From 37b8d88a31581233b4faca0a5f59e66d9a31ace6 Mon Sep 17 00:00:00 2001 From: SuperAuguste <19855629+SuperAuguste@users.noreply.github.com> Date: Tue, 27 May 2025 13:43:15 -0400 Subject: [PATCH] implement workspace symbols Co-Authored-By: Techatrix --- src/DocumentStore.zig | 151 ++++- src/Server.zig | 21 +- src/TrigramStore.zig | 694 +++++++++++++++++++++++ src/Uri.zig | 5 + src/features/document_symbol.zig | 2 +- src/features/workspace_symbols.zig | 93 +++ src/zls.zig | 1 + tests/context.zig | 27 +- tests/lsp_features/workspace_symbols.zig | 101 ++++ tests/tests.zig | 1 + 10 files changed, 1090 insertions(+), 6 deletions(-) create mode 100644 src/TrigramStore.zig create mode 100644 src/features/workspace_symbols.zig create mode 100644 tests/lsp_features/workspace_symbols.zig diff --git a/src/DocumentStore.zig b/src/DocumentStore.zig index e243bf52b6..127246a292 100644 --- a/src/DocumentStore.zig +++ b/src/DocumentStore.zig @@ -14,6 +14,7 @@ const tracy = @import("tracy"); const translate_c = @import("translate_c.zig"); const DocumentScope = @import("DocumentScope.zig"); const DiagnosticsCollection = @import("DiagnosticsCollection.zig"); +const TrigramStore = @import("TrigramStore.zig"); const DocumentStore = @This(); @@ -171,6 +172,7 @@ pub const Handle = struct { /// or has been closed with `textDocument/didClose`. 
lsp_synced: bool, document_scope: Lazy(DocumentScope, DocumentStoreContext) = .unset, + trigram_store: Lazy(TrigramStore, TrigramStoreContext) = .unset, /// private field impl: struct { @@ -454,6 +456,8 @@ pub const Handle = struct { old_handle.document_scope = handle.document_scope; handle.document_scope = .unset; + old_handle.trigram_store = handle.trigram_store; + handle.trigram_store = .unset; } fn parseTree(allocator: std.mem.Allocator, new_text: [:0]const u8, mode: Ast.Mode) error{OutOfMemory}!Ast { @@ -485,7 +489,7 @@ pub const Handle = struct { const tracy_zone = tracy.trace(@src()); defer tracy_zone.end(); - const parsed_uri = std.Uri.parse(uri.raw) catch unreachable; // The Uri is guranteed to be valid + const parsed_uri = uri.toStdUri(); const node_tags = tree.nodes.items(.tag); for (node_tags, 0..) |tag, i| { @@ -564,6 +568,7 @@ pub const Handle = struct { self.tree.deinit(allocator); } self.document_scope.deinit(allocator); + self.trigram_store.deinit(allocator); for (self.file_imports) |uri| uri.deinit(allocator); allocator.free(self.file_imports); @@ -620,6 +625,17 @@ pub const Handle = struct { return &lazy.value.?; } + pub fn getOrNull(lazy: *LazyResource, handle: *Handle) ?*const T { + const tracy_zone = tracy.traceNamed(@src(), "Lazy(" ++ @typeName(T) ++ ").getOrNull"); + defer tracy_zone.end(); + + const store = handle.impl.store; + const io = store.io; + handle.impl.lock.lockUncancelable(io); + defer handle.impl.lock.unlock(io); + return if (lazy.value) |*value| value else null; + } + pub fn getCached(lazy: *LazyResource) *const T { return &lazy.value.?; } @@ -642,6 +658,15 @@ pub const Handle = struct { document_scope.deinit(allocator); } }; + + const TrigramStoreContext = struct { + fn create(handle: *Handle, allocator: std.mem.Allocator) error{OutOfMemory}!TrigramStore { + return try .init(allocator, &handle.tree); + } + fn deinit(trigram_store: *TrigramStore, allocator: std.mem.Allocator) void { + trigram_store.deinit(allocator); + } + }; 
}; pub const HandleIterator = struct { @@ -944,6 +969,128 @@ pub fn invalidateBuildFile(self: *DocumentStore, build_file_uri: Uri) void { self.wait_group.async(self.io, invalidateBuildFileWorker, .{ self, build_file }); } +const LoadDirectoryError = error{UnsupportedScheme} || std.mem.Allocator.Error || std.Io.Dir.OpenError; + +pub fn loadDirectoryRecursive(store: *DocumentStore, directory_uri: Uri) LoadDirectoryError!usize { + const tracy_zone = tracy.trace(@src()); + defer tracy_zone.end(); + + const workspace_path = try directory_uri.toFsPath(store.allocator); + defer store.allocator.free(workspace_path); + + var workspace_dir = try std.Io.Dir.cwd().openDir(store.io, workspace_path, .{ .iterate = true }); + defer workspace_dir.close(store.io); + + var walker = try workspace_dir.walk(store.allocator); + defer walker.deinit(); + + const getOrLoadHandleVoid = struct { + fn getOrLoadHandleVoid( + s: *DocumentStore, + uri: Uri, + did_out_of_memory: *std.atomic.Value(bool), + ) std.Io.Cancelable!void { + _ = s.getOrLoadHandle(uri) catch |err| switch (err) { + error.Canceled => return error.Canceled, + error.OutOfMemory => did_out_of_memory.store(true, .release), + }; + uri.deinit(s.allocator); + } + }.getOrLoadHandleVoid; + + var group: std.Io.Group = .init; + var did_out_of_memory: std.atomic.Value(bool) = .init(false); + + var file_count: usize = 0; + while (try walker.next(store.io)) |entry| { + if (entry.kind == .directory) { + // Keep in sync with `loadTrigramStores` + if (std.mem.startsWith(u8, entry.basename, ".") or + std.mem.eql(u8, entry.basename, "zig-cache") or + std.mem.eql(u8, entry.basename, "zig-pkg")) + { + walker.leave(store.io); + } + continue; + } + if (!std.mem.eql(u8, std.fs.path.extension(entry.basename), ".zig")) continue; + + file_count += 1; + + const path = try std.fs.path.join(store.allocator, &.{ workspace_path, entry.path }); + defer store.allocator.free(path); + + const uri: Uri = try .fromPath(store.allocator, path); + errdefer comptime 
unreachable; + + group.async(store.io, getOrLoadHandleVoid, .{ store, uri, &did_out_of_memory }); + } + try group.await(store.io); + + if (did_out_of_memory.load(.acquire)) return error.OutOfMemory; + + return file_count; +} + +pub fn loadTrigramStores( + store: *DocumentStore, + filter_uris: []const std.Uri, +) error{ OutOfMemory, Canceled }![]*DocumentStore.Handle { + const tracy_zone = tracy.trace(@src()); + defer tracy_zone.end(); + + var handles: std.ArrayList(*DocumentStore.Handle) = .empty; + errdefer handles.deinit(store.allocator); + + var it: HandleIterator = .{ .store = store }; + while (it.next()) |handle| { + const uri = handle.uri.toStdUri(); + + var component_it = std.fs.path.componentIterator(uri.path.percent_encoded); + const skip = while (component_it.next()) |component| { + // Keep in sync with `loadDirectoryRecursive` + if (std.mem.startsWith(u8, component.name, ".")) break true; + if (std.mem.eql(u8, component.name, "zig-cache")) break true; + if (std.mem.eql(u8, component.name, "zig-pkg")) break true; + } else false; + if (skip) continue; + + for (filter_uris) |filter_uri| { + if (!std.ascii.eqlIgnoreCase(uri.scheme, filter_uri.scheme)) continue; + if (std.mem.startsWith(u8, uri.path.percent_encoded, filter_uri.path.percent_encoded)) break; + } else continue; + try handles.append(store.allocator, handle); + } + + const loadTrigramStore = struct { + fn loadTrigramStore( + handle: *DocumentStore.Handle, + did_out_of_memory: *std.atomic.Value(bool), + ) void { + _ = handle.trigram_store.get(handle) catch |err| switch (err) { + error.OutOfMemory => { + did_out_of_memory.store(true, .release); + return; + }, + }; + } + }.loadTrigramStore; + + var group: std.Io.Group = .init; + var did_out_of_memory: std.atomic.Value(bool) = .init(false); + + for (handles.items) |handle| { + const has_trigram_store = handle.trigram_store.getOrNull(handle) != null; + if (has_trigram_store) continue; + group.async(store.io, loadTrigramStore, .{ handle, 
&did_out_of_memory }); + } + try group.await(store.io); + + if (did_out_of_memory.load(.acquire)) return error.OutOfMemory; + + return try handles.toOwnedSlice(store.allocator); +} + const progress_token = "buildProgressToken"; fn sendMessageToClient( @@ -1840,7 +1987,7 @@ pub fn uriFromImportStr( defer tracy_zone.end(); if (std.mem.endsWith(u8, import_str, ".zig") or std.mem.endsWith(u8, import_str, ".zon")) { - const parsed_uri = std.Uri.parse(handle.uri.raw) catch unreachable; // The Uri is guranteed to be valid + const parsed_uri = handle.uri.toStdUri(); return .{ .one = try Uri.resolveImport(allocator, handle.uri, parsed_uri, import_str) }; } diff --git a/src/Server.zig b/src/Server.zig index d79921d6e2..e72f01993e 100644 --- a/src/Server.zig +++ b/src/Server.zig @@ -560,7 +560,7 @@ fn initializeHandler(server: *Server, arena: std.mem.Allocator, request: types.I .documentRangeFormattingProvider = .{ .bool = false }, .foldingRangeProvider = .{ .bool = true }, .selectionRangeProvider = .{ .bool = true }, - .workspaceSymbolProvider = .{ .bool = false }, + .workspaceSymbolProvider = .{ .bool = true }, .workspace = .{ .workspaceFolders = .{ .supported = true, @@ -857,7 +857,6 @@ const Workspace = struct { fn addWorkspace(server: *Server, uri: Uri) error{ Canceled, OutOfMemory }!void { try server.workspaces.ensureUnusedCapacity(server.allocator, 1); server.workspaces.appendAssumeCapacity(try Workspace.init(server, uri)); - log.info("added Workspace Folder: {s}", .{uri.raw}); if (BuildOnSaveSupport.isSupportedComptime() and // Don't initialize build on save until initialization finished. 
@@ -870,6 +869,17 @@ fn addWorkspace(server: *Server, uri: Uri) error{ Canceled, OutOfMemory }!void { .restart = false, }); } + + const file_count = server.document_store.loadDirectoryRecursive(uri) catch |err| switch (err) { + error.Canceled, error.OutOfMemory => |e| return e, + error.UnsupportedScheme => return, // https://github.com/microsoft/language-server-protocol/issues/1264 + else => { + log.err("failed to load files in workspace '{s}': {}", .{ uri.raw, err }); + return; + }, + }; + + log.info("added Workspace Folder: {s} ({d} files)", .{ uri.raw, file_count }); } fn removeWorkspace(server: *Server, uri: Uri) void { @@ -1569,6 +1579,10 @@ fn selectionRangeHandler(server: *Server, arena: std.mem.Allocator, request: typ return try selection_range.generateSelectionRanges(arena, handle, request.positions, server.offset_encoding); } +fn workspaceSymbolHandler(server: *Server, arena: std.mem.Allocator, request: types.workspace.Symbol.Params) Error!?types.workspace.Symbol.Result { + return try @import("features/workspace_symbols.zig").handler(server, arena, request); +} + const HandledRequestParams = union(enum) { initialize: types.InitializeParams, shutdown, @@ -1592,6 +1606,7 @@ const HandledRequestParams = union(enum) { @"textDocument/codeAction": types.CodeAction.Params, @"textDocument/foldingRange": types.FoldingRange.Params, @"textDocument/selectionRange": types.SelectionRange.Params, + @"workspace/symbol": types.workspace.Symbol.Params, other: lsp.MethodWithParams, }; @@ -1636,6 +1651,7 @@ fn isBlockingMessage(msg: Message) bool { .@"textDocument/codeAction", .@"textDocument/foldingRange", .@"textDocument/selectionRange", + .@"workspace/symbol", => return false, .other => return false, }, @@ -1805,6 +1821,7 @@ pub fn sendRequestSync(server: *Server, arena: std.mem.Allocator, comptime metho .@"textDocument/codeAction" => try server.codeActionHandler(arena, params), .@"textDocument/foldingRange" => try server.foldingRangeHandler(arena, params), 
.@"textDocument/selectionRange" => try server.selectionRangeHandler(arena, params), + .@"workspace/symbol" => try server.workspaceSymbolHandler(arena, params), .other => return null, }; } diff --git a/src/TrigramStore.zig b/src/TrigramStore.zig new file mode 100644 index 0000000000..03f0c2c01f --- /dev/null +++ b/src/TrigramStore.zig @@ -0,0 +1,694 @@ +//! A per-file trigram store for workspace symbols. + +const std = @import("std"); +const ast = @import("ast.zig"); +const Ast = std.zig.Ast; +const assert = std.debug.assert; +const offsets = @import("offsets.zig"); + +pub const TrigramStore = @This(); + +pub const Trigram = [3]u8; + +pub const Declaration = struct { + pub const Index = enum(u32) { _ }; + + pub const Kind = enum { + variable, + constant, + field, + function, + test_function, + }; + + /// Either `.identifier` or `.string_literal`. + name: Ast.TokenIndex, + kind: Kind, +}; + +filter_buckets: ?[]CuckooFilter.Bucket, +trigram_to_declarations: std.AutoArrayHashMapUnmanaged(Trigram, std.ArrayList(Declaration.Index)), +declarations: std.MultiArrayList(Declaration), + +pub fn init( + allocator: std.mem.Allocator, + tree: *const Ast, +) error{OutOfMemory}!TrigramStore { + var store: TrigramStore = .{ + .filter_buckets = null, + .trigram_to_declarations = .empty, + .declarations = .empty, + }; + errdefer store.deinit(allocator); + + var walker: ast.Walker = try .init(allocator, tree, .root); + defer walker.deinit(allocator); + + var in_function_stack: std.ArrayList(bool) = try .initCapacity(allocator, 16); + defer in_function_stack.deinit(allocator); + + while (try walker.next(allocator, tree)) |entry| { + switch (entry) { + .open => |node| switch (tree.nodeTag(node)) { + .fn_decl => try in_function_stack.append(allocator, true), + .fn_proto, + .fn_proto_multi, + .fn_proto_one, + .fn_proto_simple, + => { + const fn_token = tree.nodeMainToken(node); + if (tree.tokenTag(fn_token + 1) != .identifier) continue; + + try store.appendDeclaration( + allocator, + 
tree, + fn_token + 1, + .function, + ); + }, + .test_decl => { + try in_function_stack.append(allocator, true); + const test_name_token = tree.nodeData(node).opt_token_and_node[0].unwrap() orelse continue; + + try store.appendDeclaration( + allocator, + tree, + test_name_token, + .test_function, + ); + }, + .container_decl, + .container_decl_trailing, + .container_decl_arg, + .container_decl_arg_trailing, + .container_decl_two, + .container_decl_two_trailing, + .tagged_union, + .tagged_union_trailing, + .tagged_union_enum_tag, + .tagged_union_enum_tag_trailing, + .tagged_union_two, + .tagged_union_two_trailing, + => try in_function_stack.append(allocator, false), + + .global_var_decl, + .local_var_decl, + .simple_var_decl, + .aligned_var_decl, + => { + const in_function = in_function_stack.getLastOrNull() orelse false; + if (in_function) continue; + + const main_token = tree.nodeMainToken(node); + + const kind: Declaration.Kind = switch (tree.tokenTag(main_token)) { + .keyword_var => .variable, + .keyword_const => .constant, + else => unreachable, + }; + + if (isVarDeclAlias(tree, node)) continue; + + try store.appendDeclaration( + allocator, + tree, + main_token + 1, + kind, + ); + }, + .container_field_init, + .container_field_align, + .container_field, + => { + const name_token = tree.nodeMainToken(node); + if (tree.tokenTag(name_token) != .identifier) continue; + + try store.appendDeclaration( + allocator, + tree, + name_token, + .field, + ); + }, + else => {}, + }, + .close => |node| switch (tree.nodeTag(node)) { + .fn_decl, .test_decl => assert(in_function_stack.pop().?), + .container_decl, + .container_decl_trailing, + .container_decl_arg, + .container_decl_arg_trailing, + .container_decl_two, + .container_decl_two_trailing, + .tagged_union, + .tagged_union_trailing, + .tagged_union_enum_tag, + .tagged_union_enum_tag_trailing, + .tagged_union_two, + .tagged_union_two_trailing, + => assert(!in_function_stack.pop().?), + else => {}, + }, + } + } + + const 
lists = store.trigram_to_declarations.values(); + var index: usize = 0; + while (index < lists.len) { + if (lists[index].items.len == 0) { + lists[index].deinit(allocator); + store.trigram_to_declarations.swapRemoveAt(index); + } else { + index += 1; + } + } + + const trigrams = store.trigram_to_declarations.keys(); + + if (trigrams.len > 0) { + var prng = std.Random.DefaultPrng.init(0); + + const filter_capacity = CuckooFilter.capacityForCount(trigrams.len) catch unreachable; + const buckets = try allocator.alloc(CuckooFilter.Bucket, filter_capacity); + errdefer comptime unreachable; + + const filter: CuckooFilter = .{ .buckets = buckets }; + filter.reset(); + + for (trigrams) |trigram| { + filter.append(prng.random(), trigram) catch |err| switch (err) { + error.EvictionFailed => { + // This should generally be quite rare. + allocator.free(buckets); + break; + }, + }; + } else { + store.filter_buckets = buckets; + } + } + + return store; +} + +pub fn deinit(store: *TrigramStore, allocator: std.mem.Allocator) void { + if (store.filter_buckets) |buckets| allocator.free(buckets); + for (store.trigram_to_declarations.values()) |*list| { + list.deinit(allocator); + } + store.trigram_to_declarations.deinit(allocator); + store.declarations.deinit(allocator); + store.* = undefined; +} + +/// Asserts `query.len >= 1`. Asserts declaration_buffer.items.len == 0. 
+pub fn declarationsForQuery( + store: *const TrigramStore, + allocator: std.mem.Allocator, + query: []const u8, + declaration_buffer: *std.ArrayList(Declaration.Index), +) error{OutOfMemory}!void { + assert(query.len >= 1); + assert(declaration_buffer.items.len == 0); + + if (store.filter_buckets) |buckets| { + const filter: CuckooFilter = .{ .buckets = buckets }; + var ti: TrigramIterator = .init(query); + while (ti.next()) |trigram| { + if (!filter.contains(trigram)) { + return; + } + } + } + + var ti: TrigramIterator = .init(query); + + const first = (store.trigram_to_declarations.get(ti.next() orelse return) orelse return).items; + + try declaration_buffer.resize(allocator, first.len * 2); + + var len = first.len; + @memcpy(declaration_buffer.items[0..len], first); + + while (ti.next()) |trigram| { + const old_len = len; + len = mergeIntersection( + (store.trigram_to_declarations.get(trigram) orelse { + declaration_buffer.clearRetainingCapacity(); + return; + }).items, + declaration_buffer.items[0..len], + declaration_buffer.items[len..], + ); + @memcpy(declaration_buffer.items[0..len], declaration_buffer.items[old_len..][0..len]); + declaration_buffer.shrinkRetainingCapacity(len * 2); + } + + declaration_buffer.shrinkRetainingCapacity(declaration_buffer.items.len / 2); +} + +fn appendDeclaration( + store: *TrigramStore, + allocator: std.mem.Allocator, + tree: *const Ast, + name_token: Ast.TokenIndex, + kind: Declaration.Kind, +) error{OutOfMemory}!void { + const raw_name = tree.tokenSlice(name_token); + + const strategy: enum { raw, smart }, const name = switch (tree.tokenTag(name_token)) { + .string_literal => .{ .raw, raw_name[1 .. raw_name.len - 1] }, + .identifier => if (std.mem.startsWith(u8, raw_name, "@")) + .{ .raw, raw_name[2 .. 
raw_name.len - 1] } + else + .{ .smart, raw_name }, + else => unreachable, + }; + + switch (strategy) { + .raw => { + if (name.len < 3) return; + for (0..name.len - 2) |index| { + var trigram = name[index..][0..3].*; + for (&trigram) |*char| char.* = std.ascii.toLower(char.*); + try store.appendOneTrigram(allocator, trigram); + } + }, + .smart => { + var it: TrigramIterator = .init(name); + while (it.next()) |trigram| { + try store.appendOneTrigram(allocator, trigram); + } + }, + } + + try store.declarations.append(allocator, .{ + .name = name_token, + .kind = kind, + }); +} + +fn appendOneTrigram( + store: *TrigramStore, + allocator: std.mem.Allocator, + trigram: Trigram, +) error{OutOfMemory}!void { + const declaration_index: Declaration.Index = @enumFromInt(store.declarations.len); + + const gop = try store.trigram_to_declarations.getOrPutValue(allocator, trigram, .empty); + + if (gop.value_ptr.getLastOrNull() != declaration_index) { + try gop.value_ptr.append(allocator, declaration_index); + } +} + +/// Check if the init expression is a sequence of field accesses +/// where the last field name matches the var decl name: +/// +/// ```zig +/// const Foo = a.Foo; // true +/// const Bar = a.b.Bar; // true +/// const Baz = a.Bar; // false +/// const Biz = 5; // false +/// ``` +fn isVarDeclAlias(tree: *const Ast, var_decl: Ast.Node.Index) bool { + const main_token = tree.nodeMainToken(var_decl); + + if (tree.tokenTag(main_token) != .keyword_const) return false; + const init_node = tree.fullVarDecl(var_decl).?.ast.init_node.unwrap() orelse return false; + + if (tree.nodeTag(init_node) != .field_access) return false; + + const lhs_node, const field_name_token = tree.nodeData(init_node).node_and_token; + const alias_name = offsets.identifierTokenToNameSlice(tree, main_token + 1); + const target_name = offsets.identifierTokenToNameSlice(tree, field_name_token); + if (!std.mem.eql(u8, alias_name, target_name)) return false; + + var current_node = lhs_node; + while (true) 
{ + switch (tree.nodeTag(current_node)) { + .identifier => return true, + .field_access => current_node = tree.nodeData(current_node).node_and_token[0], + else => return false, + } + } +} + +/// Splits a symbol into trigrams with the following rules: +/// - ignore `_` symbol characters +/// - convert symbol characters to lowercase +/// - pad with trailing `\x00` (null bytes) if the symbol has fewer non-`_` characters than the trigram length (a symbol with none yields no trigrams) +const TrigramIterator = struct { + symbol: []const u8, + index: usize, + + trigram_buffer: Trigram, + trigram_buffer_index: u2, + + pub fn init(symbol: []const u8) TrigramIterator { + assert(symbol.len != 0); + return .{ + .symbol = symbol, + .index = 0, + .trigram_buffer = @splat(0), + .trigram_buffer_index = 0, + }; + } + + pub fn next(ti: *TrigramIterator) ?Trigram { + while (ti.index < ti.symbol.len) { + defer ti.index += 1; + const c = std.ascii.toLower(ti.symbol[ti.index]); + if (c == '_') continue; + + if (ti.trigram_buffer_index < 3) { + ti.trigram_buffer[ti.trigram_buffer_index] = c; + ti.trigram_buffer_index += 1; + continue; + } + + defer { + @memmove(ti.trigram_buffer[0..2], ti.trigram_buffer[1..3]); + ti.trigram_buffer[2] = c; + } + return ti.trigram_buffer; + } else if (ti.trigram_buffer_index > 0) { + ti.trigram_buffer_index = 0; + return ti.trigram_buffer; + } else { + return null; + } + } +}; + +test TrigramIterator { + try testTrigramIterator("a", &.{"a\x00\x00".*}); + try testTrigramIterator("ab", &.{"ab\x00".*}); + try testTrigramIterator("abc", &.{"abc".*}); + + try testTrigramIterator("hello", &.{ "hel".*, "ell".*, "llo".* }); + try testTrigramIterator("HELLO", &.{ "hel".*, "ell".*, "llo".* }); + try testTrigramIterator("HellO", &.{ "hel".*, "ell".*, "llo".* }); + + try testTrigramIterator("a_", &.{"a\x00\x00".*}); + try testTrigramIterator("ab_", &.{"ab\x00".*}); + try testTrigramIterator("abc_", &.{"abc".*}); + + try testTrigramIterator("_a", &.{"a\x00\x00".*}); + try testTrigramIterator("_a_", 
&.{"a\x00\x00".*}); + try testTrigramIterator("_a__", &.{"a\x00\x00".*}); + + try testTrigramIterator("_", &.{}); + try testTrigramIterator("__", &.{}); + try testTrigramIterator("___", &.{}); + + try testTrigramIterator("He_ll_O", &.{ "hel".*, "ell".*, "llo".* }); + try testTrigramIterator("He__ll___O", &.{ "hel".*, "ell".*, "llo".* }); + try testTrigramIterator("__He__ll__O_", &.{ "hel".*, "ell".*, "llo".* }); + + try testTrigramIterator("HellO__World___HelloWorld", &.{ + "hel".*, "ell".*, "llo".*, + "low".*, "owo".*, "wor".*, + "orl".*, "rld".*, "ldh".*, + "dhe".*, "hel".*, "ell".*, + "llo".*, "low".*, "owo".*, + "wor".*, "orl".*, "rld".*, + }); +} + +fn testTrigramIterator( + input: []const u8, + expected: []const Trigram, +) !void { + const allocator = std.testing.allocator; + + var actual_buffer: std.ArrayList(Trigram) = .empty; + defer actual_buffer.deinit(allocator); + + var it: TrigramIterator = .init(input); + while (it.next()) |trigram| { + try actual_buffer.append(allocator, trigram); + } + + try @import("testing.zig").expectEqual(expected, actual_buffer.items); +} + +/// Asserts `@min(a.len, b.len) <= out.len`. +fn mergeIntersection( + a: []const Declaration.Index, + b: []const Declaration.Index, + out: []Declaration.Index, +) u32 { + assert(@min(a.len, b.len) <= out.len); + + var out_idx: u32 = 0; + + var a_idx: u32 = 0; + var b_idx: u32 = 0; + + while (a_idx < a.len and b_idx < b.len) { + const a_val = a[a_idx]; + const b_val = b[b_idx]; + + if (a_val == b_val) { + out[out_idx] = a_val; + out_idx += 1; + a_idx += 1; + b_idx += 1; + } else if (@intFromEnum(a_val) < @intFromEnum(b_val)) { + a_idx += 1; + } else { + b_idx += 1; + } + } + + return out_idx; +} + +const CuckooFilter = struct { + buckets: []Bucket, + + pub const Fingerprint = enum(u8) { + none = std.math.maxInt(u8), + _, + + const precomputed_odd_hashes = blk: { + var table: [255]u32 = undefined; + + for (&table, 0..) 
|*h, index| { + h.* = @truncate(std.hash.Murmur2_64.hash(&.{index}) | 1); + } + + break :blk table; + }; + + pub fn oddHash(fingerprint: Fingerprint) u32 { + assert(fingerprint != .none); + return precomputed_odd_hashes[@intFromEnum(fingerprint)]; + } + }; + + pub const Bucket = [4]Fingerprint; + pub const BucketIndex = enum(u32) { + _, + + pub fn alternate(index: BucketIndex, fingerprint: Fingerprint, len: u32) BucketIndex { + assert(@intFromEnum(index) < len); + assert(fingerprint != .none); + + const signed_index: i64 = @intFromEnum(index); + const odd_hash: i64 = fingerprint.oddHash(); + + const unbounded = switch (parity(signed_index)) { + .even => signed_index + odd_hash, + .odd => signed_index - odd_hash, + }; + const bounded: u32 = @intCast(@mod(unbounded, len)); + + assert(parity(signed_index) != parity(bounded)); + + return @enumFromInt(bounded); + } + }; + + pub const Triplet = struct { + fingerprint: Fingerprint, + index_1: BucketIndex, + index_2: BucketIndex, + + pub fn initFromTrigram(trigram: Trigram, len: u32) Triplet { + const split: packed struct { + fingerprint: Fingerprint, + padding: u24, + index_1: u32, + } = @bitCast(std.hash.Murmur2_64.hash(&trigram)); + + const index_1: BucketIndex = @enumFromInt(split.index_1 % len); + + const fingerprint: Fingerprint = if (split.fingerprint == .none) + @enumFromInt(1) + else + split.fingerprint; + + const triplet: Triplet = .{ + .fingerprint = fingerprint, + .index_1 = index_1, + .index_2 = index_1.alternate(fingerprint, len), + }; + assert(triplet.index_2.alternate(fingerprint, len) == index_1); + + return triplet; + } + }; + + pub fn init(buckets: []Bucket) CuckooFilter { + assert(parity(buckets.len) == .even); + return .{ .buckets = buckets }; + } + + pub fn reset(filter: CuckooFilter) void { + @memset(filter.buckets, @splat(.none)); + } + + pub fn capacityForCount(count: usize) error{Overflow}!usize { + const overallocated_count = std.math.divCeil(usize, try std.math.mul(usize, count, 105), 100) catch 
|err| switch (err) { + error.DivisionByZero => unreachable, + else => |e| return e, + }; + return overallocated_count + (overallocated_count & 1); + } + + pub fn append(filter: CuckooFilter, random: std.Random, trigram: Trigram) error{EvictionFailed}!void { + const triplet: Triplet = .initFromTrigram(trigram, @intCast(filter.buckets.len)); + + if (filter.appendToBucket(triplet.index_1, triplet.fingerprint) or + filter.appendToBucket(triplet.index_2, triplet.fingerprint)) + { + return; + } + + var fingerprint = triplet.fingerprint; + var index = if (random.boolean()) triplet.index_1 else triplet.index_2; + for (0..500) |_| { + fingerprint = filter.swapFromBucket(random, index, fingerprint); + index = index.alternate(fingerprint, @intCast(filter.buckets.len)); + + if (filter.appendToBucket(index, fingerprint)) { + return; + } + } + + return error.EvictionFailed; + } + + fn bucketAt(filter: CuckooFilter, index: BucketIndex) *Bucket { + return &filter.buckets[@intFromEnum(index)]; + } + + fn appendToBucket(filter: CuckooFilter, index: BucketIndex, fingerprint: Fingerprint) bool { + assert(fingerprint != .none); + + const bucket = filter.bucketAt(index); + for (bucket) |*slot| { + if (slot.* == .none) { + slot.* = fingerprint; + return true; + } + } + + return false; + } + + fn swapFromBucket( + filter: CuckooFilter, + random: std.Random, + index: BucketIndex, + fingerprint: Fingerprint, + ) Fingerprint { + assert(fingerprint != .none); + + comptime assert(@typeInfo(Bucket).array.len == 4); + const target = &filter.bucketAt(index)[random.int(u2)]; + + const old_fingerprint = target.*; + assert(old_fingerprint != .none); + + target.* = fingerprint; + + return old_fingerprint; + } + + pub fn contains(filter: CuckooFilter, trigram: Trigram) bool { + const triplet: Triplet = .initFromTrigram(trigram, @intCast(filter.buckets.len)); + + return filter.containsInBucket(triplet.index_1, triplet.fingerprint) or + filter.containsInBucket(triplet.index_2, triplet.fingerprint); + } 
+ + fn containsInBucket(filter: CuckooFilter, index: BucketIndex, fingerprint: Fingerprint) bool { + assert(fingerprint != .none); + + const bucket = filter.bucketAt(index); + for (bucket) |*slot| { + if (slot.* == fingerprint) { + return true; + } + } + + return false; + } + + fn parity(integer: anytype) enum(u1) { even, odd } { + return @enumFromInt(integer & 1); + } +}; + +test CuckooFilter { + const allocator = std.testing.allocator; + + const element_count = 499; + const filter_size = comptime CuckooFilter.capacityForCount(element_count) catch unreachable; + + var entries: std.AutoArrayHashMapUnmanaged(Trigram, void) = .empty; + defer entries.deinit(allocator); + try entries.ensureTotalCapacity(allocator, element_count); + + var buckets: [filter_size]CuckooFilter.Bucket = undefined; + var filter: CuckooFilter = .init(&buckets); + var filter_prng: std.Random.DefaultPrng = .init(42); + + for (0..2_500) |gen_prng_seed| { + entries.clearRetainingCapacity(); + filter.reset(); + + var gen_prng: std.Random.DefaultPrng = .init(gen_prng_seed); + for (0..element_count) |_| { + const trigram: Trigram = @bitCast(gen_prng.random().int(u24)); + entries.putAssumeCapacity(trigram, {}); + try filter.append(filter_prng.random(), trigram); + } + + // No false negatives + for (entries.keys()) |trigram| { + try std.testing.expect(filter.contains(trigram)); + } + + // Reasonable false positive rate + const fpr_count = 2_500; + var false_positives: usize = 0; + var negative_prng: std.Random.DefaultPrng = .init(~gen_prng_seed); + for (0..fpr_count) |_| { + var trigram: Trigram = @bitCast(negative_prng.random().int(u24)); + while (entries.contains(trigram)) { + trigram = @bitCast(negative_prng.random().int(u24)); + } + + false_positives += @intFromBool(filter.contains(trigram)); + } + + const fpr = @as(f32, @floatFromInt(false_positives)) / fpr_count; + + errdefer std.log.err("fpr: {d}%", .{fpr * 100}); + try std.testing.expect(fpr < 0.035); + } +} diff --git a/src/Uri.zig 
b/src/Uri.zig index 26f6fa9943..c4d3d154e8 100644 --- a/src/Uri.zig +++ b/src/Uri.zig @@ -157,6 +157,11 @@ pub fn eql(a: Uri, b: Uri) bool { return std.mem.eql(u8, a.raw, b.raw); } +pub fn toStdUri(uri: Uri) std.Uri { + // The Uri is guaranteed to be valid + return std.Uri.parse(uri.raw) catch unreachable; +} + pub const format = @compileError("Cannot format @import(\"Uri.zig\") directly!. Access the underlying raw string field instead."); pub const jsonStringify = @compileError("Cannot stringify @import(\"Uri.zig\") directly!. Access the underlying raw string field instead."); diff --git a/src/features/document_symbol.zig b/src/features/document_symbol.zig index 02e1dd7ff5..4cb7d7ff4f 100644 --- a/src/features/document_symbol.zig +++ b/src/features/document_symbol.zig @@ -18,7 +18,7 @@ const Symbol = struct { children: std.ArrayList(Symbol), }; -fn tokenNameMaybeQuotes(tree: *const Ast, token: Ast.TokenIndex) []const u8 { +pub fn tokenNameMaybeQuotes(tree: *const Ast, token: Ast.TokenIndex) []const u8 { const token_slice = tree.tokenSlice(token); switch (tree.tokenTag(token)) { .identifier => return token_slice, diff --git a/src/features/workspace_symbols.zig b/src/features/workspace_symbols.zig new file mode 100644 index 0000000000..c3bed53b6a --- /dev/null +++ b/src/features/workspace_symbols.zig @@ -0,0 +1,93 @@ +//! 
Implementation of [`workspace/symbol`](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#workspace_symbol)
+
+const std = @import("std");
+
+const lsp = @import("lsp");
+const types = lsp.types;
+
+const DocumentStore = @import("../DocumentStore.zig");
+const offsets = @import("../offsets.zig");
+const Server = @import("../Server.zig");
+const TrigramStore = @import("../TrigramStore.zig");
+const Uri = @import("../Uri.zig");
+
+/// Answers a `workspace/symbol` request with the declarations that the trigram store of
+/// each handle returned by `server.document_store.loadTrigramStores` reports for `request.query`.
+///
+/// Returns `null` when the query is empty. The list of symbols is allocated with `arena`.
+pub fn handler(server: *Server, arena: std.mem.Allocator, request: types.workspace.Symbol.Params) error{ OutOfMemory, Canceled }!?types.workspace.Symbol.Result {
+    // Orders declaration indices by the token index of their name.
+    const ByNameToken = struct {
+        name_tokens: []const std.zig.Ast.TokenIndex,
+        fn lessThan(self: @This(), a: TrigramStore.Declaration.Index, b: TrigramStore.Declaration.Index) bool {
+            return self.name_tokens[@intFromEnum(a)] < self.name_tokens[@intFromEnum(b)];
+        }
+    };
+    const tokenName = @import("document_symbol.zig").tokenNameMaybeQuotes;
+
+    if (request.query.len == 0) return null;
+
+    // Gather the URI of every workspace known to the server.
+    var roots: std.ArrayList(std.Uri) = try .initCapacity(arena, server.workspaces.items.len);
+    defer roots.deinit(arena);
+    for (server.workspaces.items) |workspace| roots.appendAssumeCapacity(workspace.uri.toStdUri());
+
+    // The handle slice is released through the document store's allocator, not `arena`.
+    const handles = try server.document_store.loadTrigramStores(roots.items);
+    defer server.document_store.allocator.free(handles);
+
+    var found: std.ArrayList(types.workspace.Symbol) = .empty;
+    var matches: std.ArrayList(TrigramStore.Declaration.Index) = .empty; // reused for every handle
+
+    for (handles) |handle| {
+        const store = handle.trigram_store.getCached();
+        const source = handle.tree.source;
+
+        matches.clearRetainingCapacity();
+        try store.declarationsForQuery(arena, request.query, &matches);
+
+        const declarations = store.declarations.slice();
+        const name_tokens = declarations.items(.name);
+        const kinds = declarations.items(.kind);
+
+        // Visit the matches in name-token order so that each position is advanced from the
+        // end of the previous match instead of being recomputed from the start of the file.
+        std.mem.sortUnstable(TrigramStore.Declaration.Index, matches.items, ByNameToken{ .name_tokens = name_tokens }, ByNameToken.lessThan);
+
+        // Byte offset and position of the end of the previous match in this document.
+        var cursor_index: usize = 0;
+        var cursor: offsets.Position = .{ .line = 0, .character = 0 };
+
+        try found.ensureUnusedCapacity(arena, matches.items.len);
+        for (matches.items) |match| {
+            const i = @intFromEnum(match);
+            const loc = offsets.tokenToLoc(&handle.tree, name_tokens[i]);
+
+            const start = offsets.advancePosition(source, cursor, cursor_index, loc.start, server.offset_encoding);
+            const end = offsets.advancePosition(source, start, loc.start, loc.end, server.offset_encoding);
+            cursor_index = loc.end;
+            cursor = end;
+
+            // Capacity was reserved above, so appending cannot fail.
+            found.appendAssumeCapacity(.{
+                .name = tokenName(&handle.tree, name_tokens[i]),
+                .kind = switch (kinds[i]) {
+                    .variable => .Variable,
+                    .constant => .Constant,
+                    .field => .Field,
+                    .function => .Function,
+                    // There is no SymbolKind that represents a test, so tests are reported as methods.
+                    .test_function => .Method,
+                },
+                .location = .{
+                    .location = .{
+                        .uri = handle.uri.raw,
+                        .range = .{ .start = start, .end = end },
+                    },
+                },
+            });
+        }
+    }
+
+    return .{ .workspace_symbols = found.items };
+}
diff --git a/src/zls.zig b/src/zls.zig
index 864bcb37fa..7f244c28d6 100644
--- a/src/zls.zig
+++ b/src/zls.zig
@@ -18,6 +18,7 @@ pub const Server = @import("Server.zig");
 pub const snippets = @import("snippets.zig");
 pub const testing = @import("testing.zig");
 pub const translate_c = @import("translate_c.zig");
+pub const TrigramStore = @import("TrigramStore.zig");
 pub const Uri = @import("Uri.zig");
 
 pub const code_actions = @import("features/code_actions.zig");
diff --git a/tests/context.zig b/tests/context.zig
index 55c493460d..07586ed83c 100644
--- a/tests/context.zig
+++ b/tests/context.zig
@@ -84,9 +84,13 @@ pub const Context = struct {
     pub fn addDocument(self: *Context, options: struct {
source: []const u8,
         mode: std.zig.Ast.Mode = .zig,
+        base_directory: []const u8 = "/",
     }) !zls.Uri {
+        std.debug.assert(std.mem.startsWith(u8, options.base_directory, "/"));
+        std.debug.assert(std.mem.endsWith(u8, options.base_directory, "/"));
+
         const arena = self.arena.allocator();
-        const path = try std.fmt.allocPrint(arena, "untitled:///Untitled-{d}.{t}", .{ self.file_id, options.mode });
+        const path = try std.fmt.allocPrint(arena, "untitled://{s}Untitled-{d}.{t}", .{ options.base_directory, self.file_id, options.mode });
         const uri: zls.Uri = try .parse(arena, path);
 
         const params: types.TextDocument.DidOpenParams = .{
@@ -103,4 +107,25 @@ pub const Context = struct {
         self.file_id += 1;
         return uri;
     }
+
+    /// Registers a workspace folder with the server by sending a
+    /// `workspace/didChangeWorkspaceFolders` notification.
+    ///
+    /// `base_directory` must start and end with `/`; the folder URI is `untitled:` followed by it.
+    pub fn addWorkspace(self: *Context, name: []const u8, base_directory: []const u8) !void {
+        std.debug.assert(std.mem.startsWith(u8, base_directory, "/"));
+        std.debug.assert(std.mem.endsWith(u8, base_directory, "/"));
+
+        const arena = self.arena.allocator();
+        const folder_uri = try std.fmt.allocPrint(arena, "untitled:{s}", .{base_directory});
+
+        try self.server.sendNotificationSync(
+            arena,
+            "workspace/didChangeWorkspaceFolders",
+            .{ .event = .{
+                .added = &.{.{ .uri = folder_uri, .name = name }},
+                .removed = &.{},
+            } },
+        );
+    }
 };
diff --git a/tests/lsp_features/workspace_symbols.zig b/tests/lsp_features/workspace_symbols.zig
new file mode 100644
index 0000000000..a3ceeb1cb0
--- /dev/null
+++ b/tests/lsp_features/workspace_symbols.zig
@@ -0,0 +1,101 @@
+const std = @import("std");
+const zls = @import("zls");
+
+const Context = @import("../context.zig").Context;
+
+const types = zls.lsp.types;
+
+const allocator: std.mem.Allocator = std.testing.allocator;
+
+test "workspace symbols" {
+    var ctx: Context = try .init();
+    defer ctx.deinit();
+
+    try ctx.addWorkspace("Animal Shelter", "/animal_shelter/");
+
+    _ = try ctx.addDocument(.{ .source =
+        \\const SalamanderCrab = struct {
+        \\ fn salamander_crab() void {}
+        \\};
+    , .base_directory = "/animal_shelter/" });
+
+    _ = try ctx.addDocument(.{ .source =
+        \\const Dog = struct {
+        \\ const sheltie: Dog = .{};
+        \\ var @"Mr Crabs" = @compileError("hold up");
+        \\};
+        \\test "walk the dog" {
+        \\ const dog: Dog = .sheltie;
+        \\ _ = dog; // nah
+        \\}
+    , .base_directory = "/animal_shelter/" });
+
+    _ = try ctx.addDocument(.{ .source =
+        \\const Lion = struct {
+        \\ extern fn evolveToMonke() void;
+        \\ fn roar() void {
+        \\ var lion = "cool!";
+        \\ const Lion2 = struct {
+        \\ const lion_for_real = 0;
+        \\ };
+        \\ }
+        \\};
+    , .base_directory = "/animal_shelter/" });
+
+    _ = try ctx.addDocument(.{ .source =
+        \\const PotatoDoctor = struct {};
+    , .base_directory = "/farm/" });
+
+    try testDocumentSymbol(&ctx, "Sal",
+        \\Constant SalamanderCrab
+        \\Function salamander_crab
+    );
+    try testDocumentSymbol(&ctx, "_cr___a_b_",
+        \\Constant SalamanderCrab
+        \\Function salamander_crab
+        \\Variable @"Mr Crabs"
+    );
+    try testDocumentSymbol(&ctx, "dog",
+        \\Constant Dog
+        \\Method walk the dog
+    );
+    try testDocumentSymbol(&ctx, "potato_d", "");
+    // Becomes S\x00\x00 which matches nothing
+    try testDocumentSymbol(&ctx, "S", "");
+    try testDocumentSymbol(&ctx, "lion",
+        \\Constant Lion
+        \\Constant lion_for_real
+    );
+    try testDocumentSymbol(&ctx, "monke",
+        \\Function evolveToMonke
+    );
+}
+
+/// Sends a `workspace/symbol` request for `query` and checks the returned symbols against
+/// `expected`.
+///
+/// `expected` holds one `<kind> <name>` entry per line, without a trailing newline.
+fn testDocumentSymbol(ctx: *Context, query: []const u8, expected: []const u8) !void {
+    const maybe_response = try ctx.server.sendRequestSync(ctx.arena.allocator(), "workspace/symbol", .{ .query = query });
+    const response = maybe_response orelse {
+        std.debug.print("Server returned `null` as the result\n", .{});
+        return error.InvalidResponse;
+    };
+
+    var actual: std.ArrayList(u8) = .empty;
+    defer actual.deinit(allocator);
+
+    for (response.workspace_symbols) |symbol| {
+        // `tags` and `containerName` are unsupported for now; `expect` fails the test instead of panicking.
+        try std.testing.expect(symbol.tags == null);
+        try std.testing.expect(symbol.containerName == null);
+        try actual.print(allocator, "{t} {s}\n", .{ symbol.kind, symbol.name });
+    }
+
+    // Drop the newline that follows the last entry.
+    if (actual.items.len != 0) {
+        _ = actual.pop();
+    }
+
+    try zls.testing.expectEqualStrings(expected, actual.items);
+}
diff --git a/tests/tests.zig b/tests/tests.zig
index ffe7f37d1c..76307fd058 100644
--- a/tests/tests.zig
+++ b/tests/tests.zig
@@ -22,6 +22,7 @@ comptime {
     _ = @import("lsp_features/selection_range.zig");
     _ = @import("lsp_features/semantic_tokens.zig");
     _ = @import("lsp_features/signature_help.zig");
+    _ = @import("lsp_features/workspace_symbols.zig");
 
     // Language features
     _ = @import("language_features/cimport.zig");