diff --git a/pyrefly/lib/lsp/wasm/completion.rs b/pyrefly/lib/lsp/wasm/completion.rs index fd4ced8a40..a0fd0133f4 100644 --- a/pyrefly/lib/lsp/wasm/completion.rs +++ b/pyrefly/lib/lsp/wasm/completion.rs @@ -48,6 +48,7 @@ use crate::state::lsp::IdentifierContext; use crate::state::lsp::IdentifierWithContext; use crate::state::lsp::ImportFormat; use crate::state::lsp::MIN_CHARACTERS_TYPED_AUTOIMPORT; +use crate::state::lsp::PatternMatchParameterKind; use crate::state::state::Transaction; use crate::types::callable::Param; use crate::types::types::Type; @@ -836,6 +837,52 @@ impl Transaction<'_> { } } + /// Suggest `attr=` completions inside a class pattern like `case Point(x=...)`. + pub(crate) fn add_match_class_keyword_completions( + &self, + handle: &Handle, + covering_nodes: &[AnyNodeRef], + completions: &mut Vec, + ) { + let Some(pattern_class) = covering_nodes.iter().find_map(|node| match node { + AnyNodeRef::PatternMatchClass(pattern_class) => Some(pattern_class), + _ => None, + }) else { + return; + }; + let Some(class_ty) = self.get_type_trace(handle, pattern_class.cls.range()) else { + return; + }; + let Some(items) = self.ad_hoc_solve(handle, "completion_match_class_keywords", |solver| { + let instance_ty = match class_ty { + Type::ClassDef(cls) => solver.instantiate(&cls), + Type::ClassType(cls) => Type::ClassType(cls), + Type::Type(box Type::ClassType(cls)) => Type::ClassType(cls), + _ => return Vec::new(), + }; + solver + .completions(instance_ty, None, true) + .into_iter() + .map(|attr| { + RankedCompletion::new(CompletionItem { + label: format!("{}=", attr.name.as_str()), + detail: attr.ty.map(|ty| ty.to_string()), + kind: Some(CompletionItemKind::VARIABLE), + tags: if attr.is_deprecated { + Some(vec![CompletionItemTag::DEPRECATED]) + } else { + None + }, + ..Default::default() + }) + }) + .collect::>() + }) else { + return; + }; + completions.extend(items); + } + /// Core completion implementation returning items and incomplete flag. pub(crate) fn completion_sorted_opt_with_incomplete( &self, @@ -1046,6 +1093,14 @@ impl Transaction<'_> { if matches!(context, IdentifierContext::MethodDef { .. }) { Self::add_magic_method_completions(&identifier, &mut result); } + if matches!( + context, + IdentifierContext::PatternMatch(PatternMatchParameterKind::KeywordArgName) + ) && let Some(mod_module) = self.get_ast(handle) + { + let nodes = Ast::locate_node(&mod_module, position); + self.add_match_class_keyword_completions(handle, &nodes, &mut result); + } self.add_kwargs_completions(handle, position, &mut result); Self::add_keyword_completions(handle, &mut result); let has_local_completions = self.add_local_variable_completions( @@ -1105,6 +1160,7 @@ impl Transaction<'_> { &mut result, in_string_literal, ); + self.add_match_class_keyword_completions(handle, &nodes, &mut result); let dict_key_claimed = self.add_dict_key_completions( handle, mod_module.as_ref(), diff --git a/pyrefly/lib/test/lsp/completion.rs b/pyrefly/lib/test/lsp/completion.rs index bc364cc816..b638b3ac25 100644 --- a/pyrefly/lib/test/lsp/completion.rs +++ b/pyrefly/lib/test/lsp/completion.rs @@ -1402,6 +1402,37 @@ Completion Results: ); } +#[test] +fn completion_match_class_keyword_dataclass() { + let code = r#" +from dataclasses import dataclass + +@dataclass +class A: + a: int + b: int + +x = object() +match x: + case A( ): ... +# ^ +"#; + let (handles, state) = mk_multi_file_state(&[("main", code)], Require::Exports, false); + let handle = handles.get("main").unwrap(); + let position = extract_cursors_for_test(code)[0]; + let txn = state.transaction(); + let completions = txn.completion(handle, position, ImportFormat::Absolute, true, None); + let completion_labels: Vec<_> = completions.iter().map(|c| c.label.as_str()).collect(); + + for expected in ["a=", "b="] { + assert!( + completion_labels.contains(&expected), + "missing {expected} in completions: {:?}", + completion_labels + ); + } +} + #[test] fn completion_literal_union_alias() { let code = r#" diff --git a/pyrefly/lib/test/lsp/lsp_interaction/pytorch_benchmark.rs b/pyrefly/lib/test/lsp/lsp_interaction/pytorch_benchmark.rs index cf56f84412..62429ffd53 100644 --- a/pyrefly/lib/test/lsp/lsp_interaction/pytorch_benchmark.rs +++ b/pyrefly/lib/test/lsp/lsp_interaction/pytorch_benchmark.rs @@ -48,7 +48,7 @@ fn test_pytorch_error_propagation_latency() { }; // Use all available cores for realistic benchmarking let mut interaction = - LspInteraction::new_with_args(args, NoTelemetry, Some(ThreadCount::AllThreads)); + LspInteraction::new_with_args(args, NoTelemetry, Some(ThreadCount::AllThreads), None); interaction.set_root(pytorch_root.clone()); interaction