|
| 1 | +//! Function matching algorithms |
| 2 | +
|
| 3 | +use smart_diff_parser::{Function, MatchResult}; |
| 4 | +use serde::{Deserialize, Serialize}; |
| 5 | +use std::collections::HashMap; |
| 6 | + |
/// Function matcher that pairs functions from a source set with functions
/// from a target set based on their computed similarity.
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionMatcher {
    /// Minimum overall similarity a candidate pair must reach (inclusive) to
    /// be accepted as a match.
    threshold: f64,
}
| 11 | + |
/// Similarity score between two functions
///
/// Holds the individual components produced when two functions are compared,
/// together with the weighted combination of those components.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityScore {
    /// Similarity of the two function signatures, as reported by the
    /// signature's own `similarity` method.
    pub signature_similarity: f64,
    /// Structural similarity of the two function bodies (AST comparison).
    pub body_similarity: f64,
    /// Similarity of the functions' context, derived from the function calls
    /// each of them makes.
    pub context_similarity: f64,
    /// Weighted combination of the three components above
    /// (signature 40%, body 40%, context 20%).
    pub overall_similarity: f64,
}
| 20 | + |
| 21 | +impl FunctionMatcher { |
| 22 | + pub fn new(threshold: f64) -> Self { |
| 23 | + Self { threshold } |
| 24 | + } |
| 25 | + |
| 26 | + /// Match functions between two sets using Hungarian algorithm |
| 27 | + pub fn match_functions(&self, source_functions: &[Function], target_functions: &[Function]) -> MatchResult { |
| 28 | + let mut result = MatchResult::new(); |
| 29 | + |
| 30 | + if source_functions.is_empty() && target_functions.is_empty() { |
| 31 | + result.similarity = 1.0; |
| 32 | + return result; |
| 33 | + } |
| 34 | + |
| 35 | + // Calculate similarity matrix |
| 36 | + let similarity_matrix = self.calculate_similarity_matrix(source_functions, target_functions); |
| 37 | + |
| 38 | + // Apply Hungarian algorithm for optimal matching |
| 39 | + let matches = self.hungarian_matching(&similarity_matrix); |
| 40 | + |
| 41 | + // Process matches and create result |
| 42 | + self.process_matches(source_functions, target_functions, &matches, &mut result); |
| 43 | + |
| 44 | + result.calculate_similarity(); |
| 45 | + result |
| 46 | + } |
| 47 | + |
| 48 | + fn calculate_similarity_matrix(&self, source: &[Function], target: &[Function]) -> Vec<Vec<f64>> { |
| 49 | + let mut matrix = Vec::new(); |
| 50 | + |
| 51 | + for source_func in source { |
| 52 | + let mut row = Vec::new(); |
| 53 | + for target_func in target { |
| 54 | + let similarity = self.calculate_function_similarity(source_func, target_func); |
| 55 | + row.push(similarity.overall_similarity); |
| 56 | + } |
| 57 | + matrix.push(row); |
| 58 | + } |
| 59 | + |
| 60 | + matrix |
| 61 | + } |
| 62 | + |
| 63 | + /// Calculate similarity between two functions |
| 64 | + pub fn calculate_function_similarity(&self, func1: &Function, func2: &Function) -> SimilarityScore { |
| 65 | + // Signature similarity (40% weight) |
| 66 | + let signature_similarity = func1.signature.similarity(&func2.signature); |
| 67 | + |
| 68 | + // Body similarity using AST structure (40% weight) |
| 69 | + let body_similarity = self.calculate_ast_similarity(&func1.body, &func2.body); |
| 70 | + |
| 71 | + // Context similarity (20% weight) - based on surrounding functions, calls, etc. |
| 72 | + let context_similarity = self.calculate_context_similarity(func1, func2); |
| 73 | + |
| 74 | + // Weighted overall similarity |
| 75 | + let overall_similarity = |
| 76 | + signature_similarity * 0.4 + |
| 77 | + body_similarity * 0.4 + |
| 78 | + context_similarity * 0.2; |
| 79 | + |
| 80 | + SimilarityScore { |
| 81 | + signature_similarity, |
| 82 | + body_similarity, |
| 83 | + context_similarity, |
| 84 | + overall_similarity, |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + fn calculate_ast_similarity(&self, ast1: &smart_diff_parser::ASTNode, ast2: &smart_diff_parser::ASTNode) -> f64 { |
| 89 | + // Simple structural similarity based on node types and tree structure |
| 90 | + if ast1.node_type != ast2.node_type { |
| 91 | + return 0.0; |
| 92 | + } |
| 93 | + |
| 94 | + if ast1.children.is_empty() && ast2.children.is_empty() { |
| 95 | + return 1.0; |
| 96 | + } |
| 97 | + |
| 98 | + if ast1.children.len() != ast2.children.len() { |
| 99 | + return 0.5; // Partial similarity for different child counts |
| 100 | + } |
| 101 | + |
| 102 | + let mut total_similarity = 0.0; |
| 103 | + for (child1, child2) in ast1.children.iter().zip(ast2.children.iter()) { |
| 104 | + total_similarity += self.calculate_ast_similarity(child1, child2); |
| 105 | + } |
| 106 | + |
| 107 | + total_similarity / ast1.children.len() as f64 |
| 108 | + } |
| 109 | + |
| 110 | + fn calculate_context_similarity(&self, func1: &Function, func2: &Function) -> f64 { |
| 111 | + // Compare function calls, dependencies, etc. |
| 112 | + let calls1 = func1.extract_function_calls(); |
| 113 | + let calls2 = func2.extract_function_calls(); |
| 114 | + |
| 115 | + if calls1.is_empty() && calls2.is_empty() { |
| 116 | + return 1.0; |
| 117 | + } |
| 118 | + |
| 119 | + let common_calls = calls1.iter() |
| 120 | + .filter(|call| calls2.contains(call)) |
| 121 | + .count(); |
| 122 | + |
| 123 | + let total_calls = calls1.len().max(calls2.len()); |
| 124 | + if total_calls > 0 { |
| 125 | + common_calls as f64 / total_calls as f64 |
| 126 | + } else { |
| 127 | + 1.0 |
| 128 | + } |
| 129 | + } |
| 130 | + |
| 131 | + fn hungarian_matching(&self, similarity_matrix: &[Vec<f64>]) -> Vec<(usize, usize)> { |
| 132 | + // Placeholder implementation - in reality would use Hungarian algorithm |
| 133 | + let mut matches = Vec::new(); |
| 134 | + |
| 135 | + for (i, row) in similarity_matrix.iter().enumerate() { |
| 136 | + if let Some((j, &similarity)) = row.iter().enumerate().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) { |
| 137 | + if similarity >= self.threshold { |
| 138 | + matches.push((i, j)); |
| 139 | + } |
| 140 | + } |
| 141 | + } |
| 142 | + |
| 143 | + matches |
| 144 | + } |
| 145 | + |
| 146 | + fn process_matches(&self, source: &[Function], target: &[Function], |
| 147 | + matches: &[(usize, usize)], result: &mut MatchResult) { |
| 148 | + let mut matched_source = std::collections::HashSet::new(); |
| 149 | + let mut matched_target = std::collections::HashSet::new(); |
| 150 | + |
| 151 | + for &(source_idx, target_idx) in matches { |
| 152 | + let source_func = &source[source_idx]; |
| 153 | + let target_func = &target[target_idx]; |
| 154 | + |
| 155 | + result.mapping.insert(source_func.hash.clone(), target_func.hash.clone()); |
| 156 | + matched_source.insert(source_idx); |
| 157 | + matched_target.insert(target_idx); |
| 158 | + |
| 159 | + // Create change record if functions are different |
| 160 | + let similarity = self.calculate_function_similarity(source_func, target_func); |
| 161 | + if similarity.overall_similarity < 1.0 { |
| 162 | + let change = smart_diff_parser::Change::new( |
| 163 | + smart_diff_parser::ChangeType::Modify, |
| 164 | + format!("Function '{}' modified (similarity: {:.2})", |
| 165 | + source_func.signature.name, similarity.overall_similarity) |
| 166 | + ).with_confidence(similarity.overall_similarity); |
| 167 | + |
| 168 | + result.changes.push(change); |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + // Record unmatched functions |
| 173 | + for (i, func) in source.iter().enumerate() { |
| 174 | + if !matched_source.contains(&i) { |
| 175 | + result.unmatched_source.push(func.hash.clone()); |
| 176 | + |
| 177 | + let change = smart_diff_parser::Change::new( |
| 178 | + smart_diff_parser::ChangeType::Delete, |
| 179 | + format!("Function '{}' deleted", func.signature.name) |
| 180 | + ); |
| 181 | + result.changes.push(change); |
| 182 | + } |
| 183 | + } |
| 184 | + |
| 185 | + for (i, func) in target.iter().enumerate() { |
| 186 | + if !matched_target.contains(&i) { |
| 187 | + result.unmatched_target.push(func.hash.clone()); |
| 188 | + |
| 189 | + let change = smart_diff_parser::Change::new( |
| 190 | + smart_diff_parser::ChangeType::Add, |
| 191 | + format!("Function '{}' added", func.signature.name) |
| 192 | + ); |
| 193 | + result.changes.push(change); |
| 194 | + } |
| 195 | + } |
| 196 | + } |
| 197 | +} |
0 commit comments