Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 130 additions & 113 deletions src/mp/gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,109 @@ static bool BoxedType(const ::capnp::Type& type)
type.isFloat64() || type.isEnum());
}

struct Field
{
::capnp::StructSchema::Field param;
bool param_is_set = false;
::capnp::StructSchema::Field result;
bool result_is_set = false;
int args = 0;
bool retval = false;
bool optional = false;
bool requested = false;
bool skip = false;
kj::StringPtr exception;
};

struct FieldList
{
std::vector<Field> fields;
std::map<kj::StringPtr, int> field_idx; // name -> args index
bool has_result = false;

void addField(const ::capnp::StructSchema::Field& schema_field, bool param, bool result)
{
auto field_name = schema_field.getProto().getName();
auto inserted = field_idx.emplace(field_name, fields.size());
if (inserted.second) {
fields.emplace_back();
}
auto& field = fields[inserted.first->second];
if (param) {
field.param = schema_field;
field.param_is_set = true;
}
if (result) {
field.result = schema_field;
field.result_is_set = true;
}

if (!param && field_name == kj::StringPtr{"result"}) {
field.retval = true;
has_result = true;
}

if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
field.skip = true;
}
GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception);

int32_t count = 1;
if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) {
if (schema_field.getType().isStruct()) {
GetAnnotationInt32(schema_field.getType().asStruct().getProto(),
COUNT_ANNOTATION_ID, &count);
} else if (schema_field.getType().isInterface()) {
GetAnnotationInt32(schema_field.getType().asInterface().getProto(),
COUNT_ANNOTATION_ID, &count);
}
}


if (inserted.second && !field.retval && !field.exception.size()) {
field.args = count;
}
}

void mergeFields()
{
for (auto& field : field_idx) {
auto has_field = field_idx.find("has" + Cap(field.first));
if (has_field != field_idx.end()) {
fields[has_field->second].skip = true;
fields[field.second].optional = true;
}
auto want_field = field_idx.find("want" + Cap(field.first));
if (want_field != field_idx.end() && fields[want_field->second].param_is_set) {
fields[want_field->second].skip = true;
fields[field.second].requested = true;
}
}
}
};

std::string AccessorType(kj::StringPtr base_name, const Field& field)
{
const auto& f = field.param_is_set ? field.param : field.result;
const auto field_name = f.getProto().getName();
const auto field_type = f.getType();

std::ostringstream out;
out << "Accessor<" << base_name << "_fields::" << Cap(field_name) << ", ";
if (!field.param_is_set) {
out << "FIELD_OUT";
} else if (field.result_is_set) {
out << "FIELD_IN | FIELD_OUT";
} else {
out << "FIELD_IN";
}
if (field.optional) out << " | FIELD_OPTIONAL";
if (field.requested) out << " | FIELD_REQUESTED";
if (BoxedType(field_type)) out << " | FIELD_BOXED";
out << ">";
return out.str();
}

