diff --git a/meta-oe/recipes-devtools/protobuf/protobuf/0001-Add-recursion-check-when-parsing-unknown-fields-in-J.patch b/meta-oe/recipes-devtools/protobuf/protobuf/0001-Add-recursion-check-when-parsing-unknown-fields-in-J.patch new file mode 100644 index 0000000000..d5d85aac03 --- /dev/null +++ b/meta-oe/recipes-devtools/protobuf/protobuf/0001-Add-recursion-check-when-parsing-unknown-fields-in-J.patch @@ -0,0 +1,795 @@ +From 23cd727717604ab5a51b88e0eff51b364e6fa008 Mon Sep 17 00:00:00 2001 +From: Protobuf Team Bot +Date: Tue, 17 Sep 2024 12:03:36 -0700 +Subject: [PATCH] Add recursion check when parsing unknown fields in Java. + +PiperOrigin-RevId: 675657198 + +CVE: CVE-2024-7254 + +Upstream-Status: Backport [ac9fb5b4c71b0dd80985b27684e265d1f03abf46] + +The original patch is adjusted to fit for the current version. + +Signed-off-by: Chen Qi +--- + .../com/google/protobuf/ArrayDecoders.java | 29 +++ + .../com/google/protobuf/CodedInputStream.java | 72 +++++- + .../com/google/protobuf/MessageSchema.java | 9 +- + .../com/google/protobuf/MessageSetSchema.java | 2 +- + .../google/protobuf/UnknownFieldSchema.java | 28 ++- + .../google/protobuf/CodedInputStreamTest.java | 159 ++++++++++++ + .../java/com/google/protobuf/LiteTest.java | 232 ++++++++++++++++++ + 7 files changed, 515 insertions(+), 16 deletions(-) + +diff --git a/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java b/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java +index 39b79278c..431739cd5 100644 +--- a/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java ++++ b/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java +@@ -44,6 +44,11 @@ import java.io.IOException; + * crossing protobuf public API boundaries. + */ + final class ArrayDecoders { ++ static final int DEFAULT_RECURSION_LIMIT = 100; ++ ++ @SuppressWarnings("NonFinalStaticField") ++ private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT; ++ + /** + * A helper used to return multiple values in a Java function. Java doesn't natively support + * returning multiple values in a function. Creating a new Object to hold the return values will +@@ -58,6 +63,7 @@ final class ArrayDecoders { + public long long1; + public Object object1; + public final ExtensionRegistryLite extensionRegistry; ++ public int recursionDepth; + + Registers() { + this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry(); +@@ -265,7 +271,10 @@ final class ArrayDecoders { + if (length < 0 || length > limit - position) { + throw InvalidProtocolBufferException.truncatedMessage(); + } ++ registers.recursionDepth++; ++ checkRecursionLimit(registers.recursionDepth); + schema.mergeFrom(msg, data, position, position + length, registers); ++ registers.recursionDepth--; + registers.object1 = msg; + return position + length; + } +@@ -284,8 +293,11 @@ final class ArrayDecoders { + // and it can't be used in group fields). + final MessageSchema messageSchema = (MessageSchema) schema; + // It's OK to directly use parseProto2Message since proto3 doesn't have group. ++ registers.recursionDepth++; ++ checkRecursionLimit(registers.recursionDepth); + final int endPosition = + messageSchema.parseProto2Message(msg, data, position, limit, endGroup, registers); ++ registers.recursionDepth--; + registers.object1 = msg; + return endPosition; + } +@@ -1045,6 +1057,8 @@ final class ArrayDecoders { + final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance(); + final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP; + int lastTag = 0; ++ registers.recursionDepth++; ++ checkRecursionLimit(registers.recursionDepth); + while (position < limit) { + position = decodeVarint32(data, position, registers); + lastTag = registers.int1; +@@ -1053,6 +1067,7 @@ final class ArrayDecoders { + } + position = decodeUnknownField(lastTag, data, position, limit, child, registers); + } ++ registers.recursionDepth--; + if (position > limit || lastTag != endGroup) { + throw InvalidProtocolBufferException.parseFailure(); + } +@@ -1099,4 +1114,18 @@ final class ArrayDecoders { + throw InvalidProtocolBufferException.invalidTag(); + } + } ++ ++ /** ++ * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if ++ * the depth of the message exceeds this limit. ++ */ ++ public static void setRecursionLimit(int limit) { ++ recursionLimit = limit; ++ } ++ ++ private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException { ++ if (depth >= recursionLimit) { ++ throw InvalidProtocolBufferException.recursionLimitExceeded(); ++ } ++ } + } +diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +index 6e9c0f620..5c194cff6 100644 +--- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java ++++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +@@ -729,7 +729,14 @@ public abstract class CodedInputStream { + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); +- if (tag == 0 || !skipField(tag)) { ++ if (tag == 0) { ++ return; ++ } ++ checkRecursionLimit(); ++ ++recursionDepth; ++ boolean fieldSkipped = skipField(tag); ++ --recursionDepth; ++ if (!fieldSkipped) { + return; + } + } +@@ -739,7 +746,14 @@ public abstract class CodedInputStream { + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); +- if (tag == 0 || !skipField(tag, output)) { ++ if (tag == 0) { ++ return; ++ } ++ checkRecursionLimit(); ++ ++recursionDepth; ++ boolean fieldSkipped = skipField(tag, output); ++ --recursionDepth; ++ if (!fieldSkipped) { + return; + } + } +@@ -1444,7 +1458,14 @@ public abstract class CodedInputStream { + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); +- if (tag == 0 || !skipField(tag)) { ++ if (tag == 0) { ++ return; ++ } ++ checkRecursionLimit(); ++ ++recursionDepth; ++ boolean fieldSkipped = skipField(tag); ++ --recursionDepth; ++ if (!fieldSkipped) { + return; + } + } +@@ -1454,7 +1475,14 @@ public abstract class CodedInputStream { + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); +- if (tag == 0 || !skipField(tag, output)) { ++ if (tag == 0) { ++ return; ++ } ++ checkRecursionLimit(); ++ ++recursionDepth; ++ boolean fieldSkipped = skipField(tag, output); ++ --recursionDepth; ++ if (!fieldSkipped) { + return; + } + } +@@ -2212,7 +2240,14 @@ public abstract class CodedInputStream { + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); +- if (tag == 0 || !skipField(tag)) { ++ if (tag == 0) { ++ return; ++ } ++ checkRecursionLimit(); ++ ++recursionDepth; ++ boolean fieldSkipped = skipField(tag); ++ --recursionDepth; ++ if (!fieldSkipped) { + return; + } + } +@@ -2222,7 +2257,14 @@ public abstract class CodedInputStream { + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); +- if (tag == 0 || !skipField(tag, output)) { ++ if (tag == 0) { ++ return; ++ } ++ checkRecursionLimit(); ++ ++recursionDepth; ++ boolean fieldSkipped = skipField(tag, output); ++ --recursionDepth; ++ if (!fieldSkipped) { + return; + } + } +@@ -3333,7 +3375,14 @@ public abstract class CodedInputStream { + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); +- if (tag == 0 || !skipField(tag)) { ++ if (tag == 0) { ++ return; ++ } ++ checkRecursionLimit(); ++ ++recursionDepth; ++ boolean fieldSkipped = skipField(tag); ++ --recursionDepth; ++ if (!fieldSkipped) { + return; + } + } +@@ -3343,7 +3392,14 @@ public abstract class CodedInputStream { + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); +- if (tag == 0 || !skipField(tag, output)) { ++ if (tag == 0) { ++ return; ++ } ++ checkRecursionLimit(); ++ ++recursionDepth; ++ boolean fieldSkipped = skipField(tag, output); ++ --recursionDepth; ++ if (!fieldSkipped) { + return; + } + } +diff --git a/java/core/src/main/java/com/google/protobuf/MessageSchema.java b/java/core/src/main/java/com/google/protobuf/MessageSchema.java +index 76220cc47..d85899e14 100644 +--- a/java/core/src/main/java/com/google/protobuf/MessageSchema.java ++++ b/java/core/src/main/java/com/google/protobuf/MessageSchema.java +@@ -3946,7 +3946,8 @@ final class MessageSchema implements Schema { + unknownFields = unknownFieldSchema.getBuilderFromMessage(message); + } + // Unknown field. +- if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { ++ if (unknownFieldSchema.mergeOneFieldFrom( ++ unknownFields, reader, /* currentDepth= */ 0)) { + continue; + } + } +@@ -4321,7 +4322,8 @@ final class MessageSchema implements Schema { + if (unknownFields == null) { + unknownFields = unknownFieldSchema.getBuilderFromMessage(message); + } +- if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { ++ if (!unknownFieldSchema.mergeOneFieldFrom( ++ unknownFields, reader, /* currentDepth= */ 0)) { + return; + } + break; +@@ -4337,7 +4339,8 @@ final class MessageSchema implements Schema { + if (unknownFields == null) { + unknownFields = unknownFieldSchema.getBuilderFromMessage(message); + } +- if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { ++ if (!unknownFieldSchema.mergeOneFieldFrom( ++ unknownFields, reader, /* currentDepth= */ 0)) { + return; + } + } +diff --git a/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java b/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java +index eae93b912..705d74608 100644 +--- a/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java ++++ b/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java +@@ -301,7 +301,7 @@ final class MessageSetSchema implements Schema { + reader, extension, extensionRegistry, extensions); + return true; + } else { +- return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader); ++ return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0); + } + } else { + return reader.skipField(); +diff --git a/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java b/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java +index e736d5ce9..c68e3107c 100644 +--- a/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java ++++ b/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java +@@ -35,6 +35,11 @@ import java.io.IOException; + @ExperimentalApi + abstract class UnknownFieldSchema { + ++ static final int DEFAULT_RECURSION_LIMIT = 100; ++ ++ @SuppressWarnings("NonFinalStaticField") ++ private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT; ++ + /** Whether unknown fields should be dropped. */ + abstract boolean shouldDiscardUnknownFields(Reader reader); + +@@ -78,7 +83,8 @@ abstract class UnknownFieldSchema { + abstract void makeImmutable(Object message); + + /** Merges one field into the unknown fields. */ +- final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException { ++ final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth) ++ throws IOException { + int tag = reader.getTag(); + int fieldNumber = WireFormat.getTagFieldNumber(tag); + switch (WireFormat.getTagWireType(tag)) { +@@ -97,7 +103,12 @@ abstract class UnknownFieldSchema { + case WireFormat.WIRETYPE_START_GROUP: + final B subFields = newBuilder(); + int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP); +- mergeFrom(subFields, reader); ++ currentDepth++; ++ if (currentDepth >= recursionLimit) { ++ throw InvalidProtocolBufferException.recursionLimitExceeded(); ++ } ++ mergeFrom(subFields, reader, currentDepth); ++ currentDepth--; + if (endGroupTag != reader.getTag()) { + throw InvalidProtocolBufferException.invalidEndTag(); + } +@@ -110,10 +121,11 @@ abstract class UnknownFieldSchema { + } + } + +- final void mergeFrom(B unknownFields, Reader reader) throws IOException { ++ final void mergeFrom(B unknownFields, Reader reader, int currentDepth) ++ throws IOException { + while (true) { + if (reader.getFieldNumber() == Reader.READ_DONE +- || !mergeOneFieldFrom(unknownFields, reader)) { ++ || !mergeOneFieldFrom(unknownFields, reader, currentDepth)) { + break; + } + } +@@ -130,4 +142,12 @@ abstract class UnknownFieldSchema { + abstract int getSerializedSizeAsMessageSet(T message); + + abstract int getSerializedSize(T unknowns); ++ ++ /** ++ * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if ++ * the depth of the message exceeds this limit. ++ */ ++ public void setRecursionLimit(int limit) { ++ recursionLimit = limit; ++ } + } +diff --git a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java +index f5bb31fdf..78efd4c31 100644 +--- a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java ++++ b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java +@@ -33,6 +33,10 @@ package com.google.protobuf; + import static com.google.common.truth.Truth.assertThat; + import static com.google.common.truth.Truth.assertWithMessage; + import static org.junit.Assert.assertArrayEquals; ++import static org.junit.Assert.assertThrows; ++ ++import com.google.common.primitives.Bytes; ++import map_test.MapTestProto.MapContainer; + import protobuf_unittest.UnittestProto.BoolMessage; + import protobuf_unittest.UnittestProto.Int32Message; + import protobuf_unittest.UnittestProto.Int64Message; +@@ -57,6 +61,13 @@ public class CodedInputStreamTest { + + private static final int DEFAULT_BLOCK_SIZE = 4096; + ++ private static final int GROUP_TAP = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP); ++ ++ private static final byte[] NESTING_SGROUP = generateSGroupTags(); ++ ++ private static final byte[] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField(); ++ ++ + private enum InputType { + ARRAY { + @Override +@@ -139,6 +150,17 @@ public class CodedInputStreamTest { + return bytes; + } + ++ private static byte[] generateSGroupTags() { ++ byte[] bytes = new byte[100000]; ++ Arrays.fill(bytes, (byte) GROUP_TAP); ++ return bytes; ++ } ++ ++ private static byte[] generateSGroupTagsForMapField() { ++ byte[] initialBytes = {18, 1, 75, 26, (byte) 198, (byte) 154, 12}; ++ return Bytes.concat(initialBytes, NESTING_SGROUP); ++ } ++ + /** + * An InputStream which limits the number of bytes it reads at a time. We use this to make sure + * that CodedInputStream doesn't screw up when reading in small blocks. +@@ -684,6 +706,143 @@ public class CodedInputStreamTest { + } + } + ++ @Test ++ public void testMaliciousRecursion_unknownFields() throws Exception { ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> TestRecursiveMessage.parseFrom(NESTING_SGROUP)); ++ ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testMaliciousRecursion_skippingUnknownField() throws Exception { ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> ++ DiscardUnknownFieldsParser.wrap(TestRecursiveMessage.parser()) ++ .parseFrom(NESTING_SGROUP)); ++ ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception { ++ Throwable parseFromThrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> ++ MapContainer.parseFrom( ++ new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES))); ++ Throwable mergeFromThrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> ++ MapContainer.newBuilder() ++ .mergeFrom(new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES))); ++ ++ assertThat(parseFromThrown) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ assertThat(mergeFromThrown) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testMaliciousSGroupTags_inputStream_skipMessage() throws Exception { ++ ByteArrayInputStream inputSteam = new ByteArrayInputStream(NESTING_SGROUP); ++ CodedInputStream input = CodedInputStream.newInstance(inputSteam); ++ CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); ++ ++ Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); ++ Throwable thrown2 = ++ assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); ++ ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ assertThat(thrown2) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception { ++ Throwable parseFromThrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> MapContainer.parseFrom(NESTING_SGROUP_WITH_INITIAL_BYTES)); ++ Throwable mergeFromThrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> MapContainer.newBuilder().mergeFrom(NESTING_SGROUP_WITH_INITIAL_BYTES)); ++ ++ assertThat(parseFromThrown) ++ .hasMessageThat() ++ .contains("the input ended unexpectedly in the middle of a field"); ++ assertThat(mergeFromThrown) ++ .hasMessageThat() ++ .contains("the input ended unexpectedly in the middle of a field"); ++ } ++ ++ @Test ++ public void testMaliciousSGroupTags_arrayDecoder_skipMessage() throws Exception { ++ CodedInputStream input = CodedInputStream.newInstance(NESTING_SGROUP); ++ CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); ++ ++ Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); ++ Throwable thrown2 = ++ assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); ++ ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ assertThat(thrown2) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testMaliciousSGroupTagsWithMapField_fromByteBuffer() throws Exception { ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> MapContainer.parseFrom(ByteBuffer.wrap(NESTING_SGROUP_WITH_INITIAL_BYTES))); ++ ++ assertThat(thrown) ++ .hasMessageThat() ++ .contains("the input ended unexpectedly in the middle of a field"); ++ } ++ ++ @Test ++ public void testMaliciousSGroupTags_byteBuffer_skipMessage() throws Exception { ++ CodedInputStream input = InputType.NIO_DIRECT.newDecoder(NESTING_SGROUP); ++ CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); ++ ++ Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); ++ Throwable thrown2 = ++ assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); ++ ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ assertThat(thrown2) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testMaliciousSGroupTags_iterableByteBuffer() throws Exception { ++ CodedInputStream input = InputType.ITER_DIRECT.newDecoder(NESTING_SGROUP); ++ CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); ++ ++ Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); ++ Throwable thrown2 = ++ assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); ++ ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ assertThat(thrown2) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ } ++ + private void checkSizeLimitExceeded(InvalidProtocolBufferException e) { + assertThat(e) + .hasMessageThat() +diff --git a/java/lite/src/test/java/com/google/protobuf/LiteTest.java b/java/lite/src/test/java/com/google/protobuf/LiteTest.java +index 6535a1bf6..850228467 100644 +--- a/java/lite/src/test/java/com/google/protobuf/LiteTest.java ++++ b/java/lite/src/test/java/com/google/protobuf/LiteTest.java +@@ -2453,6 +2453,211 @@ public class LiteTest { + } + } + ++ @Test ++ public void testParseFromInputStream_concurrent_nestingUnknownGroups() throws Exception { ++ int numThreads = 200; ++ ArrayList threads = new ArrayList<>(); ++ ++ ByteString byteString = generateNestingGroups(99); ++ AtomicBoolean thrown = new AtomicBoolean(false); ++ ++ for (int i = 0; i < numThreads; i++) { ++ Thread thread = ++ new Thread( ++ () -> { ++ try { ++ TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString); ++ } catch (IOException e) { ++ if (e.getMessage().contains("Protocol message had too many levels of nesting")) { ++ thrown.set(true); ++ } ++ } ++ }); ++ thread.start(); ++ threads.add(thread); ++ } ++ ++ for (Thread thread : threads) { ++ thread.join(); ++ } ++ ++ assertThat(thrown.get()).isFalse(); ++ } ++ ++ @Test ++ public void testParseFromInputStream_nestingUnknownGroups() throws IOException { ++ ByteString byteString = generateNestingGroups(99); ++ ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString)); ++ assertThat(thrown) ++ .hasMessageThat() ++ .doesNotContain("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testParseFromInputStream_nestingUnknownGroups_exception() throws IOException { ++ ByteString byteString = generateNestingGroups(100); ++ ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString)); ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testParseFromInputStream_setRecursionLimit_exception() throws IOException { ++ ByteString byteString = generateNestingGroups(199); ++ UnknownFieldSchema schema = SchemaUtil.unknownFieldSetLiteSchema(); ++ schema.setRecursionLimit(200); ++ ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString)); ++ assertThat(thrown) ++ .hasMessageThat() ++ .doesNotContain("Protocol message had too many levels of nesting"); ++ schema.setRecursionLimit(UnknownFieldSchema.DEFAULT_RECURSION_LIMIT); ++ } ++ ++ @Test ++ public void testParseFromBytes_concurrent_nestingUnknownGroups() throws Exception { ++ int numThreads = 200; ++ ArrayList threads = new ArrayList<>(); ++ ++ ByteString byteString = generateNestingGroups(99); ++ AtomicBoolean thrown = new AtomicBoolean(false); ++ ++ for (int i = 0; i < numThreads; i++) { ++ Thread thread = ++ new Thread( ++ () -> { ++ try { ++ // Should pass in byte[] instead of ByteString to go into ArrayDecoders. ++ TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString.toByteArray()); ++ } catch (InvalidProtocolBufferException e) { ++ if (e.getMessage().contains("Protocol message had too many levels of nesting")) { ++ thrown.set(true); ++ } ++ } ++ }); ++ thread.start(); ++ threads.add(thread); ++ } ++ ++ for (Thread thread : threads) { ++ thread.join(); ++ } ++ ++ assertThat(thrown.get()).isFalse(); ++ } ++ ++ @Test ++ public void testParseFromBytes_nestingUnknownGroups() throws IOException { ++ ByteString byteString = generateNestingGroups(99); ++ ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> TestAllTypesLite.parseFrom(byteString.toByteArray())); ++ assertThat(thrown) ++ .hasMessageThat() ++ .doesNotContain("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testParseFromBytes_nestingUnknownGroups_exception() throws IOException { ++ ByteString byteString = generateNestingGroups(100); ++ ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> TestAllTypesLite.parseFrom(byteString.toByteArray())); ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testParseFromBytes_setRecursionLimit_exception() throws IOException { ++ ByteString byteString = generateNestingGroups(199); ++ ArrayDecoders.setRecursionLimit(200); ++ ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> TestAllTypesLite.parseFrom(byteString.toByteArray())); ++ assertThat(thrown) ++ .hasMessageThat() ++ .doesNotContain("Protocol message had too many levels of nesting"); ++ ArrayDecoders.setRecursionLimit(ArrayDecoders.DEFAULT_RECURSION_LIMIT); ++ } ++ ++ @Test ++ public void testParseFromBytes_recursiveMessages() throws Exception { ++ byte[] data99 = makeRecursiveMessage(99).toByteArray(); ++ byte[] data100 = makeRecursiveMessage(100).toByteArray(); ++ ++ RecursiveMessage unused = RecursiveMessage.parseFrom(data99); ++ Throwable thrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, () -> RecursiveMessage.parseFrom(data100)); ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testParseFromBytes_recursiveKnownGroups() throws Exception { ++ byte[] data99 = makeRecursiveGroup(99).toByteArray(); ++ byte[] data100 = makeRecursiveGroup(100).toByteArray(); ++ ++ RecursiveGroup unused = RecursiveGroup.parseFrom(data99); ++ Throwable thrown = ++ assertThrows(InvalidProtocolBufferException.class, () -> RecursiveGroup.parseFrom(data100)); ++ assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ @SuppressWarnings("ProtoParseFromByteString") ++ public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception { ++ ByteString byteString = generateNestingGroups(102); ++ ++ Throwable parseFromThrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> MapContainer.parseFrom(byteString.toByteArray())); ++ Throwable mergeFromThrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> MapContainer.newBuilder().mergeFrom(byteString.toByteArray())); ++ ++ assertThat(parseFromThrown) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ assertThat(mergeFromThrown) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ } ++ ++ @Test ++ public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception { ++ byte[] bytes = generateNestingGroups(101).toByteArray(); ++ ++ Throwable parseFromThrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> MapContainer.parseFrom(new ByteArrayInputStream(bytes))); ++ Throwable mergeFromThrown = ++ assertThrows( ++ InvalidProtocolBufferException.class, ++ () -> MapContainer.newBuilder().mergeFrom(new ByteArrayInputStream(bytes))); ++ ++ assertThat(parseFromThrown) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ assertThat(mergeFromThrown) ++ .hasMessageThat() ++ .contains("Protocol message had too many levels of nesting"); ++ } ++ + @Test + public void testParseFromByteBuffer_extensions() throws Exception { + TestAllExtensionsLite message = +@@ -2809,4 +3014,31 @@ public class LiteTest { + } + return false; + } ++ ++ private static ByteString generateNestingGroups(int num) throws IOException { ++ int groupTap = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP); ++ ByteString.Output byteStringOutput = ByteString.newOutput(); ++ CodedOutputStream codedOutput = CodedOutputStream.newInstance(byteStringOutput); ++ for (int i = 0; i < num; i++) { ++ codedOutput.writeInt32NoTag(groupTap); ++ } ++ codedOutput.flush(); ++ return byteStringOutput.toByteString(); ++ } ++ ++ private static RecursiveMessage makeRecursiveMessage(int num) { ++ if (num == 0) { ++ return RecursiveMessage.getDefaultInstance(); ++ } else { ++ return RecursiveMessage.newBuilder().setRecurse(makeRecursiveMessage(num - 1)).build(); ++ } ++ } ++ ++ private static RecursiveGroup makeRecursiveGroup(int num) { ++ if (num == 0) { ++ return RecursiveGroup.getDefaultInstance(); ++ } else { ++ return RecursiveGroup.newBuilder().setRecurse(makeRecursiveGroup(num - 1)).build(); ++ } ++ } + } +-- +2.25.1 + diff --git a/meta-oe/recipes-devtools/protobuf/protobuf_3.19.6.bb b/meta-oe/recipes-devtools/protobuf/protobuf_3.19.6.bb index 8e50054718..2b19d07ff8 100644 --- a/meta-oe/recipes-devtools/protobuf/protobuf_3.19.6.bb +++ b/meta-oe/recipes-devtools/protobuf/protobuf_3.19.6.bb @@ -19,6 +19,7 @@ SRC_URI = "git://github.com/protocolbuffers/protobuf.git;branch=3.19.x;protocol= file://0001-examples-Makefile-respect-CXX-LDFLAGS-variables-fix-.patch \ file://0001-Fix-linking-error-with-ld-gold.patch \ file://0001-Lower-init-prio-for-extension-attributes.patch \ + file://0001-Add-recursion-check-when-parsing-unknown-fields-in-J.patch \ " SRC_URI:append:mips:toolchain-clang = " file://0001-Fix-build-on-mips-clang.patch " SRC_URI:append:mipsel:toolchain-clang = " file://0001-Fix-build-on-mips-clang.patch "