Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.formats.protobuf.deserialize;

import org.apache.flink.formats.protobuf.PbCodegenException;
import org.apache.flink.formats.protobuf.util.PbCodegenAppender;

import com.google.protobuf.Descriptors.FieldDescriptor;

/**
* Deserializer that converts a protobuf message to its raw binary bytes. Used when a recursive
* message type is detected and represented as BYTES in the Flink schema, preserving the data for
* optional later unpacking via a UDF.
*/
public class PbCodegenBytesDeserializer implements PbCodegenDeserializer {
private final FieldDescriptor fd;

public PbCodegenBytesDeserializer(FieldDescriptor fd) {
this.fd = fd;
}

@Override
public String codegen(String resultVar, String pbObjectCode, int indent)
throws PbCodegenException {
PbCodegenAppender appender = new PbCodegenAppender(indent);
appender.appendLine(resultVar + " = " + pbObjectCode + ".toByteArray()");
return appender.code();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,25 @@
import org.apache.flink.formats.protobuf.util.PbFormatUtils;
import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.MapType;
import org.apache.flink.table.types.logical.RowType;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType;

/** Codegen factory class which return {@link PbCodegenDeserializer} of different data type. */
public class PbCodegenDeserializeFactory {
public static PbCodegenDeserializer getPbCodegenDes(
Descriptors.FieldDescriptor fd, LogicalType type, PbFormatContext formatContext)
throws PbCodegenException {
// Check for recursive message type represented as raw bytes:
// the proto field is a MESSAGE but the logical type has been resolved to VARBINARY
// by PbToRowTypeUtil's cycle detection.
if (fd.getJavaType() == JavaType.MESSAGE
&& type.getTypeRoot() == LogicalTypeRoot.VARBINARY) {
return new PbCodegenBytesDeserializer(fd);
}
// We do not use FieldDescriptor to check because there's no way to get
// element field descriptor of array type.
if (type instanceof RowType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ private static void validateTypeMatch(FieldDescriptor fd, LogicalType logicalTyp
// simple type
validateSimpleType(fd, logicalType.getTypeRoot());
} else {
// message type
// message type - may be RowType (normal) or VarBinaryType (recursive, as bytes)
if (logicalType.getTypeRoot() == LogicalTypeRoot.VARBINARY) {
// Recursive message type represented as raw bytes - valid mapping
return;
}
if (!(logicalType instanceof RowType)) {
throw new ValidationException(
"Unexpected LogicalType: " + logicalType + ". It should be RowType");
Expand Down Expand Up @@ -131,6 +135,10 @@ private static void validateTypeMatch(FieldDescriptor fd, LogicalType logicalTyp
if (fd.getJavaType() == JavaType.MESSAGE) {
// array message type
LogicalType elementType = arrayType.getElementType();
if (elementType.getTypeRoot() == LogicalTypeRoot.VARBINARY) {
// Recursive message type as raw bytes - valid
return;
}
if (!(elementType instanceof RowType)) {
throw new ValidationException(
"Unexpected logicalType: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,53 @@
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType;

import java.util.HashSet;
import java.util.Set;

/** Generate Row type information according to pb descriptors. */
public class PbToRowTypeUtil {
public static RowType generateRowType(Descriptors.Descriptor root) {
return generateRowType(root, false);
}

public static RowType generateRowType(Descriptors.Descriptor root, boolean enumAsInt) {
// Track message types currently being resolved in the ancestor chain to detect
// recursive references (e.g., A -> B -> A). Without this, recursive proto
// definitions cause infinite recursion and StackOverflowError.
Set<String> ancestors = new HashSet<>();
return generateRowTypeInternal(root, enumAsInt, ancestors);
}

/**
* @param ancestors message type full names currently being resolved in the call stack. Used to
* detect cycles: if a field's message type is already in this set, it's a recursive
* reference and gets emitted as BYTES instead of recursing infinitely.
*/
private static RowType generateRowTypeInternal(
Descriptors.Descriptor root, boolean enumAsInt, Set<String> ancestors) {
int size = root.getFields().size();
LogicalType[] types = new LogicalType[size];
String[] rowFieldNames = new String[size];

for (int i = 0; i < size; i++) {
FieldDescriptor field = root.getFields().get(i);
rowFieldNames[i] = field.getName();
types[i] = generateFieldTypeInformation(field, enumAsInt);
// Mark this type as "being resolved" before processing its fields
String fullName = root.getFullName();
ancestors.add(fullName);

try {
for (int i = 0; i < size; i++) {
FieldDescriptor field = root.getFields().get(i);
rowFieldNames[i] = field.getName();
types[i] = generateFieldTypeInformation(field, enumAsInt, ancestors);
}
} finally {
// Unmark when we're done - sibling branches shouldn't see this type as an ancestor
ancestors.remove(fullName);
}
return RowType.of(types, rowFieldNames);
}

private static LogicalType generateFieldTypeInformation(
FieldDescriptor field, boolean enumAsInt) {
FieldDescriptor field, boolean enumAsInt, Set<String> ancestors) {
JavaType fieldType = field.getJavaType();
LogicalType type;
if (fieldType.equals(JavaType.MESSAGE)) {
Expand All @@ -66,16 +92,36 @@ private static LogicalType generateFieldTypeInformation(
generateFieldTypeInformation(
field.getMessageType()
.findFieldByName(PbConstant.PB_MAP_KEY_NAME),
enumAsInt),
enumAsInt,
ancestors),
generateFieldTypeInformation(
field.getMessageType()
.findFieldByName(PbConstant.PB_MAP_VALUE_NAME),
enumAsInt));
enumAsInt,
ancestors));
return mapType;
} else if (field.isRepeated()) {
return new ArrayType(generateRowType(field.getMessageType()));
}

// Cycle detection: if this field's message type is already being resolved
// in the ancestor chain, we have a recursive proto definition
// (e.g., Node -> Child -> Node). Columnar schemas cannot represent
// infinite recursion, so we emit the field as raw BYTES. The protobuf
// binary is preserved and can be deserialized on demand if consumers
// need the recursive data.
String msgFullName = field.getMessageType().getFullName();
if (ancestors.contains(msgFullName)) {
LogicalType bytesType = new VarBinaryType(Integer.MAX_VALUE);
if (field.isRepeated()) {
return new ArrayType(bytesType);
}
return bytesType;
}

if (field.isRepeated()) {
return new ArrayType(
generateRowTypeInternal(field.getMessageType(), enumAsInt, ancestors));
} else {
return generateRowType(field.getMessageType());
return generateRowTypeInternal(field.getMessageType(), enumAsInt, ancestors);
}
} else {
if (fieldType.equals(JavaType.STRING)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.formats.protobuf;

import org.apache.flink.formats.protobuf.deserialize.PbRowDataDeserializationSchema;
import org.apache.flink.formats.protobuf.testproto.RecursiveMessageProto2Test;
import org.apache.flink.formats.protobuf.testproto.RecursiveMessageTest;
import org.apache.flink.formats.protobuf.util.PbFormatUtils;
import org.apache.flink.formats.protobuf.util.PbToRowTypeUtil;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.RowType;

import com.google.protobuf.Descriptors;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

/** Test handling of recursive protobuf message types. */
public class RecursiveMessageProtoToRowTest {

private static final String PROTO3_CLASS =
"org.apache.flink.formats.protobuf.testproto.RecursiveMessageTest";
private static final String PROTO2_CLASS =
"org.apache.flink.formats.protobuf.testproto.RecursiveMessageProto2Test";

/** Deserializes proto bytes through the full Flink codegen pipeline. */
private RowData deserialize(String className, boolean readDefaultValues, byte[] data)
throws Exception {
RowType rowType = PbToRowTypeUtil.generateRowType(PbFormatUtils.getDescriptor(className));
PbFormatConfig config = new PbFormatConfig(className, false, readDefaultValues, "");
PbRowDataDeserializationSchema schema =
new PbRowDataDeserializationSchema(rowType, InternalTypeInfo.of(rowType), config);
schema.open(null);
return schema.deserialize(data);
}

// --- schema generation ---

@Test
public void testCycleDetectionProducesBytes() {
Descriptors.Descriptor descriptor = RecursiveMessageTest.getDescriptor();
RowType rowType = PbToRowTypeUtil.generateRowType(descriptor);

assertEquals(3, rowType.getFieldCount());
assertEquals("id", rowType.getFieldNames().get(0));
assertEquals("name", rowType.getFieldNames().get(1));
assertEquals("parent", rowType.getFieldNames().get(2));
assertEquals(
"Recursive field should be VARBINARY",
LogicalTypeRoot.VARBINARY,
rowType.getTypeAt(2).getTypeRoot());
}

// --- proto3 deserialization ---

@Test
public void testProto3NestedDataPreservedAsBytes() throws Exception {
RecursiveMessageTest grandparent =
RecursiveMessageTest.newBuilder().setId(1).setName("grandparent").build();
RecursiveMessageTest parent =
RecursiveMessageTest.newBuilder()
.setId(2)
.setName("parent")
.setParent(grandparent)
.build();
RecursiveMessageTest message =
RecursiveMessageTest.newBuilder()
.setId(3)
.setName("child")
.setParent(parent)
.build();

RowData row = deserialize(PROTO3_CLASS, false, message.toByteArray());
assertNotNull(row);
assertEquals(3, row.getInt(0));
assertEquals("child", row.getString(1).toString());

// Parse the bytes back - should contain full parent including grandparent
byte[] parentBytes = row.getBinary(2);
RecursiveMessageTest parsedParent = RecursiveMessageTest.parseFrom(parentBytes);
assertEquals(2, parsedParent.getId());
assertEquals("parent", parsedParent.getName());
assertTrue(parsedParent.hasParent());
assertEquals(1, parsedParent.getParent().getId());
assertEquals("grandparent", parsedParent.getParent().getName());
}

@Test
public void testProto3UnsetFieldReadsDefaultBytes() throws Exception {
// In proto3, the recursive field is treated as a primitive type (VARBINARY)
// by the codegen, so it always reads default values regardless of the
// readDefaultValues config. Both true and false produce the same result:
// empty bytes from .toByteArray() on the default message instance.
RecursiveMessageTest message =
RecursiveMessageTest.newBuilder().setId(1).setName("leaf").build();

for (boolean readDefaults : new boolean[] {true, false}) {
RowData row = deserialize(PROTO3_CLASS, readDefaults, message.toByteArray());
assertNotNull(row);
assertEquals(1, row.getInt(0));
assertEquals("leaf", row.getString(1).toString());
byte[] parentBytes = row.getBinary(2);
assertNotNull("readDefaultValues=" + readDefaults, parentBytes);
RecursiveMessageTest parsed = RecursiveMessageTest.parseFrom(parentBytes);
assertEquals(0, parsed.getId());
assertEquals("", parsed.getName());
}
}

// --- proto2 deserialization ---

@Test
public void testProto2SetFieldPreservedAsBytes() throws Exception {
RecursiveMessageProto2Test parent =
RecursiveMessageProto2Test.newBuilder().setId(1).setName("parent").build();
RecursiveMessageProto2Test message =
RecursiveMessageProto2Test.newBuilder()
.setId(2)
.setName("child")
.setParent(parent)
.build();

RowData row = deserialize(PROTO2_CLASS, false, message.toByteArray());
assertNotNull(row);
assertEquals(2, row.getInt(0));
assertEquals("child", row.getString(1).toString());

byte[] parentBytes = row.getBinary(2);
assertNotNull(parentBytes);
RecursiveMessageProto2Test parsed = RecursiveMessageProto2Test.parseFrom(parentBytes);
assertEquals(1, parsed.getId());
assertEquals("parent", parsed.getName());
}

@Test
public void testProto2UnsetFieldIsNull() throws Exception {
// Proto2 has explicit field presence. With readDefaultValues=false,
// hasParent() returns false so the field is null.
RecursiveMessageProto2Test message =
RecursiveMessageProto2Test.newBuilder().setId(1).setName("leaf").build();

RowData row = deserialize(PROTO2_CLASS, false, message.toByteArray());
assertNotNull(row);
assertEquals(1, row.getInt(0));
assertEquals("leaf", row.getString(1).toString());
assertTrue("Proto2 unset recursive field should be null", row.isNullAt(2));
}

@Test
public void testProto2UnsetFieldWithReadDefaultValues() throws Exception {
// With readDefaultValues=true, proto2 returns the default instance as bytes
// instead of null.
RecursiveMessageProto2Test message =
RecursiveMessageProto2Test.newBuilder().setId(1).setName("leaf").build();

RowData row = deserialize(PROTO2_CLASS, true, message.toByteArray());
assertNotNull(row);
assertEquals(1, row.getInt(0));
assertEquals("leaf", row.getString(1).toString());
byte[] parentBytes = row.getBinary(2);
assertNotNull(parentBytes);
RecursiveMessageProto2Test parsed = RecursiveMessageProto2Test.parseFrom(parentBytes);
assertEquals(0, parsed.getId());
assertEquals("", parsed.getName());
}
}
Loading