// src_file is path to .capnp file to generate stub code from.
//
// src_prefix can be used to generate outputs in a different directory than the
Expand Down Expand Up @@ -332,6 +435,13 @@ static void Generate(kj::StringPtr src_prefix,

if (node.getProto().isStruct()) {
const auto& struc = node.asStruct();

FieldList fields;
for (const auto schema_field : struc.getFields()) {
fields.addField(schema_field, true, true);
}
fields.mergeFields();

std::ostringstream generic_name;
generic_name << node_name;
dec << "template<";
Expand All @@ -352,22 +462,18 @@ static void Generate(kj::StringPtr src_prefix,
dec << "struct ProxyStruct<" << message_namespace << "::" << generic_name.str() << ">\n";
dec << "{\n";
dec << " using Struct = " << message_namespace << "::" << generic_name.str() << ";\n";
for (const auto field : struc.getFields()) {
auto field_name = field.getProto().getName();
for (const auto& field : fields.fields) {
auto field_name = field.param.getProto().getName();
add_accessor(field_name);
dec << " using " << Cap(field_name) << "Accessor = Accessor<" << base_name
<< "_fields::" << Cap(field_name) << ", FIELD_IN | FIELD_OUT";
if (BoxedType(field.getType())) dec << " | FIELD_BOXED";
dec << ">;\n";
dec << " using " << Cap(field_name) << "Accessor = "
<< AccessorType(base_name, field) << ";\n";
}
dec << " using Accessors = std::tuple<";
size_t i = 0;
for (const auto field : struc.getFields()) {
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
continue;
}
for (const auto& field : fields.fields) {
if (field.skip) continue;
if (i) dec << ", ";
dec << Cap(field.getProto().getName()) << "Accessor";
dec << Cap(field.param.getProto().getName()) << "Accessor";
++i;
}
dec << ">;\n";
Expand All @@ -381,13 +487,11 @@ static void Generate(kj::StringPtr src_prefix,
inl << "public:\n";
inl << " using Struct = " << message_namespace << "::" << node_name << ";\n";
size_t i = 0;
for (const auto field : struc.getFields()) {
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
continue;
}
auto field_name = field.getProto().getName();
for (const auto& field : fields.fields) {
if (field.skip) continue;
auto field_name = field.param.getProto().getName();
auto member_name = field_name;
GetAnnotationText(field.getProto(), NAME_ANNOTATION_ID, &member_name);
GetAnnotationText(field.param.getProto(), NAME_ANNOTATION_ID, &member_name);
inl << " static decltype(auto) get(std::integral_constant<size_t, " << i << ">) { return "
<< "&" << proxied_class_type << "::" << member_name << "; }\n";
++i;
Expand Down Expand Up @@ -430,85 +534,14 @@ static void Generate(kj::StringPtr src_prefix,
const bool is_construct = method_name == kj::StringPtr{"construct"};
const bool is_destroy = method_name == kj::StringPtr{"destroy"};

struct Field
{
::capnp::StructSchema::Field param;
bool param_is_set = false;
::capnp::StructSchema::Field result;
bool result_is_set = false;
int args = 0;
bool retval = false;
bool optional = false;
bool requested = false;
bool skip = false;
kj::StringPtr exception;
};

std::vector<Field> fields;
std::map<kj::StringPtr, int> field_idx; // name -> args index
bool has_result = false;

auto add_field = [&](const ::capnp::StructSchema::Field& schema_field, bool param) {
if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
return;
}

auto field_name = schema_field.getProto().getName();
auto inserted = field_idx.emplace(field_name, fields.size());
if (inserted.second) {
fields.emplace_back();
}
auto& field = fields[inserted.first->second];
if (param) {
field.param = schema_field;
field.param_is_set = true;
} else {
field.result = schema_field;
field.result_is_set = true;
}

if (!param && field_name == kj::StringPtr{"result"}) {
field.retval = true;
has_result = true;
}

GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception);

int32_t count = 1;
if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) {
if (schema_field.getType().isStruct()) {
GetAnnotationInt32(schema_field.getType().asStruct().getProto(),
COUNT_ANNOTATION_ID, &count);
} else if (schema_field.getType().isInterface()) {
GetAnnotationInt32(schema_field.getType().asInterface().getProto(),
COUNT_ANNOTATION_ID, &count);
}
}


if (inserted.second && !field.retval && !field.exception.size()) {
field.args = count;
}
};

FieldList fields;
for (const auto schema_field : method.getParamType().getFields()) {
add_field(schema_field, true);
fields.addField(schema_field, true, false);
}
for (const auto schema_field : method.getResultType().getFields()) {
add_field(schema_field, false);
}
for (auto& field : field_idx) {
auto has_field = field_idx.find("has" + Cap(field.first));
if (has_field != field_idx.end()) {
fields[has_field->second].skip = true;
fields[field.second].optional = true;
}
auto want_field = field_idx.find("want" + Cap(field.first));
if (want_field != field_idx.end() && fields[want_field->second].param_is_set) {
fields[want_field->second].skip = true;
fields[field.second].requested = true;
}
fields.addField(schema_field, false, true);
}
fields.mergeFields();

if (!is_construct && !is_destroy && (&method_interface == &interface)) {
methods << "template<>\n";
Expand All @@ -524,25 +557,11 @@ static void Generate(kj::StringPtr src_prefix,
std::ostringstream server_invoke_start;
std::ostringstream server_invoke_end;
int argc = 0;
for (const auto& field : fields) {
for (const auto& field : fields.fields) {
if (field.skip) continue;

const auto& f = field.param_is_set ? field.param : field.result;
auto field_name = f.getProto().getName();
auto field_type = f.getType();

std::ostringstream field_flags;
if (!field.param_is_set) {
field_flags << "FIELD_OUT";
} else if (field.result_is_set) {
field_flags << "FIELD_IN | FIELD_OUT";
} else {
field_flags << "FIELD_IN";
}
if (field.optional) field_flags << " | FIELD_OPTIONAL";
if (field.requested) field_flags << " | FIELD_REQUESTED";
if (BoxedType(field_type)) field_flags << " | FIELD_BOXED";

add_accessor(field_name);

std::ostringstream fwd_args;
Expand All @@ -569,8 +588,7 @@ static void Generate(kj::StringPtr src_prefix,
client_invoke << "MakeClientParam<";
}

client_invoke << "Accessor<" << base_name << "_fields::" << Cap(field_name) << ", "
<< field_flags.str() << ">>(";
client_invoke << AccessorType(base_name, field) << ">(";

if (field.retval) {
client_invoke << field_name;
Expand All @@ -586,8 +604,7 @@ static void Generate(kj::StringPtr src_prefix,
} else {
server_invoke_start << "MakeServerField<" << field.args;
}
server_invoke_start << ", Accessor<" << base_name << "_fields::" << Cap(field_name) << ", "
<< field_flags.str() << ">>(";
server_invoke_start << ", " << AccessorType(base_name, field) << ">(";
server_invoke_end << ")";
}

Expand All @@ -603,12 +620,12 @@ static void Generate(kj::StringPtr src_prefix,
def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method_ordinal
<< "::Result ProxyClient<" << message_namespace << "::" << node_name << ">::" << method_name
<< "(" << super_str << client_args.str() << ") {\n";
if (has_result) {
if (fields.has_result) {
def_client << " typename M" << method_ordinal << "::Result result;\n";
}
def_client << " clientInvoke(" << self_str << ", &" << message_namespace << "::" << node_name
<< "::Client::" << method_name << "Request" << client_invoke.str() << ");\n";
if (has_result) def_client << " return result;\n";
if (fields.has_result) def_client << " return result;\n";
def_client << "}\n";

server << " kj::Promise<void> " << method_name << "(" << Cap(method_name)
Expand Down
1 change: 1 addition & 0 deletions test/mp/test/foo-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <mp/type-map.h>
#include <mp/type-message.h>
#include <mp/type-number.h>
#include <mp/type-optional.h>
#include <mp/type-pointer.h>
#include <mp/type-set.h>
#include <mp/type-string.h>
Expand Down
2 changes: 2 additions & 0 deletions test/mp/test/foo.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ struct FooStruct $Proxy.wrap("mp::test::FooStruct") {
name @0 :Text;
setint @1 :List(Int32);
vbool @2 :List(Bool);
optionalInt @3 :Int32 $Proxy.name("optional_int");
hasOptionalInt @4 :Bool;
}

struct FooCustom $Proxy.wrap("mp::test::FooCustom") {
Expand Down
2 changes: 2 additions & 0 deletions test/mp/test/foo.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <set>
#include <vector>
Expand All @@ -21,6 +22,7 @@ struct FooStruct
std::string name;
std::set<int> setint;
std::vector<bool> vbool;
std::optional<int> optional_int;
};

enum class FooEnum : uint8_t { ONE = 1, TWO = 2, };
Expand Down
7 changes: 7 additions & 0 deletions test/mp/test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ KJ_TEST("Call FooInterface methods")
in.vbool.push_back(false);
in.vbool.push_back(true);
in.vbool.push_back(false);
in.optional_int = 3;
FooStruct out = foo->pass(in);
KJ_EXPECT(in.name == out.name);
KJ_EXPECT(in.setint.size() == out.setint.size());
Expand All @@ -150,6 +151,12 @@ KJ_TEST("Call FooInterface methods")
for (size_t i = 0; i < in.vbool.size(); ++i) {
KJ_EXPECT(in.vbool[i] == out.vbool[i]);
}
KJ_EXPECT(in.optional_int == out.optional_int);

// Additional checks for std::optional member
KJ_EXPECT(foo->pass(in).optional_int == 3);
in.optional_int.reset();
KJ_EXPECT(!foo->pass(in).optional_int);

FooStruct err;
try {
Expand Down
Loading