Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 96 additions & 7 deletions src/mcp/mcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,9 @@ static const tool_def_t TOOLS[] = {
"\"string\"},\"name_pattern\":{\"type\":\"string\"},\"qn_pattern\":{\"type\":\"string\"},"
"\"file_pattern\":{\"type\":\"string\"},\"relationship\":{\"type\":\"string\"},\"min_degree\":"
"{\"type\":\"integer\"},\"max_degree\":{\"type\":\"integer\"},\"exclude_entry_points\":{"
"\"type\":\"boolean\"},\"include_connected\":{\"type\":\"boolean\"},\"limit\":{\"type\":"
"\"type\":\"boolean\"},\"include_connected\":{\"type\":\"boolean\",\"description\":"
"\"Include connected node names for up to the first 50 results, capped at 10 names per "
"result.\"},\"limit\":{\"type\":"
"\"integer\",\"description\":\"Max results. Default: "
"unlimited\"},\"offset\":{\"type\":\"integer\",\"default\":0}},\"required\":[\"project\"]}"},

Expand Down Expand Up @@ -494,6 +496,44 @@ bool cbm_mcp_get_bool_arg(const char *args_json, const char *key) {
return result;
}

/* Extract the JSON string array stored under `key` in the object `args_json`.
 *
 * On success returns a heap-allocated array of heap-allocated strings and
 * stores the number of strings in *out_count. Ownership passes to the caller,
 * who must free() each element and then the array itself.
 *
 * Return value:
 *   - NULL (with *out_count = 0) when args_json is NULL or cannot be parsed,
 *     the key is absent, the value is not an array, or allocation fails.
 *   - non-NULL with *out_count = 0 when the array is explicitly empty ([]) or
 *     contains no string elements, so callers can distinguish "not provided"
 *     from "explicitly empty".
 *
 * Non-string elements are skipped rather than treated as an error. */
static char **cbm_mcp_get_string_array_arg(const char *args_json, const char *key, int *out_count) {
    *out_count = 0;

    /* strlen(NULL) is undefined behaviour; treat a missing payload as "not provided". */
    if (!args_json) {
        return NULL;
    }

    yyjson_doc *doc = yyjson_read(args_json, strlen(args_json), 0);
    if (!doc) {
        return NULL;
    }

    yyjson_val *root = yyjson_doc_get_root(doc);
    yyjson_val *arr = yyjson_obj_get(root, key);
    if (!arr || !yyjson_is_arr(arr)) {
        yyjson_doc_free(doc);
        return NULL;
    }

    /* Always allocate at least one slot so that [] still yields a non-NULL,
     * freeable pointer. calloc zero-fills the slots and guards the
     * count * size multiplication against overflow. */
    size_t count = yyjson_arr_size(arr);
    char **result = calloc(count > 0 ? count : 1, sizeof *result);
    if (!result) {
        yyjson_doc_free(doc);
        return NULL;
    }

    int n = 0;
    size_t idx, max;
    yyjson_val *val;
    yyjson_arr_foreach(arr, idx, max, val) {
        if (!yyjson_is_str(val)) {
            continue;
        }
        /* Defensive: never store a NULL entry. heap_strdup's failure behaviour
         * is not visible from here, so a failed copy is simply skipped. */
        char *copy = heap_strdup(yyjson_get_str(val));
        if (copy) {
            result[n++] = copy;
        }
    }

    /* The copies are independent of the document, so it can be released now. */
    yyjson_doc_free(doc);

    *out_count = n;
    return result;
}

