diff --git a/mongo/catalog/catalog.go b/mongo/catalog/catalog.go new file mode 100644 index 00000000..f98aeef6 --- /dev/null +++ b/mongo/catalog/catalog.go @@ -0,0 +1,29 @@ +// Package catalog provides metadata context for MongoDB auto-completion. +package catalog + +import "sort" + +// Catalog stores collection names for a database context. +type Catalog struct { + collections map[string]struct{} +} + +// New creates an empty Catalog. +func New() *Catalog { + return &Catalog{collections: make(map[string]struct{})} +} + +// AddCollection registers a collection name. Duplicates are ignored. +func (c *Catalog) AddCollection(name string) { + c.collections[name] = struct{}{} +} + +// Collections returns all registered collection names in sorted order. +func (c *Catalog) Collections() []string { + result := make([]string, 0, len(c.collections)) + for name := range c.collections { + result = append(result, name) + } + sort.Strings(result) + return result +} diff --git a/mongo/catalog/catalog_test.go b/mongo/catalog/catalog_test.go new file mode 100644 index 00000000..e9f700b7 --- /dev/null +++ b/mongo/catalog/catalog_test.go @@ -0,0 +1,52 @@ +package catalog_test + +import ( + "slices" + "testing" + + "github.com/bytebase/omni/mongo/catalog" +) + +func TestNewCatalogEmpty(t *testing.T) { + cat := catalog.New() + if got := cat.Collections(); len(got) != 0 { + t.Errorf("new catalog Collections() = %v, want empty", got) + } +} + +func TestAddAndListCollections(t *testing.T) { + cat := catalog.New() + cat.AddCollection("users") + cat.AddCollection("orders") + + got := cat.Collections() + want := []string{"orders", "users"} + if !slices.Equal(got, want) { + t.Errorf("Collections() = %v, want %v", got, want) + } +} + +func TestAddCollectionDedup(t *testing.T) { + cat := catalog.New() + cat.AddCollection("users") + cat.AddCollection("users") + cat.AddCollection("users") + + got := cat.Collections() + if len(got) != 1 { + t.Errorf("Collections() = %v, want 1 entry", got) + } +} + +func TestCollectionsSortOrder(t *testing.T) { + cat := catalog.New() + cat.AddCollection("zebra") + cat.AddCollection("alpha") + cat.AddCollection("middle") + + got := cat.Collections() + want := []string{"alpha", "middle", "zebra"} + if !slices.Equal(got, want) { + t.Errorf("Collections() = %v, want %v", got, want) + } +} diff --git a/mongo/completion/candidates.go b/mongo/completion/candidates.go new file mode 100644 index 00000000..a7481844 --- /dev/null +++ b/mongo/completion/candidates.go @@ -0,0 +1,269 @@ +package completion + +import "github.com/bytebase/omni/mongo/catalog" + +// candidatesForContext returns the raw candidate list for a given context, +// optionally enriched by the catalog. +func candidatesForContext(ctx completionContext, cat *catalog.Catalog) []Candidate { + switch ctx { + case contextTopLevel: + return topLevelCandidates() + case contextAfterDbDot: + return afterDbDotCandidates(cat) + case contextAfterCollDot: + return collectionMethodCandidates() + case contextAfterBracket: + return bracketCandidates(cat) + case contextCursorChain: + return cursorMethodCandidates() + case contextShowTarget: + return showTargetCandidates() + case contextAfterRsDot: + return rsMethodCandidates() + case contextAfterShDot: + return shMethodCandidates() + case contextAggStage: + return aggStageCandidates() + case contextQueryOperator: + return queryOperatorCandidates() + case contextInsideArgs: + return insideArgsCandidates() + case contextDocumentKey: + return documentKeyCandidates() + default: + return nil + } +} + +func topLevelCandidates() []Candidate { + keywords := []string{ + "db", "rs", "sh", "sp", "show", + "sleep", "load", "print", "printjson", + "quit", "exit", "help", "it", "cls", "version", + } + candidates := make([]Candidate, 0, len(keywords)+len(bsonHelpers)) + for _, kw := range keywords { + candidates = append(candidates, Candidate{Text: kw, Type: CandidateKeyword}) + } + for _, h := range bsonHelpers { + candidates = append(candidates, Candidate{Text: h, Type: CandidateBSONHelper}) + } + return candidates +} + +func afterDbDotCandidates(cat *catalog.Catalog) []Candidate { + var candidates []Candidate + if cat != nil { + for _, name := range cat.Collections() { + candidates = append(candidates, Candidate{Text: name, Type: CandidateCollection}) + } + } + for _, m := range dbMethods { + candidates = append(candidates, Candidate{Text: m, Type: CandidateDbMethod}) + } + return candidates +} + +func collectionMethodCandidates() []Candidate { + candidates := make([]Candidate, 0, len(collectionMethods)) + for _, m := range collectionMethods { + candidates = append(candidates, Candidate{Text: m, Type: CandidateMethod}) + } + return candidates +} + +func bracketCandidates(cat *catalog.Catalog) []Candidate { + if cat == nil { + return nil + } + var candidates []Candidate + for _, name := range cat.Collections() { + candidates = append(candidates, Candidate{Text: name, Type: CandidateCollection}) + } + return candidates +} + +func cursorMethodCandidates() []Candidate { + candidates := make([]Candidate, 0, len(cursorMethods)) + for _, m := range cursorMethods { + candidates = append(candidates, Candidate{Text: m, Type: CandidateCursorMethod}) + } + return candidates +} + +func showTargetCandidates() []Candidate { + candidates := make([]Candidate, 0, len(showTargets)) + for _, t := range showTargets { + candidates = append(candidates, Candidate{Text: t, Type: CandidateShowTarget}) + } + return candidates +} + +func rsMethodCandidates() []Candidate { + candidates := make([]Candidate, 0, len(rsMethods)) + for _, m := range rsMethods { + candidates = append(candidates, Candidate{Text: m, Type: CandidateRsMethod}) + } + return candidates +} + +func shMethodCandidates() []Candidate { + candidates := make([]Candidate, 0, len(shMethods)) + for _, m := range shMethods { + candidates = append(candidates, Candidate{Text: m, Type: CandidateShMethod}) + } + return candidates +} + +func aggStageCandidates() []Candidate { + candidates := make([]Candidate, 0, len(aggStages)) + for _, s := range aggStages { + candidates = append(candidates, Candidate{Text: s, Type: CandidateAggStage}) + } + return candidates +} + +func queryOperatorCandidates() []Candidate { + candidates := make([]Candidate, 0, len(queryOperators)) + for _, op := range queryOperators { + candidates = append(candidates, Candidate{Text: op, Type: CandidateQueryOperator}) + } + return candidates +} + +func insideArgsCandidates() []Candidate { + literals := []string{"true", "false", "null"} + candidates := make([]Candidate, 0, len(bsonHelpers)+len(literals)) + for _, h := range bsonHelpers { + candidates = append(candidates, Candidate{Text: h, Type: CandidateBSONHelper}) + } + for _, l := range literals { + candidates = append(candidates, Candidate{Text: l, Type: CandidateKeyword}) + } + return candidates +} + +func documentKeyCandidates() []Candidate { + candidates := queryOperatorCandidates() + candidates = append(candidates, insideArgsCandidates()...) + return candidates +} + +// --- Hardcoded candidate lists --- + +var bsonHelpers = []string{ + "ObjectId", "NumberLong", "NumberInt", "NumberDecimal", + "Timestamp", "Date", "ISODate", "UUID", + "MD5", "HexData", "BinData", "Code", + "DBRef", "MinKey", "MaxKey", "RegExp", "Symbol", +} + +var collectionMethods = []string{ + "find", "findOne", "findOneAndDelete", "findOneAndReplace", "findOneAndUpdate", + "insertOne", "insertMany", + "updateOne", "updateMany", + "deleteOne", "deleteMany", + "replaceOne", "bulkWrite", + "aggregate", + "count", "countDocuments", "estimatedDocumentCount", + "distinct", "mapReduce", "watch", + "createIndex", "createIndexes", + "dropIndex", "dropIndexes", "getIndexes", "reIndex", + "drop", "renameCollection", + "stats", "dataSize", "storageSize", "totalSize", "totalIndexSize", + "validate", "explain", + "getShardDistribution", "latencyStats", + "getPlanCache", + "initializeOrderedBulkOp", "initializeUnorderedBulkOp", +} + +var cursorMethods = []string{ + "sort", "limit", "skip", + "toArray", "forEach", "map", + "hasNext", "next", "itcount", "size", + "pretty", "hint", "min", "max", + "readPref", "comment", "batchSize", "close", + "collation", "noCursorTimeout", "allowPartialResults", + "returnKey", "showRecordId", "allowDiskUse", + "maxTimeMS", "readConcern", "writeConcern", + "tailable", "oplogReplay", "projection", +} + +var dbMethods = []string{ + "getName", "getSiblingDB", "getMongo", + "getCollectionNames", "getCollectionInfos", "getCollection", + "createCollection", "createView", + "dropDatabase", + "adminCommand", "runCommand", + "getProfilingStatus", "setProfilingLevel", + "getLogComponents", "setLogLevel", + "fsyncLock", "fsyncUnlock", + "currentOp", "killOp", + "getUser", "getUsers", "createUser", "updateUser", + "dropUser", "dropAllUsers", + "grantRolesToUser", "revokeRolesFromUser", + "getRole", "getRoles", "createRole", "updateRole", + "dropRole", "dropAllRoles", + "grantPrivilegesToRole", "revokePrivilegesFromRole", + "grantRolesToRole", "revokeRolesFromRole", + "serverStatus", "isMaster", "hello", "hostInfo", +} + +var showTargets = []string{ + "dbs", "databases", "collections", "tables", + "profile", "users", "roles", + "log", "logs", "startupWarnings", +} + +var rsMethods = []string{ + "status", "conf", "config", + "initiate", "reconfig", + "add", "addArb", + "stepDown", "freeze", + "slaveOk", "secondaryOk", + "syncFrom", + "printReplicationInfo", "printSecondaryReplicationInfo", +} + +var shMethods = []string{ + "addShard", "addShardTag", "addShardToZone", "addTagRange", + "disableAutoSplit", "enableAutoSplit", + "enableSharding", "disableBalancing", "enableBalancing", + "getBalancerState", "isBalancerRunning", + "moveChunk", + "removeRangeFromZone", "removeShard", "removeShardTag", "removeShardFromZone", + "setBalancerState", "shardCollection", + "splitAt", "splitFind", + "startBalancer", "stopBalancer", + "updateZoneKeyRange", + "status", +} + +var aggStages = []string{ + "$match", "$group", "$project", "$sort", "$limit", "$skip", + "$unwind", "$lookup", "$graphLookup", + "$addFields", "$set", "$unset", + "$out", "$merge", + "$bucket", "$bucketAuto", "$facet", + "$replaceRoot", "$replaceWith", + "$sample", "$count", "$redact", + "$geoNear", "$setWindowFields", "$fill", "$densify", + "$unionWith", + "$collStats", "$indexStats", "$planCacheStats", + "$search", "$searchMeta", "$changeStream", +} + +var queryOperators = []string{ + // Comparison + "$eq", "$ne", "$gt", "$gte", "$lt", "$lte", "$in", "$nin", + // Logical + "$and", "$or", "$not", "$nor", + // Element + "$exists", "$type", + // Evaluation + "$regex", "$expr", "$mod", "$text", "$where", "$jsonSchema", + // Array + "$all", "$elemMatch", "$size", + // Geospatial + "$geoWithin", "$geoIntersects", "$near", "$nearSphere", +} diff --git a/mongo/completion/completion.go b/mongo/completion/completion.go new file mode 100644 index 00000000..81190a9f --- /dev/null +++ b/mongo/completion/completion.go @@ -0,0 +1,103 @@ +// Package completion provides auto-complete for MongoDB shell (mongosh) commands. +package completion + +import ( + "strings" + + "github.com/bytebase/omni/mongo/catalog" + "github.com/bytebase/omni/mongo/parser" +) + +// CandidateType classifies a completion candidate. +type CandidateType int + +const ( + CandidateKeyword CandidateType = iota // top-level keywords (db, rs, sh, show, ...) + CandidateCollection // collection name from catalog + CandidateMethod // collection method (find, insertOne, ...) + CandidateCursorMethod // cursor modifier (sort, limit, ...) + CandidateAggStage // aggregation stage ($match, $group, ...) + CandidateQueryOperator // query operator ($gt, $in, ...) + CandidateBSONHelper // BSON constructor (ObjectId, NumberLong, ...) + CandidateShowTarget // show command target (dbs, collections, ...) + CandidateDbMethod // database method (getName, runCommand, ...) + CandidateRsMethod // replica set method (status, conf, ...) + CandidateShMethod // sharding method (addShard, status, ...) +) + +// Candidate is a single completion suggestion. +type Candidate struct { + Text string // the completion text + Type CandidateType // what kind of object this is + Definition string // optional definition/signature + Comment string // optional doc comment +} + +// Complete returns completion candidates for the given mongosh input at the cursor offset. +// cat may be nil if no catalog context is available. +func Complete(input string, cursorOffset int, cat *catalog.Catalog) []Candidate { + if cursorOffset > len(input) { + cursorOffset = len(input) + } + + prefix := extractPrefix(input, cursorOffset) + tokens := tokenize(input, cursorOffset-len(prefix)) + ctx := detectContext(tokens) + candidates := candidatesForContext(ctx, cat) + + return filterByPrefix(candidates, prefix) +} + +// tokenize lexes input up to the given byte offset and returns all tokens. +func tokenize(input string, limit int) []parser.Token { + if limit > len(input) { + limit = len(input) + } + if limit < 0 { + limit = 0 + } + lex := parser.NewLexer(input[:limit]) + var tokens []parser.Token + for { + tok := lex.NextToken() + if tok.Type == parser.TokEOF { + break + } + tokens = append(tokens, tok) + } + return tokens +} + +// extractPrefix returns the partial token the user is typing at cursorOffset. +// Includes $ as a valid prefix character (for $match, $gt, etc.). +func extractPrefix(input string, cursorOffset int) string { + if cursorOffset > len(input) { + cursorOffset = len(input) + } + i := cursorOffset + for i > 0 { + c := input[i-1] + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || c == '_' || c == '$' { + i-- + } else { + break + } + } + return input[i:cursorOffset] +} + +// filterByPrefix filters candidates whose Text starts with prefix. +// Matching is case-sensitive (mongosh is case-sensitive). +func filterByPrefix(candidates []Candidate, prefix string) []Candidate { + if prefix == "" { + return candidates + } + var result []Candidate + for _, c := range candidates { + if strings.HasPrefix(c.Text, prefix) { + result = append(result, c) + } + } + return result +} diff --git a/mongo/completion/completion_test.go b/mongo/completion/completion_test.go new file mode 100644 index 00000000..293c12fc --- /dev/null +++ b/mongo/completion/completion_test.go @@ -0,0 +1,581 @@ +package completion + +import ( + "testing" + + "github.com/bytebase/omni/mongo/catalog" +) + +// detectContextFromInput is a test helper that simulates what Complete() does: +// extract prefix, tokenize, then detect context. +func detectContextFromInput(input string) completionContext { + prefix := extractPrefix(input, len(input)) + tokens := tokenize(input, len(input)-len(prefix)) + return detectContext(tokens) +} + +func TestDetectContextTopLevel(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"", "empty input"}, + {" ", "whitespace only"}, + {"db.users.find();\n", "after semicolon and newline"}, + {"db.users.find();", "after semicolon"}, + {"var x = 1;", "after variable assignment"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextTopLevel { + t.Errorf("detectContextFromInput(%q) = %d, want contextTopLevel (%d)", tt.input, got, contextTopLevel) + } + }) + } +} + +func TestDetectContextAfterDbDot(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"db.", "db dot"}, + {"db.u", "db dot with prefix u"}, + {"db.get", "db dot with prefix get"}, + {"db.getC", "db dot with prefix getC"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextAfterDbDot { + t.Errorf("detectContextFromInput(%q) = %d, want contextAfterDbDot (%d)", tt.input, got, contextAfterDbDot) + } + }) + } +} + +func TestDetectContextAfterCollDot(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"db.users.", "db.coll."}, + {"db.users.f", "db.coll.prefix"}, + {`db["users"].`, "db bracket coll dot"}, + {`db["users"].f`, "db bracket coll dot prefix"}, + {`db.getCollection("users").`, "db.getCollection dot"}, + {`db.getCollection("users").f`, "db.getCollection dot prefix"}, + {"db.myCollection.", "db.myCollection dot"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextAfterCollDot { + t.Errorf("detectContextFromInput(%q) = %d, want contextAfterCollDot (%d)", tt.input, got, contextAfterCollDot) + } + }) + } +} + +func TestDetectContextAfterBracket(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"db[", "db open bracket"}, + {`db["`, "db bracket with opening quote"}, + {`db["us`, "db bracket with partial collection name"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextAfterBracket { + t.Errorf("detectContextFromInput(%q) = %d, want contextAfterBracket (%d)", tt.input, got, contextAfterBracket) + } + }) + } +} + +func TestDetectContextCursorChain(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"db.users.find().", "after method call dot"}, + {"db.users.find().s", "after method call dot with prefix"}, + {"db.users.find({}).sort({a:1}).", "chained cursor methods"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextCursorChain { + t.Errorf("detectContextFromInput(%q) = %d, want contextCursorChain (%d)", tt.input, got, contextCursorChain) + } + }) + } +} + +func TestDetectContextShowTarget(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"show ", "show space"}, + {"show d", "show with prefix d"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextShowTarget { + t.Errorf("detectContextFromInput(%q) = %d, want contextShowTarget (%d)", tt.input, got, contextShowTarget) + } + }) + } +} + +func TestDetectContextAfterRsDot(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"rs.", "rs dot"}, + {"rs.s", "rs dot with prefix s"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextAfterRsDot { + t.Errorf("detectContextFromInput(%q) = %d, want contextAfterRsDot (%d)", tt.input, got, contextAfterRsDot) + } + }) + } +} + +func TestDetectContextAfterShDot(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"sh.", "sh dot"}, + {"sh.a", "sh dot with prefix a"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextAfterShDot { + t.Errorf("detectContextFromInput(%q) = %d, want contextAfterShDot (%d)", tt.input, got, contextAfterShDot) + } + }) + } +} + +func TestDetectContextAggStage(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"db.users.aggregate([{$", "agg stage with dollar prefix"}, + {"db.users.aggregate([{$m", "agg stage with dollar m prefix"}, + {"db.users.aggregate([{$match: {}}, {$", "agg stage after comma"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextAggStage { + t.Errorf("detectContextFromInput(%q) = %d, want contextAggStage (%d)", tt.input, got, contextAggStage) + } + }) + } +} + +func TestDetectContextQueryOperator(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"db.users.find({age: {$", "query operator with dollar prefix"}, + {"db.users.find({age: {$g", "query operator with dollar g prefix"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextQueryOperator { + t.Errorf("detectContextFromInput(%q) = %d, want contextQueryOperator (%d)", tt.input, got, contextQueryOperator) + } + }) + } +} + +func TestDetectContextInsideArgs(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"db.users.find(", "find open paren"}, + {"db.users.insertOne(", "insertOne open paren"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextInsideArgs { + t.Errorf("detectContextFromInput(%q) = %d, want contextInsideArgs (%d)", tt.input, got, contextInsideArgs) + } + }) + } +} + +func TestDetectContextDocumentKey(t *testing.T) { + tests := []struct { + input string + name string + }{ + {"db.users.find({", "open brace after paren"}, + {"db.users.insertOne({name: 1, ", "after comma inside brace"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectContextFromInput(tt.input) + if got != contextDocumentKey { + t.Errorf("detectContextFromInput(%q) = %d, want contextDocumentKey (%d)", tt.input, got, contextDocumentKey) + } + }) + } +} + +// --- Test helpers --- + +func newTestCatalog(names ...string) *catalog.Catalog { + cat := catalog.New() + for _, name := range names { + cat.AddCollection(name) + } + return cat +} + +func candidateTexts(candidates []Candidate) []string { + texts := make([]string, len(candidates)) + for i, c := range candidates { + texts[i] = c.Text + } + return texts +} + +func hasCandidate(candidates []Candidate, text string) bool { + for _, c := range candidates { + if c.Text == text { + return true + } + } + return false +} + +func hasCandidateWithType(candidates []Candidate, text string, typ CandidateType) bool { + for _, c := range candidates { + if c.Text == text && c.Type == typ { + return true + } + } + return false +} + +// --- extractPrefix tests --- + +func TestExtractPrefix(t *testing.T) { + tests := []struct { + name string + input string + offset int + want string + }{ + {"db dot", "db.", 3, ""}, + {"db dot partial", "db.us", 5, "us"}, + {"cursor chain prefix", "db.users.find().s", 17, "s"}, + {"dollar prefix", "{age: {$g", 9, "$g"}, + {"empty input", "", 0, ""}, + {"show prefix", "show d", 6, "d"}, + {"full word", "db", 2, "db"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractPrefix(tt.input, tt.offset) + if got != tt.want { + t.Errorf("extractPrefix(%q, %d) = %q, want %q", tt.input, tt.offset, got, tt.want) + } + }) + } +} + +// --- filterByPrefix tests --- + +func TestFilterByPrefix(t *testing.T) { + candidates := []Candidate{ + {Text: "find", Type: CandidateMethod}, + {Text: "findOne", Type: CandidateMethod}, + {Text: "aggregate", Type: CandidateMethod}, + {Text: "$gt", Type: CandidateQueryOperator}, + {Text: "$gte", Type: CandidateQueryOperator}, + } + + t.Run("empty prefix returns all", func(t *testing.T) { + got := filterByPrefix(candidates, "") + if len(got) != len(candidates) { + t.Errorf("filterByPrefix with empty prefix returned %d candidates, want %d", len(got), len(candidates)) + } + }) + + t.Run("case sensitive", func(t *testing.T) { + got := filterByPrefix(candidates, "F") + if len(got) != 0 { + t.Errorf("filterByPrefix with prefix 'F' returned %d candidates, want 0 (case-sensitive)", len(got)) + } + }) + + t.Run("dollar prefix", func(t *testing.T) { + got := filterByPrefix(candidates, "$g") + if len(got) != 2 { + t.Errorf("filterByPrefix with prefix '$g' returned %d candidates, want 2", len(got)) + } + for _, c := range got { + if c.Text != "$gt" && c.Text != "$gte" { + t.Errorf("unexpected candidate %q for prefix '$g'", c.Text) + } + } + }) + + t.Run("prefix f", func(t *testing.T) { + got := filterByPrefix(candidates, "f") + if len(got) != 2 { + t.Errorf("filterByPrefix with prefix 'f' returned %d candidates, want 2", len(got)) + } + }) +} + +// --- Complete end-to-end tests --- + +func TestCompleteAfterDbDot(t *testing.T) { + cat := newTestCatalog("users", "orders") + results := Complete("db.", 3, cat) + + // Should include collection names. + if !hasCandidateWithType(results, "users", CandidateCollection) { + t.Error("expected collection 'users' in results") + } + if !hasCandidateWithType(results, "orders", CandidateCollection) { + t.Error("expected collection 'orders' in results") + } + // Should include db methods. + if !hasCandidateWithType(results, "getName", CandidateDbMethod) { + t.Error("expected db method 'getName' in results") + } + if !hasCandidateWithType(results, "runCommand", CandidateDbMethod) { + t.Error("expected db method 'runCommand' in results") + } +} + +func TestCompleteCollectionMethodPrefix(t *testing.T) { + results := Complete("db.users.f", 10, nil) + + expected := []string{"find", "findOne", "findOneAndDelete", "findOneAndReplace", "findOneAndUpdate"} + for _, e := range expected { + if !hasCandidateWithType(results, e, CandidateMethod) { + t.Errorf("expected method %q in results", e) + } + } +} + +func TestCompleteCursorChainPrefix(t *testing.T) { + results := Complete("db.users.find().s", 17, nil) + + expected := []string{"sort", "skip", "size", "showRecordId"} + for _, e := range expected { + if !hasCandidateWithType(results, e, CandidateCursorMethod) { + t.Errorf("expected cursor method %q in results", e) + } + } +} + +func TestCompleteAggStage(t *testing.T) { + results := Complete("db.users.aggregate([{$m", 23, nil) + + expected := []string{"$match", "$merge"} + for _, e := range expected { + if !hasCandidateWithType(results, e, CandidateAggStage) { + t.Errorf("expected agg stage %q in results", e) + } + } +} + +func TestCompleteQueryOperator(t *testing.T) { + results := Complete("db.users.find({age: {$g", 23, nil) + + expected := []string{"$gt", "$gte", "$geoWithin", "$geoIntersects"} + for _, e := range expected { + if !hasCandidateWithType(results, e, CandidateQueryOperator) { + t.Errorf("expected query operator %q in results", e) + } + } +} + +func TestCompleteBracketWithCatalog(t *testing.T) { + cat := newTestCatalog("system.profile", "users", "system.views") + results := Complete(`db["sys`, 7, cat) + + if !hasCandidateWithType(results, "system.profile", CandidateCollection) { + t.Error("expected 'system.profile' in results") + } + if !hasCandidateWithType(results, "system.views", CandidateCollection) { + t.Error("expected 'system.views' in results") + } + if hasCandidate(results, "users") { + t.Error("should NOT include 'users' for prefix 'sys'") + } +} + +func TestCompleteShowTarget(t *testing.T) { + results := Complete("show d", 6, nil) + + expected := []string{"dbs", "databases"} + for _, e := range expected { + if !hasCandidateWithType(results, e, CandidateShowTarget) { + t.Errorf("expected show target %q in results", e) + } + } +} + +func TestCompleteRsMethods(t *testing.T) { + results := Complete("rs.", 3, nil) + + expected := []string{"status", "conf", "initiate"} + for _, e := range expected { + if !hasCandidateWithType(results, e, CandidateRsMethod) { + t.Errorf("expected rs method %q in results", e) + } + } +} + +func TestCompleteShMethods(t *testing.T) { + results := Complete("sh.", 3, nil) + + expected := []string{"addShard", "enableSharding", "status"} + for _, e := range expected { + if !hasCandidateWithType(results, e, CandidateShMethod) { + t.Errorf("expected sh method %q in results", e) + } + } +} + +func TestCompleteTopLevelEmpty(t *testing.T) { + results := Complete("", 0, nil) + + expected := []string{"db", "rs", "sh", "show"} + for _, e := range expected { + if !hasCandidateWithType(results, e, CandidateKeyword) { + t.Errorf("expected keyword %q in results", e) + } + } +} + +func TestCompleteGetCollection(t *testing.T) { + results := Complete(`db.getCollection("users").f`, 27, nil) + + expected := []string{"find", "findOne"} + for _, e := range expected { + if !hasCandidateWithType(results, e, CandidateMethod) { + t.Errorf("expected method %q in results", e) + } + } +} + +// --- Edge case tests --- + +func TestCompleteNilCatalog(t *testing.T) { + results := Complete("db.", 3, nil) + + // Should still have db methods. + if !hasCandidateWithType(results, "getName", CandidateDbMethod) { + t.Error("expected db method 'getName' with nil catalog") + } + // Should NOT have any collection candidates. + for _, c := range results { + if c.Type == CandidateCollection { + t.Errorf("unexpected collection candidate %q with nil catalog", c.Text) + } + } +} + +func TestCompleteCursorOvershoot(t *testing.T) { + // Should not panic when cursor offset exceeds input length. + results := Complete("db.", 100, nil) + if !hasCandidateWithType(results, "getName", CandidateDbMethod) { + t.Error("expected db method 'getName' with overshooting cursor") + } +} + +func TestCompleteBracketWithCatalogQuote(t *testing.T) { + cat := newTestCatalog("users", "orders") + results := Complete(`db["`, 4, cat) + + if !hasCandidateWithType(results, "users", CandidateCollection) { + t.Error("expected 'users' in bracket completion") + } + if !hasCandidateWithType(results, "orders", CandidateCollection) { + t.Error("expected 'orders' in bracket completion") + } +} + +// --- Additional edge case tests --- + +func TestCompleteMultiStatement(t *testing.T) { + // Completion in a second statement after semicolon. + input := "db.users.find(); db." + results := Complete(input, len(input), nil) + if !hasCandidateWithType(results, "getName", CandidateDbMethod) { + t.Error("expected db methods after semicolon in multi-statement input") + } +} + +func TestCompleteDeepNestedQueryOperator(t *testing.T) { + // Deeply nested query operator context. + input := `db.users.find({$and: [{age: {$` + results := Complete(input, len(input), nil) + if !hasCandidateWithType(results, "$gt", CandidateQueryOperator) { + t.Errorf("expected query operators in deeply nested context, got %v", candidateTexts(results)) + } +} + +func TestCompleteDocumentKeyAfterColon(t *testing.T) { + // After colon inside a document (typing a value). + input := "db.users.find({name: " + results := Complete(input, len(input), nil) + // Should offer BSON helpers and literals for value position. + if !hasCandidate(results, "true") { + t.Error("expected 'true' in document key context after colon") + } +} + +// --- Negative tests --- + +func TestCompleteNegativeCases(t *testing.T) { + t.Run("collection method prefix should not include unrelated", func(t *testing.T) { + results := Complete("db.users.f", 10, nil) + + if hasCandidate(results, "aggregate") { + t.Error("should NOT include 'aggregate' for prefix 'f'") + } + if hasCandidate(results, "sort") { + t.Error("should NOT include cursor method 'sort' in collection methods") + } + }) + + t.Run("case sensitivity", func(t *testing.T) { + results := Complete("db.users.F", 10, nil) + + if hasCandidate(results, "find") { + t.Error("should NOT match 'find' for prefix 'F' (case-sensitive)") + } + if hasCandidate(results, "findOne") { + t.Error("should NOT match 'findOne' for prefix 'F' (case-sensitive)") + } + }) +} diff --git a/mongo/completion/context.go b/mongo/completion/context.go new file mode 100644 index 00000000..a9a60f19 --- /dev/null +++ b/mongo/completion/context.go @@ -0,0 +1,253 @@ +package completion + +import "github.com/bytebase/omni/mongo/parser" + +// completionContext identifies the kind of completion expected. +type completionContext int + +const ( + contextTopLevel completionContext = iota // start of input or after semicolon + contextAfterDbDot // db.| + contextAfterCollDot // db.users.| + contextAfterBracket // db[| + contextInsideArgs // db.users.find(| + contextDocumentKey // {| or {age: 1, | + contextQueryOperator // {age: {$| + contextAggStage // [{$| + contextCursorChain // db.users.find().| + contextShowTarget // show | + contextAfterRsDot // rs.| + contextAfterShDot // sh.| +) + +// detectContext analyzes the token sequence to determine the completion context. +func detectContext(tokens []parser.Token) completionContext { + n := len(tokens) + if n == 0 { + return contextTopLevel + } + + last := tokens[n-1] + + // Ends with semicolon → top level. + if last.Str == ";" { + return contextTopLevel + } + + // Ends with "." → classify dot context. + if last.Str == "." { + return classifyDotContext(tokens[:n-1]) + } + + // Ends with "show" keyword → show target. + if last.Str == "show" { + return contextShowTarget + } + + // Ends with "[" → check if preceded by "db". + if last.Str == "[" { + if n >= 2 && tokens[n-2].Str == "db" { + return contextAfterBracket + } + return contextTopLevel + } + + // Ends with a string token after "db[" → bracket access with partial/full collection name. + // Handles: db["us (unterminated string) and db["users" (complete string). + if last.Type == parser.TokString { + if n >= 3 && tokens[n-2].Str == "[" && tokens[n-3].Str == "db" { + return contextAfterBracket + } + } + + // Ends with "(" → inside args. + if last.Str == "(" { + return contextInsideArgs + } + + // Ends with "{" → classify brace context. + if last.Str == "{" { + return classifyBraceContext(tokens[:n-1]) + } + + // Ends with "," or ":" → check if inside unclosed brace. + if last.Str == "," || last.Str == ":" { + if insideBrace(tokens[:n-1]) { + return contextDocumentKey + } + return contextTopLevel + } + + // Ends with ")" followed by nothing — this shouldn't produce completions + // in most cases, but we handle it as top level. + if last.Str == ")" { + return contextTopLevel + } + + return contextTopLevel +} + +// classifyDotContext determines the context when the last token is ".". +// tokens is the slice WITHOUT the trailing ".". +func classifyDotContext(tokens []parser.Token) completionContext { + n := len(tokens) + if n == 0 { + return contextTopLevel + } + + last := tokens[n-1] + + // "db." → afterDbDot + if last.Str == "db" && n == 1 { + return contextAfterDbDot + } + if last.Str == "db" { + // Check that what's before "db" isn't a dot (which would mean db is a property). + prev := tokens[n-2] + if prev.Str != "." { + return contextAfterDbDot + } + } + + // "rs." → afterRsDot + if last.Str == "rs" { + if n == 1 || tokens[n-2].Str != "." { + return contextAfterRsDot + } + } + + // "sh." → afterShDot + if last.Str == "sh" { + if n == 1 || tokens[n-2].Str != "." { + return contextAfterShDot + } + } + + // Ends with ")" → could be cursor chain or getCollection(...). + if last.Str == ")" { + // Try to find the matching "(" and check if it's db.getCollection("x"). + openIdx := findMatchingOpen(tokens, n-1, "(", ")") + if openIdx >= 0 { + // Check for db.getCollection("x"). + if openIdx >= 2 && + tokens[openIdx-1].Str == "getCollection" && + tokens[openIdx-2].Str == "." { + // Check if preceded by "db". + if openIdx >= 3 && tokens[openIdx-3].Str == "db" { + return contextAfterCollDot + } + } + } + // General case: method(). + return contextCursorChain + } + + // Ends with "]" → check for db["coll"]. + if last.Str == "]" { + openIdx := findMatchingOpen(tokens, n-1, "[", "]") + if openIdx >= 1 && tokens[openIdx-1].Str == "db" { + return contextAfterCollDot + } + return contextCursorChain + } + + // Ends with a word → check if it's a collection name after "db.". + if last.IsWord() { + // db . . → afterCollDot + if n >= 3 && tokens[n-2].Str == "." && tokens[n-3].Str == "db" { + // Make sure "db" isn't itself after a dot. + if n == 3 || tokens[n-4].Str != "." { + return contextAfterCollDot + } + } + } + + // Default: treat unrecognized patterns ending in "." as cursor chain. + // In mongosh, arbitrary expressions are commonly chained with dots + // (e.g., result.length, cursor.next()), so cursor methods are the + // most useful default when we can't determine a more specific context. + return contextCursorChain +} + +// classifyBraceContext determines the context when the last token is "{". +// tokens is the slice WITHOUT the trailing "{". +func classifyBraceContext(tokens []parser.Token) completionContext { + n := len(tokens) + if n == 0 { + return contextDocumentKey + } + + last := tokens[n-1] + + // Preceded by ":" → query operator context (nested document for operator). + if last.Str == ":" { + return contextQueryOperator + } + + // Preceded by "[" → agg stage context (pipeline array). + if last.Str == "[" { + return contextAggStage + } + + // Preceded by "," → need to check if we're inside an array (agg pipeline). + if last.Str == "," { + if insideArray(tokens[:n-1]) { + return contextAggStage + } + return contextDocumentKey + } + + return contextDocumentKey +} + +// findMatchingOpen walks backward from pos to find the matching open delimiter. +// pos should point to the close delimiter. Returns the index of the matching open, +// or -1 if not found. +func findMatchingOpen(tokens []parser.Token, pos int, open, close string) int { + depth := 0 + for i := pos; i >= 0; i-- { + if tokens[i].Str == close { + depth++ + } else if tokens[i].Str == open { + depth-- + if depth == 0 { + return i + } + } + } + return -1 +} + +// insideBrace checks if there is an unclosed "{" in tokens. +func insideBrace(tokens []parser.Token) bool { + depth := 0 + for i := len(tokens) - 1; i >= 0; i-- { + switch tokens[i].Str { + case "}": + depth++ + case "{": + if depth == 0 { + return true + } + depth-- + } + } + return false +} + +// insideArray checks if there is an unclosed "[" in tokens. +func insideArray(tokens []parser.Token) bool { + depth := 0 + for i := len(tokens) - 1; i >= 0; i-- { + switch tokens[i].Str { + case "]": + depth++ + case "[": + if depth == 0 { + return true + } + depth-- + } + } + return false +} diff --git a/mongo/parser/token.go b/mongo/parser/token.go index 66c65ca3..4361ad42 100644 --- a/mongo/parser/token.go +++ b/mongo/parser/token.go @@ -349,6 +349,18 @@ const ( kwNaN // NaN ) +// Exported token type constants for use by external packages (e.g., completion). +const ( + TokEOF = tokEOF + TokString = tokString + TokIdent = tokIdent +) + +// IsWord returns true if the token is an identifier or keyword (a "word" token). +func (t Token) IsWord() bool { + return t.Type == tokIdent || t.Type >= 700 +} + // keywords maps keyword strings to their token types. // mongosh keywords are case-sensitive. var keywords = map[string]int{