/* ══════════════════════════════════════════════════════════════════
* MCP SERVER
* ══════════════════════════════════════════════════════════════════ */
Expand Down Expand Up @@ -940,6 +980,9 @@ static char *handle_search_graph(cbm_mcp_server_t *srv, const char *args) {
char *label = cbm_mcp_get_string_arg(args, "label");
char *name_pattern = cbm_mcp_get_string_arg(args, "name_pattern");
char *file_pattern = cbm_mcp_get_string_arg(args, "file_pattern");
char *relationship = cbm_mcp_get_string_arg(args, "relationship");
bool exclude_entry_points = cbm_mcp_get_bool_arg(args, "exclude_entry_points");
bool include_connected = cbm_mcp_get_bool_arg(args, "include_connected");
int limit = cbm_mcp_get_int_arg(args, "limit", 500000);
int offset = cbm_mcp_get_int_arg(args, "offset", 0);
int min_degree = cbm_mcp_get_int_arg(args, "min_degree", -1);
Expand All @@ -950,6 +993,9 @@ static char *handle_search_graph(cbm_mcp_server_t *srv, const char *args) {
.label = label,
.name_pattern = name_pattern,
.file_pattern = file_pattern,
.relationship = relationship,
.exclude_entry_points = exclude_entry_points,
.include_connected = include_connected,
.limit = limit,
.offset = offset,
.min_degree = min_degree,
Expand Down Expand Up @@ -977,6 +1023,16 @@ static char *handle_search_graph(cbm_mcp_server_t *srv, const char *args) {
sr->node.file_path ? sr->node.file_path : "");
yyjson_mut_obj_add_int(doc, item, "in_degree", sr->in_degree);
yyjson_mut_obj_add_int(doc, item, "out_degree", sr->out_degree);
/* Include connected node names if populated */
if (sr->connected_count > 0 && sr->connected_names) {
yyjson_mut_val *conn = yyjson_mut_arr(doc);
for (int j = 0; j < sr->connected_count; j++) {
if (sr->connected_names[j]) {
yyjson_mut_arr_add_strcpy(doc, conn, sr->connected_names[j]);
}
}
yyjson_mut_obj_add_val(doc, item, "connected_names", conn);
}
yyjson_mut_arr_add_val(results, item);
}
yyjson_mut_obj_add_val(doc, root, "results", results);
Expand All @@ -990,6 +1046,7 @@ static char *handle_search_graph(cbm_mcp_server_t *srv, const char *args) {
free(label);
free(name_pattern);
free(file_pattern);
free(relationship);

char *result = cbm_mcp_text_result(json, false);
free(json);
Expand Down Expand Up @@ -1225,9 +1282,17 @@ static char *handle_trace_call_path(cbm_mcp_server_t *srv, const char *args) {
char *direction = cbm_mcp_get_string_arg(args, "direction");
int depth = cbm_mcp_get_int_arg(args, "depth", 3);

/* Extract edge_types array; fall back to {"CALLS"} if not provided */
int user_edge_type_count = 0;
char **user_edge_types =
cbm_mcp_get_string_array_arg(args, "edge_types", &user_edge_type_count);

if (!func_name) {
free(project);
free(direction);
for (int i = 0; i < user_edge_type_count; i++)
free(user_edge_types[i]);
free(user_edge_types);
return cbm_mcp_text_result("function_name is required", true);
}
if (!store) {
Expand All @@ -1237,6 +1302,9 @@ static char *handle_trace_call_path(cbm_mcp_server_t *srv, const char *args) {
free(func_name);
free(project);
free(direction);
for (int i = 0; i < user_edge_type_count; i++)
free(user_edge_types[i]);
free(user_edge_types);
return _res;
}

Expand All @@ -1245,6 +1313,9 @@ static char *handle_trace_call_path(cbm_mcp_server_t *srv, const char *args) {
free(func_name);
free(project);
free(direction);
for (int i = 0; i < user_edge_type_count; i++)
free(user_edge_types[i]);
free(user_edge_types);
return not_indexed;
}

Expand All @@ -1261,6 +1332,9 @@ static char *handle_trace_call_path(cbm_mcp_server_t *srv, const char *args) {
free(func_name);
free(project);
free(direction);
for (int i = 0; i < user_edge_type_count; i++)
free(user_edge_types[i]);
free(user_edge_types);
cbm_store_free_nodes(nodes, 0);
return cbm_mcp_text_result("{\"error\":\"function not found\"}", true);
}
Expand All @@ -1272,8 +1346,16 @@ static char *handle_trace_call_path(cbm_mcp_server_t *srv, const char *args) {
yyjson_mut_obj_add_str(doc, root, "function", func_name);
yyjson_mut_obj_add_str(doc, root, "direction", direction);

const char *edge_types[] = {"CALLS"};
int edge_type_count = 1;
const char *default_edge_types[] = {"CALLS"};
const char **edge_types;
int edge_type_count;
if (user_edge_types) {
edge_types = (const char **)user_edge_types;
edge_type_count = user_edge_type_count;
} else {
edge_types = default_edge_types;
edge_type_count = 1;
}

/* Run BFS for each requested direction.
* IMPORTANT: yyjson_mut_obj_add_str borrows pointers — we must keep
Expand All @@ -1287,8 +1369,10 @@ static char *handle_trace_call_path(cbm_mcp_server_t *srv, const char *args) {
cbm_traverse_result_t tr_in = {0};

if (do_outbound) {
cbm_store_bfs(store, nodes[0].id, "outbound", edge_types, edge_type_count, depth, 100,
&tr_out);
if (edge_type_count > 0) {
cbm_store_bfs(store, nodes[0].id, "outbound", edge_types, edge_type_count, depth, 100,
&tr_out);
}

yyjson_mut_val *callees = yyjson_mut_arr(doc);
for (int i = 0; i < tr_out.visited_count; i++) {
Expand All @@ -1305,8 +1389,10 @@ static char *handle_trace_call_path(cbm_mcp_server_t *srv, const char *args) {
}

if (do_inbound) {
cbm_store_bfs(store, nodes[0].id, "inbound", edge_types, edge_type_count, depth, 100,
&tr_in);
if (edge_type_count > 0) {
cbm_store_bfs(store, nodes[0].id, "inbound", edge_types, edge_type_count, depth, 100,
&tr_in);
}

yyjson_mut_val *callers = yyjson_mut_arr(doc);
for (int i = 0; i < tr_in.visited_count; i++) {
Expand Down Expand Up @@ -1338,6 +1424,9 @@ static char *handle_trace_call_path(cbm_mcp_server_t *srv, const char *args) {
free(func_name);
free(project);
free(direction);
for (int i = 0; i < user_edge_type_count; i++)
free(user_edge_types[i]);
free(user_edge_types);

char *result = cbm_mcp_text_result(json, false);
free(json);
Expand Down
105 changes: 97 additions & 8 deletions src/store/store.c
Original file line number Diff line number Diff line change
Expand Up @@ -1955,12 +1955,33 @@ int cbm_store_search(cbm_store_t *s, const cbm_search_params_t *params, cbm_sear
char count_sql[4096];
int bind_idx = 0;

/* We build a query that selects nodes with optional degree subqueries */
const char *select_cols =
"SELECT n.id, n.project, n.label, n.name, n.qualified_name, "
"n.file_path, n.start_line, n.end_line, n.properties, "
"(SELECT COUNT(*) FROM edges e WHERE e.target_id = n.id AND e.type = 'CALLS') AS in_deg, "
"(SELECT COUNT(*) FROM edges e WHERE e.source_id = n.id AND e.type = 'CALLS') AS out_deg ";
/* When a relationship filter is active, compute degree using that edge type
* so that min_degree/max_degree/exclude_entry_points filter consistently.
* Edge type is validated to contain only safe chars (A-Z, a-z, 0-9, _)
* to prevent SQL injection since it's inlined in the SELECT clause. */
const char *deg_edge_type = "CALLS";
if (params->relationship) {
bool safe = true;
for (const char *p = params->relationship; *p; p++) {
if (!((*p >= 'A' && *p <= 'Z') || (*p >= 'a' && *p <= 'z') ||
(*p >= '0' && *p <= '9') || *p == '_')) {
safe = false;
break;
}
}
if (safe && params->relationship[0] != '\0' && strlen(params->relationship) <= 64) {
deg_edge_type = params->relationship;
}
}
char select_cols[640];
snprintf(select_cols, sizeof(select_cols),
"SELECT n.id, n.project, n.label, n.name, n.qualified_name, "
"n.file_path, n.start_line, n.end_line, n.properties, "
"(SELECT COUNT(*) FROM edges e WHERE e.target_id = n.id AND e.type = '%s'"
" AND e.source_id != e.target_id) AS in_deg, "
"(SELECT COUNT(*) FROM edges e WHERE e.source_id = n.id AND e.type = '%s'"
" AND e.source_id != e.target_id) AS out_deg ",
deg_edge_type, deg_edge_type);

/* Start building WHERE */
char where[2048] = "";
Expand Down Expand Up @@ -2019,6 +2040,17 @@ int cbm_store_search(cbm_store_t *s, const cbm_search_params_t *params, cbm_sear
BIND_TEXT(like_pattern);
}

/* Relationship filter: only include nodes that have at least one edge of this type */
if (params->relationship) {
char rel_clause[256];
snprintf(rel_clause, sizeof(rel_clause),
"EXISTS (SELECT 1 FROM edges e2 WHERE e2.type = ?%d"
" AND (e2.source_id = n.id OR e2.target_id = n.id))",
bind_idx + 1);
ADD_WHERE(rel_clause);
BIND_TEXT(params->relationship);
}

/* Exclude labels: use parameterized placeholders to prevent SQL injection */
if (params->exclude_labels) {
char excl_clause[512] = "n.label NOT IN (";
Expand Down Expand Up @@ -2051,7 +2083,8 @@ int cbm_store_search(cbm_store_t *s, const cbm_search_params_t *params, cbm_sear
/* Degree filters: -1 = no filter, 0+ = active filter.
* Wraps in subquery to filter on computed degree columns. */
// NOLINTNEXTLINE(readability-implicit-bool-conversion)
bool has_degree_filter = (params->min_degree >= 0 || params->max_degree >= 0);
bool has_degree_filter =
(params->min_degree >= 0 || params->max_degree >= 0 || params->exclude_entry_points);
if (has_degree_filter) {
char inner_sql[4096];
snprintf(inner_sql, sizeof(inner_sql), "%s", sql);
Expand All @@ -2063,9 +2096,17 @@ int cbm_store_search(cbm_store_t *s, const cbm_search_params_t *params, cbm_sear
} else if (params->min_degree >= 0) {
snprintf(sql, sizeof(sql), "SELECT * FROM (%s) WHERE (in_deg + out_deg) >= %d",
inner_sql, params->min_degree);
} else {
} else if (params->max_degree >= 0) {
snprintf(sql, sizeof(sql), "SELECT * FROM (%s) WHERE (in_deg + out_deg) <= %d",
inner_sql, params->max_degree);
} else {
/* Only exclude_entry_points is active */
snprintf(sql, sizeof(sql), "SELECT * FROM (%s) WHERE 1=1", inner_sql);
}

/* Exclude entry points: functions that call things but aren't called themselves */
if (params->exclude_entry_points) {
strncat(sql, " AND NOT (in_deg = 0 AND out_deg > 0)", sizeof(sql) - strlen(sql) - 1);
}
}

Expand Down Expand Up @@ -2129,6 +2170,54 @@ int cbm_store_search(cbm_store_t *s, const cbm_search_params_t *params, cbm_sear
sqlite3_finalize(main_stmt);
free(like_pattern);

/* Populate connected_names if requested.
* Capped to first 50 results to prevent N+1 blowup on large result sets. */
if (params->include_connected && n > 0) {
int conn_limit = n < 50 ? n : 50;
const char *edge_type = params->relationship ? params->relationship : "CALLS";

char conn_sql[512];
snprintf(conn_sql, sizeof(conn_sql),
"SELECT DISTINCT name FROM ("
" SELECT tn.name FROM edges e JOIN nodes tn ON tn.id = e.target_id"
" WHERE e.source_id = ?1 AND e.type = ?2 AND tn.id != ?1"
" UNION"
" SELECT sn.name FROM edges e JOIN nodes sn ON sn.id = e.source_id"
" WHERE e.target_id = ?1 AND e.type = ?2 AND sn.id != ?1"
") ORDER BY name LIMIT 10");

sqlite3_stmt *conn_stmt = NULL;
int conn_rc = sqlite3_prepare_v2(s->db, conn_sql, -1, &conn_stmt, NULL);
if (conn_rc == SQLITE_OK) {
for (int i = 0; i < conn_limit; i++) {
sqlite3_reset(conn_stmt);
sqlite3_bind_int64(conn_stmt, 1, results[i].node.id);
bind_text(conn_stmt, 2, edge_type);
int ccap = 4;
int cn = 0;
const char **names = NULL;
while (sqlite3_step(conn_stmt) == SQLITE_ROW) {
if (!names) {
names = malloc((size_t)ccap * sizeof(const char *));
}
if (cn >= ccap) {
ccap *= 2;
names = safe_realloc(names, (size_t)ccap * sizeof(const char *));
}
const char *cname = (const char *)sqlite3_column_text(conn_stmt, 0);
names[cn++] = cname ? heap_strdup(cname) : heap_strdup("");
}
if (cn > 0) {
results[i].connected_names = names;
results[i].connected_count = cn;
} else {
free(names);
}
}
sqlite3_finalize(conn_stmt);
}
}

out->results = results;
out->count = n;
return CBM_STORE_OK;
Expand Down
Loading