From 6c8088437a95746af7b817d50a925194f525f607 Mon Sep 17 00:00:00 2001
From: rostislav
Date: Fri, 23 Jan 2026 03:33:41 +0200
Subject: [PATCH 01/34] fix: Allow async dispatches and error page rendering without authentication in WebSecurityConfig

---
 .../rostilos/codecrow/security/web/WebSecurityConfig.java | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/java-ecosystem/libs/security/src/main/java/org/rostilos/codecrow/security/web/WebSecurityConfig.java b/java-ecosystem/libs/security/src/main/java/org/rostilos/codecrow/security/web/WebSecurityConfig.java
index 77a284f3..4c8eef38 100644
--- a/java-ecosystem/libs/security/src/main/java/org/rostilos/codecrow/security/web/WebSecurityConfig.java
+++ b/java-ecosystem/libs/security/src/main/java/org/rostilos/codecrow/security/web/WebSecurityConfig.java
@@ -1,5 +1,6 @@
 package org.rostilos.codecrow.security.web;
 
+import jakarta.servlet.DispatcherType;
 import java.util.Arrays;
 import org.rostilos.codecrow.security.web.jwt.AuthEntryPoint;
 import org.rostilos.codecrow.security.web.jwt.AuthTokenFilter;
@@ -113,8 +114,12 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
             )
             .authorizeHttpRequests(auth -> auth
+                // Allow async dispatches to complete (SSE, streaming responses)
+                .dispatcherTypeMatchers(DispatcherType.ASYNC, DispatcherType.ERROR).permitAll()
                 // Allow all OPTIONS requests (CORS preflight)
                 .requestMatchers(org.springframework.http.HttpMethod.OPTIONS, "/**").permitAll()
+                // Allow error page to be rendered without authentication
+                .requestMatchers("/error").permitAll()
                 .requestMatchers("/api/auth/**").permitAll()
                 .requestMatchers("/api/test/**").permitAll()
                 // OAuth callbacks need to be public (called by VCS providers)

From b5cab46f1f35ede347952d8f51a7b9b994e467ec Mon Sep 17 00:00:00 2001
From: rostislav
Date: Fri, 23 Jan 2026 12:45:36 +0200
Subject: [PATCH 02/34] fix: Add analysis lock checks to prevent duplicate PR analysis in webhook handlers

---
 .../BitbucketCloudPullRequestWebhookHandler.java | 16 +++++++++++++++-
 .../GitHubPullRequestWebhookHandler.java | 16 +++++++++++++++-
 .../src/rag_pipeline/core/ast_splitter.py | 15 ++++++++++-----
 3 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java
index 950f601c..845362d4 100644
--- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java
+++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java
@@ -1,10 +1,12 @@
 package org.rostilos.codecrow.pipelineagent.bitbucket.webhookhandler;
 
+import org.rostilos.codecrow.core.model.analysis.AnalysisLockType;
 import org.rostilos.codecrow.core.model.codeanalysis.AnalysisType;
 import org.rostilos.codecrow.core.model.project.Project;
 import org.rostilos.codecrow.core.model.vcs.EVcsProvider;
 import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest;
 import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor;
+import org.rostilos.codecrow.analysisengine.service.AnalysisLockService;
 import
org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; @@ -48,13 +50,16 @@ public class BitbucketCloudPullRequestWebhookHandler extends AbstractWebhookHand private final PullRequestAnalysisProcessor pullRequestAnalysisProcessor; private final VcsServiceFactory vcsServiceFactory; + private final AnalysisLockService analysisLockService; public BitbucketCloudPullRequestWebhookHandler( PullRequestAnalysisProcessor pullRequestAnalysisProcessor, - VcsServiceFactory vcsServiceFactory + VcsServiceFactory vcsServiceFactory, + AnalysisLockService analysisLockService ) { this.pullRequestAnalysisProcessor = pullRequestAnalysisProcessor; this.vcsServiceFactory = vcsServiceFactory; + this.analysisLockService = analysisLockService; } @Override @@ -104,6 +109,15 @@ private WebhookResult handlePullRequestEvent(WebhookPayload payload, Project pro String placeholderCommentId = null; try { + // Check if analysis is already in progress for this branch BEFORE posting placeholder + // This prevents duplicate webhooks from both posting placeholders and deleting each other's comments + String sourceBranch = payload.sourceBranch(); + if (analysisLockService.isLocked(project.getId(), sourceBranch, AnalysisLockType.PR_ANALYSIS)) { + log.info("PR analysis already in progress for project={}, branch={}, PR={} - skipping duplicate webhook", + project.getId(), sourceBranch, payload.pullRequestId()); + return WebhookResult.ignored("PR analysis already in progress for this branch"); + } + // Post placeholder comment immediately to show analysis has started placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId())); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java index b64a4f7d..7dfe21e8 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java @@ -1,10 +1,12 @@ package org.rostilos.codecrow.pipelineagent.github.webhookhandler; +import org.rostilos.codecrow.core.model.analysis.AnalysisLockType; import org.rostilos.codecrow.core.model.codeanalysis.AnalysisType; import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.vcs.EVcsProvider; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor; +import org.rostilos.codecrow.analysisengine.service.AnalysisLockService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; @@ -47,13 +49,16 @@ public class GitHubPullRequestWebhookHandler extends AbstractWebhookHandler impl private final PullRequestAnalysisProcessor pullRequestAnalysisProcessor; private final VcsServiceFactory vcsServiceFactory; + private final AnalysisLockService 
analysisLockService; public GitHubPullRequestWebhookHandler( PullRequestAnalysisProcessor pullRequestAnalysisProcessor, - VcsServiceFactory vcsServiceFactory + VcsServiceFactory vcsServiceFactory, + AnalysisLockService analysisLockService ) { this.pullRequestAnalysisProcessor = pullRequestAnalysisProcessor; this.vcsServiceFactory = vcsServiceFactory; + this.analysisLockService = analysisLockService; } @Override @@ -117,6 +122,15 @@ private WebhookResult handlePullRequestEvent( String placeholderCommentId = null; try { + // Check if analysis is already in progress for this branch BEFORE posting placeholder + // This prevents duplicate webhooks from both posting placeholders and deleting each other's comments + String sourceBranch = payload.sourceBranch(); + if (analysisLockService.isLocked(project.getId(), sourceBranch, AnalysisLockType.PR_ANALYSIS)) { + log.info("PR analysis already in progress for project={}, branch={}, PR={} - skipping duplicate webhook", + project.getId(), sourceBranch, payload.pullRequestId()); + return WebhookResult.ignored("PR analysis already in progress for this branch"); + } + // Post placeholder comment immediately to show analysis has started placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId())); diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py index 29e5369b..691a3b0d 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py @@ -620,6 +620,10 @@ def _extract_ast_chunks_with_context( chunks = [] processed_ranges: Set[tuple] = set() # Track (start, end) to avoid duplicates + # IMPORTANT: Tree-sitter returns byte positions, not character positions + # We need to slice bytes and decode, not slice the string directly + source_bytes = source_code.encode('utf-8') + # File-level metadata collected dynamically from AST file_metadata: Dict[str, Any] = { 'imports': [], @@ -633,8 +637,8 @@ def _extract_ast_chunks_with_context( all_semantic_types = class_types | function_types def get_node_text(node) -> str: - """Get full text content of a node""" - return source_code[node.start_byte:node.end_byte] + """Get full text content of a node using byte positions""" + return source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace') def extract_identifiers(node) -> List[str]: """Recursively extract all identifier names from a node""" @@ -753,10 +757,11 @@ def traverse(node, parent_context: List[str], depth: int = 0): if node_range in processed_ranges: return - content = source_code[node.start_byte:node.end_byte] + # Use bytes for slicing since tree-sitter returns byte positions + content = source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace') - # Calculate line numbers - start_line = source_code[:node.start_byte].count('\n') + 1 + # Calculate line numbers (use bytes for consistency) + start_line = source_bytes[:node.start_byte].count(b'\n') + 1 end_line = start_line + content.count('\n') # Get the name of this node From 3bc5967c97d2a79efc6a62a73a1e62a3e151aa6c Mon Sep 17 00:00:00 2001 From: rostislav Date: Fri, 23 Jan 2026 12:59:29 +0200 Subject: [PATCH 03/34] fix: Implement atomic upsert for command rate limiting to prevent race conditions --- .../CommentCommandRateLimitRepository.java | 18 +++++++++ .../service/VcsRagIndexingService.java | 37 +++++++++---------- .../CommentCommandRateLimitService.java 
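
The ast_splitter.py change above switches get_node_text and the line-number math from character slicing to byte slicing because tree-sitter reports byte offsets. The following is a self-contained sketch of why that matters (no tree-sitter dependency; the offsets are hand-picked stand-ins for node.start_byte/node.end_byte):

# Illustrative sketch: why byte offsets must be applied to bytes, not str.
# The offsets below are stand-ins for node.start_byte / node.end_byte.
source_code = 'greeting = "héllo"\ndef f():\n    return greeting\n'
source_bytes = source_code.encode("utf-8")

# Suppose a parser reports the second statement at these *byte* offsets.
start_byte = source_bytes.index(b"def")
end_byte = len(source_bytes)

# Wrong: indexing the str with byte offsets drifts once a multi-byte
# character ("é" is 2 bytes in UTF-8) appears before the node.
wrong = source_code[start_byte:end_byte]

# Right: slice the bytes, then decode (mirrors get_node_text in the patch).
right = source_bytes[start_byte:end_byte].decode("utf-8", errors="replace")

print(wrong.startswith("def"))   # False - shifted by one position
print(right.startswith("def"))   # True

# Line numbers computed on bytes stay consistent with byte offsets:
start_line = source_bytes[:start_byte].count(b"\n") + 1
print(start_line)  # 2
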
| 14 ++----- 3 files changed, 38 insertions(+), 31 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/analysis/CommentCommandRateLimitRepository.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/analysis/CommentCommandRateLimitRepository.java index 9118391f..35527873 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/analysis/CommentCommandRateLimitRepository.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/analysis/CommentCommandRateLimitRepository.java @@ -31,4 +31,22 @@ int countCommandsInWindow( @Param("projectId") Long projectId, @Param("windowStart") OffsetDateTime windowStart ); + + /** + * Atomic upsert: increments command count if record exists, creates with count=1 if not. + * Uses PostgreSQL ON CONFLICT DO UPDATE to avoid race conditions. + */ + @Modifying + @Query(value = """ + INSERT INTO comment_command_rate_limit (project_id, window_start, command_count, last_command_at) + VALUES (:projectId, :windowStart, 1, NOW()) + ON CONFLICT (project_id, window_start) + DO UPDATE SET + command_count = comment_command_rate_limit.command_count + 1, + last_command_at = NOW() + """, nativeQuery = true) + void upsertCommandCount( + @Param("projectId") Long projectId, + @Param("windowStart") OffsetDateTime windowStart + ); } diff --git a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/VcsRagIndexingService.java b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/VcsRagIndexingService.java index 6c05e098..8a82b569 100644 --- a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/VcsRagIndexingService.java +++ b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/VcsRagIndexingService.java @@ -98,6 +98,23 @@ public Map indexProjectFromVcs( String branch = determineBranch(requestBranch, config); + // Check if indexing can start BEFORE creating job to avoid orphan "failed" jobs + if (!ragIndexTrackingService.canStartIndexing(project)) { + log.warn("RAG indexing already in progress for project: {}", project.getName()); + return Map.of("status", "locked", "message", "RAG indexing is already in progress"); + } + + // Try to acquire lock BEFORE creating job - this is the authoritative check + Optional lockKey = analysisLockService.acquireLock( + project, branch, AnalysisLockType.RAG_INDEXING + ); + + if (lockKey.isEmpty()) { + log.warn("Failed to acquire RAG indexing lock for project: {} (another process holds the lock)", project.getName()); + return Map.of("status", "locked", "message", "RAG indexing is already in progress (lock held by another process)"); + } + + // Now that we have the lock, create and start the job Job job = jobService.createRagIndexJob(project, null); if (job != null) { jobService.startJob(job); @@ -111,26 +128,6 @@ public Map indexProjectFromVcs( "message", "Starting RAG indexing for branch: " + branch )); - if (!ragIndexTrackingService.canStartIndexing(project)) { - log.warn("RAG indexing already in progress for project: {}", project.getName()); - if (job != null) { - jobService.failJob(job, "RAG indexing is already in progress"); - } - return Map.of("status", "locked", "message", "RAG indexing is already in progress"); - } - - Optional lockKey = analysisLockService.acquireLock( - project, branch, AnalysisLockType.RAG_INDEXING - ); - 
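
The ON CONFLICT query added to CommentCommandRateLimitRepository above replaces a find-then-save sequence, so two concurrent commands in the same window increment one row instead of racing on a duplicate key. Below is a standalone sketch of the same idea, with sqlite3 standing in for PostgreSQL and the window bucketing mirroring recordCommand; the 15-minute window size is an assumed value, not taken from the project configuration.

import sqlite3
from datetime import datetime, timedelta, timezone

WINDOW_MINUTES = 15  # assumed value; the real size comes from configuration

def window_start(now: datetime) -> datetime:
    # Same bucketing as recordCommand(): truncate to the hour, then add whole
    # multiples of the window length.
    return now.replace(minute=0, second=0, microsecond=0) + timedelta(
        minutes=(now.minute // WINDOW_MINUTES) * WINDOW_MINUTES
    )

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE comment_command_rate_limit (
           project_id INTEGER NOT NULL,
           window_start TEXT NOT NULL,
           command_count INTEGER NOT NULL,
           last_command_at TEXT NOT NULL,
           PRIMARY KEY (project_id, window_start))"""
)

def record_command(project_id: int) -> None:
    now = datetime.now(timezone.utc)
    # Atomic upsert: insert with count=1, or increment the existing row.
    conn.execute(
        """INSERT INTO comment_command_rate_limit
               (project_id, window_start, command_count, last_command_at)
           VALUES (?, ?, 1, ?)
           ON CONFLICT (project_id, window_start)
           DO UPDATE SET command_count = command_count + 1,
                         last_command_at = excluded.last_command_at""",
        (project_id, window_start(now).isoformat(), now.isoformat()),
    )

for _ in range(3):
    record_command(42)
print(conn.execute(
    "SELECT command_count FROM comment_command_rate_limit WHERE project_id = 42"
).fetchone())  # (3,)
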
- if (lockKey.isEmpty()) { - log.warn("Failed to acquire RAG indexing lock for project: {}", project.getName()); - if (job != null) { - jobService.failJob(job, "Could not acquire lock for RAG indexing"); - } - return Map.of("status", "locked", "message", "Could not acquire lock for RAG indexing"); - } - try { return performIndexing(project, vcsConnection, workspaceSlug, repoSlug, branch, config, messageConsumer, job); } finally { diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/CommentCommandRateLimitService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/CommentCommandRateLimitService.java index 3a5bfabe..fc26ca02 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/CommentCommandRateLimitService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/CommentCommandRateLimitService.java @@ -54,6 +54,7 @@ public boolean isCommandAllowed(Project project) { /** * Record a command execution for rate limiting purposes. + * Uses atomic upsert to avoid race conditions with concurrent requests. * * @param project The project that executed the command */ @@ -68,17 +69,8 @@ public void recordCommand(Project project) { .truncatedTo(ChronoUnit.HOURS) .plus((OffsetDateTime.now().getMinute() / windowMinutes) * windowMinutes, ChronoUnit.MINUTES); - Optional existingRecord = rateLimitRepository - .findByProjectIdAndWindowStart(project.getId(), windowStart); - - if (existingRecord.isPresent()) { - existingRecord.get().incrementCommandCount(); - rateLimitRepository.save(existingRecord.get()); - } else { - CommentCommandRateLimit newRecord = new CommentCommandRateLimit(project, windowStart); - newRecord.incrementCommandCount(); - rateLimitRepository.save(newRecord); - } + // Use atomic upsert to avoid duplicate key violations from concurrent requests + rateLimitRepository.upsertCommandCount(project.getId(), windowStart); log.debug("Recorded command for project {}, window starting at {}", project.getId(), windowStart); } From 28a400751fd14aa26c5a1df0586ad52f0772f8e1 Mon Sep 17 00:00:00 2001 From: rostislav Date: Fri, 23 Jan 2026 13:16:07 +0200 Subject: [PATCH 04/34] fix: Improve alias management by ensuring direct collections are deleted before alias creation --- .../src/rag_pipeline/core/index_manager.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py index d5497d1a..f92014af 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py @@ -535,9 +535,19 @@ def index_repository( if old_collection_exists and not is_direct_collection: old_versioned_name = self._resolve_alias_to_collection(alias_name) + # If there's a direct collection with the target name, we need to delete it FIRST + # before we can create an alias with that name + if is_direct_collection: + logger.info(f"Migrating from direct collection to alias-based indexing. 
Deleting old collection: {alias_name}") + try: + self.qdrant_client.delete_collection(alias_name) + except Exception as del_err: + logger.error(f"Failed to delete old direct collection before alias swap: {del_err}") + raise Exception(f"Cannot create alias - collection '{alias_name}' exists and cannot be deleted: {del_err}") + alias_operations = [] - # Delete old alias if exists (not direct collection) + # Delete old alias if exists (not direct collection - already handled above) if old_collection_exists and not is_direct_collection: alias_operations.append( DeleteAliasOperation(delete_alias=DeleteAlias(alias_name=alias_name)) @@ -558,14 +568,8 @@ def index_repository( logger.info(f"Alias swap completed successfully: {alias_name} -> {temp_collection_name}") - # NOW delete old collections (after alias swap is complete) - if is_direct_collection: - logger.info(f"Migrating from direct collection to alias-based indexing. Deleting old collection: {alias_name}") - try: - self.qdrant_client.delete_collection(alias_name) - except Exception as del_err: - logger.warning(f"Failed to delete old direct collection: {del_err}") - elif old_versioned_name and old_versioned_name != temp_collection_name: + # Delete old versioned collection (after alias swap is complete) + if old_versioned_name and old_versioned_name != temp_collection_name: logger.info(f"Deleting old versioned collection: {old_versioned_name}") try: self.qdrant_client.delete_collection(old_versioned_name) From efc42d2f2628ff04510041f6049c810c390ec59a Mon Sep 17 00:00:00 2001 From: rostislav Date: Fri, 23 Jan 2026 14:28:45 +0200 Subject: [PATCH 05/34] fix: Enhance locking mechanism in PR webhook handlers to prevent race conditions --- .../service/VcsRagIndexingService.java | 28 +++++++++---------- ...tbucketCloudPullRequestWebhookHandler.java | 17 +++++++++-- .../GitHubPullRequestWebhookHandler.java | 17 +++++++++-- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/VcsRagIndexingService.java b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/VcsRagIndexingService.java index 8a82b569..e8edbdf9 100644 --- a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/VcsRagIndexingService.java +++ b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/VcsRagIndexingService.java @@ -114,21 +114,21 @@ public Map indexProjectFromVcs( return Map.of("status", "locked", "message", "RAG indexing is already in progress (lock held by another process)"); } - // Now that we have the lock, create and start the job - Job job = jobService.createRagIndexJob(project, null); - if (job != null) { - jobService.startJob(job); - jobService.logToJob(job, JobLogLevel.INFO, "init", - "Starting RAG indexing for branch: " + branch); - } - - messageConsumer.accept(Map.of( - "type", "progress", - "stage", "init", - "message", "Starting RAG indexing for branch: " + branch - )); - try { + // Now that we have the lock, create and start the job + Job job = jobService.createRagIndexJob(project, null); + if (job != null) { + jobService.startJob(job); + jobService.logToJob(job, JobLogLevel.INFO, "init", + "Starting RAG indexing for branch: " + branch); + } + + messageConsumer.accept(Map.of( + "type", "progress", + "stage", "init", + "message", "Starting RAG indexing for branch: " + branch + )); + return performIndexing(project, vcsConnection, workspaceSlug, repoSlug, branch, config, 
messageConsumer, job); } finally { analysisLockService.releaseLock(lockKey.get()); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java index 845362d4..04036114 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java @@ -17,6 +17,7 @@ import org.springframework.stereotype.Component; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; @@ -109,15 +110,25 @@ private WebhookResult handlePullRequestEvent(WebhookPayload payload, Project pro String placeholderCommentId = null; try { - // Check if analysis is already in progress for this branch BEFORE posting placeholder - // This prevents duplicate webhooks from both posting placeholders and deleting each other's comments + // Try to acquire lock atomically BEFORE posting placeholder + // This prevents TOCTOU race where multiple webhooks could pass isLocked() check simultaneously + // Note: PullRequestAnalysisProcessor.process() uses acquireLockWithWait() which will + // reuse this lock since it's for the same project/branch/type String sourceBranch = payload.sourceBranch(); - if (analysisLockService.isLocked(project.getId(), sourceBranch, AnalysisLockType.PR_ANALYSIS)) { + Optional earlyLock = analysisLockService.acquireLock( + project, sourceBranch, AnalysisLockType.PR_ANALYSIS, + payload.commitHash(), Long.parseLong(payload.pullRequestId())); + + if (earlyLock.isEmpty()) { log.info("PR analysis already in progress for project={}, branch={}, PR={} - skipping duplicate webhook", project.getId(), sourceBranch, payload.pullRequestId()); return WebhookResult.ignored("PR analysis already in progress for this branch"); } + // Lock acquired - placeholder posting is now protected from race conditions + // Note: We don't release this lock here - PullRequestAnalysisProcessor will manage it + // since acquireLockWithWait() will detect the existing lock and use it + // Post placeholder comment immediately to show analysis has started placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId())); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java index 7dfe21e8..2997a833 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java @@ -17,6 +17,7 @@ import org.springframework.stereotype.Component; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; @@ -122,15 +123,25 @@ private WebhookResult handlePullRequestEvent( String placeholderCommentId = null; try { - // 
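
The Bitbucket handler above (and the matching GitHub change that follows) stops calling isLocked() and instead tries to acquire the lock before posting the placeholder, because two deliveries of the same webhook can both see "not locked" and then both post, and later delete, placeholder comments. The snippet below is a toy stand-in for AnalysisLockService that shows why the single atomic acquire closes that window; it is illustrative only, not the real service API.

import threading
from typing import Dict, Optional, Tuple

_guard = threading.Lock()
_locks: Dict[Tuple[int, str, str], str] = {}

def is_locked(project_id: int, branch: str, lock_type: str) -> bool:
    with _guard:
        return (project_id, branch, lock_type) in _locks

def acquire_lock(project_id: int, branch: str, lock_type: str, owner: str) -> Optional[Tuple[int, str, str]]:
    # Atomic test-and-set: for a given key, only one caller can win.
    with _guard:
        key = (project_id, branch, lock_type)
        if key in _locks:
            return None
        _locks[key] = owner
        return key

# Old pattern (mirrors the removed isLocked() check): both deliveries can pass
# the check before either records anything, so both post placeholder comments.
def handle_webhook_racy(owner: str) -> str:
    if is_locked(1, "feature/x", "PR_ANALYSIS"):
        return f"{owner}: ignored"
    # ... a second delivery can run the same check right here ...
    return f"{owner}: posted placeholder"

# New pattern (mirrors the patch): the acquire itself decides the winner, and
# the analysis processor later reuses and releases the lock.
def handle_webhook_atomic(owner: str) -> str:
    if acquire_lock(1, "feature/x", "PR_ANALYSIS", owner) is None:
        return f"{owner}: ignored (analysis already in progress)"
    return f"{owner}: posted placeholder, lock held for the analysis"

print(handle_webhook_atomic("delivery-1"))
print(handle_webhook_atomic("delivery-2"))  # the duplicate delivery is ignored
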
Check if analysis is already in progress for this branch BEFORE posting placeholder - // This prevents duplicate webhooks from both posting placeholders and deleting each other's comments + // Try to acquire lock atomically BEFORE posting placeholder + // This prevents TOCTOU race where multiple webhooks could pass isLocked() check simultaneously + // Note: PullRequestAnalysisProcessor.process() uses acquireLockWithWait() which will + // reuse this lock since it's for the same project/branch/type String sourceBranch = payload.sourceBranch(); - if (analysisLockService.isLocked(project.getId(), sourceBranch, AnalysisLockType.PR_ANALYSIS)) { + Optional earlyLock = analysisLockService.acquireLock( + project, sourceBranch, AnalysisLockType.PR_ANALYSIS, + payload.commitHash(), Long.parseLong(payload.pullRequestId())); + + if (earlyLock.isEmpty()) { log.info("PR analysis already in progress for project={}, branch={}, PR={} - skipping duplicate webhook", project.getId(), sourceBranch, payload.pullRequestId()); return WebhookResult.ignored("PR analysis already in progress for this branch"); } + // Lock acquired - placeholder posting is now protected from race conditions + // Note: We don't release this lock here - PullRequestAnalysisProcessor will manage it + // since acquireLockWithWait() will detect the existing lock and use it + // Post placeholder comment immediately to show analysis has started placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId())); From 13a63c8451088cc5f2410be2049f630b13482279 Mon Sep 17 00:00:00 2001 From: rostislav Date: Fri, 23 Jan 2026 14:35:47 +0200 Subject: [PATCH 06/34] fix: Enhance alias management by implementing backup and migration strategy for direct collections --- .../src/rag_pipeline/core/index_manager.py | 77 +++++++++++++++---- 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py index f92014af..1694fcd9 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py @@ -535,19 +535,9 @@ def index_repository( if old_collection_exists and not is_direct_collection: old_versioned_name = self._resolve_alias_to_collection(alias_name) - # If there's a direct collection with the target name, we need to delete it FIRST - # before we can create an alias with that name - if is_direct_collection: - logger.info(f"Migrating from direct collection to alias-based indexing. 
Deleting old collection: {alias_name}") - try: - self.qdrant_client.delete_collection(alias_name) - except Exception as del_err: - logger.error(f"Failed to delete old direct collection before alias swap: {del_err}") - raise Exception(f"Cannot create alias - collection '{alias_name}' exists and cannot be deleted: {del_err}") - alias_operations = [] - # Delete old alias if exists (not direct collection - already handled above) + # Delete old alias if exists (not direct collection) if old_collection_exists and not is_direct_collection: alias_operations.append( DeleteAliasOperation(delete_alias=DeleteAlias(alias_name=alias_name)) @@ -562,11 +552,66 @@ def index_repository( ) # Perform atomic alias swap - self.qdrant_client.update_collection_aliases( - change_aliases_operations=alias_operations - ) - - logger.info(f"Alias swap completed successfully: {alias_name} -> {temp_collection_name}") + # If this fails due to a direct collection existing, we'll handle it below + try: + self.qdrant_client.update_collection_aliases( + change_aliases_operations=alias_operations + ) + logger.info(f"Alias swap completed successfully: {alias_name} -> {temp_collection_name}") + except Exception as alias_err: + # Check if failure is due to direct collection conflict + if is_direct_collection and "already exists" in str(alias_err).lower(): + logger.info(f"Alias creation failed due to existing collection. Migrating from direct collection to alias-based indexing...") + + # Rename the old collection to a backup name before deleting + # This way if alias creation fails, we can recover + backup_name = f"{alias_name}_backup_{int(time.time())}" + logger.info(f"Creating backup: renaming {alias_name} to {backup_name}") + + # Qdrant doesn't support rename, so we need to: + # 1. Create an alias pointing old collection to backup name + # 2. Delete the original collection name (which is actually a collection, not alias) + # 3. Create the new alias + + # Since we can't rename, delete the old collection but only AFTER + # we've verified the temp collection is ready (already done above) + logger.warning(f"Deleting old direct collection: {alias_name} (temp collection {temp_collection_name} verified with {temp_collection_info.points_count} points)") + + try: + self.qdrant_client.delete_collection(alias_name) + except Exception as del_err: + logger.error(f"Failed to delete old direct collection: {del_err}") + raise Exception(f"Cannot create alias - collection '{alias_name}' exists and cannot be deleted: {del_err}") + + # Now create the alias (should succeed since collection is deleted) + retry_operations = [ + CreateAliasOperation(create_alias=CreateAlias( + alias_name=alias_name, + collection_name=temp_collection_name + )) + ] + + try: + self.qdrant_client.update_collection_aliases( + change_aliases_operations=retry_operations + ) + logger.info(f"Alias swap completed successfully after migration: {alias_name} -> {temp_collection_name}") + except Exception as retry_err: + # Critical failure - we've deleted the old collection but can't create alias + # The temp collection still exists with all data, log clear instructions + logger.critical( + f"CRITICAL: Alias creation failed after deleting old collection! " + f"Data is safe in '{temp_collection_name}'. " + f"Manual fix required: create alias '{alias_name}' -> '{temp_collection_name}'. " + f"Error: {retry_err}" + ) + raise Exception( + f"Alias creation failed after migration. Data preserved in '{temp_collection_name}'. 
" + f"Run: qdrant alias create {alias_name} -> {temp_collection_name}" + ) + else: + # Some other alias error, re-raise + raise alias_err # Delete old versioned collection (after alias swap is complete) if old_versioned_name and old_versioned_name != temp_collection_name: From 95d74e1c93cb5db9a1334e53769bc6f52f400108 Mon Sep 17 00:00:00 2001 From: rostislav Date: Fri, 23 Jan 2026 14:59:29 +0200 Subject: [PATCH 07/34] fix: Enhance AI analysis by incorporating full PR issue history and resolution tracking --- .../dto/request/ai/AiAnalysisRequestImpl.java | 23 +++++ .../request/ai/AiRequestPreviousIssueDTO.java | 19 +++- .../PullRequestAnalysisProcessor.java | 12 ++- .../service/vcs/VcsAiClientService.java | 20 +++++ .../ai/AiRequestPreviousIssueDTOTest.java | 44 +++++++-- .../codeanalysis/CodeAnalysisRepository.java | 8 ++ .../core/service/CodeAnalysisService.java | 89 ++++++++++++++++--- .../service/BitbucketAiClientService.java | 21 ++++- .../github/service/GitHubAiClientService.java | 22 ++++- .../gitlab/service/GitLabAiClientService.java | 22 ++++- python-ecosystem/mcp-client/model/models.py | 5 ++ .../service/multi_stage_orchestrator.py | 33 ++++++- .../mcp-client/utils/response_parser.py | 12 ++- 13 files changed, 290 insertions(+), 40 deletions(-) diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiAnalysisRequestImpl.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiAnalysisRequestImpl.java index a8244eca..711027fc 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiAnalysisRequestImpl.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiAnalysisRequestImpl.java @@ -278,6 +278,29 @@ public T withPreviousAnalysisData(Optional optionalPreviousAnalysi return self(); } + /** + * Set previous issues from ALL PR analysis versions. + * This provides the LLM with complete issue history including resolved issues, + * helping it understand what was already found and fixed. + * + * Issues are deduplicated by keeping only the most recent version of each issue. + * Resolved issues are included so the LLM knows what was already addressed. 
+ * + * @param allPrAnalyses List of all analyses for this PR, ordered by version DESC (newest first) + */ + public T withAllPrAnalysesData(List allPrAnalyses) { + if (allPrAnalyses == null || allPrAnalyses.isEmpty()) { + return self(); + } + + this.previousCodeAnalysisIssues = allPrAnalyses.stream() + .flatMap(analysis -> analysis.getIssues().stream()) + .map(AiRequestPreviousIssueDTO::fromEntity) + .toList(); + + return self(); + } + public T withMaxAllowedTokens(int maxAllowedTokens) { this.maxAllowedTokens = maxAllowedTokens; return self(); diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiRequestPreviousIssueDTO.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiRequestPreviousIssueDTO.java index 8d166455..d8bfedd9 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiRequestPreviousIssueDTO.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiRequestPreviousIssueDTO.java @@ -15,12 +15,23 @@ public record AiRequestPreviousIssueDTO( String branch, String pullRequestId, String status, // open|resolved|ignored - String category + String category, + // Resolution tracking fields + Integer prVersion, // Which PR iteration this issue was found in + String resolvedDescription, // Description of how the issue was resolved + String resolvedByCommit, // Commit hash that resolved the issue + Long resolvedInPrVersion // PR version where this was resolved (null if still open) ) { public static AiRequestPreviousIssueDTO fromEntity(CodeAnalysisIssue issue) { String categoryStr = issue.getIssueCategory() != null ? issue.getIssueCategory().name() : IssueCategory.CODE_QUALITY.name(); + + Integer prVersion = null; + if (issue.getAnalysis() != null) { + prVersion = issue.getAnalysis().getPrVersion(); + } + return new AiRequestPreviousIssueDTO( String.valueOf(issue.getId()), categoryStr, @@ -33,7 +44,11 @@ public static AiRequestPreviousIssueDTO fromEntity(CodeAnalysisIssue issue) { issue.getAnalysis() == null ? null : issue.getAnalysis().getBranchName(), issue.getAnalysis() == null || issue.getAnalysis().getPrNumber() == null ? null : String.valueOf(issue.getAnalysis().getPrNumber()), issue.isResolved() ? 
"resolved" : "open", - categoryStr + categoryStr, + prVersion, + issue.getResolvedDescription(), + issue.getResolvedCommitHash(), + issue.getResolvedAnalysisId() ); } } diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java index 634f0419..f5db459a 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java @@ -29,6 +29,7 @@ import java.time.Duration; import java.time.Instant; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -139,16 +140,23 @@ public Map process( return Map.of("status", "cached", "cached", true); } - Optional previousAnalysis = codeAnalysisService.getPreviousVersionCodeAnalysis( + // Get all previous analyses for this PR to provide full issue history to AI + List allPrAnalyses = codeAnalysisService.getAllPrAnalyses( project.getId(), request.getPullRequestId() ); + + // Get the most recent analysis for incremental diff calculation + Optional previousAnalysis = allPrAnalyses.isEmpty() + ? Optional.empty() + : Optional.of(allPrAnalyses.get(0)); // Ensure branch index exists for target branch if configured ensureRagIndexForTargetBranch(project, request.getTargetBranchName(), consumer); VcsAiClientService aiClientService = vcsServiceFactory.getAiClientService(provider); - AiAnalysisRequest aiRequest = aiClientService.buildAiAnalysisRequest(project, request, previousAnalysis); + AiAnalysisRequest aiRequest = aiClientService.buildAiAnalysisRequest( + project, request, previousAnalysis, allPrAnalyses); Map aiResponse = aiAnalysisClient.performAnalysis(aiRequest, event -> { try { diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/service/vcs/VcsAiClientService.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/service/vcs/VcsAiClientService.java index 1ee29fbd..3b894fa9 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/service/vcs/VcsAiClientService.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/service/vcs/VcsAiClientService.java @@ -7,6 +7,7 @@ import org.rostilos.codecrow.analysisengine.dto.request.processor.AnalysisProcessRequest; import java.security.GeneralSecurityException; +import java.util.List; import java.util.Optional; /** @@ -31,4 +32,23 @@ AiAnalysisRequest buildAiAnalysisRequest( AnalysisProcessRequest request, Optional previousAnalysis ) throws GeneralSecurityException; + + /** + * Builds an AI analysis request with full PR issue history. 
+ * + * @param project The project being analyzed + * @param request The analysis process request + * @param previousAnalysis Optional previous analysis for incremental analysis (used for delta diff calculation) + * @param allPrAnalyses All analyses for this PR, ordered by version DESC (for issue history) + * @return The AI analysis request ready to be sent to the AI client + */ + default AiAnalysisRequest buildAiAnalysisRequest( + Project project, + AnalysisProcessRequest request, + Optional previousAnalysis, + List allPrAnalyses + ) throws GeneralSecurityException { + // Default implementation falls back to the previous method for backward compatibility + return buildAiAnalysisRequest(project, request, previousAnalysis); + } } diff --git a/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiRequestPreviousIssueDTOTest.java b/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiRequestPreviousIssueDTOTest.java index 2fe1aca2..d36d1c74 100644 --- a/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiRequestPreviousIssueDTOTest.java +++ b/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiRequestPreviousIssueDTOTest.java @@ -30,7 +30,11 @@ void shouldCreateRecordWithAllFields() { "main", "100", "open", - "SECURITY" + "SECURITY", + 1, // prVersion + null, // resolvedDescription + null, // resolvedByCommit + null // resolvedInPrVersion ); assertThat(dto.id()).isEqualTo("123"); @@ -45,13 +49,14 @@ void shouldCreateRecordWithAllFields() { assertThat(dto.pullRequestId()).isEqualTo("100"); assertThat(dto.status()).isEqualTo("open"); assertThat(dto.category()).isEqualTo("SECURITY"); + assertThat(dto.prVersion()).isEqualTo(1); } @Test @DisplayName("should handle null values") void shouldHandleNullValues() { AiRequestPreviousIssueDTO dto = new AiRequestPreviousIssueDTO( - null, null, null, null, null, null, null, null, null, null, null, null + null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null ); assertThat(dto.id()).isNull(); @@ -64,13 +69,13 @@ void shouldHandleNullValues() { @DisplayName("should implement equals correctly") void shouldImplementEqualsCorrectly() { AiRequestPreviousIssueDTO dto1 = new AiRequestPreviousIssueDTO( - "1", "type", "HIGH", "reason", "fix", "diff", "file.java", 10, "main", "1", "open", "cat" + "1", "type", "HIGH", "reason", "fix", "diff", "file.java", 10, "main", "1", "open", "cat", 1, null, null, null ); AiRequestPreviousIssueDTO dto2 = new AiRequestPreviousIssueDTO( - "1", "type", "HIGH", "reason", "fix", "diff", "file.java", 10, "main", "1", "open", "cat" + "1", "type", "HIGH", "reason", "fix", "diff", "file.java", 10, "main", "1", "open", "cat", 1, null, null, null ); AiRequestPreviousIssueDTO dto3 = new AiRequestPreviousIssueDTO( - "2", "type", "HIGH", "reason", "fix", "diff", "file.java", 10, "main", "1", "open", "cat" + "2", "type", "HIGH", "reason", "fix", "diff", "file.java", 10, "main", "1", "open", "cat", 1, null, null, null ); assertThat(dto1).isEqualTo(dto2); @@ -81,10 +86,10 @@ void shouldImplementEqualsCorrectly() { @DisplayName("should implement hashCode correctly") void shouldImplementHashCodeCorrectly() { AiRequestPreviousIssueDTO dto1 = new AiRequestPreviousIssueDTO( - "1", "type", "HIGH", "reason", "fix", "diff", "file.java", 10, "main", "1", "open", "cat" + "1", "type", "HIGH", "reason", "fix", "diff", 
"file.java", 10, "main", "1", "open", "cat", 1, null, null, null ); AiRequestPreviousIssueDTO dto2 = new AiRequestPreviousIssueDTO( - "1", "type", "HIGH", "reason", "fix", "diff", "file.java", 10, "main", "1", "open", "cat" + "1", "type", "HIGH", "reason", "fix", "diff", "file.java", 10, "main", "1", "open", "cat", 1, null, null, null ); assertThat(dto1.hashCode()).isEqualTo(dto2.hashCode()); @@ -94,10 +99,14 @@ void shouldImplementHashCodeCorrectly() { @DisplayName("should support resolved status") void shouldSupportResolvedStatus() { AiRequestPreviousIssueDTO dto = new AiRequestPreviousIssueDTO( - "1", "type", "LOW", "reason", null, null, "file.java", 5, "dev", "2", "resolved", "CODE_QUALITY" + "1", "type", "LOW", "reason", null, null, "file.java", 5, "dev", "2", "resolved", "CODE_QUALITY", + 1, "Fixed by adding null check", "abc123", 2L ); assertThat(dto.status()).isEqualTo("resolved"); + assertThat(dto.resolvedDescription()).isEqualTo("Fixed by adding null check"); + assertThat(dto.resolvedByCommit()).isEqualTo("abc123"); + assertThat(dto.resolvedInPrVersion()).isEqualTo(2L); } @Nested @@ -110,6 +119,7 @@ void shouldConvertEntityWithAllFields() { CodeAnalysis analysis = mock(CodeAnalysis.class); when(analysis.getBranchName()).thenReturn("feature-branch"); when(analysis.getPrNumber()).thenReturn(42L); + when(analysis.getPrVersion()).thenReturn(2); CodeAnalysisIssue issue = mock(CodeAnalysisIssue.class); when(issue.getId()).thenReturn(123L); @@ -122,6 +132,9 @@ void shouldConvertEntityWithAllFields() { when(issue.getFilePath()).thenReturn("src/Main.java"); when(issue.getLineNumber()).thenReturn(50); when(issue.isResolved()).thenReturn(false); + when(issue.getResolvedDescription()).thenReturn(null); + when(issue.getResolvedCommitHash()).thenReturn(null); + when(issue.getResolvedAnalysisId()).thenReturn(null); AiRequestPreviousIssueDTO dto = AiRequestPreviousIssueDTO.fromEntity(issue); @@ -137,14 +150,16 @@ void shouldConvertEntityWithAllFields() { assertThat(dto.pullRequestId()).isEqualTo("42"); assertThat(dto.status()).isEqualTo("open"); assertThat(dto.category()).isEqualTo("SECURITY"); + assertThat(dto.prVersion()).isEqualTo(2); } @Test - @DisplayName("should convert resolved entity") + @DisplayName("should convert resolved entity with resolution tracking") void shouldConvertResolvedEntity() { CodeAnalysis analysis = mock(CodeAnalysis.class); when(analysis.getBranchName()).thenReturn("main"); when(analysis.getPrNumber()).thenReturn(10L); + when(analysis.getPrVersion()).thenReturn(3); CodeAnalysisIssue issue = mock(CodeAnalysisIssue.class); when(issue.getId()).thenReturn(456L); @@ -155,10 +170,17 @@ void shouldConvertResolvedEntity() { when(issue.getFilePath()).thenReturn("src/Utils.java"); when(issue.getLineNumber()).thenReturn(10); when(issue.isResolved()).thenReturn(true); + when(issue.getResolvedDescription()).thenReturn("Fixed by refactoring"); + when(issue.getResolvedCommitHash()).thenReturn("abc123def"); + when(issue.getResolvedAnalysisId()).thenReturn(5L); AiRequestPreviousIssueDTO dto = AiRequestPreviousIssueDTO.fromEntity(issue); assertThat(dto.status()).isEqualTo("resolved"); + assertThat(dto.prVersion()).isEqualTo(3); + assertThat(dto.resolvedDescription()).isEqualTo("Fixed by refactoring"); + assertThat(dto.resolvedByCommit()).isEqualTo("abc123def"); + assertThat(dto.resolvedInPrVersion()).isEqualTo(5L); } @Test @@ -167,6 +189,7 @@ void shouldHandleNullIssueCategoryWithDefault() { CodeAnalysis analysis = mock(CodeAnalysis.class); 
when(analysis.getBranchName()).thenReturn("main"); when(analysis.getPrNumber()).thenReturn(1L); + when(analysis.getPrVersion()).thenReturn(1); CodeAnalysisIssue issue = mock(CodeAnalysisIssue.class); when(issue.getId()).thenReturn(1L); @@ -186,6 +209,7 @@ void shouldHandleNullIssueCategoryWithDefault() { void shouldHandleNullSeverity() { CodeAnalysis analysis = mock(CodeAnalysis.class); when(analysis.getBranchName()).thenReturn("main"); + when(analysis.getPrVersion()).thenReturn(1); CodeAnalysisIssue issue = mock(CodeAnalysisIssue.class); when(issue.getId()).thenReturn(2L); @@ -213,6 +237,7 @@ void shouldHandleNullAnalysis() { assertThat(dto.branch()).isNull(); assertThat(dto.pullRequestId()).isNull(); + assertThat(dto.prVersion()).isNull(); } @Test @@ -221,6 +246,7 @@ void shouldHandleAnalysisWithNullPrNumber() { CodeAnalysis analysis = mock(CodeAnalysis.class); when(analysis.getBranchName()).thenReturn("develop"); when(analysis.getPrNumber()).thenReturn(null); + when(analysis.getPrVersion()).thenReturn(1); CodeAnalysisIssue issue = mock(CodeAnalysisIssue.class); when(issue.getId()).thenReturn(4L); diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/codeanalysis/CodeAnalysisRepository.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/codeanalysis/CodeAnalysisRepository.java index dbab64ca..27aff62d 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/codeanalysis/CodeAnalysisRepository.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/codeanalysis/CodeAnalysisRepository.java @@ -111,4 +111,12 @@ Page searchAnalyses( }) @Query("SELECT ca FROM CodeAnalysis ca WHERE ca.id = :id") Optional findByIdWithIssues(@Param("id") Long id); + + /** + * Find all analyses for a PR across all versions, ordered by version descending. + * Used to provide LLM with full issue history including resolved issues. 
+ */ + @org.springframework.data.jpa.repository.EntityGraph(attributePaths = {"issues"}) + @Query("SELECT ca FROM CodeAnalysis ca WHERE ca.project.id = :projectId AND ca.prNumber = :prNumber ORDER BY ca.prVersion DESC") + List findAllByProjectIdAndPrNumberOrderByPrVersionDesc(@Param("projectId") Long projectId, @Param("prNumber") Long prNumber); } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java index 7f3ca41f..c4210261 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java @@ -110,6 +110,11 @@ private CodeAnalysis fillAnalysisData( return analysisRepository.save(analysis); } + // Save analysis first to get its ID for resolution tracking + CodeAnalysis savedAnalysis = analysisRepository.save(analysis); + Long analysisId = savedAnalysis.getId(); + Long prNumber = savedAnalysis.getPrNumber(); + // Handle issues as List (array format from AI) if (issuesObj instanceof List) { List issuesList = (List) issuesObj; @@ -122,9 +127,11 @@ private CodeAnalysis fillAnalysisData( log.warn("Null issue data at index: {}", i); continue; } - CodeAnalysisIssue issue = createIssueFromData(issueData, String.valueOf(i), vcsAuthorId, vcsAuthorUsername); + CodeAnalysisIssue issue = createIssueFromData( + issueData, String.valueOf(i), vcsAuthorId, vcsAuthorUsername, + commitHash, prNumber, analysisId); if (issue != null) { - analysis.addIssue(issue); + savedAnalysis.addIssue(issue); } } catch (Exception e) { log.error("Error processing issue at index '{}': {}", i, e.getMessage(), e); @@ -145,9 +152,11 @@ else if (issuesObj instanceof Map) { continue; } - CodeAnalysisIssue issue = createIssueFromData(issueData, entry.getKey(), vcsAuthorId, vcsAuthorUsername); + CodeAnalysisIssue issue = createIssueFromData( + issueData, entry.getKey(), vcsAuthorId, vcsAuthorUsername, + commitHash, prNumber, analysisId); if (issue != null) { - analysis.addIssue(issue); + savedAnalysis.addIssue(issue); } } catch (Exception e) { log.error("Error processing issue with key '{}': {}", entry.getKey(), e.getMessage(), e); @@ -157,19 +166,19 @@ else if (issuesObj instanceof Map) { log.warn("Issues field is neither List nor Map: {}", issuesObj.getClass().getName()); } - log.info("Successfully created analysis with {} issues", analysis.getIssues().size()); + log.info("Successfully created analysis with {} issues", savedAnalysis.getIssues().size()); // Evaluate quality gate - QualityGate qualityGate = getQualityGateForAnalysis(analysis); + QualityGate qualityGate = getQualityGateForAnalysis(savedAnalysis); if (qualityGate != null) { - QualityGateResult qgResult = qualityGateEvaluator.evaluate(analysis, qualityGate); - analysis.setAnalysisResult(qgResult.result()); + QualityGateResult qgResult = qualityGateEvaluator.evaluate(savedAnalysis, qualityGate); + savedAnalysis.setAnalysisResult(qgResult.result()); log.info("Quality gate '{}' evaluated with result: {}", qualityGate.getName(), qgResult.result()); } else { log.info("No quality gate found for analysis, skipping evaluation"); } - return analysisRepository.save(analysis); + return analysisRepository.save(savedAnalysis); } catch (Exception e) { log.error("Error creating analysis from AI response: {}", e.getMessage(), e); @@ -225,6 +234,14 @@ public Optional 
getPreviousVersionCodeAnalysis(Long projectId, Lon return codeAnalysisRepository.findByProjectIdAndPrNumberWithMaxPrVersion(projectId, prNumber); } + /** + * Get all analyses for a PR across all versions. + * Useful for providing full issue history to AI including resolved issues. + */ + public List getAllPrAnalyses(Long projectId, Long prNumber) { + return codeAnalysisRepository.findAllByProjectIdAndPrNumberOrderByPrVersionDesc(projectId, prNumber); + } + public int getMaxAnalysisPrVersion(Long projectId, Long prNumber) { return codeAnalysisRepository.findMaxPrVersion(projectId, prNumber).orElse(0); } @@ -233,13 +250,43 @@ public Optional findAnalysisByProjectAndPrNumberAndVersion(Long pr return codeAnalysisRepository.findByProjectIdAndPrNumberAndPrVersion(projectId, prNumber, prVersion); } - private CodeAnalysisIssue createIssueFromData(Map issueData, String issueKey, String vcsAuthorId, String vcsAuthorUsername) { + private CodeAnalysisIssue createIssueFromData( + Map issueData, + String issueKey, + String vcsAuthorId, + String vcsAuthorUsername, + String commitHash, + Long prNumber, + Long analysisId + ) { try { CodeAnalysisIssue issue = new CodeAnalysisIssue(); issue.setVcsAuthorId(vcsAuthorId); issue.setVcsAuthorUsername(vcsAuthorUsername); + // Check if this is a persisted issue from previous analysis (has original ID) + Object originalIdObj = issueData.get("id"); + CodeAnalysisIssue originalIssue = null; + if (originalIdObj != null) { + try { + Long originalId = null; + if (originalIdObj instanceof String) { + originalId = Long.parseLong((String) originalIdObj); + } else if (originalIdObj instanceof Number) { + originalId = ((Number) originalIdObj).longValue(); + } + if (originalId != null) { + originalIssue = issueRepository.findById(originalId).orElse(null); + if (originalIssue != null) { + log.debug("Found original issue {} for reconciliation", originalId); + } + } + } catch (NumberFormatException e) { + log.debug("Could not parse issue ID '{}' as Long, treating as new issue", originalIdObj); + } + } + String severityStr = (String) issueData.get("severity"); if (severityStr == null) { log.warn("No severity found for issue {}", issueKey); @@ -303,6 +350,17 @@ private CodeAnalysisIssue createIssueFromData(Map issueData, Str boolean isResolved = isResolvedObj != null ? 
isResolvedObj : false; issue.setResolved(isResolved); + // If this issue is resolved and we have original issue data, populate resolution tracking + if (isResolved && originalIssue != null) { + issue.setResolvedDescription(reason); // AI provides resolution reason in the 'reason' field + issue.setResolvedByPr(prNumber); + issue.setResolvedCommitHash(commitHash); + issue.setResolvedAnalysisId(analysisId); + issue.setResolvedAt(OffsetDateTime.now()); + issue.setResolvedBy(vcsAuthorUsername); + log.info("Issue {} marked as resolved by PR {} commit {}", originalIdObj, prNumber, commitHash); + } + String categoryStr = (String) issueData.get("category"); if (categoryStr != null && !categoryStr.isBlank()) { issue.setIssueCategory(IssueCategory.fromString(categoryStr)); @@ -310,8 +368,8 @@ private CodeAnalysisIssue createIssueFromData(Map issueData, Str issue.setIssueCategory(IssueCategory.CODE_QUALITY); } - log.debug("Created issue: {} severity, category: {}, file: {}, line: {}", - issue.getSeverity(), issue.getIssueCategory(), issue.getFilePath(), issue.getLineNumber()); + log.debug("Created issue: {} severity, category: {}, file: {}, line: {}, resolved: {}", + issue.getSeverity(), issue.getIssueCategory(), issue.getFilePath(), issue.getLineNumber(), isResolved); return issue; @@ -321,6 +379,13 @@ private CodeAnalysisIssue createIssueFromData(Map issueData, Str } } + /** + * Overload for backward compatibility with callers that don't have resolution context + */ + private CodeAnalysisIssue createIssueFromData(Map issueData, String issueKey, String vcsAuthorId, String vcsAuthorUsername) { + return createIssueFromData(issueData, issueKey, vcsAuthorId, vcsAuthorUsername, null, null, null); + } + public CodeAnalysis createAnalysis(Project project, AnalysisType analysisType) { CodeAnalysis analysis = new CodeAnalysis(); analysis.setProject(project); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java index 0db3cee3..0094252d 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java @@ -94,14 +94,29 @@ public AiAnalysisRequest buildAiAnalysisRequest( if(request.getAnalysisType() == AnalysisType.BRANCH_ANALYSIS){ return buildBranchAnalysisRequest(project, (BranchProcessRequest) request, previousAnalysis); } else { - return buildPrAnalysisRequest(project, (PrProcessRequest) request, previousAnalysis); + return buildPrAnalysisRequest(project, (PrProcessRequest) request, previousAnalysis, Collections.emptyList()); + } + } + + @Override + public AiAnalysisRequest buildAiAnalysisRequest( + Project project, + AnalysisProcessRequest request, + Optional previousAnalysis, + List allPrAnalyses + ) throws GeneralSecurityException { + if(request.getAnalysisType() == AnalysisType.BRANCH_ANALYSIS){ + return buildBranchAnalysisRequest(project, (BranchProcessRequest) request, previousAnalysis); + } else { + return buildPrAnalysisRequest(project, (PrProcessRequest) request, previousAnalysis, allPrAnalyses); } } public AiAnalysisRequest buildPrAnalysisRequest( Project project, PrProcessRequest request, - Optional previousAnalysis + Optional 
previousAnalysis, + List allPrAnalyses ) throws GeneralSecurityException { VcsInfo vcsInfo = getVcsInfo(project); VcsConnection vcsConnection = vcsInfo.vcsConnection(); @@ -216,7 +231,7 @@ public AiAnalysisRequest buildPrAnalysisRequest( .withProjectVcsConnectionBindingInfo(vcsInfo.workspace(), vcsInfo.repoSlug()) .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(projectAiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) - .withPreviousAnalysisData(previousAnalysis) + .withAllPrAnalysesData(allPrAnalyses) // Use full PR history instead of just previous version .withMaxAllowedTokens(aiConnection.getTokenLimitation()) .withAnalysisType(request.getAnalysisType()) .withPrTitle(prTitle) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java index 7075aca6..b39202de 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java @@ -88,14 +88,30 @@ public AiAnalysisRequest buildAiAnalysisRequest( case BRANCH_ANALYSIS: return buildBranchAnalysisRequest(project, (BranchProcessRequest) request, previousAnalysis); default: - return buildPrAnalysisRequest(project, (PrProcessRequest) request, previousAnalysis); + return buildPrAnalysisRequest(project, (PrProcessRequest) request, previousAnalysis, Collections.emptyList()); + } + } + + @Override + public AiAnalysisRequest buildAiAnalysisRequest( + Project project, + AnalysisProcessRequest request, + Optional previousAnalysis, + List allPrAnalyses + ) throws GeneralSecurityException { + switch (request.getAnalysisType()) { + case BRANCH_ANALYSIS: + return buildBranchAnalysisRequest(project, (BranchProcessRequest) request, previousAnalysis); + default: + return buildPrAnalysisRequest(project, (PrProcessRequest) request, previousAnalysis, allPrAnalyses); } } private AiAnalysisRequest buildPrAnalysisRequest( Project project, PrProcessRequest request, - Optional previousAnalysis + Optional previousAnalysis, + List allPrAnalyses ) throws GeneralSecurityException { VcsInfo vcsInfo = getVcsInfo(project); VcsConnection vcsConnection = vcsInfo.vcsConnection(); @@ -206,7 +222,7 @@ private AiAnalysisRequest buildPrAnalysisRequest( .withProjectVcsConnectionBindingInfo(vcsInfo.owner(), vcsInfo.repoSlug()) .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) - .withPreviousAnalysisData(previousAnalysis) + .withAllPrAnalysesData(allPrAnalyses) // Use full PR history instead of just previous version .withMaxAllowedTokens(aiConnection.getTokenLimitation()) .withAnalysisType(request.getAnalysisType()) .withPrTitle(prTitle) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java index a6dd0fce..1c697fc7 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java +++ 
b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java @@ -88,14 +88,30 @@ public AiAnalysisRequest buildAiAnalysisRequest( case BRANCH_ANALYSIS: return buildBranchAnalysisRequest(project, (BranchProcessRequest) request, previousAnalysis); default: - return buildMrAnalysisRequest(project, (PrProcessRequest) request, previousAnalysis); + return buildMrAnalysisRequest(project, (PrProcessRequest) request, previousAnalysis, Collections.emptyList()); + } + } + + @Override + public AiAnalysisRequest buildAiAnalysisRequest( + Project project, + AnalysisProcessRequest request, + Optional previousAnalysis, + List allPrAnalyses + ) throws GeneralSecurityException { + switch (request.getAnalysisType()) { + case BRANCH_ANALYSIS: + return buildBranchAnalysisRequest(project, (BranchProcessRequest) request, previousAnalysis); + default: + return buildMrAnalysisRequest(project, (PrProcessRequest) request, previousAnalysis, allPrAnalyses); } } private AiAnalysisRequest buildMrAnalysisRequest( Project project, PrProcessRequest request, - Optional previousAnalysis + Optional previousAnalysis, + List allPrAnalyses ) throws GeneralSecurityException { VcsInfo vcsInfo = getVcsInfo(project); VcsConnection vcsConnection = vcsInfo.vcsConnection(); @@ -207,7 +223,7 @@ private AiAnalysisRequest buildMrAnalysisRequest( .withProjectVcsConnectionBindingInfo(vcsInfo.namespace(), vcsInfo.repoSlug()) .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) - .withPreviousAnalysisData(previousAnalysis) + .withAllPrAnalysesData(allPrAnalyses) // Use full PR history instead of just previous version .withMaxAllowedTokens(aiConnection.getTokenLimitation()) .withAnalysisType(request.getAnalysisType()) .withPrTitle(mrTitle) diff --git a/python-ecosystem/mcp-client/model/models.py b/python-ecosystem/mcp-client/model/models.py index 153de871..e6aa26e8 100644 --- a/python-ecosystem/mcp-client/model/models.py +++ b/python-ecosystem/mcp-client/model/models.py @@ -41,6 +41,11 @@ class IssueDTO(BaseModel): branch: Optional[str] = None pullRequestId: Optional[str] = None status: Optional[str] = None # open|resolved|ignored + # Resolution tracking fields (for full PR issue history) + prVersion: Optional[int] = None # Which PR iteration this issue was found in + resolvedDescription: Optional[str] = None # How the issue was resolved + resolvedByCommit: Optional[str] = None # Commit hash that resolved the issue + resolvedInPrVersion: Optional[int] = None # PR version where this was resolved # Legacy fields for backwards compatibility title: Optional[str] = None # Legacy - use reason instead description: Optional[str] = None # Legacy - use suggestedFixDescription instead diff --git a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py index 70fa3cd7..a479a1e3 100644 --- a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py +++ b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py @@ -344,11 +344,20 @@ def _issue_matches_files(self, issue: Any, file_paths: List[str]) -> bool: return False def _format_previous_issues_for_batch(self, issues: List[Any]) -> str: - """Format previous issues for inclusion in batch prompt.""" + """Format previous issues for inclusion in batch prompt. 
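The batch formatter above receives previous issues as plain dicts (via model_dump()), so the new resolution-tracking fields travel alongside the existing ones. A minimal sketch of the shape it consumes; the field names mirror the IssueDTO keys added above, while the concrete values and file path are invented for illustration:

# Illustrative only: values are made up; keys follow the IssueDTO fields added above.
previous_issues = [
    {
        "id": "42",                          # DB id, preserved so the LLM can echo it back
        "severity": "HIGH",
        "file": "src/payment/Gateway.php",
        "line": 118,
        "reason": "API key is logged in plain text",
        "status": "open",                    # still open: the LLM must decide isResolved true/false
        "prVersion": 1,
    },
    {
        "id": "17",
        "severity": "MEDIUM",
        "file": "src/payment/Gateway.php",
        "line": 73,
        "reason": "Missing null check on config value",
        "status": "resolved",                # already resolved: context only, must not be re-reported
        "prVersion": 1,
        "resolvedDescription": "Null check added",
        "resolvedInPrVersion": 2,
    },
]
# The formatter renders each entry roughly as:
#   [ID:42] HIGH @ src/payment/Gateway.php:118 (v1)
#     Status: OPEN
#     Issue: API key is logged in plain text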
+ + Includes full issue history with resolution tracking so LLM knows: + - Which issues were previously found + - Which have been resolved (and how) + - Which PR version each issue was found/resolved in + """ if not issues: return "" - lines = ["=== PREVIOUS ISSUES IN THESE FILES (check if resolved) ==="] + lines = ["=== PREVIOUS ISSUES HISTORY (check if resolved/persisting) ==="] + lines.append("Issues from ALL previous PR iterations. Status indicates if resolved or still open.") + lines.append("") + for issue in issues: if hasattr(issue, 'model_dump'): data = issue.model_dump() @@ -362,12 +371,28 @@ def _format_previous_issues_for_batch(self, issues: List[Any]) -> str: file_path = data.get('file', data.get('filePath', 'unknown')) line = data.get('line', data.get('lineNumber', '?')) reason = data.get('reason', data.get('description', 'No description')) + status = data.get('status', 'open') + pr_version = data.get('prVersion', '?') + + # Format status with resolution details if resolved + status_display = status.upper() + if status == 'resolved': + resolved_desc = data.get('resolvedDescription', '') + resolved_in = data.get('resolvedInPrVersion', '') + if resolved_desc: + status_display += f" - {resolved_desc}" + if resolved_in: + status_display += f" (in v{resolved_in})" - lines.append(f"[ID:{issue_id}] {severity} @ {file_path}:{line}") + lines.append(f"[ID:{issue_id}] {severity} @ {file_path}:{line} (v{pr_version})") + lines.append(f" Status: {status_display}") lines.append(f" Issue: {reason}") lines.append("") - lines.append("Mark these as 'isResolved: true' if fixed in the diff above.") + lines.append("INSTRUCTIONS:") + lines.append("- For OPEN issues: Set 'isResolved: true' if fixed in the current diff, 'isResolved: false' if still present") + lines.append("- For already RESOLVED issues: Do NOT re-report them (they're just for context)") + lines.append("- Preserve the 'id' field for all issues you report") lines.append("=== END PREVIOUS ISSUES ===") return "\n".join(lines) diff --git a/python-ecosystem/mcp-client/utils/response_parser.py b/python-ecosystem/mcp-client/utils/response_parser.py index 9336a7ee..d32c2c61 100644 --- a/python-ecosystem/mcp-client/utils/response_parser.py +++ b/python-ecosystem/mcp-client/utils/response_parser.py @@ -12,7 +12,7 @@ class ResponseParser: # Valid issue fields - others will be removed VALID_ISSUE_FIELDS = { - 'id', 'severity', 'category', 'file', 'line', 'reason', + 'id', 'issueId', 'severity', 'category', 'file', 'line', 'reason', 'suggestedFixDescription', 'suggestedFixDiff', 'isResolved' } @@ -66,7 +66,7 @@ def _normalize_diff(diff_value: Any) -> Optional[str]: def _clean_issue(issue: Dict[str, Any]) -> Dict[str, Any]: """ Clean and normalize a single issue object. 
- - Removes unexpected fields (like 'id') + - Normalizes issueId to id (for consistent field naming) - Normalizes suggestedFixDiff format - Normalizes severity and category @@ -81,7 +81,15 @@ def _clean_issue(issue: Dict[str, Any]) -> Dict[str, Any]: cleaned = {} + # Normalize issueId to id (AI may use either) + if 'issueId' in issue and 'id' not in issue: + issue['id'] = issue['issueId'] + for key, value in issue.items(): + # Skip issueId since we normalized it to id above + if key == 'issueId': + continue + # Skip fields not in valid set if key not in ResponseParser.VALID_ISSUE_FIELDS: continue From e509ffcb4208fa54b14610db72afe539ea694095 Mon Sep 17 00:00:00 2001 From: rostislav Date: Fri, 23 Jan 2026 15:27:51 +0200 Subject: [PATCH 08/34] fix: Update issue reconciliation logic to handle previous issues in both incremental and full modes --- .../mcp-client/service/multi_stage_orchestrator.py | 9 ++++++--- .../mcp-client/utils/prompts/prompt_constants.py | 8 +++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py index a479a1e3..a34c5744 100644 --- a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py +++ b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py @@ -195,8 +195,9 @@ async def orchestrate_review( ) self._emit_progress(60, f"Stage 1 Complete: {len(file_issues)} issues found across files") - # === STAGE 1.5: Issue Reconciliation (INCREMENTAL only) === - if is_incremental and request.previousCodeAnalysisIssues: + # === STAGE 1.5: Issue Reconciliation === + # Run reconciliation if we have previous issues (both INCREMENTAL and FULL modes) + if request.previousCodeAnalysisIssues: self._emit_status("reconciliation_started", "Reconciling previous issues...") file_issues = await self._reconcile_previous_issues( request, file_issues, processed_diff @@ -625,8 +626,10 @@ async def _review_file_batch( logger.warning(f"Failed to fetch per-batch RAG context: {e}") # For incremental mode, filter previous issues relevant to this batch + # Also pass previous issues in FULL mode if they exist (subsequent PR iterations) previous_issues_for_batch = "" - if is_incremental and request.previousCodeAnalysisIssues: + has_previous_issues = request.previousCodeAnalysisIssues and len(request.previousCodeAnalysisIssues) > 0 + if has_previous_issues: relevant_prev_issues = [ issue for issue in request.previousCodeAnalysisIssues if self._issue_matches_files(issue, batch_file_paths) diff --git a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py index 64d09c20..91cd4e7b 100644 --- a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py +++ b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py @@ -369,7 +369,7 @@ Review each file below independently. For each file, produce a review result. Use the CODEBASE CONTEXT above to understand how the changed code integrates with existing patterns, dependencies, and architectural decisions. -If previous issue fixed in a current version, mark it as resolved. +If a previous issue (from PREVIOUS ISSUES section) is fixed in the current version, include it with isResolved=true and preserve its original id. 
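Taken together, the parser change above and the updated prompt rule define a small contract for reconciled issues: the model echoes the original id (an issueId key is folded into id), and isResolved is reported as a boolean. A minimal sketch of a single issue object that satisfies the contract; the values are invented:

# Hypothetical model output for a previously reported issue that the current diff fixes.
reconciled_issue = {
    "id": "42",                 # original id preserved; "issueId" would be normalized to "id"
    "severity": "HIGH",
    "category": "SECURITY",
    "file": "src/payment/Gateway.php",
    "line": "118",
    "reason": "Resolved: the API key is no longer written to the log",
    "isResolved": True,         # JSON boolean; the parser also coerces the string "true" as a fallback
}

Any keys outside VALID_ISSUE_FIELDS are dropped during cleaning, so extra fields the model invents do not reach the Java side.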
INPUT FILES: Priority: {priority} @@ -385,14 +385,15 @@ "analysis_summary": "Summary of findings for file 1", "issues": [ {{ + "id": "original-issue-id-if-from-previous-issues", "severity": "HIGH|MEDIUM|LOW|INFO", "category": "SECURITY|PERFORMANCE|CODE_QUALITY|BUG_RISK|STYLE|DOCUMENTATION|BEST_PRACTICES|ERROR_HANDLING|TESTING|ARCHITECTURE", "file": "path/to/file1", "line": "42", - "reason": "Detailed explanation of the issue", + "reason": "Detailed explanation of the issue (or resolution reason if isResolved=true)", "suggestedFixDescription": "Clear description of how to fix the issue", "suggestedFixDiff": "Unified diff showing exact code changes (MUST follow SUGGESTED_FIX_DIFF_FORMAT)", - "isResolved": false|true + "isResolved": false }} ], "confidence": "HIGH|MEDIUM|LOW|INFO", @@ -409,6 +410,7 @@ - Return exactly one review object per input file. - Match file paths exactly. - Skip style nits. +- For PREVIOUS ISSUES that are now RESOLVED: include them with isResolved=true and PRESERVE the original id field. - suggestedFixDiff MUST be a valid unified diff string if a fix is proposed. """ From a2889c7327e99b8025111be58a5d7f7d7a48c0bb Mon Sep 17 00:00:00 2001 From: rostislav Date: Fri, 23 Jan 2026 17:59:50 +0200 Subject: [PATCH 09/34] fix: Improve handling of issue resolution status and logging for better clarity --- .../core/service/CodeAnalysisService.java | 13 +++++++-- .../service/multi_stage_orchestrator.py | 27 ++++++++++++------- .../utils/prompts/prompt_constants.py | 3 ++- .../mcp-client/utils/response_parser.py | 6 +++++ 4 files changed, 37 insertions(+), 12 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java index c4210261..d0fd4624 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java @@ -346,9 +346,18 @@ private CodeAnalysisIssue createIssueFromData( } issue.setSuggestedFixDiff(suggestedFixDiff); - Boolean isResolvedObj = (Boolean) issueData.get("isResolved"); - boolean isResolved = isResolvedObj != null ? isResolvedObj : false; + // Parse isResolved - handle both Boolean and String representations + Object isResolvedObj = issueData.get("isResolved"); + boolean isResolved = false; + if (isResolvedObj instanceof Boolean) { + isResolved = (Boolean) isResolvedObj; + } else if (isResolvedObj instanceof String) { + isResolved = "true".equalsIgnoreCase((String) isResolvedObj); + } issue.setResolved(isResolved); + + log.debug("Issue resolved status: isResolvedObj={}, type={}, parsed={}", + isResolvedObj, isResolvedObj != null ? 
isResolvedObj.getClass().getSimpleName() : "null", isResolved); // If this issue is resolved and we have original issue data, populate resolution tracking if (isResolved && originalIssue != null) { diff --git a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py index a34c5744..a8dce904 100644 --- a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py +++ b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py @@ -391,9 +391,11 @@ def _format_previous_issues_for_batch(self, issues: List[Any]) -> str: lines.append("") lines.append("INSTRUCTIONS:") - lines.append("- For OPEN issues: Set 'isResolved: true' if fixed in the current diff, 'isResolved: false' if still present") + lines.append("- For OPEN issues that are now FIXED: report with 'isResolved': true (boolean)") + lines.append("- For OPEN issues still present: report with 'isResolved': false (boolean)") lines.append("- For already RESOLVED issues: Do NOT re-report them (they're just for context)") - lines.append("- Preserve the 'id' field for all issues you report") + lines.append("- IMPORTANT: 'isResolved' MUST be a JSON boolean (true/false), not a string") + lines.append("- Preserve the 'id' field for all issues you report from previous issues") lines.append("=== END PREVIOUS ISSUES ===") return "\n".join(lines) @@ -606,7 +608,7 @@ async def _review_file_batch( try: rag_response = await self.rag_client.get_pr_context( workspace=request.projectWorkspace, - project=request.projectVcsRepoSlug, + project=request.projectNamespace, branch=request.targetBranchName, changed_files=batch_file_paths, diff_snippets=batch_diff_snippets, @@ -965,11 +967,17 @@ def _format_rag_context( any(path.endswith(f) or f.endswith(path) for f in pr_changed_set) ) - # Skip chunks from files being modified in the PR - they're stale + # For chunks from modified files: + # - Skip if low relevance (score < 0.85) - likely not useful and stale + # - Include if high relevance (score >= 0.85) - context is still valuable even if code may change + # The LLM can use this context to understand patterns even if specific lines changed if is_from_modified_file: - logger.debug(f"Skipping RAG chunk from modified file: {path}") - skipped_modified += 1 - continue + if score < 0.85: + logger.debug(f"Skipping RAG chunk from modified file (low score): {path} (score={score})") + skipped_modified += 1 + continue + else: + logger.debug(f"Including RAG chunk from modified file (high relevance): {path} (score={score})") # Optionally filter by relevance to batch files if relevant_files: @@ -1048,10 +1056,11 @@ def _format_rag_context( ) if not formatted_parts: - logger.info(f"No RAG chunks included in prompt (total: {len(chunks)}, skipped_modified: {skipped_modified}, skipped_relevance: {skipped_relevance})") + logger.warning(f"No RAG chunks included in prompt (total: {len(chunks)}, skipped_modified: {skipped_modified}, skipped_relevance: {skipped_relevance}). 
" + f"PR changed files: {pr_changed_files[:5] if pr_changed_files else 'none'}...") return "" - logger.info(f"Included {len(formatted_parts)} RAG chunks in prompt context (skipped: {skipped_modified} modified, {skipped_relevance} low relevance)") + logger.info(f"Included {len(formatted_parts)} RAG chunks in prompt context (total: {len(chunks)}, skipped: {skipped_modified} low-score modified, {skipped_relevance} low relevance)") return "\n".join(formatted_parts) def _emit_status(self, state: str, message: str): diff --git a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py index 91cd4e7b..d20fca9d 100644 --- a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py +++ b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py @@ -410,7 +410,8 @@ - Return exactly one review object per input file. - Match file paths exactly. - Skip style nits. -- For PREVIOUS ISSUES that are now RESOLVED: include them with isResolved=true and PRESERVE the original id field. +- For PREVIOUS ISSUES that are now RESOLVED: set "isResolved": true (boolean, not string) and PRESERVE the original id field. +- The "isResolved" field MUST be a JSON boolean: true or false, NOT a string "true" or "false". - suggestedFixDiff MUST be a valid unified diff string if a fix is proposed. """ diff --git a/python-ecosystem/mcp-client/utils/response_parser.py b/python-ecosystem/mcp-client/utils/response_parser.py index d32c2c61..fc169f47 100644 --- a/python-ecosystem/mcp-client/utils/response_parser.py +++ b/python-ecosystem/mcp-client/utils/response_parser.py @@ -119,10 +119,16 @@ def _clean_issue(issue: Dict[str, Any]) -> Dict[str, Any]: # Ensure isResolved is boolean if key == 'isResolved': + original_value = value if isinstance(value, str): value = value.lower() == 'true' elif not isinstance(value, bool): value = False + # Log resolved issues for debugging + if value: + issue_id = issue.get('id') or issue.get('issueId', 'unknown') + import logging + logging.getLogger(__name__).info(f"Issue {issue_id} marked as isResolved=True (original: {original_value})") # Ensure id is a string when present (preserve mapping to DB ids) if key == 'id': From 6fb569357035402c1bd866d7df2dea02923ede6b Mon Sep 17 00:00:00 2001 From: rostislav Date: Fri, 23 Jan 2026 18:27:28 +0200 Subject: [PATCH 10/34] fix: Implement method to retrieve branch differences from GitLab API --- .../vcsclient/gitlab/GitLabClient.java | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/gitlab/GitLabClient.java b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/gitlab/GitLabClient.java index 7e50f566..30a6b555 100644 --- a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/gitlab/GitLabClient.java +++ b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/gitlab/GitLabClient.java @@ -498,6 +498,68 @@ public String getLatestCommitHash(String workspaceId, String repoIdOrSlug, Strin } } + @Override + public String getBranchDiff(String workspaceId, String repoIdOrSlug, String baseBranch, String compareBranch) throws IOException { + // GitLab: GET /projects/:id/repository/compare + // Returns diff between two branches/commits + // API: https://docs.gitlab.com/ee/api/repositories.html#compare-branches-tags-or-commits + String projectPath = workspaceId + "/" + repoIdOrSlug; + String encodedPath = 
URLEncoder.encode(projectPath, StandardCharsets.UTF_8); + String encodedFrom = URLEncoder.encode(baseBranch, StandardCharsets.UTF_8); + String encodedTo = URLEncoder.encode(compareBranch, StandardCharsets.UTF_8); + + String url = baseUrl + "/projects/" + encodedPath + "/repository/compare?from=" + encodedFrom + "&to=" + encodedTo; + + Request request = createGetRequest(url); + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + throw createException("get branch diff", response); + } + + JsonNode root = objectMapper.readTree(response.body().string()); + JsonNode diffs = root.get("diffs"); + + if (diffs == null || !diffs.isArray() || diffs.isEmpty()) { + return ""; + } + + // Build unified diff format from GitLab's compare response + StringBuilder diffBuilder = new StringBuilder(); + for (JsonNode diff : diffs) { + String oldPath = getTextOrNull(diff, "old_path"); + String newPath = getTextOrNull(diff, "new_path"); + boolean newFile = diff.has("new_file") && diff.get("new_file").asBoolean(); + boolean deletedFile = diff.has("deleted_file") && diff.get("deleted_file").asBoolean(); + boolean renamedFile = diff.has("renamed_file") && diff.get("renamed_file").asBoolean(); + String diffContent = getTextOrNull(diff, "diff"); + + // Build git diff header + diffBuilder.append("diff --git a/").append(oldPath).append(" b/").append(newPath).append("\n"); + + if (newFile) { + diffBuilder.append("new file mode 100644\n"); + } else if (deletedFile) { + diffBuilder.append("deleted file mode 100644\n"); + } else if (renamedFile) { + diffBuilder.append("rename from ").append(oldPath).append("\n"); + diffBuilder.append("rename to ").append(newPath).append("\n"); + } + + diffBuilder.append("--- a/").append(oldPath).append("\n"); + diffBuilder.append("+++ b/").append(newPath).append("\n"); + + if (diffContent != null && !diffContent.isEmpty()) { + diffBuilder.append(diffContent); + if (!diffContent.endsWith("\n")) { + diffBuilder.append("\n"); + } + } + } + + return diffBuilder.toString(); + } + } + @Override public List listBranches(String workspaceId, String repoIdOrSlug) throws IOException { List branches = new ArrayList<>(); From 56761fbd0dc5fc9188b1673d123c960fb1e9ac9d Mon Sep 17 00:00:00 2001 From: rostislav Date: Fri, 23 Jan 2026 19:33:14 +0200 Subject: [PATCH 11/34] fix: Enhance logging and implement deterministic context retrieval in RAG pipeline --- frontend | 2 +- .../service/IncrementalRagUpdateService.java | 18 +- .../service/RagOperationsServiceImpl.java | 68 +++- .../service/multi_stage_orchestrator.py | 102 ++--- .../mcp-client/service/rag_client.py | 67 +++ .../utils/prompts/prompt_constants.py | 1 + .../rag-pipeline/src/rag_pipeline/api/api.py | 45 ++ .../rag_pipeline/services/query_service.py | 385 +++++++++++++++++- 8 files changed, 588 insertions(+), 100 deletions(-) diff --git a/frontend b/frontend index 9ce813e0..fdbb0555 160000 --- a/frontend +++ b/frontend @@ -1 +1 @@ -Subproject commit 9ce813e0ad5798c7059cf6c23c08b4d0a36c74b5 +Subproject commit fdbb055524794f49a0299fd7f020177243855e58 diff --git a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/IncrementalRagUpdateService.java b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/IncrementalRagUpdateService.java index bae4b707..04a6fd5f 100644 --- a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/IncrementalRagUpdateService.java +++ 
b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/IncrementalRagUpdateService.java @@ -43,15 +43,29 @@ public IncrementalRagUpdateService( public boolean shouldPerformIncrementalUpdate(Project project) { if (!ragApiEnabled) { + log.info("shouldPerformIncrementalUpdate: ragApiEnabled=false for project={}", project.getId()); return false; } ProjectConfig config = project.getConfiguration(); - if (config == null || config.ragConfig() == null || !config.ragConfig().enabled()) { + if (config == null) { + log.info("shouldPerformIncrementalUpdate: config is null for project={}", project.getId()); + return false; + } + + if (config.ragConfig() == null) { + log.info("shouldPerformIncrementalUpdate: ragConfig is null for project={}", project.getId()); + return false; + } + + if (!config.ragConfig().enabled()) { + log.info("shouldPerformIncrementalUpdate: ragConfig.enabled=false for project={}", project.getId()); return false; } - return ragIndexTrackingService.isProjectIndexed(project); + boolean isIndexed = ragIndexTrackingService.isProjectIndexed(project); + log.info("shouldPerformIncrementalUpdate: project={} isProjectIndexed={}", project.getId(), isIndexed); + return isIndexed; } public Map performIncrementalUpdate( diff --git a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/RagOperationsServiceImpl.java b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/RagOperationsServiceImpl.java index 65720a4d..e8128be4 100644 --- a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/RagOperationsServiceImpl.java +++ b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/RagOperationsServiceImpl.java @@ -387,6 +387,8 @@ public boolean ensureBranchIndexForPrTarget( // With single-collection architecture, we check if branch has any indexed data // If not, we need to index the branch + log.info("ensureBranchIndexForPrTarget called for project={}, branch={}", project.getId(), targetBranch); + if (!isRagEnabled(project)) { log.debug("RAG not enabled for project={}", project.getId()); return false; @@ -402,12 +404,6 @@ public boolean ensureBranchIndexForPrTarget( return false; } - // Check if branch already has indexed data - if (isBranchIndexReady(project, targetBranch)) { - log.debug("Branch {} already indexed for project={}", targetBranch, project.getId()); - return true; - } - // Get VCS connection info VcsRepoBinding vcsRepoBinding = project.getVcsRepoBinding(); if (vcsRepoBinding == null) { @@ -424,43 +420,57 @@ public boolean ensureBranchIndexForPrTarget( // Same branch? 
Already indexed via main index if (targetBranch.equals(baseBranch)) { - log.debug("Target branch is same as base branch - already indexed"); + log.debug("Target branch {} is same as base branch {} - already indexed", targetBranch, baseBranch); return true; } + // Check if branch already has indexed data (RagBranchIndex exists) + // Note: We still proceed with diff check to ensure any new changes are indexed + boolean branchIndexExists = isBranchIndexReady(project, targetBranch); + log.info("Branch index status for project={}, branch={}: exists={}", + project.getId(), targetBranch, branchIndexExists); + try { - log.info("Indexing branch data for project={}, branch={}", project.getId(), targetBranch); + log.info("Fetching diff between base branch '{}' and target branch '{}' for project={}", + baseBranch, targetBranch, project.getId()); eventConsumer.accept(Map.of( "type", "status", "state", "branch_index", - "message", String.format("Indexing branch '%s'", targetBranch) + "message", String.format("Indexing branch '%s' (diff vs '%s')", targetBranch, baseBranch) )); // Fetch diff between base branch and target branch VcsClient vcsClient = vcsClientProvider.getClient(vcsConnection); String rawDiff = vcsClient.getBranchDiff(workspaceSlug, repoSlug, baseBranch, targetBranch); + log.info("Branch diff result for project={}, branch={}: diffLength={}", + project.getId(), targetBranch, rawDiff != null ? rawDiff.length() : 0); + if (rawDiff == null || rawDiff.isEmpty()) { - log.debug("No diff between {} and {} - using main index", baseBranch, targetBranch); + log.info("No diff between '{}' and '{}' - branch has same content as base, using main index", + baseBranch, targetBranch); eventConsumer.accept(Map.of( "type", "info", - "message", String.format("No changes between %s and %s", baseBranch, targetBranch) + "message", String.format("No changes between %s and %s - using main branch index", baseBranch, targetBranch) )); return true; } // Get latest commit hash on target branch String targetCommit = vcsClient.getLatestCommitHash(workspaceSlug, repoSlug, targetBranch); + log.info("Target branch '{}' commit hash: {}", targetBranch, targetCommit); // Trigger incremental update for this branch + log.info("Triggering incremental update for project={}, branch={}, commit={}, diffBytes={}", + project.getId(), targetBranch, targetCommit, rawDiff.length()); triggerIncrementalUpdate(project, targetBranch, targetCommit, rawDiff, eventConsumer); return true; } catch (Exception e) { - log.error("Failed to index branch data for project={}, branch={}", - project.getId(), targetBranch, e); + log.error("Failed to index branch data for project={}, branch={}: {}", + project.getId(), targetBranch, e.getMessage(), e); eventConsumer.accept(Map.of( "type", "warning", "state", "branch_error", @@ -638,8 +648,10 @@ public boolean ensureRagIndexUpToDate( String targetBranch, Consumer> eventConsumer ) { + log.info("ensureRagIndexUpToDate called for project={}, targetBranch={}", project.getId(), targetBranch); + if (!isRagEnabled(project)) { - log.debug("RAG not enabled for project={}", project.getId()); + log.info("RAG not enabled for project={}", project.getId()); return false; } @@ -656,16 +668,20 @@ public boolean ensureRagIndexUpToDate( // Get base branch (main branch) String baseBranch = getBaseBranch(project); + log.info("Base branch for project={}: '{}'", project.getId(), baseBranch); try { VcsClient vcsClient = vcsClientProvider.getClient(vcsConnection); // Case 1: Target branch is the main branch - check/update main RAG 
index if (targetBranch.equals(baseBranch)) { + log.info("Target branch '{}' equals base branch '{}' - updating main index only", targetBranch, baseBranch); return ensureMainIndexUpToDate(project, targetBranch, vcsClient, workspaceSlug, repoSlug, eventConsumer); } // Case 2: Different branch - ensure main index is ready, then ensure branch is indexed + log.info("Target branch '{}' differs from base branch '{}' - will ensure branch index", targetBranch, baseBranch); + // First ensure main index is up to date ensureMainIndexUpToDate(project, baseBranch, vcsClient, workspaceSlug, repoSlug, eventConsumer); @@ -745,6 +761,7 @@ private boolean ensureMainIndexUpToDate( /** * Ensures the branch index is up-to-date with the current commit. + * For non-main branches, this compares against the previously indexed commit. */ private boolean ensureBranchIndexUpToDate( Project project, @@ -755,38 +772,47 @@ private boolean ensureBranchIndexUpToDate( String repoSlug, Consumer> eventConsumer ) throws IOException { + log.info("ensureBranchIndexUpToDate called for project={}, targetBranch={}, baseBranch={}", + project.getId(), targetBranch, baseBranch); + // Get current commit on target branch String currentCommit = vcsClient.getLatestCommitHash(workspaceSlug, repoSlug, targetBranch); + log.info("Current commit on branch '{}': {}", targetBranch, currentCommit); // Check if we have branch index tracking Optional branchIndexOpt = ragBranchIndexRepository .findByProjectIdAndBranchName(project.getId(), targetBranch); if (branchIndexOpt.isEmpty()) { - // No branch index exists - create it - log.info("Branch index does not exist for project={}, branch={} - creating", - project.getId(), targetBranch); + // No branch index exists - create it by getting full diff vs main + log.info("No RagBranchIndex entry found for project={}, branch={} - will create with full diff vs {}", + project.getId(), targetBranch, baseBranch); return ensureBranchIndexForPrTarget(project, targetBranch, eventConsumer); } RagBranchIndex branchIndex = branchIndexOpt.get(); String indexedCommit = branchIndex.getCommitHash(); + log.info("Existing RagBranchIndex for project={}, branch={}: indexedCommit={}", + project.getId(), targetBranch, indexedCommit); // If commits match, index is up to date if (currentCommit.equals(indexedCommit)) { - log.debug("Branch index is up-to-date for project={}, branch={}, commit={}", + log.info("Branch index is up-to-date for project={}, branch={}, commit={}", project.getId(), targetBranch, currentCommit); return true; } - log.info("Branch index outdated for project={}, branch={}: indexed={}, current={}", + log.info("Branch index outdated for project={}, branch={}: indexed={}, current={} - fetching incremental diff", project.getId(), targetBranch, indexedCommit, currentCommit); // Fetch diff between indexed commit and current commit on this branch String rawDiff = vcsClient.getBranchDiff(workspaceSlug, repoSlug, indexedCommit, currentCommit); + log.info("Incremental diff for branch '{}' ({}..{}): bytes={}", + targetBranch, indexedCommit.substring(0, Math.min(7, indexedCommit.length())), + currentCommit.substring(0, 7), rawDiff != null ? 
rawDiff.length() : 0); if (rawDiff == null || rawDiff.isEmpty()) { - log.debug("No diff between {} and {} - index is up to date", indexedCommit, currentCommit); + log.info("No diff between {} and {} - updating commit hash only", indexedCommit, currentCommit); // Update commit hash branchIndex.setCommitHash(currentCommit); branchIndex.setUpdatedAt(OffsetDateTime.now()); @@ -803,6 +829,8 @@ private boolean ensureBranchIndexUpToDate( )); // Trigger incremental update for this branch + log.info("Triggering incremental update for branch '{}' with diff of {} bytes", + targetBranch, rawDiff.length()); triggerIncrementalUpdate(project, targetBranch, currentCommit, rawDiff, eventConsumer); return true; diff --git a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py index a8dce904..f58c702e 100644 --- a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py +++ b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py @@ -407,35 +407,18 @@ def _get_diff_snippets_for_batch( """ Filter diff snippets to only include those relevant to the batch files. - The diffSnippets from ReviewRequestDto are extracted by Java DiffParser.extractDiffSnippets() - which creates snippets in format that may include file paths in diff headers. - We filter to get only snippets relevant to the current batch for targeted RAG queries. + Note: Java DiffParser.extractDiffSnippets() returns CLEAN CODE SNIPPETS (no file paths). + These snippets are just significant code lines like function signatures. + Since snippets don't contain file paths, we return all snippets for semantic search. + The embedding similarity will naturally prioritize relevant matches. """ - if not all_diff_snippets or not batch_file_paths: + if not all_diff_snippets: return [] - batch_snippets = [] - batch_file_names = {path.split('/')[-1] for path in batch_file_paths} - - for snippet in all_diff_snippets: - # Check if snippet relates to any file in the batch - # Diff snippets typically contain file paths in headers like "diff --git a/path b/path" - # or "--- a/path" or "+++ b/path" - snippet_lower = snippet.lower() - - for file_path in batch_file_paths: - if file_path.lower() in snippet_lower: - batch_snippets.append(snippet) - break - else: - # Also check by filename only (for cases where paths differ) - for file_name in batch_file_names: - if file_name.lower() in snippet_lower: - batch_snippets.append(snippet) - break - - logger.debug(f"Filtered {len(batch_snippets)} diff snippets for batch from {len(all_diff_snippets)} total") - return batch_snippets + # Java snippets are clean code (no file paths), so we can't filter by path + # Return all snippets - the semantic search will find relevant matches + logger.info(f"Using {len(all_diff_snippets)} diff snippets for batch files {batch_file_paths}") + return all_diff_snippets async def _execute_stage_0_planning(self, request: ReviewRequestDto, is_incremental: bool = False) -> ReviewPlan: """ @@ -529,7 +512,10 @@ async def _execute_stage_1_file_reviews( for batch_idx, batch in enumerate(wave_batches, start=wave_start + 1): batch_paths = [item["file"].path for item in batch] logger.debug(f"Batch {batch_idx}: {batch_paths}") - tasks.append(self._review_file_batch(request, batch, processed_diff, is_incremental)) + tasks.append(self._review_file_batch( + request, batch, processed_diff, is_incremental, + fallback_rag_context=rag_context + )) results = await asyncio.gather(*tasks, return_exceptions=True) @@ -555,7 +541,8 @@ 
async def _review_file_batch( request: ReviewRequestDto, batch_items: List[Dict[str, Any]], processed_diff: Optional[ProcessedDiff] = None, - is_incremental: bool = False + is_incremental: bool = False, + fallback_rag_context: Optional[Dict[str, Any]] = None ) -> List[CodeReviewIssue]: """ Review a batch of files in a single LLM call with per-batch RAG context. @@ -595,37 +582,18 @@ async def _review_file_batch( "is_incremental": is_incremental # Pass mode to prompt builder }) - # Filter pre-computed diff snippets for files in this batch (for RAG query) - # The diffSnippets from ReviewRequestDto are already properly extracted by Java DiffParser - batch_diff_snippets = self._get_diff_snippets_for_batch( - request.diffSnippets or [], - batch_file_paths - ) - - # Fetch per-batch RAG context (targeted to these specific files) + # Use initial RAG context (already fetched with all files/snippets) + # The initial query is more comprehensive - it uses ALL changed files and snippets + # Per-batch filtering is done in _format_rag_context via relevant_files param rag_context_text = "" - if self.rag_client and self.rag_client.enabled: - try: - rag_response = await self.rag_client.get_pr_context( - workspace=request.projectWorkspace, - project=request.projectNamespace, - branch=request.targetBranchName, - changed_files=batch_file_paths, - diff_snippets=batch_diff_snippets, - pr_title=request.prTitle, - top_k=5, # Focused context for this batch - min_relevance_score=0.75 # Higher threshold for per-batch - ) - # Pass ALL changed files from the PR to filter out stale context - # RAG returns context from main branch, so files being modified in PR are stale - rag_context_text = self._format_rag_context( - rag_response.get("context", {}), - set(batch_file_paths), - pr_changed_files=request.changedFiles - ) - logger.debug(f"Batch RAG context retrieved for {batch_file_paths}") - except Exception as e: - logger.warning(f"Failed to fetch per-batch RAG context: {e}") + if fallback_rag_context: + logger.info(f"Using initial RAG context for batch: {batch_file_paths}") + rag_context_text = self._format_rag_context( + fallback_rag_context, + set(batch_file_paths), + pr_changed_files=request.changedFiles + ) + logger.info(f"RAG context for batch: {len(rag_context_text)} chars") # For incremental mode, filter previous issues relevant to this batch # Also pass previous issues in FULL mode if they exist (subsequent PR iterations) @@ -933,6 +901,7 @@ def _format_rag_context( return "" logger.debug(f"Processing {len(chunks)} RAG chunks for context") + logger.debug(f"PR changed files for filtering: {pr_changed_files[:5] if pr_changed_files else 'none'}...") # Normalize PR changed files for comparison pr_changed_set = set() @@ -948,7 +917,8 @@ def _format_rag_context( skipped_modified = 0 skipped_relevance = 0 for chunk in chunks: - if included_count >= 10: # Limit to top 10 chunks for focused context + if included_count >= 15: # Increased from 10 for more context + logger.debug(f"Reached chunk limit of 15, stopping") break # Extract metadata @@ -968,16 +938,15 @@ def _format_rag_context( ) # For chunks from modified files: - # - Skip if low relevance (score < 0.85) - likely not useful and stale - # - Include if high relevance (score >= 0.85) - context is still valuable even if code may change - # The LLM can use this context to understand patterns even if specific lines changed + # - Skip if very low relevance (score < 0.70) - likely not useful + # - Include if moderate+ relevance (score >= 0.70) - context is valuable if 
is_from_modified_file: - if score < 0.85: + if score < 0.70: logger.debug(f"Skipping RAG chunk from modified file (low score): {path} (score={score})") skipped_modified += 1 continue else: - logger.debug(f"Including RAG chunk from modified file (high relevance): {path} (score={score})") + logger.debug(f"Including RAG chunk from modified file (relevant): {path} (score={score})") # Optionally filter by relevance to batch files if relevant_files: @@ -987,8 +956,9 @@ def _format_rag_context( path.rsplit("/", 1)[-1] == f.rsplit("/", 1)[-1] for f in relevant_files ) - # Also include high-scoring chunks regardless - if not is_relevant and score < 0.85: + # Also include chunks with moderate+ score regardless + if not is_relevant and score < 0.70: + logger.debug(f"Skipping RAG chunk (not relevant to batch and low score): {path} (score={score})") skipped_relevance += 1 continue diff --git a/python-ecosystem/mcp-client/service/rag_client.py b/python-ecosystem/mcp-client/service/rag_client.py index 44c98d34..ae9ba481 100644 --- a/python-ecosystem/mcp-client/service/rag_client.py +++ b/python-ecosystem/mcp-client/service/rag_client.py @@ -217,3 +217,70 @@ async def is_healthy(self) -> bool: logger.warning(f"RAG health check failed: {e}") return False + async def get_deterministic_context( + self, + workspace: str, + project: str, + branches: List[str], + file_paths: List[str], + limit_per_file: int = 10 + ) -> Dict[str, Any]: + """ + Get context using DETERMINISTIC metadata-based retrieval. + + Two-step process leveraging tree-sitter metadata: + 1. Get chunks for the changed file_paths + 2. Extract semantic_names/imports/extends from those chunks + 3. Find related definitions using extracted identifiers + + NO language-specific parsing needed - tree-sitter did it during indexing! + Predictable: same input = same output. 
+ + Args: + workspace: Workspace identifier + project: Project identifier + branches: Branches to search (e.g., ['release/1.29', 'master']) + file_paths: Changed file paths from diff + limit_per_file: Max chunks per file (default 10) + + Returns: + Dict with chunks grouped by: changed_files, related_definitions + """ + if not self.enabled: + logger.debug("RAG disabled, returning empty deterministic context") + return {"context": {"chunks": [], "changed_files": {}, "related_definitions": {}}} + + start_time = datetime.now() + + try: + payload = { + "workspace": workspace, + "project": project, + "branches": branches, + "file_paths": file_paths, + "limit_per_file": limit_per_file + } + + client = await self._get_client() + response = await client.post( + f"{self.base_url}/query/deterministic", + json=payload + ) + response.raise_for_status() + result = response.json() + + # Log timing and stats + elapsed_ms = (datetime.now() - start_time).total_seconds() * 1000 + context = result.get("context", {}) + chunk_count = len(context.get("chunks", [])) + logger.info(f"Deterministic RAG query completed in {elapsed_ms:.2f}ms, " + f"retrieved {chunk_count} chunks for {len(file_paths)} files") + + return result + + except httpx.HTTPError as e: + logger.warning(f"Failed to retrieve deterministic context: {e}") + return {"context": {"chunks": [], "by_identifier": {}, "by_file": {}}} + except Exception as e: + logger.error(f"Unexpected error in deterministic RAG query: {e}") + return {"context": {"chunks": [], "by_identifier": {}, "by_file": {}}} diff --git a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py index d20fca9d..47e65d1b 100644 --- a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py +++ b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py @@ -343,6 +343,7 @@ - Missing variable declarations - they may exist outside the diff context - Missing function definitions - the function may be defined elsewhere in the file - Missing class properties - they may be declared outside the visible changes +- Security issues in code that is not visible in the diff of RAG context ONLY report issues that you can VERIFY from the visible diff content. If you suspect an issue but cannot confirm it from the diff, DO NOT report it. diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py index 0b97486d..2ea2c73a 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py @@ -67,6 +67,23 @@ class PRContextRequest(BaseModel): deleted_files: Optional[List[str]] = [] # Files deleted in target branch +class DeterministicContextRequest(BaseModel): + """Request for deterministic metadata-based context retrieval. + + TWO-STEP process leveraging tree-sitter metadata: + 1. Get chunks for the changed file_paths + 2. Extract semantic_names/imports/extends from those chunks + 3. Find related definitions using extracted identifiers + + NO language-specific parsing needed - tree-sitter already did it during indexing! 
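On the consuming side, the mcp-client wrapper shown above exposes this endpoint as get_deterministic_context(). A minimal usage sketch, assuming an already configured client instance; the workspace, project, branch and file names are invented:

# Sketch: consuming the deterministic context described above.
# `rag_client` is assumed to be an instance of the client defined in service/rag_client.py.
async def fetch_pr_context(rag_client):
    result = await rag_client.get_deterministic_context(
        workspace="acme",                        # invented example values
        project="webshop",
        branches=["feature/checkout", "main"],   # target branch plus base branch as fallback
        file_paths=["src/payment/Gateway.php"],
        limit_per_file=10,
    )
    context = result.get("context", {})
    # Chunks are grouped by how they were matched (changed_file, definition,
    # class_context, namespace_context), each carrying its tree-sitter metadata payload.
    for chunk in context.get("chunks", []):
        print(chunk["_match_type"], chunk["score"], chunk["metadata"].get("path"))

Because retrieval is pure metadata filtering over the indexed chunks, repeated calls with the same inputs return the same context, which keeps the review prompts reproducible.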
+ """ + workspace: str + project: str + branches: List[str] # Branches to search (e.g., ['release/1.29', 'master']) + file_paths: List[str] # Changed file paths from diff + limit_per_file: Optional[int] = 10 # Max chunks per file + + class DeleteBranchRequest(BaseModel): workspace: str project: str @@ -422,6 +439,34 @@ def get_pr_context(request: PRContextRequest): raise HTTPException(status_code=500, detail=str(e)) +@app.post("/query/deterministic") +def get_deterministic_context(request: DeterministicContextRequest): + """ + Get context using DETERMINISTIC metadata-based retrieval. + + Two-step process: + 1. Get chunks for changed file_paths + 2. Extract semantic_names/imports/extends from those chunks (tree-sitter metadata!) + 3. Find related definitions using extracted identifiers + + No language-specific parsing needed - tree-sitter already did it during indexing. + Predictable: same input = same output. + """ + try: + context = query_service.get_deterministic_context( + workspace=request.workspace, + project=request.project, + branches=request.branches, + file_paths=request.file_paths, + limit_per_file=request.limit_per_file or 10 + ) + + return {"context": context} + except Exception as e: + logger.error(f"Error getting deterministic context: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + # ============================================================================= # BRANCH MANAGEMENT ENDPOINTS # ============================================================================= diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py index a8dcf3f5..dc222980 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py @@ -288,6 +288,340 @@ def _get_fallback_branch(self, workspace: str, project: str, requested_branch: s return None + def get_deterministic_context( + self, + workspace: str, + project: str, + branches: List[str], + file_paths: List[str], + limit_per_file: int = 10 + ) -> Dict: + """ + Get context using DETERMINISTIC metadata-based retrieval. + + Leverages ALL tree-sitter metadata extracted during indexing: + - semantic_names: function/method/class names + - primary_name: main identifier + - parent_class: containing class + - full_path: qualified name (e.g., "Data.getConfigData") + - imports: import statements + - extends: parent classes/interfaces + - namespace: package/namespace + - node_type: method_declaration, class_definition, etc. + + Multi-step process: + 1. Query chunks for changed file_paths + 2. Extract metadata (identifiers, parent classes, namespaces, imports) + 3. Find related definitions by: + a) primary_name match (definitions of used identifiers) + b) parent_class match (other methods in same class) + c) namespace match (related code in same package) + + NO LANGUAGE-SPECIFIC PARSING NEEDED - tree-sitter already did that! + Same input always produces same output (deterministic). 
+ + Args: + workspace: VCS workspace + project: Project name + branches: Branches to search (target + base for PRs) + file_paths: Changed file paths from diff + limit_per_file: Max chunks per file + + Returns: + Dict with chunks grouped by retrieval type and rich metadata + """ + collection_name = self._get_project_collection_name(workspace, project) + + if not self._collection_or_alias_exists(collection_name): + logger.warning(f"Collection {collection_name} does not exist") + return {"chunks": [], "changed_files": {}, "related_definitions": {}, + "class_context": {}, "namespace_context": {}, + "_metadata": {"error": "collection_not_found"}} + + logger.info(f"Deterministic context: files={file_paths[:5]}, branches={branches}") + + all_chunks = [] + changed_files_chunks = {} + related_definitions = {} + class_context = {} # Other methods in same classes + namespace_context = {} # Related code in same namespaces + + # Metadata to collect from changed files + identifiers_to_find = set() + parent_classes = set() + namespaces = set() + imports_raw = set() + extends_raw = set() + + # Track changed file paths for deduplication + changed_file_paths = set() + seen_texts = set() + + # Build branch filter + branch_filter = ( + FieldCondition(key="branch", match=MatchValue(value=branches[0])) + if len(branches) == 1 + else FieldCondition(key="branch", match=MatchAny(any=branches)) + ) + + # ========== STEP 1: Get chunks from changed files ========== + for file_path in file_paths: + try: + normalized_path = file_path.lstrip("/") + filename = normalized_path.rsplit("/", 1)[-1] if "/" in normalized_path else normalized_path + + # Try exact path match + results, _ = self.qdrant_client.scroll( + collection_name=collection_name, + scroll_filter=Filter( + must=[ + branch_filter, + FieldCondition(key="path", match=MatchValue(value=normalized_path)) + ] + ), + limit=limit_per_file, + with_payload=True, + with_vectors=False + ) + + # Fallback: partial match if exact fails + if not results: + all_results, _ = self.qdrant_client.scroll( + collection_name=collection_name, + scroll_filter=Filter(must=[branch_filter]), + limit=1000, + with_payload=True, + with_vectors=False + ) + results = [ + p for p in all_results + if normalized_path in p.payload.get("path", "") or + filename == p.payload.get("path", "").rsplit("/", 1)[-1] + ][:limit_per_file] + + chunks_for_file = [] + for point in results: + payload = point.payload + text = payload.get("text", payload.get("_node_content", "")) + + if text in seen_texts: + continue + seen_texts.add(text) + + chunk = { + "text": text, + "metadata": {k: v for k, v in payload.items() if k not in ("text", "_node_content")}, + "score": 1.0, + "_match_type": "changed_file", + "_matched_on": file_path + } + chunks_for_file.append(chunk) + all_chunks.append(chunk) + changed_file_paths.add(payload.get("path", "")) + + # Extract ALL tree-sitter metadata for step 2-4 + if isinstance(payload.get("semantic_names"), list): + identifiers_to_find.update(payload["semantic_names"]) + + if payload.get("primary_name"): + identifiers_to_find.add(payload["primary_name"]) + + if payload.get("parent_class"): + parent_classes.add(payload["parent_class"]) + + if payload.get("namespace"): + namespaces.add(payload["namespace"]) + + if isinstance(payload.get("imports"), list): + for imp in payload["imports"]: + # Extract class name from import statement + # "use Magento\Store\Model\ScopeInterface;" -> "ScopeInterface" + if isinstance(imp, str): + parts = imp.replace(";", "").split("\\") + if parts: + 
imports_raw.add(parts[-1].strip()) + + if isinstance(payload.get("extends"), list): + extends_raw.update(payload["extends"]) + + if payload.get("parent_class"): + extends_raw.add(payload["parent_class"]) + + changed_files_chunks[file_path] = chunks_for_file + + except Exception as e: + logger.warning(f"Error querying file '{file_path}': {e}") + + logger.info(f"Step 1: {len(all_chunks)} chunks from changed files. " + f"Extracted: {len(identifiers_to_find)} identifiers, " + f"{len(parent_classes)} parent_classes, {len(namespaces)} namespaces, " + f"{len(imports_raw)} imports, {len(extends_raw)} extends") + + # ========== STEP 2: Find definitions by primary_name ========== + # Find where identifiers/imports/extends are DEFINED + all_to_find = identifiers_to_find | imports_raw | extends_raw + if all_to_find: + try: + batch = list(all_to_find)[:100] + results, _ = self.qdrant_client.scroll( + collection_name=collection_name, + scroll_filter=Filter( + must=[ + branch_filter, + FieldCondition(key="primary_name", match=MatchAny(any=batch)) + ] + ), + limit=200, + with_payload=True, + with_vectors=False + ) + + for point in results: + payload = point.payload + if payload.get("path") in changed_file_paths: + continue + + text = payload.get("text", payload.get("_node_content", "")) + if text in seen_texts: + continue + seen_texts.add(text) + + primary_name = payload.get("primary_name", "") + chunk = { + "text": text, + "metadata": {k: v for k, v in payload.items() if k not in ("text", "_node_content")}, + "score": 0.95, + "_match_type": "definition", + "_matched_on": primary_name + } + all_chunks.append(chunk) + + if primary_name not in related_definitions: + related_definitions[primary_name] = [] + related_definitions[primary_name].append(chunk) + + logger.info(f"Step 2: Found {len(related_definitions)} definitions by primary_name") + + except Exception as e: + logger.warning(f"Error in primary_name query: {e}") + + # ========== STEP 3: Find other methods in same parent_class ========== + # If we're changing a method in class "Data", find other methods of "Data" + if parent_classes: + try: + batch = list(parent_classes)[:20] + results, _ = self.qdrant_client.scroll( + collection_name=collection_name, + scroll_filter=Filter( + must=[ + branch_filter, + FieldCondition(key="parent_class", match=MatchAny(any=batch)) + ] + ), + limit=100, + with_payload=True, + with_vectors=False + ) + + for point in results: + payload = point.payload + if payload.get("path") in changed_file_paths: + continue + + text = payload.get("text", payload.get("_node_content", "")) + if text in seen_texts: + continue + seen_texts.add(text) + + parent_class = payload.get("parent_class", "") + chunk = { + "text": text, + "metadata": {k: v for k, v in payload.items() if k not in ("text", "_node_content")}, + "score": 0.85, + "_match_type": "class_context", + "_matched_on": parent_class + } + all_chunks.append(chunk) + + if parent_class not in class_context: + class_context[parent_class] = [] + class_context[parent_class].append(chunk) + + logger.info(f"Step 3: Found {sum(len(v) for v in class_context.values())} class context chunks") + + except Exception as e: + logger.warning(f"Error in parent_class query: {e}") + + # ========== STEP 4: Find related code in same namespace ========== + # Lower priority - only get a few for broader context + if namespaces: + try: + batch = list(namespaces)[:10] + results, _ = self.qdrant_client.scroll( + collection_name=collection_name, + scroll_filter=Filter( + must=[ + branch_filter, + 
FieldCondition(key="namespace", match=MatchAny(any=batch)) + ] + ), + limit=30, + with_payload=True, + with_vectors=False + ) + + for point in results: + payload = point.payload + if payload.get("path") in changed_file_paths: + continue + + text = payload.get("text", payload.get("_node_content", "")) + if text in seen_texts: + continue + seen_texts.add(text) + + namespace = payload.get("namespace", "") + chunk = { + "text": text, + "metadata": {k: v for k, v in payload.items() if k not in ("text", "_node_content")}, + "score": 0.75, + "_match_type": "namespace_context", + "_matched_on": namespace + } + all_chunks.append(chunk) + + if namespace not in namespace_context: + namespace_context[namespace] = [] + namespace_context[namespace].append(chunk) + + logger.info(f"Step 4: Found {sum(len(v) for v in namespace_context.values())} namespace context chunks") + + except Exception as e: + logger.warning(f"Error in namespace query: {e}") + + logger.info(f"Deterministic context complete: {len(all_chunks)} total chunks " + f"(changed: {sum(len(v) for v in changed_files_chunks.values())}, " + f"definitions: {sum(len(v) for v in related_definitions.values())}, " + f"class_ctx: {sum(len(v) for v in class_context.values())}, " + f"ns_ctx: {sum(len(v) for v in namespace_context.values())})") + + return { + "chunks": all_chunks, + "changed_files": changed_files_chunks, + "related_definitions": related_definitions, + "class_context": class_context, + "namespace_context": namespace_context, + "_metadata": { + "branches_searched": branches, + "files_requested": file_paths, + "identifiers_extracted": list(identifiers_to_find)[:30], + "parent_classes_found": list(parent_classes), + "namespaces_found": list(namespaces), + "imports_extracted": list(imports_raw)[:30], + "extends_extracted": list(extends_raw)[:20] + } + } + def get_context_for_pr( self, workspace: str, @@ -356,11 +690,15 @@ def get_context_for_pr( diff_snippets=diff_snippets, changed_files=changed_files ) + + logger.info(f"Generated {len(queries)} queries for PR context") + for i, (q_text, q_weight, q_top_k, q_type) in enumerate(queries): + logger.info(f" Query {i+1}: weight={q_weight}, top_k={q_top_k}, text='{q_text[:80]}...'") all_results = [] # 2. 
Execute queries with multi-branch search - for q_text, q_weight, q_top_k, q_instruction_type in queries: + for i, (q_text, q_weight, q_top_k, q_instruction_type) in enumerate(queries): if not q_text.strip(): continue @@ -373,6 +711,8 @@ def get_context_for_pr( instruction_type=q_instruction_type, excluded_paths=deleted_files ) + + logger.info(f"Query {i+1}/{len(queries)} returned {len(results)} results") for r in results: r["_query_weight"] = q_weight @@ -419,9 +759,12 @@ def get_context_for_pr( if "path" in result["metadata"]: related_files.add(result["metadata"]["path"]) - logger.info( - f"Smart RAG: Final context has {len(relevant_code)} chunks " - f"from {len(related_files)} files across {len(branches_to_search)} branches") + # Log top results for debugging + logger.info(f"Smart RAG: Final context has {len(relevant_code)} chunks from {len(related_files)} files") + for i, r in enumerate(relevant_code[:5]): + path = r["metadata"].get("path", "unknown") + primary_name = r["metadata"].get("primary_name", "N/A") + logger.info(f" Chunk {i+1}: score={r['score']:.3f}, name={primary_name}, path=...{path[-60:]}") result = { "relevant_code": relevant_code, @@ -444,6 +787,7 @@ def _decompose_queries( """ from collections import defaultdict import os + import re queries = [] @@ -486,14 +830,33 @@ def _decompose_queries( queries.append((q, 0.8, 5, InstructionType.LOGIC)) # C. Snippet Queries (Low Level) - Weight 1.2 (High precision) - for snippet in diff_snippets[:3]: - # Clean snippet: remove +/- markers, take first few lines - lines = [l.strip() for l in snippet.split('\n') if l.strip() and not l.startswith(('+', '-'))] + # Use actual changed code for semantic matching (not just context lines) + for snippet in diff_snippets[:5]: + # Extract meaningful code from diff - INCLUDE changed lines (+/-), they ARE the code + lines = [] + for line in snippet.split('\n'): + stripped = line.strip() + if not stripped: + continue + # Skip diff headers but keep actual code (including +/- prefixed lines) + if stripped.startswith(('diff --git', '---', '+++', '@@', 'index ')): + continue + # Remove the +/- prefix but keep the code content + if stripped.startswith('+') or stripped.startswith('-'): + code_line = stripped[1:].strip() + if code_line and len(code_line) > 3: # Skip empty/trivial lines + lines.append(code_line) + elif stripped: + lines.append(stripped) + if lines: - # Join first 2-3 significant lines - clean_snippet = " ".join(lines[:3]) - if len(clean_snippet) > 10: - queries.append((clean_snippet, 1.2, 5, InstructionType.DEPENDENCY)) + # Join significant lines (function names, method calls, etc.) + clean_snippet = " ".join(lines[:5]) + if len(clean_snippet) > 15: + queries.append((clean_snippet, 1.2, 8, InstructionType.DEPENDENCY)) + + # Log the generated queries for debugging + logger.debug(f"Decomposed into {len(queries)} queries: {[(q[0][:50], q[1]) for q in queries]}") return queries From 8fc46f3bd4a3bdc4e74ebb99d80ec448ba1eb25b Mon Sep 17 00:00:00 2001 From: rostislav Date: Mon, 26 Jan 2026 15:12:42 +0200 Subject: [PATCH 12/34] Refactor AI connection handling and improve job deletion logic - Updated JobService to use REQUIRES_NEW transaction propagation for deleting ignored jobs, ensuring fresh entity retrieval and preventing issues with the calling transaction. - Removed token limitation from AI connection model and related DTOs, transitioning to project-level configuration for token limits. - Adjusted AIConnectionDTO tests to reflect the removal of token limitation. 
- Enhanced Bitbucket, GitHub, and GitLab AI client services to check token limits before analysis, throwing DiffTooLargeException when limits are exceeded. - Updated command processors to utilize project-level token limits instead of AI connection-specific limits. - Modified webhook processing to handle diff size issues gracefully, posting informative messages to VCS when analysis is skipped due to large diffs. - Cleaned up integration tests to remove references to token limitation in AI connection creation and updates. --- java-ecosystem/libs/analysis-engine/pom.xml | 6 ++ .../src/main/java/module-info.java | 1 + .../exception/DiffTooLargeException.java | 47 ++++++++++ .../analysisengine/util/TokenEstimator.java | 83 +++++++++++++++++ .../codecrow/core/dto/ai/AIConnectionDTO.java | 6 +- .../codecrow/core/model/ai/AIConnection.java | 11 --- .../model/project/config/ProjectConfig.java | 34 ++++++- .../codecrow/core/service/JobService.java | 19 +++- ...ve_token_limitation_from_ai_connection.sql | 5 ++ .../core/dto/ai/AIConnectionDTOTest.java | 35 ++------ .../core/model/ai/AIConnectionTest.java | 15 ---- .../service/BitbucketAiClientService.java | 23 ++++- .../processor/WebhookAsyncProcessor.java | 89 ++++++++++++++++++- .../command/AskCommandProcessor.java | 2 +- .../command/ReviewCommandProcessor.java | 2 +- .../command/SummarizeCommandProcessor.java | 2 +- .../github/service/GitHubAiClientService.java | 23 ++++- .../gitlab/service/GitLabAiClientService.java | 23 ++++- .../GitLabMergeRequestWebhookHandler.java | 27 +++++- .../request/CreateAIConnectionRequest.java | 2 - .../request/UpdateAiConnectionRequest.java | 1 - .../ai/service/AIConnectionService.java | 4 - .../integration/ai/AIConnectionCrudIT.java | 34 +++---- .../integration/auth/UserAuthFlowIT.java | 3 +- .../builder/AIConnectionBuilder.java | 7 -- .../integration/util/AuthTestHelper.java | 1 - 26 files changed, 393 insertions(+), 112 deletions(-) create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/exception/DiffTooLargeException.java create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/TokenEstimator.java create mode 100644 java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql diff --git a/java-ecosystem/libs/analysis-engine/pom.xml b/java-ecosystem/libs/analysis-engine/pom.xml index 4d16d658..3fba0c89 100644 --- a/java-ecosystem/libs/analysis-engine/pom.xml +++ b/java-ecosystem/libs/analysis-engine/pom.xml @@ -68,6 +68,12 @@ okhttp + + + com.knuddels + jtokkit + + org.junit.jupiter diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java b/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java index b3e30345..d03cdf59 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java @@ -18,6 +18,7 @@ requires com.fasterxml.jackson.annotation; requires jakarta.persistence; requires kotlin.stdlib; + requires jtokkit; exports org.rostilos.codecrow.analysisengine.aiclient; exports org.rostilos.codecrow.analysisengine.config; diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/exception/DiffTooLargeException.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/exception/DiffTooLargeException.java new file mode 100644 index 00000000..7304448c --- 
/dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/exception/DiffTooLargeException.java @@ -0,0 +1,47 @@ +package org.rostilos.codecrow.analysisengine.exception; + +/** + * Exception thrown when a diff exceeds the configured token limit for analysis. + * This is a soft skip - the analysis is not performed but the job is not marked as failed. + */ +public class DiffTooLargeException extends RuntimeException { + + private final int estimatedTokens; + private final int maxAllowedTokens; + private final Long projectId; + private final Long pullRequestId; + + public DiffTooLargeException(int estimatedTokens, int maxAllowedTokens, Long projectId, Long pullRequestId) { + super(String.format( + "PR diff exceeds token limit: estimated %d tokens, max allowed %d tokens (project=%d, PR=%d)", + estimatedTokens, maxAllowedTokens, projectId, pullRequestId + )); + this.estimatedTokens = estimatedTokens; + this.maxAllowedTokens = maxAllowedTokens; + this.projectId = projectId; + this.pullRequestId = pullRequestId; + } + + public int getEstimatedTokens() { + return estimatedTokens; + } + + public int getMaxAllowedTokens() { + return maxAllowedTokens; + } + + public Long getProjectId() { + return projectId; + } + + public Long getPullRequestId() { + return pullRequestId; + } + + /** + * Returns the percentage of the token limit that would be used. + */ + public double getUtilizationPercentage() { + return maxAllowedTokens > 0 ? (estimatedTokens * 100.0 / maxAllowedTokens) : 0; + } +} diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/TokenEstimator.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/TokenEstimator.java new file mode 100644 index 00000000..4ccb613e --- /dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/TokenEstimator.java @@ -0,0 +1,83 @@ +package org.rostilos.codecrow.analysisengine.util; + +import com.knuddels.jtokkit.Encodings; +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingRegistry; +import com.knuddels.jtokkit.api.EncodingType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Utility class for estimating token counts in text content. + * Uses the cl100k_base encoding (used by GPT-4, Claude, and most modern LLMs). + */ +public class TokenEstimator { + private static final Logger log = LoggerFactory.getLogger(TokenEstimator.class); + + private static final EncodingRegistry ENCODING_REGISTRY = Encodings.newDefaultEncodingRegistry(); + private static final Encoding ENCODING = ENCODING_REGISTRY.getEncoding(EncodingType.CL100K_BASE); + + /** + * Estimate the number of tokens in the given text. + * + * @param text The text to estimate tokens for + * @return The estimated token count, or 0 if text is null/empty + */ + public static int estimateTokens(String text) { + if (text == null || text.isEmpty()) { + return 0; + } + try { + return ENCODING.countTokens(text); + } catch (Exception e) { + log.warn("Failed to count tokens, using fallback estimation: {}", e.getMessage()); + // Fallback: rough estimate of ~4 characters per token + return text.length() / 4; + } + } + + /** + * Check if the estimated token count exceeds the given limit. 
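A minimal sketch (not part of the patch) of how these two new classes are meant to compose; the DiffSizeGuard class and guardDiffSize method are illustrative names only, and the cl100k_base count is just an approximation for providers that use their own tokenizers, which is adequate for a coarse size guard:

    // Illustrative only: guard a raw diff against a project-level token limit.
    import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException;
    import org.rostilos.codecrow.analysisengine.util.TokenEstimator;

    final class DiffSizeGuard {

        /** Throws when the diff's estimated token count exceeds the allowed limit. */
        static int guardDiffSize(String rawDiff, int maxTokens, Long projectId, Long pullRequestId) {
            int estimated = TokenEstimator.estimateTokens(rawDiff);
            if (estimated > maxTokens) {
                // Soft skip: callers treat this as "skip analysis", not as a failed job.
                throw new DiffTooLargeException(estimated, maxTokens, projectId, pullRequestId);
            }
            return estimated;
        }
    }

The provider-specific AI client services later in this patch perform exactly this check inline before building the analysis request.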
+ * + * @param text The text to check + * @param maxTokens The maximum allowed tokens + * @return true if the text exceeds the limit, false otherwise + */ + public static boolean exceedsLimit(String text, int maxTokens) { + return estimateTokens(text) > maxTokens; + } + + /** + * Result of a token estimation check with details. + */ + public record TokenEstimationResult( + int estimatedTokens, + int maxAllowedTokens, + boolean exceedsLimit, + double utilizationPercentage + ) { + public String toLogString() { + return String.format("Tokens: %d / %d (%.1f%%) - %s", + estimatedTokens, maxAllowedTokens, utilizationPercentage, + exceedsLimit ? "EXCEEDS LIMIT" : "within limit"); + } + } + + /** + * Estimate tokens and check against limit, returning detailed result. + * + * @param text The text to check + * @param maxTokens The maximum allowed tokens + * @return Detailed estimation result + */ + public static TokenEstimationResult estimateAndCheck(String text, int maxTokens) { + int estimated = estimateTokens(text); + double utilization = maxTokens > 0 ? (estimated * 100.0 / maxTokens) : 0; + return new TokenEstimationResult( + estimated, + maxTokens, + estimated > maxTokens, + utilization + ); + } +} diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTO.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTO.java index b9e16fb5..b04b4434 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTO.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTO.java @@ -11,8 +11,7 @@ public record AIConnectionDTO( AIProviderKey providerKey, String aiModel, OffsetDateTime createdAt, - OffsetDateTime updatedAt, - int tokenLimitation + OffsetDateTime updatedAt ) { public static AIConnectionDTO fromAiConnection(AIConnection aiConnection) { @@ -22,8 +21,7 @@ public static AIConnectionDTO fromAiConnection(AIConnection aiConnection) { aiConnection.getProviderKey(), aiConnection.getAiModel(), aiConnection.getCreatedAt(), - aiConnection.getUpdatedAt(), - aiConnection.getTokenLimitation() + aiConnection.getUpdatedAt() ); } } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/ai/AIConnection.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/ai/AIConnection.java index 2ca682c3..f6558f75 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/ai/AIConnection.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/ai/AIConnection.java @@ -39,9 +39,6 @@ public class AIConnection { @Column(name = "updated_at", nullable = false) private OffsetDateTime updatedAt = OffsetDateTime.now(); - @Column(name= "token_limitation", nullable = false) - private int tokenLimitation = 100000; - @PreUpdate public void onUpdate() { this.updatedAt = OffsetDateTime.now(); @@ -98,12 +95,4 @@ public OffsetDateTime getCreatedAt() { public OffsetDateTime getUpdatedAt() { return updatedAt; } - - public void setTokenLimitation(int tokenLimitation) { - this.tokenLimitation = tokenLimitation; - } - - public int getTokenLimitation() { - return tokenLimitation; - } } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/config/ProjectConfig.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/config/ProjectConfig.java index 66335185..99d18764 100644 --- 
a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/config/ProjectConfig.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/config/ProjectConfig.java @@ -24,6 +24,8 @@ * - branchAnalysisEnabled: whether to analyze branch pushes (default: true). * - installationMethod: how the project integration is installed (WEBHOOK, PIPELINE, GITHUB_ACTION). * - commentCommands: configuration for PR comment-triggered commands (/codecrow analyze, summarize, ask). + * - maxAnalysisTokenLimit: maximum allowed tokens for PR analysis (default: 200000). + * Analysis will be skipped if the diff exceeds this limit. * * @see BranchAnalysisConfig * @see RagConfig @@ -32,6 +34,8 @@ */ @JsonIgnoreProperties(ignoreUnknown = true) public class ProjectConfig { + public static final int DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT = 200000; + @JsonProperty("useLocalMcp") private boolean useLocalMcp; @@ -56,16 +60,27 @@ public class ProjectConfig { private InstallationMethod installationMethod; @JsonProperty("commentCommands") private CommentCommandsConfig commentCommands; + @JsonProperty("maxAnalysisTokenLimit") + private Integer maxAnalysisTokenLimit; public ProjectConfig() { this.useLocalMcp = false; this.prAnalysisEnabled = true; this.branchAnalysisEnabled = true; + this.maxAnalysisTokenLimit = DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; } public ProjectConfig(boolean useLocalMcp, String mainBranch, BranchAnalysisConfig branchAnalysis, RagConfig ragConfig, Boolean prAnalysisEnabled, Boolean branchAnalysisEnabled, InstallationMethod installationMethod, CommentCommandsConfig commentCommands) { + this(useLocalMcp, mainBranch, branchAnalysis, ragConfig, prAnalysisEnabled, branchAnalysisEnabled, + installationMethod, commentCommands, DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT); + } + + public ProjectConfig(boolean useLocalMcp, String mainBranch, BranchAnalysisConfig branchAnalysis, + RagConfig ragConfig, Boolean prAnalysisEnabled, Boolean branchAnalysisEnabled, + InstallationMethod installationMethod, CommentCommandsConfig commentCommands, + Integer maxAnalysisTokenLimit) { this.useLocalMcp = useLocalMcp; this.mainBranch = mainBranch; this.defaultBranch = mainBranch; // Keep in sync for backward compatibility @@ -75,6 +90,7 @@ public ProjectConfig(boolean useLocalMcp, String mainBranch, BranchAnalysisConfi this.branchAnalysisEnabled = branchAnalysisEnabled; this.installationMethod = installationMethod; this.commentCommands = commentCommands; + this.maxAnalysisTokenLimit = maxAnalysisTokenLimit != null ? maxAnalysisTokenLimit : DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; } public ProjectConfig(boolean useLocalMcp, String mainBranch) { @@ -112,6 +128,14 @@ public String defaultBranch() { public InstallationMethod installationMethod() { return installationMethod; } public CommentCommandsConfig commentCommands() { return commentCommands; } + /** + * Get the maximum token limit for PR analysis. + * Returns the configured value or the default (200000) if not set. + */ + public int maxAnalysisTokenLimit() { + return maxAnalysisTokenLimit != null ? 
maxAnalysisTokenLimit : DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; + } + // Setters for Jackson public void setUseLocalMcp(boolean useLocalMcp) { this.useLocalMcp = useLocalMcp; } @@ -149,6 +173,9 @@ public void setDefaultBranch(String defaultBranch) { public void setBranchAnalysisEnabled(Boolean branchAnalysisEnabled) { this.branchAnalysisEnabled = branchAnalysisEnabled; } public void setInstallationMethod(InstallationMethod installationMethod) { this.installationMethod = installationMethod; } public void setCommentCommands(CommentCommandsConfig commentCommands) { this.commentCommands = commentCommands; } + public void setMaxAnalysisTokenLimit(Integer maxAnalysisTokenLimit) { + this.maxAnalysisTokenLimit = maxAnalysisTokenLimit != null ? maxAnalysisTokenLimit : DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; + } public void ensureMainBranchInPatterns() { String main = mainBranch(); @@ -230,13 +257,15 @@ public boolean equals(Object o) { Objects.equals(prAnalysisEnabled, that.prAnalysisEnabled) && Objects.equals(branchAnalysisEnabled, that.branchAnalysisEnabled) && installationMethod == that.installationMethod && - Objects.equals(commentCommands, that.commentCommands); + Objects.equals(commentCommands, that.commentCommands) && + Objects.equals(maxAnalysisTokenLimit, that.maxAnalysisTokenLimit); } @Override public int hashCode() { return Objects.hash(useLocalMcp, mainBranch, branchAnalysis, ragConfig, - prAnalysisEnabled, branchAnalysisEnabled, installationMethod, commentCommands); + prAnalysisEnabled, branchAnalysisEnabled, installationMethod, + commentCommands, maxAnalysisTokenLimit); } @Override @@ -250,6 +279,7 @@ public String toString() { ", branchAnalysisEnabled=" + branchAnalysisEnabled + ", installationMethod=" + installationMethod + ", commentCommands=" + commentCommands + + ", maxAnalysisTokenLimit=" + maxAnalysisTokenLimit + '}'; } } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index f239b9c1..04036c47 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -311,14 +311,27 @@ public Job skipJob(Job job, String reason) { * Used for jobs that were created but then determined to be unnecessary * (e.g., branch not matching pattern, PR analysis disabled). * This prevents DB clutter from ignored webhooks. + * Uses REQUIRES_NEW to ensure this runs in its own transaction, + * allowing it to work even if the calling transaction has issues. 
*/ - @Transactional + @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW) public void deleteIgnoredJob(Job job, String reason) { log.info("Deleting ignored job {} ({}): {}", job.getExternalId(), job.getJobType(), reason); + // Re-fetch the job to ensure we have a fresh entity in this new transaction + Long jobId = job.getId(); + if (jobId == null) { + log.warn("Cannot delete ignored job - job ID is null"); + return; + } + Optional existingJob = jobRepository.findById(jobId); + if (existingJob.isEmpty()) { + log.warn("Cannot delete ignored job {} - not found in database", job.getExternalId()); + return; + } // Delete any logs first (foreign key constraint) - jobLogRepository.deleteByJobId(job.getId()); + jobLogRepository.deleteByJobId(jobId); // Delete the job - jobRepository.delete(job); + jobRepository.delete(existingJob.get()); } /** diff --git a/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql b/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql new file mode 100644 index 00000000..fcd0faaf --- /dev/null +++ b/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql @@ -0,0 +1,5 @@ +-- Remove token_limitation column from ai_connection table +-- Token limitation is now configured per-project in the project configuration JSON +-- Default value is 200000 tokens, configured in ProjectConfig.maxAnalysisTokenLimit + +ALTER TABLE ai_connection DROP COLUMN IF EXISTS token_limitation; diff --git a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTOTest.java b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTOTest.java index 5a8af6ef..7536258c 100644 --- a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTOTest.java +++ b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTOTest.java @@ -24,7 +24,7 @@ void shouldCreateWithAllFields() { OffsetDateTime now = OffsetDateTime.now(); AIConnectionDTO dto = new AIConnectionDTO( 1L, "Test Connection", AIProviderKey.ANTHROPIC, "claude-3-opus", - now, now, 100000 + now, now ); assertThat(dto.id()).isEqualTo(1L); @@ -33,14 +33,13 @@ void shouldCreateWithAllFields() { assertThat(dto.aiModel()).isEqualTo("claude-3-opus"); assertThat(dto.createdAt()).isEqualTo(now); assertThat(dto.updatedAt()).isEqualTo(now); - assertThat(dto.tokenLimitation()).isEqualTo(100000); } @Test @DisplayName("should create AIConnectionDTO with null optional fields") void shouldCreateWithNullOptionalFields() { AIConnectionDTO dto = new AIConnectionDTO( - 1L, null, AIProviderKey.OPENAI, null, null, null, 50000 + 1L, null, AIProviderKey.OPENAI, null, null, null ); assertThat(dto.id()).isEqualTo(1L); @@ -53,24 +52,14 @@ void shouldCreateWithNullOptionalFields() { @Test @DisplayName("should create AIConnectionDTO with different providers") void shouldCreateWithDifferentProviders() { - AIConnectionDTO openai = new AIConnectionDTO(1L, "OpenAI", AIProviderKey.OPENAI, "gpt-4", null, null, 100000); - AIConnectionDTO anthropic = new AIConnectionDTO(2L, "Anthropic", AIProviderKey.ANTHROPIC, "claude-3", null, null, 200000); - AIConnectionDTO google = new AIConnectionDTO(3L, "Google", AIProviderKey.GOOGLE, "gemini-pro", null, null, 150000); + AIConnectionDTO openai = new AIConnectionDTO(1L, "OpenAI", 
AIProviderKey.OPENAI, "gpt-4", null, null); + AIConnectionDTO anthropic = new AIConnectionDTO(2L, "Anthropic", AIProviderKey.ANTHROPIC, "claude-3", null, null); + AIConnectionDTO google = new AIConnectionDTO(3L, "Google", AIProviderKey.GOOGLE, "gemini-pro", null, null); assertThat(openai.providerKey()).isEqualTo(AIProviderKey.OPENAI); assertThat(anthropic.providerKey()).isEqualTo(AIProviderKey.ANTHROPIC); assertThat(google.providerKey()).isEqualTo(AIProviderKey.GOOGLE); } - - @Test - @DisplayName("should support different token limitations") - void shouldSupportDifferentTokenLimitations() { - AIConnectionDTO small = new AIConnectionDTO(1L, "Small", AIProviderKey.OPENAI, "gpt-3.5", null, null, 16000); - AIConnectionDTO large = new AIConnectionDTO(2L, "Large", AIProviderKey.ANTHROPIC, "claude-3", null, null, 200000); - - assertThat(small.tokenLimitation()).isEqualTo(16000); - assertThat(large.tokenLimitation()).isEqualTo(200000); - } } @Nested @@ -85,7 +74,6 @@ void shouldConvertWithAllFields() { connection.setName("Production AI"); setField(connection, "providerKey", AIProviderKey.ANTHROPIC); setField(connection, "aiModel", "claude-3-opus"); - setField(connection, "tokenLimitation", 100000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -93,7 +81,6 @@ void shouldConvertWithAllFields() { assertThat(dto.name()).isEqualTo("Production AI"); assertThat(dto.providerKey()).isEqualTo(AIProviderKey.ANTHROPIC); assertThat(dto.aiModel()).isEqualTo("claude-3-opus"); - assertThat(dto.tokenLimitation()).isEqualTo(100000); } @Test @@ -104,7 +91,6 @@ void shouldConvertWithNullName() { connection.setName(null); setField(connection, "providerKey", AIProviderKey.OPENAI); setField(connection, "aiModel", "gpt-4"); - setField(connection, "tokenLimitation", 50000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -120,7 +106,6 @@ void shouldConvertWithNullModel() { connection.setName("Test"); setField(connection, "providerKey", AIProviderKey.GOOGLE); setField(connection, "aiModel", null); - setField(connection, "tokenLimitation", 75000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -134,7 +119,6 @@ void shouldConvertAllProviderTypes() { AIConnection connection = new AIConnection(); setField(connection, "id", 1L); setField(connection, "providerKey", providerKey); - setField(connection, "tokenLimitation", 100000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -149,7 +133,6 @@ void shouldHandleTimestamps() { setField(connection, "id", 1L); connection.setName("Test"); setField(connection, "providerKey", AIProviderKey.ANTHROPIC); - setField(connection, "tokenLimitation", 100000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -165,8 +148,8 @@ class EqualityTests { @DisplayName("should be equal for same values") void shouldBeEqualForSameValues() { OffsetDateTime now = OffsetDateTime.now(); - AIConnectionDTO dto1 = new AIConnectionDTO(1L, "Test", AIProviderKey.OPENAI, "gpt-4", now, now, 100000); - AIConnectionDTO dto2 = new AIConnectionDTO(1L, "Test", AIProviderKey.OPENAI, "gpt-4", now, now, 100000); + AIConnectionDTO dto1 = new AIConnectionDTO(1L, "Test", AIProviderKey.OPENAI, "gpt-4", now, now); + AIConnectionDTO dto2 = new AIConnectionDTO(1L, "Test", AIProviderKey.OPENAI, "gpt-4", now, now); assertThat(dto1).isEqualTo(dto2); assertThat(dto1.hashCode()).isEqualTo(dto2.hashCode()); @@ -176,8 +159,8 @@ void shouldBeEqualForSameValues() { @DisplayName("should not be equal for different values") void 
shouldNotBeEqualForDifferentValues() { OffsetDateTime now = OffsetDateTime.now(); - AIConnectionDTO dto1 = new AIConnectionDTO(1L, "Test1", AIProviderKey.OPENAI, "gpt-4", now, now, 100000); - AIConnectionDTO dto2 = new AIConnectionDTO(2L, "Test2", AIProviderKey.ANTHROPIC, "claude", now, now, 200000); + AIConnectionDTO dto1 = new AIConnectionDTO(1L, "Test1", AIProviderKey.OPENAI, "gpt-4", now, now); + AIConnectionDTO dto2 = new AIConnectionDTO(2L, "Test2", AIProviderKey.ANTHROPIC, "claude", now, now); assertThat(dto1).isNotEqualTo(dto2); } diff --git a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/model/ai/AIConnectionTest.java b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/model/ai/AIConnectionTest.java index d2dd3d4d..6adcfaa6 100644 --- a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/model/ai/AIConnectionTest.java +++ b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/model/ai/AIConnectionTest.java @@ -67,25 +67,12 @@ void shouldSetAndGetApiKeyEncrypted() { aiConnection.setApiKeyEncrypted("encrypted-api-key-xyz"); assertThat(aiConnection.getApiKeyEncrypted()).isEqualTo("encrypted-api-key-xyz"); } - - @Test - @DisplayName("Should set and get tokenLimitation") - void shouldSetAndGetTokenLimitation() { - aiConnection.setTokenLimitation(50000); - assertThat(aiConnection.getTokenLimitation()).isEqualTo(50000); - } } @Nested @DisplayName("Default value tests") class DefaultValueTests { - @Test - @DisplayName("Default tokenLimitation should be 100000") - void defaultTokenLimitationShouldBe100000() { - assertThat(aiConnection.getTokenLimitation()).isEqualTo(100000); - } - @Test @DisplayName("Id should be null for new entity") void idShouldBeNullForNewEntity() { @@ -154,14 +141,12 @@ void shouldBeAbleToUpdateAllFields() { aiConnection.setProviderKey(AIProviderKey.ANTHROPIC); aiConnection.setAiModel("claude-3-opus"); aiConnection.setApiKeyEncrypted("new-encrypted-key"); - aiConnection.setTokenLimitation(200000); assertThat(aiConnection.getName()).isEqualTo("Updated Name"); assertThat(aiConnection.getWorkspace()).isSameAs(workspace); assertThat(aiConnection.getProviderKey()).isEqualTo(AIProviderKey.ANTHROPIC); assertThat(aiConnection.getAiModel()).isEqualTo("claude-3-opus"); assertThat(aiConnection.getApiKeyEncrypted()).isEqualTo("new-encrypted-key"); - assertThat(aiConnection.getTokenLimitation()).isEqualTo(200000); } @Test diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java index 0094252d..142cba75 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java @@ -13,9 +13,11 @@ import org.rostilos.codecrow.analysisengine.dto.request.processor.BranchProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.ai.AiAnalysisRequest; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsAiClientService; import org.rostilos.codecrow.analysisengine.util.DiffContentFilter; import 
org.rostilos.codecrow.analysisengine.util.DiffParser; +import org.rostilos.codecrow.analysisengine.util.TokenEstimator; import org.rostilos.codecrow.security.oauth.TokenEncryptionService; import org.rostilos.codecrow.vcsclient.VcsClientProvider; import org.rostilos.codecrow.vcsclient.bitbucket.cloud.actions.GetCommitRangeDiffAction; @@ -172,6 +174,23 @@ public AiAnalysisRequest buildPrAnalysisRequest( originalSize > 0 ? (100 - (filteredSize * 100 / originalSize)) : 0); } + // Check token limit before proceeding with analysis + int maxTokenLimit = project.getEffectiveConfig().maxAnalysisTokenLimit(); + TokenEstimator.TokenEstimationResult tokenEstimate = TokenEstimator.estimateAndCheck(rawDiff, maxTokenLimit); + log.info("Token estimation for PR diff: {}", tokenEstimate.toLogString()); + + if (tokenEstimate.exceedsLimit()) { + log.warn("PR diff exceeds token limit - skipping analysis. Project={}, PR={}, Tokens={}/{}", + project.getId(), request.getPullRequestId(), + tokenEstimate.estimatedTokens(), tokenEstimate.maxAllowedTokens()); + throw new DiffTooLargeException( + tokenEstimate.estimatedTokens(), + tokenEstimate.maxAllowedTokens(), + project.getId(), + request.getPullRequestId() + ); + } + // Determine analysis mode: INCREMENTAL if we have previous analysis with different commit boolean canUseIncremental = previousAnalysis.isPresent() && previousCommitHash != null @@ -232,7 +251,7 @@ public AiAnalysisRequest buildPrAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(projectAiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withAllPrAnalysesData(allPrAnalyses) // Use full PR history instead of just previous version - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withPrTitle(prTitle) .withPrDescription(prDescription) @@ -303,7 +322,7 @@ public AiAnalysisRequest buildBranchAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(projectAiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withPreviousAnalysisData(previousAnalysis) - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withTargetBranchName(request.getTargetBranchName()) .withCurrentCommitHash(request.getCommitHash()) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index a216316a..039eb86c 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -7,6 +7,7 @@ import org.rostilos.codecrow.core.service.JobService; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; import org.rostilos.codecrow.pipelineagent.generic.webhookhandler.WebhookHandler; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.slf4j.Logger; @@ -127,7 
+128,18 @@ public void processWebhookAsync( deletePlaceholderComment(provider, project, payload, finalPlaceholderCommentId); } // Delete the job entirely - don't clutter DB with ignored webhooks - jobService.deleteIgnoredJob(job, result.message()); + // If deletion fails, skip the job instead + try { + jobService.deleteIgnoredJob(job, result.message()); + } catch (Exception deleteError) { + log.warn("Failed to delete ignored job {}, skipping instead: {}", + job.getExternalId(), deleteError.getMessage()); + try { + jobService.skipJob(job, result.message()); + } catch (Exception skipError) { + log.error("Failed to skip job {}: {}", job.getExternalId(), skipError.getMessage()); + } + } return; } @@ -151,6 +163,42 @@ public void processWebhookAsync( jobService.failJob(job, result.message()); } + } catch (DiffTooLargeException diffEx) { + // Handle diff too large - this is a soft skip, not an error + log.warn("Diff too large for analysis - skipping: {}", diffEx.getMessage()); + + String skipMessage = String.format( + "⚠️ **Analysis Skipped - PR Too Large**\n\n" + + "This PR's diff exceeds the configured token limit:\n" + + "- **Estimated tokens:** %,d\n" + + "- **Maximum allowed:** %,d (%.1f%% of limit)\n\n" + + "To analyze this PR, consider:\n" + + "1. Breaking it into smaller PRs\n" + + "2. Increasing the token limit in project settings\n" + + "3. Using `/codecrow analyze` command on specific commits", + diffEx.getEstimatedTokens(), + diffEx.getMaxAllowedTokens(), + diffEx.getUtilizationPercentage() + ); + + try { + if (project == null) { + project = projectRepository.findById(projectId).orElse(null); + } + if (project != null) { + initializeProjectAssociations(project); + postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, job); + } + } catch (Exception postError) { + log.error("Failed to post skip message to VCS: {}", postError.getMessage()); + } + + try { + jobService.skipJob(job, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); + } catch (Exception skipError) { + log.error("Failed to skip job: {}", skipError.getMessage()); + } + } catch (Exception e) { log.error("Error processing webhook for job {}", job.getExternalId(), e); @@ -390,6 +438,45 @@ private void postErrorToVcs(EVcsProvider provider, Project project, WebhookPaylo } } + /** + * Post an info message to VCS as a comment (for skipped/info scenarios). + * If placeholderCommentId is provided, update that comment with the info. 
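A worked example (with invented numbers) of what the skip comment's figures render to, loosely mirroring the String.format call added above; Locale.US is pinned so the grouped output shown in the comment is deterministic:

    import java.util.Locale;

    public class SkipMessageFormatExample {
        public static void main(String[] args) {
            int estimatedTokens = 312_456;   // invented value
            int maxAllowedTokens = 200_000;  // the ProjectConfig default
            double utilization = estimatedTokens * 100.0 / maxAllowedTokens;
            System.out.printf(Locale.US,
                    "Estimated tokens: %,d | Maximum allowed: %,d (%.1f%% of limit)%n",
                    estimatedTokens, maxAllowedTokens, utilization);
            // Prints: Estimated tokens: 312,456 | Maximum allowed: 200,000 (156.2% of limit)
        }
    }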
+ */ + private void postInfoToVcs(EVcsProvider provider, Project project, WebhookPayload payload, + String infoMessage, String placeholderCommentId, Job job) { + try { + if (payload.pullRequestId() == null) { + return; + } + + VcsReportingService reportingService = vcsServiceFactory.getReportingService(provider); + + // If we have a placeholder comment, update it with the info + if (placeholderCommentId != null) { + reportingService.updateComment( + project, + Long.parseLong(payload.pullRequestId()), + placeholderCommentId, + infoMessage, + CODECROW_COMMAND_MARKER + ); + log.info("Updated placeholder comment {} with info message for PR {}", placeholderCommentId, payload.pullRequestId()); + } else { + // No placeholder - post new info comment + reportingService.postComment( + project, + Long.parseLong(payload.pullRequestId()), + infoMessage, + CODECROW_COMMAND_MARKER + ); + log.info("Posted info message to PR {}", payload.pullRequestId()); + } + + } catch (Exception e) { + log.error("Failed to post info to VCS: {}", e.getMessage()); + } + } + /** * Sanitize error messages for display on VCS platforms. * Removes sensitive technical details like API keys, quotas, and internal stack traces. diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/AskCommandProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/AskCommandProcessor.java index b8ceb92f..68c5f87d 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/AskCommandProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/AskCommandProcessor.java @@ -375,7 +375,7 @@ private AskRequest buildAskRequest( credentials.oAuthClient(), credentials.oAuthSecret(), credentials.accessToken(), - aiConnection.getTokenLimitation(), + project.getEffectiveConfig().maxAnalysisTokenLimit(), credentials.vcsProviderString(), analysisContext, context.issueReferences() diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java index 78a8dcb2..8d381e79 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java @@ -180,7 +180,7 @@ private ReviewRequest buildReviewRequest(Project project, WebhookPayload payload credentials.oAuthClient(), credentials.oAuthSecret(), credentials.accessToken(), - aiConnection.getTokenLimitation(), + project.getEffectiveConfig().maxAnalysisTokenLimit(), credentials.vcsProviderString() ); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/SummarizeCommandProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/SummarizeCommandProcessor.java index c2312413..8e3f992b 100644 --- 
a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/SummarizeCommandProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/SummarizeCommandProcessor.java @@ -292,7 +292,7 @@ private SummarizeRequest buildSummarizeRequest( credentials.oAuthSecret(), credentials.accessToken(), diagramType == PrSummarizeCache.DiagramType.MERMAID, - aiConnection.getTokenLimitation(), + project.getEffectiveConfig().maxAnalysisTokenLimit(), credentials.vcsProviderString() ); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java index b39202de..f8f0dc09 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java @@ -13,9 +13,11 @@ import org.rostilos.codecrow.analysisengine.dto.request.processor.AnalysisProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.BranchProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsAiClientService; import org.rostilos.codecrow.analysisengine.util.DiffContentFilter; import org.rostilos.codecrow.analysisengine.util.DiffParser; +import org.rostilos.codecrow.analysisengine.util.TokenEstimator; import org.rostilos.codecrow.security.oauth.TokenEncryptionService; import org.rostilos.codecrow.vcsclient.VcsClientProvider; import org.rostilos.codecrow.vcsclient.github.actions.GetCommitRangeDiffAction; @@ -165,6 +167,23 @@ private AiAnalysisRequest buildPrAnalysisRequest( originalSize > 0 ? (100 - (filteredSize * 100 / originalSize)) : 0); } + // Check token limit before proceeding with analysis + int maxTokenLimit = project.getEffectiveConfig().maxAnalysisTokenLimit(); + TokenEstimator.TokenEstimationResult tokenEstimate = TokenEstimator.estimateAndCheck(rawDiff, maxTokenLimit); + log.info("Token estimation for PR diff: {}", tokenEstimate.toLogString()); + + if (tokenEstimate.exceedsLimit()) { + log.warn("PR diff exceeds token limit - skipping analysis. 
Project={}, PR={}, Tokens={}/{}", + project.getId(), request.getPullRequestId(), + tokenEstimate.estimatedTokens(), tokenEstimate.maxAllowedTokens()); + throw new DiffTooLargeException( + tokenEstimate.estimatedTokens(), + tokenEstimate.maxAllowedTokens(), + project.getId(), + request.getPullRequestId() + ); + } + // Determine analysis mode: INCREMENTAL if we have previous analysis with different commit boolean canUseIncremental = previousAnalysis.isPresent() && previousCommitHash != null @@ -223,7 +242,7 @@ private AiAnalysisRequest buildPrAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withAllPrAnalysesData(allPrAnalyses) // Use full PR history instead of just previous version - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withPrTitle(prTitle) .withPrDescription(prDescription) @@ -291,7 +310,7 @@ private AiAnalysisRequest buildBranchAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withPreviousAnalysisData(previousAnalysis) - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withTargetBranchName(request.getTargetBranchName()) .withCurrentCommitHash(request.getCommitHash()) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java index 1c697fc7..ed2be10a 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java @@ -13,9 +13,11 @@ import org.rostilos.codecrow.analysisengine.dto.request.processor.AnalysisProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.BranchProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsAiClientService; import org.rostilos.codecrow.analysisengine.util.DiffContentFilter; import org.rostilos.codecrow.analysisengine.util.DiffParser; +import org.rostilos.codecrow.analysisengine.util.TokenEstimator; import org.rostilos.codecrow.security.oauth.TokenEncryptionService; import org.rostilos.codecrow.vcsclient.VcsClientProvider; import org.rostilos.codecrow.vcsclient.gitlab.actions.GetCommitRangeDiffAction; @@ -166,6 +168,23 @@ private AiAnalysisRequest buildMrAnalysisRequest( originalSize > 0 ? (100 - (filteredSize * 100 / originalSize)) : 0); } + // Check token limit before proceeding with analysis + int maxTokenLimit = project.getEffectiveConfig().maxAnalysisTokenLimit(); + TokenEstimator.TokenEstimationResult tokenEstimate = TokenEstimator.estimateAndCheck(rawDiff, maxTokenLimit); + log.info("Token estimation for MR diff: {}", tokenEstimate.toLogString()); + + if (tokenEstimate.exceedsLimit()) { + log.warn("MR diff exceeds token limit - skipping analysis. 
Project={}, PR={}, Tokens={}/{}", + project.getId(), request.getPullRequestId(), + tokenEstimate.estimatedTokens(), tokenEstimate.maxAllowedTokens()); + throw new DiffTooLargeException( + tokenEstimate.estimatedTokens(), + tokenEstimate.maxAllowedTokens(), + project.getId(), + request.getPullRequestId() + ); + } + // Determine analysis mode: INCREMENTAL if we have previous analysis with different commit boolean canUseIncremental = previousAnalysis.isPresent() && previousCommitHash != null @@ -224,7 +243,7 @@ private AiAnalysisRequest buildMrAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withAllPrAnalysesData(allPrAnalyses) // Use full PR history instead of just previous version - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withPrTitle(mrTitle) .withPrDescription(mrDescription) @@ -292,7 +311,7 @@ private AiAnalysisRequest buildBranchAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withPreviousAnalysisData(previousAnalysis) - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withTargetBranchName(request.getTargetBranchName()) .withCurrentCommitHash(request.getCommitHash()) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/webhookhandler/GitLabMergeRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/webhookhandler/GitLabMergeRequestWebhookHandler.java index 15fefd59..77f62466 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/webhookhandler/GitLabMergeRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/webhookhandler/GitLabMergeRequestWebhookHandler.java @@ -1,10 +1,12 @@ package org.rostilos.codecrow.pipelineagent.gitlab.webhookhandler; +import org.rostilos.codecrow.core.model.analysis.AnalysisLockType; import org.rostilos.codecrow.core.model.codeanalysis.AnalysisType; import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.vcs.EVcsProvider; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor; +import org.rostilos.codecrow.analysisengine.service.AnalysisLockService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; @@ -15,6 +17,7 @@ import org.springframework.stereotype.Component; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; @@ -47,13 +50,16 @@ public class GitLabMergeRequestWebhookHandler extends AbstractWebhookHandler imp private final PullRequestAnalysisProcessor pullRequestAnalysisProcessor; private final VcsServiceFactory vcsServiceFactory; + private final AnalysisLockService analysisLockService; public GitLabMergeRequestWebhookHandler( PullRequestAnalysisProcessor 
pullRequestAnalysisProcessor, - VcsServiceFactory vcsServiceFactory + VcsServiceFactory vcsServiceFactory, + AnalysisLockService analysisLockService ) { this.pullRequestAnalysisProcessor = pullRequestAnalysisProcessor; this.vcsServiceFactory = vcsServiceFactory; + this.analysisLockService = analysisLockService; } @Override @@ -119,6 +125,25 @@ private WebhookResult handleMergeRequestEvent( String placeholderCommentId = null; try { + // Try to acquire lock atomically BEFORE posting placeholder + // This prevents TOCTOU race where multiple webhooks could pass isLocked() check simultaneously + // Note: PullRequestAnalysisProcessor.process() uses acquireLockWithWait() which will + // reuse this lock since it's for the same project/branch/type + String sourceBranch = payload.sourceBranch(); + Optional earlyLock = analysisLockService.acquireLock( + project, sourceBranch, AnalysisLockType.PR_ANALYSIS, + payload.commitHash(), Long.parseLong(payload.pullRequestId())); + + if (earlyLock.isEmpty()) { + log.info("MR analysis already in progress for project={}, branch={}, MR={} - skipping duplicate webhook", + project.getId(), sourceBranch, payload.pullRequestId()); + return WebhookResult.ignored("MR analysis already in progress for this branch"); + } + + // Lock acquired - placeholder posting is now protected from race conditions + // Note: We don't release this lock here - PullRequestAnalysisProcessor will manage it + // since acquireLockWithWait() will detect the existing lock and use it + // Post placeholder comment immediately to show analysis has started placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId())); diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/CreateAIConnectionRequest.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/CreateAIConnectionRequest.java index 558f236e..307dc44e 100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/CreateAIConnectionRequest.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/CreateAIConnectionRequest.java @@ -13,6 +13,4 @@ public class CreateAIConnectionRequest { public String aiModel; @NotBlank(message = "API key is required") public String apiKey; - @NotBlank(message = "Please specify max token limit") - public String tokenLimitation; } diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/UpdateAiConnectionRequest.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/UpdateAiConnectionRequest.java index ae834bb5..f7805b90 100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/UpdateAiConnectionRequest.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/UpdateAiConnectionRequest.java @@ -9,5 +9,4 @@ public class UpdateAiConnectionRequest { public AIProviderKey providerKey; public String aiModel; public String apiKey; - public String tokenLimitation; } diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/service/AIConnectionService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/service/AIConnectionService.java index 510f0901..48949fc7 100644 --- 
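The GitLab handler change above replaces a check-then-act isLocked() test with a single atomic acquireLock() call. A minimal in-memory illustration of the race it closes, with AtomicBoolean standing in for the DB-backed AnalysisLockService (all names here are hypothetical):

    import java.util.Optional;
    import java.util.concurrent.atomic.AtomicBoolean;

    class InMemoryLockSketch {
        private final AtomicBoolean locked = new AtomicBoolean(false);

        // Racy when used as "if (!isLocked()) { proceed(); }":
        // two webhook threads can both observe false before either sets it.
        boolean isLocked() {
            return locked.get();
        }

        // Atomic: exactly one caller wins; the duplicate webhook gets Optional.empty().
        Optional<String> acquireLock(String lockKey) {
            return locked.compareAndSet(false, true) ? Optional.of(lockKey) : Optional.empty();
        }
    }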
a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/service/AIConnectionService.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/service/AIConnectionService.java @@ -53,7 +53,6 @@ public AIConnection createAiConnection(Long workspaceId, CreateAIConnectionReque newAiConnection.setProviderKey(request.providerKey); newAiConnection.setAiModel(request.aiModel); newAiConnection.setApiKeyEncrypted(apiKeyEncrypted); - newAiConnection.setTokenLimitation(Integer.parseInt(request.tokenLimitation)); return connectionRepository.save(newAiConnection); } @@ -77,9 +76,6 @@ public AIConnection updateAiConnection(Long workspaceId, Long connectionId, Upda connection.setApiKeyEncrypted(apiKeyEncrypted); } - if(request.tokenLimitation != null && !request.tokenLimitation.isEmpty()) { - connection.setTokenLimitation(Integer.parseInt(request.tokenLimitation)); - } return connectionRepository.save(connection); } diff --git a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/ai/AIConnectionCrudIT.java b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/ai/AIConnectionCrudIT.java index 2c94d528..69d3bef6 100644 --- a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/ai/AIConnectionCrudIT.java +++ b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/ai/AIConnectionCrudIT.java @@ -36,8 +36,7 @@ void shouldCreateOpenRouterConnection() { { "providerKey": "OPENROUTER", "aiModel": "anthropic/claude-3-haiku", - "apiKey": "test-api-key-openrouter", - "tokenLimitation": "200000" + "apiKey": "test-api-key-openrouter" } """; @@ -50,7 +49,6 @@ void shouldCreateOpenRouterConnection() { .statusCode(201) .body("providerKey", equalTo("OPENROUTER")) .body("aiModel", equalTo("anthropic/claude-3-haiku")) - .body("tokenLimitation", equalTo(200000)) .body("id", notNullValue()); } @@ -62,8 +60,7 @@ void shouldCreateOpenAIConnection() { { "providerKey": "OPENAI", "aiModel": "gpt-4o-mini", - "apiKey": "test-api-key-openai", - "tokenLimitation": "128000" + "apiKey": "test-api-key-openai" } """; @@ -86,8 +83,7 @@ void shouldCreateAnthropicConnection() { { "providerKey": "ANTHROPIC", "aiModel": "claude-3-haiku-20240307", - "apiKey": "test-api-key-anthropic", - "tokenLimitation": "200000" + "apiKey": "test-api-key-anthropic" } """; @@ -110,8 +106,7 @@ void shouldListAIConnections() { { "providerKey": "OPENROUTER", "aiModel": "test-model", - "apiKey": "test-key", - "tokenLimitation": "100000" + "apiKey": "test-key" } """; @@ -140,8 +135,7 @@ void shouldUpdateAIConnection() { { "providerKey": "OPENROUTER", "aiModel": "original-model", - "apiKey": "original-key", - "tokenLimitation": "100000" + "apiKey": "original-key" } """; @@ -158,8 +152,7 @@ void shouldUpdateAIConnection() { { "providerKey": "OPENROUTER", "aiModel": "updated-model", - "apiKey": "updated-key", - "tokenLimitation": "150000" + "apiKey": "updated-key" } """; @@ -170,8 +163,7 @@ void shouldUpdateAIConnection() { .patch("/api/{workspaceSlug}/ai/{connectionId}", testWorkspace.getSlug(), connectionId) .then() .statusCode(200) - .body("aiModel", equalTo("updated-model")) - .body("tokenLimitation", equalTo(150000)); + .body("aiModel", equalTo("updated-model")); } @Test @@ -182,8 +174,7 @@ void shouldDeleteAIConnection() { { "providerKey": "OPENROUTER", "aiModel": "to-delete", - "apiKey": "delete-key", - "tokenLimitation": "100000" + "apiKey": "delete-key" } """; 
@@ -212,8 +203,7 @@ void shouldRequireAdminRightsForAIOperations() { { "providerKey": "OPENROUTER", "aiModel": "test-model", - "apiKey": "test-key", - "tokenLimitation": "100000" + "apiKey": "test-key" } """; @@ -272,8 +262,7 @@ void shouldValidateProviderKey() { { "providerKey": "INVALID_PROVIDER", "aiModel": "test-model", - "apiKey": "test-key", - "tokenLimitation": "100000" + "apiKey": "test-key" } """; @@ -300,8 +289,7 @@ void shouldPreventCrossWorkspaceAccess() { { "providerKey": "OPENROUTER", "aiModel": "other-ws-model", - "apiKey": "other-ws-key", - "tokenLimitation": "100000" + "apiKey": "other-ws-key" } """; diff --git a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/auth/UserAuthFlowIT.java b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/auth/UserAuthFlowIT.java index 5d0ffaca..bb8d4220 100644 --- a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/auth/UserAuthFlowIT.java +++ b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/auth/UserAuthFlowIT.java @@ -313,8 +313,7 @@ void shouldHandleWorkspaceRoleDowngrade() { { "providerKey": "OPENROUTER", "aiModel": "test", - "apiKey": "key", - "tokenLimitation": "100000" + "apiKey": "key" } """; given() diff --git a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/builder/AIConnectionBuilder.java b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/builder/AIConnectionBuilder.java index ef858489..69ecb04e 100644 --- a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/builder/AIConnectionBuilder.java +++ b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/builder/AIConnectionBuilder.java @@ -14,7 +14,6 @@ public class AIConnectionBuilder { private AIProviderKey providerKey = AIProviderKey.OPENROUTER; private String aiModel = "anthropic/claude-3-haiku"; private String apiKeyEncrypted = "encrypted-test-key"; - private int tokenLimitation = 200000; public static AIConnectionBuilder anAIConnection() { return new AIConnectionBuilder(); @@ -45,11 +44,6 @@ public AIConnectionBuilder withApiKeyEncrypted(String apiKeyEncrypted) { return this; } - public AIConnectionBuilder withTokenLimitation(int tokenLimitation) { - this.tokenLimitation = tokenLimitation; - return this; - } - public AIConnectionBuilder openAI() { this.providerKey = AIProviderKey.OPENAI; this.aiModel = "gpt-4o-mini"; @@ -75,7 +69,6 @@ public AIConnection build() { connection.setProviderKey(providerKey); connection.setAiModel(aiModel); connection.setApiKeyEncrypted(apiKeyEncrypted); - connection.setTokenLimitation(tokenLimitation); return connection; } } diff --git a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/util/AuthTestHelper.java b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/util/AuthTestHelper.java index c819e219..8902eaf5 100644 --- a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/util/AuthTestHelper.java +++ b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/util/AuthTestHelper.java @@ -254,7 +254,6 @@ public AIConnection createTestAiConnection(Workspace workspace, String name, AIP aiConnection.setProviderKey(provider); aiConnection.setAiModel("gpt-4"); aiConnection.setApiKeyEncrypted("test-encrypted-api-key-" + 
UUID.randomUUID().toString().substring(0, 8)); - aiConnection.setTokenLimitation(100000); return aiConnectionRepository.save(aiConnection); } From 7c780573801fc3dd8c52c39af26c616385237306 Mon Sep 17 00:00:00 2001 From: rostislav Date: Mon, 26 Jan 2026 18:46:00 +0200 Subject: [PATCH 13/34] feat: Add pre-acquired lock key to prevent double-locking in PR analysis processing. Project PR analysis max analysis token limit implementation --- .../request/processor/PrProcessRequest.java | 9 +++ .../PullRequestAnalysisProcessor.java | 63 +++++++++++-------- .../codecrow/core/dto/project/ProjectDTO.java | 9 ++- .../codecrow/core/model/project/Project.java | 9 +++ .../core/dto/project/ProjectDTOTest.java | 4 +- ...tbucketCloudPullRequestWebhookHandler.java | 8 +-- .../GitHubPullRequestWebhookHandler.java | 8 +-- .../project/controller/ProjectController.java | 6 +- .../project/service/ProjectService.java | 7 ++- 9 files changed, 78 insertions(+), 45 deletions(-) diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/processor/PrProcessRequest.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/processor/PrProcessRequest.java index 5dc47f1e..752efc2a 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/processor/PrProcessRequest.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/processor/PrProcessRequest.java @@ -35,6 +35,13 @@ public class PrProcessRequest implements AnalysisProcessRequest { public String prAuthorId; public String prAuthorUsername; + + /** + * Optional pre-acquired lock key. If set, the processor will skip lock acquisition + * and use this lock key directly. This prevents double-locking when the webhook handler + * has already acquired the lock before calling the processor. + */ + public String preAcquiredLockKey; public Long getProjectId() { @@ -64,4 +71,6 @@ public String getSourceBranchName() { public String getPrAuthorId() { return prAuthorId; } public String getPrAuthorUsername() { return prAuthorUsername; } + + public String getPreAcquiredLockKey() { return preAcquiredLockKey; } } diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java index f5db459a..ce7b3292 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java @@ -91,34 +91,43 @@ public Map process( // Publish analysis started event publishAnalysisStartedEvent(project, request, correlationId); - Optional lockKey = analysisLockService.acquireLockWithWait( - project, - request.getSourceBranchName(), - AnalysisLockType.PR_ANALYSIS, - request.getCommitHash(), - request.getPullRequestId(), - consumer::accept - ); - - if (lockKey.isEmpty()) { - String message = String.format( - "Failed to acquire lock after %d minutes for project=%s, PR=%d, branch=%s. 
Another analysis is still in progress.", - analysisLockService.getLockWaitTimeoutMinutes(), - project.getId(), - request.getPullRequestId(), - request.getSourceBranchName() - ); - log.warn(message); - - // Publish failed event due to lock timeout - publishAnalysisCompletedEvent(project, request, correlationId, startTime, - AnalysisCompletedEvent.CompletionStatus.FAILED, 0, 0, "Lock acquisition timeout"); - - throw new AnalysisLockedException( - AnalysisLockType.PR_ANALYSIS.name(), + // Check if a lock was already acquired by the caller (e.g., webhook handler) + // to prevent double-locking which causes unnecessary 2-minute waits + String lockKey; + if (request.getPreAcquiredLockKey() != null && !request.getPreAcquiredLockKey().isBlank()) { + lockKey = request.getPreAcquiredLockKey(); + log.info("Using pre-acquired lock: {} for project={}, PR={}", lockKey, project.getId(), request.getPullRequestId()); + } else { + Optional acquiredLock = analysisLockService.acquireLockWithWait( + project, request.getSourceBranchName(), - project.getId() + AnalysisLockType.PR_ANALYSIS, + request.getCommitHash(), + request.getPullRequestId(), + consumer::accept ); + + if (acquiredLock.isEmpty()) { + String message = String.format( + "Failed to acquire lock after %d minutes for project=%s, PR=%d, branch=%s. Another analysis is still in progress.", + analysisLockService.getLockWaitTimeoutMinutes(), + project.getId(), + request.getPullRequestId(), + request.getSourceBranchName() + ); + log.warn(message); + + // Publish failed event due to lock timeout + publishAnalysisCompletedEvent(project, request, correlationId, startTime, + AnalysisCompletedEvent.CompletionStatus.FAILED, 0, 0, "Lock acquisition timeout"); + + throw new AnalysisLockedException( + AnalysisLockType.PR_ANALYSIS.name(), + request.getSourceBranchName(), + project.getId() + ); + } + lockKey = acquiredLock.get(); } try { @@ -216,7 +225,7 @@ public Map process( return Map.of("status", "error", "message", e.getMessage()); } finally { - analysisLockService.releaseLock(lockKey.get()); + analysisLockService.releaseLock(lockKey); } } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/project/ProjectDTO.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/project/ProjectDTO.java index 9c2a70a3..98027a0d 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/project/ProjectDTO.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/project/ProjectDTO.java @@ -32,7 +32,8 @@ public record ProjectDTO( String installationMethod, CommentCommandsConfigDTO commentCommandsConfig, Boolean webhooksConfigured, - Long qualityGateId + Long qualityGateId, + Integer maxAnalysisTokenLimit ) { public static ProjectDTO fromProject(Project project) { Long vcsConnectionId = null; @@ -123,6 +124,9 @@ public static ProjectDTO fromProject(Project project) { if (project.getVcsRepoBinding() != null) { webhooksConfigured = project.getVcsRepoBinding().isWebhooksConfigured(); } + + // Get maxAnalysisTokenLimit from config + Integer maxAnalysisTokenLimit = config != null ? config.maxAnalysisTokenLimit() : ProjectConfig.DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; return new ProjectDTO( project.getId(), @@ -146,7 +150,8 @@ public static ProjectDTO fromProject(Project project) { installationMethod, commentCommandsConfigDTO, webhooksConfigured, - project.getQualityGate() != null ? project.getQualityGate().getId() : null + project.getQualityGate() != null ? 
project.getQualityGate().getId() : null, + maxAnalysisTokenLimit ); } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/Project.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/Project.java index 1a956edd..b12b4227 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/Project.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/Project.java @@ -222,6 +222,15 @@ public void setConfiguration(org.rostilos.codecrow.core.model.project.config.Pro this.configuration = configuration; } + /** + * Returns the effective project configuration. + * If configuration is null, returns a new default ProjectConfig. + * This ensures callers always get a valid config with default values. + */ + public org.rostilos.codecrow.core.model.project.config.ProjectConfig getEffectiveConfig() { + return configuration != null ? configuration : new org.rostilos.codecrow.core.model.project.config.ProjectConfig(); + } + public org.rostilos.codecrow.core.model.branch.Branch getDefaultBranch() { return defaultBranch; } diff --git a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/project/ProjectDTOTest.java b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/project/ProjectDTOTest.java index 63096a11..2fe071a6 100644 --- a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/project/ProjectDTOTest.java +++ b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/project/ProjectDTOTest.java @@ -50,7 +50,7 @@ void shouldCreateWithAllFields() { 20L, "namespace", "main", "main", 100L, stats, ragConfig, true, false, "WEBHOOK", - commandsConfig, true, 50L + commandsConfig, true, 50L, 200000 ); assertThat(dto.id()).isEqualTo(1L); @@ -84,7 +84,7 @@ void shouldCreateWithNullOptionalFields() { 1L, "Test", null, true, null, null, null, null, null, null, null, null, null, null, null, - null, null, null, null, null, null, null + null, null, null, null, null, null, null, null ); assertThat(dto.description()).isNull(); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java index 04036114..71124e1d 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java @@ -111,9 +111,7 @@ private WebhookResult handlePullRequestEvent(WebhookPayload payload, Project pro try { // Try to acquire lock atomically BEFORE posting placeholder - // This prevents TOCTOU race where multiple webhooks could pass isLocked() check simultaneously - // Note: PullRequestAnalysisProcessor.process() uses acquireLockWithWait() which will - // reuse this lock since it's for the same project/branch/type + // This prevents race condition where multiple webhooks could post duplicate placeholders String sourceBranch = payload.sourceBranch(); Optional earlyLock = analysisLockService.acquireLock( project, sourceBranch, AnalysisLockType.PR_ANALYSIS, @@ -126,8 +124,6 @@ private WebhookResult 
handlePullRequestEvent(WebhookPayload payload, Project pro } // Lock acquired - placeholder posting is now protected from race conditions - // Note: We don't release this lock here - PullRequestAnalysisProcessor will manage it - // since acquireLockWithWait() will detect the existing lock and use it // Post placeholder comment immediately to show analysis has started placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId())); @@ -143,6 +139,8 @@ private WebhookResult handlePullRequestEvent(WebhookPayload payload, Project pro request.placeholderCommentId = placeholderCommentId; request.prAuthorId = payload.prAuthorId(); request.prAuthorUsername = payload.prAuthorUsername(); + // Pass the pre-acquired lock key to avoid double-locking in the processor + request.preAcquiredLockKey = earlyLock.get(); log.info("Processing PR analysis: project={}, PR={}, source={}, target={}, placeholderCommentId={}", project.getId(), request.pullRequestId, request.sourceBranchName, request.targetBranchName, placeholderCommentId); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java index 2997a833..7f72aeaf 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java @@ -124,9 +124,7 @@ private WebhookResult handlePullRequestEvent( try { // Try to acquire lock atomically BEFORE posting placeholder - // This prevents TOCTOU race where multiple webhooks could pass isLocked() check simultaneously - // Note: PullRequestAnalysisProcessor.process() uses acquireLockWithWait() which will - // reuse this lock since it's for the same project/branch/type + // This prevents race condition where multiple webhooks could post duplicate placeholders String sourceBranch = payload.sourceBranch(); Optional earlyLock = analysisLockService.acquireLock( project, sourceBranch, AnalysisLockType.PR_ANALYSIS, @@ -139,8 +137,6 @@ private WebhookResult handlePullRequestEvent( } // Lock acquired - placeholder posting is now protected from race conditions - // Note: We don't release this lock here - PullRequestAnalysisProcessor will manage it - // since acquireLockWithWait() will detect the existing lock and use it // Post placeholder comment immediately to show analysis has started placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId())); @@ -156,6 +152,8 @@ private WebhookResult handlePullRequestEvent( request.placeholderCommentId = placeholderCommentId; request.prAuthorId = payload.prAuthorId(); request.prAuthorUsername = payload.prAuthorUsername(); + // Pass the pre-acquired lock key to avoid double-locking in the processor + request.preAcquiredLockKey = earlyLock.get(); log.info("Processing PR analysis: project={}, PR={}, source={}, target={}, placeholderCommentId={}", project.getId(), request.pullRequestId, request.sourceBranchName, request.targetBranchName, placeholderCommentId); diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/controller/ProjectController.java 
b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/controller/ProjectController.java index 9704f48b..774b41fd 100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/controller/ProjectController.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/controller/ProjectController.java @@ -584,7 +584,8 @@ public ResponseEntity updateAnalysisSettings( project.getId(), request.prAnalysisEnabled(), request.branchAnalysisEnabled(), - installationMethod + installationMethod, + request.maxAnalysisTokenLimit() ); return new ResponseEntity<>(ProjectDTO.fromProject(updated), HttpStatus.OK); } @@ -592,7 +593,8 @@ public ResponseEntity updateAnalysisSettings( public record UpdateAnalysisSettingsRequest( Boolean prAnalysisEnabled, Boolean branchAnalysisEnabled, - String installationMethod + String installationMethod, + Integer maxAnalysisTokenLimit ) {} /** diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java index 07aeb2be..ac98db3a 100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java @@ -561,7 +561,8 @@ public Project updateAnalysisSettings( Long projectId, Boolean prAnalysisEnabled, Boolean branchAnalysisEnabled, - InstallationMethod installationMethod + InstallationMethod installationMethod, + Integer maxAnalysisTokenLimit ) { Project project = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); @@ -572,6 +573,7 @@ public Project updateAnalysisSettings( var branchAnalysis = currentConfig != null ? currentConfig.branchAnalysis() : null; var ragConfig = currentConfig != null ? currentConfig.ragConfig() : null; var commentCommands = currentConfig != null ? currentConfig.commentCommands() : null; + int currentMaxTokenLimit = currentConfig != null ? currentConfig.maxAnalysisTokenLimit() : ProjectConfig.DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; Boolean newPrAnalysis = prAnalysisEnabled != null ? prAnalysisEnabled : (currentConfig != null ? currentConfig.prAnalysisEnabled() : true); @@ -579,6 +581,7 @@ public Project updateAnalysisSettings( (currentConfig != null ? currentConfig.branchAnalysisEnabled() : true); var newInstallationMethod = installationMethod != null ? installationMethod : (currentConfig != null ? currentConfig.installationMethod() : null); + int newMaxTokenLimit = maxAnalysisTokenLimit != null ? maxAnalysisTokenLimit : currentMaxTokenLimit; // Update both the direct column and the JSON config //TODO: remove duplication @@ -586,7 +589,7 @@ public Project updateAnalysisSettings( project.setBranchAnalysisEnabled(newBranchAnalysis != null ? 
newBranchAnalysis : true); project.setConfiguration(new ProjectConfig(useLocalMcp, mainBranch, branchAnalysis, ragConfig, - newPrAnalysis, newBranchAnalysis, newInstallationMethod, commentCommands)); + newPrAnalysis, newBranchAnalysis, newInstallationMethod, commentCommands, newMaxTokenLimit)); return projectRepository.save(project); } From 6d80d71283010177ebe0abac77e3e2552630c254 Mon Sep 17 00:00:00 2001 From: rostislav Date: Mon, 26 Jan 2026 20:59:16 +0200 Subject: [PATCH 14/34] feat: Implement handling for AnalysisLockedException and DiffTooLargeException in webhook processors --- ...tbucketCloudPullRequestWebhookHandler.java | 6 ++++ .../processor/WebhookAsyncProcessor.java | 35 +++++++++++++++++++ .../CommentCommandWebhookHandler.java | 10 ++++++ .../GitHubPullRequestWebhookHandler.java | 6 ++++ 4 files changed, 57 insertions(+) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java index 71124e1d..68b30ae4 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java @@ -5,6 +5,8 @@ import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.vcs.EVcsProvider; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor; import org.rostilos.codecrow.analysisengine.service.AnalysisLockService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; @@ -166,6 +168,10 @@ private WebhookResult handlePullRequestEvent(WebhookPayload payload, Project pro return WebhookResult.success("PR analysis completed", result); + } catch (DiffTooLargeException | AnalysisLockedException e) { + // Re-throw these exceptions so WebhookAsyncProcessor can handle them properly + log.warn("PR analysis failed with recoverable exception for project {}: {}", project.getId(), e.getMessage()); + throw e; } catch (Exception e) { log.error("PR analysis failed for project {}", project.getId(), e); // Try to update placeholder with error message diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 039eb86c..d65c7236 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -7,6 +7,7 @@ import org.rostilos.codecrow.core.service.JobService; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; import 
org.rostilos.codecrow.pipelineagent.generic.webhookhandler.WebhookHandler; +import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; @@ -199,6 +200,40 @@ public void processWebhookAsync( log.error("Failed to skip job: {}", skipError.getMessage()); } + } catch (AnalysisLockedException lockEx) { + // Handle lock acquisition failure - mark job as failed + log.warn("Lock acquisition failed for analysis: {}", lockEx.getMessage()); + + String failMessage = String.format( + "⚠️ **Analysis Failed - Resource Locked**\n\n" + + "Could not acquire analysis lock after timeout:\n" + + "- **Lock type:** %s\n" + + "- **Branch:** %s\n" + + "- **Project:** %d\n\n" + + "Another analysis may be in progress. Please try again later.", + lockEx.getLockType(), + lockEx.getBranchName(), + lockEx.getProjectId() + ); + + try { + if (project == null) { + project = projectRepository.findById(projectId).orElse(null); + } + if (project != null) { + initializeProjectAssociations(project); + postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, job); + } + } catch (Exception postError) { + log.error("Failed to post lock error to VCS: {}", postError.getMessage()); + } + + try { + jobService.failJob(job, "Lock acquisition timeout: " + lockEx.getMessage()); + } catch (Exception failError) { + log.error("Failed to fail job: {}", failError.getMessage()); + } + } catch (Exception e) { log.error("Error processing webhook for job {}", job.getExternalId(), e); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/CommentCommandWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/CommentCommandWebhookHandler.java index 6267cbd7..0a120fb7 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/CommentCommandWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/CommentCommandWebhookHandler.java @@ -13,6 +13,8 @@ import org.rostilos.codecrow.core.persistence.repository.codeanalysis.PrSummarizeCacheRepository; import org.rostilos.codecrow.core.service.CodeAnalysisService; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; import org.rostilos.codecrow.pipelineagent.generic.service.CommandAuthorizationService; @@ -445,6 +447,14 @@ private WebhookResult runPrAnalysis( // If we got here, the processor posted results directly (which it does) return WebhookResult.success("Analysis completed", Map.of("commandType", commandType)); + } catch (DiffTooLargeException e) { + // Re-throw DiffTooLargeException so WebhookAsyncProcessor can handle it with proper job status + log.warn("PR diff too large for {} command: {}", commandType, e.getMessage()); + throw e; + } catch (AnalysisLockedException e) { + // Re-throw 
AnalysisLockedException so WebhookAsyncProcessor can handle it with proper job status + log.warn("Lock acquisition failed for {} command: {}", commandType, e.getMessage()); + throw e; } catch (Exception e) { log.error("Error running PR analysis for {} command: {}", commandType, e.getMessage(), e); return WebhookResult.error("Analysis failed: " + e.getMessage()); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java index 7f72aeaf..87f949f9 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java @@ -5,6 +5,8 @@ import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.vcs.EVcsProvider; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor; import org.rostilos.codecrow.analysisengine.service.AnalysisLockService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; @@ -179,6 +181,10 @@ private WebhookResult handlePullRequestEvent( return WebhookResult.success("PR analysis completed", result); + } catch (DiffTooLargeException | AnalysisLockedException e) { + // Re-throw these exceptions so WebhookAsyncProcessor can handle them properly + log.warn("PR analysis failed with recoverable exception for project {}: {}", project.getId(), e.getMessage()); + throw e; } catch (Exception e) { log.error("PR analysis failed for project {}", project.getId(), e); // Try to update placeholder with error message From e2c14743383242c33891d4d6f989ceec2bb92b11 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 10:35:19 +0200 Subject: [PATCH 15/34] feat: Re-fetch job entities in transaction methods to handle detached entities from async contexts --- .../codecrow/core/service/JobService.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 04036c47..0a50c269 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -221,9 +221,12 @@ private String getCommandJobTitle(JobType type, Long prNumber) { /** * Start a job (transition from PENDING to RUNNING). + * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job startJob(Job job) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.start(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "start", "Job started"); @@ -241,9 +244,12 @@ public Job startJob(String externalId) { /** * Complete a job successfully. 
+ * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job completeJob(Job job) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.complete(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "complete", "Job completed successfully"); @@ -253,9 +259,12 @@ public Job completeJob(Job job) { /** * Complete a job and link it to a code analysis. + * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job completeJob(Job job, CodeAnalysis codeAnalysis) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.setCodeAnalysis(codeAnalysis); job.complete(); job = jobRepository.save(job); @@ -284,9 +293,12 @@ public Job failJob(Job job, String errorMessage) { /** * Cancel a job. + * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job cancelJob(Job job) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.cancel(); job = jobRepository.save(job); addLog(job, JobLogLevel.WARN, "cancel", "Job cancelled"); @@ -296,9 +308,12 @@ public Job cancelJob(Job job) { /** * Skip a job (e.g., due to branch pattern settings). + * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job skipJob(Job job, String reason) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.skip(reason); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "skipped", reason); From 342c4fadbbe9b09dbe5ab56c69c02af4493fb13f Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 11:01:57 +0200 Subject: [PATCH 16/34] feat: Update JobService and WebhookAsyncProcessor to manage job entities without re-fetching in async contexts --- .../codecrow/core/service/JobService.java | 15 --- .../processor/WebhookAsyncProcessor.java | 92 ++++++++++++++----- 2 files changed, 70 insertions(+), 37 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 0a50c269..04036c47 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -221,12 +221,9 @@ private String getCommandJobTitle(JobType type, Long prNumber) { /** * Start a job (transition from PENDING to RUNNING). - * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job startJob(Job job) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.start(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "start", "Job started"); @@ -244,12 +241,9 @@ public Job startJob(String externalId) { /** * Complete a job successfully. - * Re-fetches the job to handle detached entities from async contexts. 
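For reference, the JPA behaviour driving this back-and-forth: an entity created in the HTTP-request transaction is detached by the time the async worker runs, so a @Transactional service method that receives it has to re-attach it (here by re-loading it by id) before state changes are reliably flushed. A minimal sketch of that pattern with hypothetical names (TaskEntity, TaskRepository), not the project's Job types:

import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

// Hypothetical repository; TaskEntity stands in for any JPA entity with a Long id.
interface TaskRepository extends JpaRepository<TaskEntity, Long> {}

@Service
class TaskStatusService {

    private final TaskRepository repository;

    TaskStatusService(TaskRepository repository) {
        this.repository = repository;
    }

    @Transactional
    public TaskEntity complete(TaskEntity maybeDetached) {
        // Re-load by id so the instance is managed by this transaction's
        // persistence context; fall back to the passed object if it is gone.
        TaskEntity managed = repository.findById(maybeDetached.getId())
                .orElse(maybeDetached);
        managed.markCompleted();
        return repository.save(managed);
    }
}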
*/ @Transactional public Job completeJob(Job job) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.complete(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "complete", "Job completed successfully"); @@ -259,12 +253,9 @@ public Job completeJob(Job job) { /** * Complete a job and link it to a code analysis. - * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job completeJob(Job job, CodeAnalysis codeAnalysis) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.setCodeAnalysis(codeAnalysis); job.complete(); job = jobRepository.save(job); @@ -293,12 +284,9 @@ public Job failJob(Job job, String errorMessage) { /** * Cancel a job. - * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job cancelJob(Job job) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.cancel(); job = jobRepository.save(job); addLog(job, JobLogLevel.WARN, "cancel", "Job cancelled"); @@ -308,12 +296,9 @@ public Job cancelJob(Job job) { /** * Skip a job (e.g., due to branch pattern settings). - * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job skipJob(Job job, String reason) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.skip(reason); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "skipped", reason); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index d65c7236..133c148f 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -82,10 +82,11 @@ public WebhookAsyncProcessor( } /** - * Process a webhook asynchronously with proper transactional context. + * Process a webhook asynchronously. + * Note: This method is NOT transactional to avoid issues with nested transactions + * (e.g., failJob uses REQUIRES_NEW). Each inner operation manages its own transaction. 
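The REQUIRES_NEW interaction mentioned in this javadoc is standard Spring propagation: a REQUIRES_NEW method on another bean suspends the caller's transaction and commits or rolls back on its own, which is why mixing it with an outer transaction on the async entry point needs care. A hedged sketch with hypothetical services (AuditWriter, OuterProcessor), not the project's failJob path:

import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

@Service
class AuditWriter {

    // Runs in its own transaction: the caller's transaction is suspended while
    // this one commits, so the failure record survives a later outer rollback.
    @Transactional(propagation = Propagation.REQUIRES_NEW)
    public void recordFailure(String reason) {
        // persist a failure record here
    }
}

@Service
class OuterProcessor {

    private final AuditWriter auditWriter;

    OuterProcessor(AuditWriter auditWriter) {
        this.auditWriter = auditWriter;
    }

    @Transactional
    public void process() {
        try {
            // main work inside the outer transaction
        } catch (RuntimeException e) {
            auditWriter.recordFailure(e.getMessage()); // committed independently
            throw e; // the outer transaction still rolls back
        }
    }
}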
*/ @Async("webhookExecutor") - @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, @@ -96,19 +97,35 @@ public void processWebhookAsync( String placeholderCommentId = null; Project project = null; + // Store job external ID for re-fetching - the passed Job entity is detached + // since it was created in the HTTP request transaction which has already committed + String jobExternalId = job.getExternalId(); + + // Declare managed job reference that will be set after re-fetching + // This needs to be accessible in catch blocks for error handling + Job managedJob = null; + try { - // Re-fetch project within transaction to ensure all lazy associations are available + // Re-fetch project to ensure all lazy associations are available project = projectRepository.findById(projectId) .orElseThrow(() -> new IllegalStateException("Project not found: " + projectId)); // Initialize lazy associations we'll need initializeProjectAssociations(project); - jobService.startJob(job); + // Re-fetch the job by external ID to get a managed entity in the current context + // This is necessary because the Job was created in the HTTP request transaction + // which has already committed by the time this async method runs + managedJob = jobService.findByExternalIdOrThrow(jobExternalId); + + // Create final reference for use in lambda + final Job jobForLambda = managedJob; + + jobService.startJob(managedJob); // Post placeholder comment immediately if this is a CodeCrow command on a PR if (payload.hasCodecrowCommand() && payload.pullRequestId() != null) { - placeholderCommentId = postPlaceholderComment(provider, project, payload, job); + placeholderCommentId = postPlaceholderComment(provider, project, payload, managedJob); } // Store placeholder ID for use in result posting @@ -118,7 +135,7 @@ public void processWebhookAsync( WebhookHandler.WebhookResult result = handler.handle(payload, project, event -> { String message = (String) event.getOrDefault("message", "Processing..."); String state = (String) event.getOrDefault("state", "processing"); - jobService.info(job, state, message); + jobService.info(jobForLambda, state, message); }); // Check if the webhook was ignored (e.g., branch not matching pattern, analysis disabled) @@ -131,14 +148,14 @@ public void processWebhookAsync( // Delete the job entirely - don't clutter DB with ignored webhooks // If deletion fails, skip the job instead try { - jobService.deleteIgnoredJob(job, result.message()); + jobService.deleteIgnoredJob(managedJob, result.message()); } catch (Exception deleteError) { log.warn("Failed to delete ignored job {}, skipping instead: {}", - job.getExternalId(), deleteError.getMessage()); + managedJob.getExternalId(), deleteError.getMessage()); try { - jobService.skipJob(job, result.message()); + jobService.skipJob(managedJob, result.message()); } catch (Exception skipError) { - log.error("Failed to skip job {}: {}", job.getExternalId(), skipError.getMessage()); + log.error("Failed to skip job {}: {}", managedJob.getExternalId(), skipError.getMessage()); } } return; @@ -146,28 +163,38 @@ public void processWebhookAsync( if (result.success()) { // Post result to VCS if there's content to post - postResultToVcs(provider, project, payload, result, finalPlaceholderCommentId, job); + postResultToVcs(provider, project, payload, result, finalPlaceholderCommentId, managedJob); if (result.data().containsKey("analysisId")) { Long analysisId = ((Number) result.data().get("analysisId")).longValue(); - jobService.info(job, "complete", 
"Analysis completed. Analysis ID: " + analysisId); + jobService.info(managedJob, "complete", "Analysis completed. Analysis ID: " + analysisId); } - jobService.completeJob(job); + jobService.completeJob(managedJob); } else { // Post error to VCS (update placeholder if exists) - but ensure failJob is always called try { - postErrorToVcs(provider, project, payload, result.message(), finalPlaceholderCommentId, job); + postErrorToVcs(provider, project, payload, result.message(), finalPlaceholderCommentId, managedJob); } catch (Exception postError) { log.error("Failed to post error to VCS: {}", postError.getMessage()); } // Always mark the job as failed, even if posting to VCS failed - jobService.failJob(job, result.message()); + jobService.failJob(managedJob, result.message()); } } catch (DiffTooLargeException diffEx) { // Handle diff too large - this is a soft skip, not an error log.warn("Diff too large for analysis - skipping: {}", diffEx.getMessage()); + // Re-fetch job if not yet fetched + if (managedJob == null) { + try { + managedJob = jobService.findByExternalIdOrThrow(jobExternalId); + } catch (Exception fetchError) { + log.error("Failed to fetch job {} for skip operation: {}", jobExternalId, fetchError.getMessage()); + return; + } + } + String skipMessage = String.format( "⚠️ **Analysis Skipped - PR Too Large**\n\n" + "This PR's diff exceeds the configured token limit:\n" + @@ -188,14 +215,14 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, job); + postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, managedJob); } } catch (Exception postError) { log.error("Failed to post skip message to VCS: {}", postError.getMessage()); } try { - jobService.skipJob(job, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); + jobService.skipJob(managedJob, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); } catch (Exception skipError) { log.error("Failed to skip job: {}", skipError.getMessage()); } @@ -204,6 +231,16 @@ public void processWebhookAsync( // Handle lock acquisition failure - mark job as failed log.warn("Lock acquisition failed for analysis: {}", lockEx.getMessage()); + // Re-fetch job if not yet fetched + if (managedJob == null) { + try { + managedJob = jobService.findByExternalIdOrThrow(jobExternalId); + } catch (Exception fetchError) { + log.error("Failed to fetch job {} for fail operation: {}", jobExternalId, fetchError.getMessage()); + return; + } + } + String failMessage = String.format( "⚠️ **Analysis Failed - Resource Locked**\n\n" + "Could not acquire analysis lock after timeout:\n" + @@ -222,20 +259,31 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, job); + postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, managedJob); } } catch (Exception postError) { log.error("Failed to post lock error to VCS: {}", postError.getMessage()); } try { - jobService.failJob(job, "Lock acquisition timeout: " + lockEx.getMessage()); + jobService.failJob(managedJob, "Lock acquisition timeout: " + lockEx.getMessage()); } catch (Exception failError) { log.error("Failed to fail job: {}", failError.getMessage()); } } catch (Exception e) { - log.error("Error processing webhook for job 
{}", job.getExternalId(), e); + // Re-fetch job if not yet fetched + if (managedJob == null) { + try { + managedJob = jobService.findByExternalIdOrThrow(jobExternalId); + } catch (Exception fetchError) { + log.error("Failed to fetch job {} for fail operation: {}", jobExternalId, fetchError.getMessage()); + log.error("Original error processing webhook", e); + return; + } + } + + log.error("Error processing webhook for job {}", managedJob.getExternalId(), e); try { if (project == null) { @@ -243,14 +291,14 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, job); + postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, managedJob); } } catch (Exception postError) { log.error("Failed to post error to VCS: {}", postError.getMessage()); } try { - jobService.failJob(job, "Processing failed: " + e.getMessage()); + jobService.failJob(managedJob, "Processing failed: " + e.getMessage()); } catch (Exception failError) { log.error("Failed to mark job as failed: {}", failError.getMessage()); } From 409c42df5ffaf13b6ab101f74496941d87977d3a Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 11:05:40 +0200 Subject: [PATCH 17/34] feat: Enable transaction management in processWebhookAsync to support lazy loading of associations --- .../generic/processor/WebhookAsyncProcessor.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 133c148f..52ce4565 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -83,10 +83,11 @@ public WebhookAsyncProcessor( /** * Process a webhook asynchronously. - * Note: This method is NOT transactional to avoid issues with nested transactions - * (e.g., failJob uses REQUIRES_NEW). Each inner operation manages its own transaction. + * This method uses a transaction to ensure lazy associations can be loaded. + * Inner operations like failJob use REQUIRES_NEW which creates nested transactions as needed. 
*/ @Async("webhookExecutor") + @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, From 11c983c4b5377c330f3066cc5541b483f802547c Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 11:57:26 +0200 Subject: [PATCH 18/34] feat: Re-fetch job entities in JobService methods to ensure consistency across transaction contexts --- .../codecrow/core/service/JobService.java | 8 ++ .../processor/WebhookAsyncProcessor.java | 88 ++++--------------- 2 files changed, 27 insertions(+), 69 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 04036c47..29148529 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -224,6 +224,8 @@ private String getCommandJobTitle(JobType type, Long prNumber) { */ @Transactional public Job startJob(Job job) { + // Re-fetch the job in case it was passed from a different transaction context + job = jobRepository.findById(job.getId()).orElse(job); job.start(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "start", "Job started"); @@ -244,6 +246,8 @@ public Job startJob(String externalId) { */ @Transactional public Job completeJob(Job job) { + // Re-fetch the job in case it was passed from a different transaction context + job = jobRepository.findById(job.getId()).orElse(job); job.complete(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "complete", "Job completed successfully"); @@ -287,6 +291,8 @@ public Job failJob(Job job, String errorMessage) { */ @Transactional public Job cancelJob(Job job) { + // Re-fetch the job in case it was passed from a different transaction context + job = jobRepository.findById(job.getId()).orElse(job); job.cancel(); job = jobRepository.save(job); addLog(job, JobLogLevel.WARN, "cancel", "Job cancelled"); @@ -299,6 +305,8 @@ public Job cancelJob(Job job) { */ @Transactional public Job skipJob(Job job, String reason) { + // Re-fetch the job in case it was passed from a different transaction context + job = jobRepository.findById(job.getId()).orElse(job); job.skip(reason); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "skipped", reason); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 52ce4565..1d41b507 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -83,11 +83,8 @@ public WebhookAsyncProcessor( /** * Process a webhook asynchronously. - * This method uses a transaction to ensure lazy associations can be loaded. - * Inner operations like failJob use REQUIRES_NEW which creates nested transactions as needed. 
*/ @Async("webhookExecutor") - @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, @@ -98,14 +95,6 @@ public void processWebhookAsync( String placeholderCommentId = null; Project project = null; - // Store job external ID for re-fetching - the passed Job entity is detached - // since it was created in the HTTP request transaction which has already committed - String jobExternalId = job.getExternalId(); - - // Declare managed job reference that will be set after re-fetching - // This needs to be accessible in catch blocks for error handling - Job managedJob = null; - try { // Re-fetch project to ensure all lazy associations are available project = projectRepository.findById(projectId) @@ -114,19 +103,11 @@ public void processWebhookAsync( // Initialize lazy associations we'll need initializeProjectAssociations(project); - // Re-fetch the job by external ID to get a managed entity in the current context - // This is necessary because the Job was created in the HTTP request transaction - // which has already committed by the time this async method runs - managedJob = jobService.findByExternalIdOrThrow(jobExternalId); - - // Create final reference for use in lambda - final Job jobForLambda = managedJob; - - jobService.startJob(managedJob); + jobService.startJob(job); // Post placeholder comment immediately if this is a CodeCrow command on a PR if (payload.hasCodecrowCommand() && payload.pullRequestId() != null) { - placeholderCommentId = postPlaceholderComment(provider, project, payload, managedJob); + placeholderCommentId = postPlaceholderComment(provider, project, payload, job); } // Store placeholder ID for use in result posting @@ -136,7 +117,7 @@ public void processWebhookAsync( WebhookHandler.WebhookResult result = handler.handle(payload, project, event -> { String message = (String) event.getOrDefault("message", "Processing..."); String state = (String) event.getOrDefault("state", "processing"); - jobService.info(jobForLambda, state, message); + jobService.info(job, state, message); }); // Check if the webhook was ignored (e.g., branch not matching pattern, analysis disabled) @@ -149,14 +130,14 @@ public void processWebhookAsync( // Delete the job entirely - don't clutter DB with ignored webhooks // If deletion fails, skip the job instead try { - jobService.deleteIgnoredJob(managedJob, result.message()); + jobService.deleteIgnoredJob(job, result.message()); } catch (Exception deleteError) { log.warn("Failed to delete ignored job {}, skipping instead: {}", - managedJob.getExternalId(), deleteError.getMessage()); + job.getExternalId(), deleteError.getMessage()); try { - jobService.skipJob(managedJob, result.message()); + jobService.skipJob(job, result.message()); } catch (Exception skipError) { - log.error("Failed to skip job {}: {}", managedJob.getExternalId(), skipError.getMessage()); + log.error("Failed to skip job {}: {}", job.getExternalId(), skipError.getMessage()); } } return; @@ -164,38 +145,28 @@ public void processWebhookAsync( if (result.success()) { // Post result to VCS if there's content to post - postResultToVcs(provider, project, payload, result, finalPlaceholderCommentId, managedJob); + postResultToVcs(provider, project, payload, result, finalPlaceholderCommentId, job); if (result.data().containsKey("analysisId")) { Long analysisId = ((Number) result.data().get("analysisId")).longValue(); - jobService.info(managedJob, "complete", "Analysis completed. Analysis ID: " + analysisId); + jobService.info(job, "complete", "Analysis completed. 
Analysis ID: " + analysisId); } - jobService.completeJob(managedJob); + jobService.completeJob(job); } else { // Post error to VCS (update placeholder if exists) - but ensure failJob is always called try { - postErrorToVcs(provider, project, payload, result.message(), finalPlaceholderCommentId, managedJob); + postErrorToVcs(provider, project, payload, result.message(), finalPlaceholderCommentId, job); } catch (Exception postError) { log.error("Failed to post error to VCS: {}", postError.getMessage()); } // Always mark the job as failed, even if posting to VCS failed - jobService.failJob(managedJob, result.message()); + jobService.failJob(job, result.message()); } } catch (DiffTooLargeException diffEx) { // Handle diff too large - this is a soft skip, not an error log.warn("Diff too large for analysis - skipping: {}", diffEx.getMessage()); - // Re-fetch job if not yet fetched - if (managedJob == null) { - try { - managedJob = jobService.findByExternalIdOrThrow(jobExternalId); - } catch (Exception fetchError) { - log.error("Failed to fetch job {} for skip operation: {}", jobExternalId, fetchError.getMessage()); - return; - } - } - String skipMessage = String.format( "⚠️ **Analysis Skipped - PR Too Large**\n\n" + "This PR's diff exceeds the configured token limit:\n" + @@ -216,14 +187,14 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, managedJob); + postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, job); } } catch (Exception postError) { log.error("Failed to post skip message to VCS: {}", postError.getMessage()); } try { - jobService.skipJob(managedJob, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); + jobService.skipJob(job, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); } catch (Exception skipError) { log.error("Failed to skip job: {}", skipError.getMessage()); } @@ -232,16 +203,6 @@ public void processWebhookAsync( // Handle lock acquisition failure - mark job as failed log.warn("Lock acquisition failed for analysis: {}", lockEx.getMessage()); - // Re-fetch job if not yet fetched - if (managedJob == null) { - try { - managedJob = jobService.findByExternalIdOrThrow(jobExternalId); - } catch (Exception fetchError) { - log.error("Failed to fetch job {} for fail operation: {}", jobExternalId, fetchError.getMessage()); - return; - } - } - String failMessage = String.format( "⚠️ **Analysis Failed - Resource Locked**\n\n" + "Could not acquire analysis lock after timeout:\n" + @@ -260,31 +221,20 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, managedJob); + postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, job); } } catch (Exception postError) { log.error("Failed to post lock error to VCS: {}", postError.getMessage()); } try { - jobService.failJob(managedJob, "Lock acquisition timeout: " + lockEx.getMessage()); + jobService.failJob(job, "Lock acquisition timeout: " + lockEx.getMessage()); } catch (Exception failError) { log.error("Failed to fail job: {}", failError.getMessage()); } } catch (Exception e) { - // Re-fetch job if not yet fetched - if (managedJob == null) { - try { - managedJob = jobService.findByExternalIdOrThrow(jobExternalId); - } catch (Exception 
fetchError) { - log.error("Failed to fetch job {} for fail operation: {}", jobExternalId, fetchError.getMessage()); - log.error("Original error processing webhook", e); - return; - } - } - - log.error("Error processing webhook for job {}", managedJob.getExternalId(), e); + log.error("Error processing webhook for job {}", job.getExternalId(), e); try { if (project == null) { @@ -292,14 +242,14 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, managedJob); + postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, job); } } catch (Exception postError) { log.error("Failed to post error to VCS: {}", postError.getMessage()); } try { - jobService.failJob(managedJob, "Processing failed: " + e.getMessage()); + jobService.failJob(job, "Processing failed: " + e.getMessage()); } catch (Exception failError) { log.error("Failed to mark job as failed: {}", failError.getMessage()); } From c75eaba20c13abbc0a170271f3a94667093c2745 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 12:04:31 +0200 Subject: [PATCH 19/34] feat: Add @Transactional annotation to processWebhookAsync for lazy loading of associations --- .../pipelineagent/generic/processor/WebhookAsyncProcessor.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 1d41b507..f9e210a2 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -83,8 +83,10 @@ public WebhookAsyncProcessor( /** * Process a webhook asynchronously. + * Uses @Transactional to ensure lazy associations can be loaded. 
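Since the last few patches alternate on whether @Transactional on this method actually takes effect, a quick way to verify it at runtime is Spring's TransactionSynchronizationManager; a throwaway diagnostic sketch, not part of the patch:

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.transaction.support.TransactionSynchronizationManager;

final class TxDiagnostics {

    private static final Logger log = LoggerFactory.getLogger(TxDiagnostics.class);

    private TxDiagnostics() {}

    // Call at the top of the async method: if the annotation was bypassed by
    // self-invocation, isActualTransactionActive() reports false.
    static void logTransactionState(String where) {
        log.info("{}: transactionActive={}, txName={}",
                where,
                TransactionSynchronizationManager.isActualTransactionActive(),
                TransactionSynchronizationManager.getCurrentTransactionName());
    }
}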
*/ @Async("webhookExecutor") + @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, From 8afc0ada6a03ed71399c630d65cdbe00519d865a Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 12:27:09 +0200 Subject: [PATCH 20/34] feat: Implement self-injection in WebhookAsyncProcessor for proper transaction management in async context --- .../processor/WebhookAsyncProcessor.java | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index f9e210a2..85dd25b2 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -15,6 +15,8 @@ import java.io.IOException; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Lazy; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -71,6 +73,11 @@ public class WebhookAsyncProcessor { private final JobService jobService; private final VcsServiceFactory vcsServiceFactory; + // Self-injection for @Transactional proxy to work from @Async method + @Autowired + @Lazy + private WebhookAsyncProcessor self; + public WebhookAsyncProcessor( ProjectRepository projectRepository, JobService jobService, @@ -83,16 +90,35 @@ public WebhookAsyncProcessor( /** * Process a webhook asynchronously. - * Uses @Transactional to ensure lazy associations can be loaded. + * Delegates to a transactional method to ensure lazy associations can be loaded. + * NOTE: @Async and @Transactional cannot be on the same method - the transaction + * proxy gets bypassed. We use self-injection to call a separate @Transactional method. */ @Async("webhookExecutor") - @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, WebhookPayload payload, WebhookHandler handler, Job job + ) { + log.info("processWebhookAsync started for job {} (projectId={}, event={})", + job.getExternalId(), projectId, payload.eventType()); + // Delegate to transactional method via self-reference to ensure proxy is used + self.processWebhookInTransaction(provider, projectId, payload, handler, job); + } + + /** + * Process webhook within a transaction. + * Called from async method via self-injection to ensure transaction proxy works. 
+ */ + @Transactional + public void processWebhookInTransaction( + EVcsProvider provider, + Long projectId, + WebhookPayload payload, + WebhookHandler handler, + Job job ) { String placeholderCommentId = null; Project project = null; From 402486b97def35d3098f8e2a6a8b4c2b4c72f4c4 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 13:09:26 +0200 Subject: [PATCH 21/34] feat: Enhance logging and error handling in processWebhookAsync for improved job management --- .../processor/WebhookAsyncProcessor.java | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 85dd25b2..819855ee 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -104,8 +104,19 @@ public void processWebhookAsync( ) { log.info("processWebhookAsync started for job {} (projectId={}, event={})", job.getExternalId(), projectId, payload.eventType()); - // Delegate to transactional method via self-reference to ensure proxy is used - self.processWebhookInTransaction(provider, projectId, payload, handler, job); + try { + // Delegate to transactional method via self-reference to ensure proxy is used + self.processWebhookInTransaction(provider, projectId, payload, handler, job); + log.info("processWebhookAsync completed normally for job {}", job.getExternalId()); + } catch (Exception e) { + log.error("processWebhookAsync FAILED for job {}: {}", job.getExternalId(), e.getMessage(), e); + // Try to fail the job so it doesn't stay in PENDING + try { + jobService.failJob(job, "Async processing failed: " + e.getMessage()); + } catch (Exception failError) { + log.error("Failed to mark job {} as failed: {}", job.getExternalId(), failError.getMessage()); + } + } } /** @@ -120,6 +131,7 @@ public void processWebhookInTransaction( WebhookHandler handler, Job job ) { + log.info("processWebhookInTransaction ENTERED for job {}", job.getExternalId()); String placeholderCommentId = null; Project project = null; @@ -131,7 +143,9 @@ public void processWebhookInTransaction( // Initialize lazy associations we'll need initializeProjectAssociations(project); + log.info("Calling jobService.startJob for job {}", job.getExternalId()); jobService.startJob(job); + log.info("jobService.startJob completed for job {}", job.getExternalId()); // Post placeholder comment immediately if this is a CodeCrow command on a PR if (payload.hasCodecrowCommand() && payload.pullRequestId() != null) { @@ -142,15 +156,17 @@ public void processWebhookInTransaction( final String finalPlaceholderCommentId = placeholderCommentId; // Create event consumer that logs to job + log.info("Calling handler.handle for job {}", job.getExternalId()); WebhookHandler.WebhookResult result = handler.handle(payload, project, event -> { String message = (String) event.getOrDefault("message", "Processing..."); String state = (String) event.getOrDefault("state", "processing"); jobService.info(job, state, message); }); + log.info("handler.handle completed for job {}, result status={}", job.getExternalId(), result.status()); // Check if the webhook was ignored 
(e.g., branch not matching pattern, analysis disabled) if ("ignored".equals(result.status())) { - log.info("Webhook ignored: {}", result.message()); + log.info("Webhook ignored for job {}: {}", job.getExternalId(), result.message()); // Delete placeholder if we posted one for an ignored command if (finalPlaceholderCommentId != null) { deletePlaceholderComment(provider, project, payload, finalPlaceholderCommentId); @@ -158,7 +174,9 @@ public void processWebhookInTransaction( // Delete the job entirely - don't clutter DB with ignored webhooks // If deletion fails, skip the job instead try { + log.info("Deleting ignored job {}", job.getExternalId()); jobService.deleteIgnoredJob(job, result.message()); + log.info("Successfully deleted ignored job {}", job.getExternalId()); } catch (Exception deleteError) { log.warn("Failed to delete ignored job {}, skipping instead: {}", job.getExternalId(), deleteError.getMessage()); @@ -179,7 +197,9 @@ public void processWebhookInTransaction( Long analysisId = ((Number) result.data().get("analysisId")).longValue(); jobService.info(job, "complete", "Analysis completed. Analysis ID: " + analysisId); } + log.info("Calling jobService.completeJob for job {}", job.getExternalId()); jobService.completeJob(job); + log.info("jobService.completeJob completed for job {}", job.getExternalId()); } else { // Post error to VCS (update placeholder if exists) - but ensure failJob is always called try { From fdcdca0aac15e47ba8f99826d810384a55cfa854 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 14:24:23 +0200 Subject: [PATCH 22/34] feat: Implement webhook deduplication service to prevent duplicate commit analysis --- frontend | 2 +- .../BitbucketCloudBranchWebhookHandler.java | 18 +++- .../pipelineagent/config/AsyncConfig.java | 11 ++- .../controller/ProviderWebhookController.java | 5 ++ .../service/WebhookDeduplicationService.java | 86 +++++++++++++++++++ 5 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/WebhookDeduplicationService.java diff --git a/frontend b/frontend index fdbb0555..d97b8264 160000 --- a/frontend +++ b/frontend @@ -1 +1 @@ -Subproject commit fdbb055524794f49a0299fd7f020177243855e58 +Subproject commit d97b826464e13edb3da8a2b9a3e4b32680f23001 diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudBranchWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudBranchWebhookHandler.java index 81dd08c9..4b6a4a9e 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudBranchWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudBranchWebhookHandler.java @@ -7,6 +7,7 @@ import org.rostilos.codecrow.analysisengine.dto.request.processor.BranchProcessRequest; import org.rostilos.codecrow.analysisengine.processor.analysis.BranchAnalysisProcessor; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; +import org.rostilos.codecrow.pipelineagent.generic.service.WebhookDeduplicationService; import org.rostilos.codecrow.pipelineagent.generic.webhookhandler.AbstractWebhookHandler; import 
org.rostilos.codecrow.pipelineagent.generic.webhookhandler.WebhookHandler; import org.slf4j.Logger; @@ -38,13 +39,16 @@ public class BitbucketCloudBranchWebhookHandler extends AbstractWebhookHandler i private final BranchAnalysisProcessor branchAnalysisProcessor; private final RagOperationsService ragOperationsService; + private final WebhookDeduplicationService deduplicationService; public BitbucketCloudBranchWebhookHandler( BranchAnalysisProcessor branchAnalysisProcessor, - @Autowired(required = false) RagOperationsService ragOperationsService + @Autowired(required = false) RagOperationsService ragOperationsService, + WebhookDeduplicationService deduplicationService ) { this.branchAnalysisProcessor = branchAnalysisProcessor; this.ragOperationsService = ragOperationsService; + this.deduplicationService = deduplicationService; } @Override @@ -87,6 +91,12 @@ public WebhookResult handle(WebhookPayload payload, Project project, Consumer { + log.error("WEBHOOK EXECUTOR REJECTED TASK! Queue is full. Pool size: {}, Active: {}, Queue size: {}", + e.getPoolSize(), e.getActiveCount(), e.getQueue().size()); + // Try to run in caller thread as fallback + if (!e.isShutdown()) { + r.run(); + } + }); executor.initialize(); - log.info("Webhook executor initialized with core={}, max={}", 4, 8); + log.info("Webhook executor initialized with core={}, max={}, queueCapacity={}", 4, 8, 100); return executor; } diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/controller/ProviderWebhookController.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/controller/ProviderWebhookController.java index d475014f..95ecfc0a 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/controller/ProviderWebhookController.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/controller/ProviderWebhookController.java @@ -260,6 +260,9 @@ private ResponseEntity processWebhook(EVcsProvider provider, WebhookPayload p String jobUrl = buildJobUrl(project, job); String logsStreamUrl = buildJobLogsStreamUrl(job); + log.info("Dispatching webhook to async processor: job={}, event={}", + job.getExternalId(), payload.eventType()); + // Process webhook asynchronously with proper transactional context webhookAsyncProcessor.processWebhookAsync( provider, @@ -269,6 +272,8 @@ private ResponseEntity processWebhook(EVcsProvider provider, WebhookPayload p job ); + log.info("Webhook dispatched to async processor: job={}", job.getExternalId()); + return ResponseEntity.accepted().body(Map.of( "status", "accepted", "message", "Webhook received, processing started", diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/WebhookDeduplicationService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/WebhookDeduplicationService.java new file mode 100644 index 00000000..a7f87088 --- /dev/null +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/WebhookDeduplicationService.java @@ -0,0 +1,86 @@ +package org.rostilos.codecrow.pipelineagent.generic.service; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; + +import java.time.Instant; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + 
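For orientation while reading this patch, here is a minimal usage sketch (the wrapper class below is hypothetical and not part of the change) showing the call pattern the service defined next is meant to support: a handler asks whether a commit was already analyzed before it triggers a branch analysis.

package org.rostilos.codecrow.pipelineagent.generic.service;

// Hypothetical usage sketch only -- this class is not part of the patch.
public class WebhookDeduplicationUsageSketch {

    private final WebhookDeduplicationService deduplicationService;

    public WebhookDeduplicationUsageSketch(WebhookDeduplicationService deduplicationService) {
        this.deduplicationService = deduplicationService;
    }

    /** Returns true when the commit analysis should actually run. */
    public boolean shouldAnalyze(Long projectId, String commitHash, String eventType) {
        // pullrequest:fulfilled and repo:push deliver the same merge commit;
        // only the first event inside the dedup window proceeds, later ones are skipped.
        return !deduplicationService.isDuplicateCommitAnalysis(projectId, commitHash, eventType);
    }
}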
+/**
+ * Service to deduplicate webhook events based on commit hash.
+ *
+ * When a PR is merged in Bitbucket, it sends both:
+ * - pullrequest:fulfilled (with merge commit)
+ * - repo:push (with same merge commit)
+ *
+ * Both events would trigger analysis for the same commit, causing duplicate processing.
+ * This service tracks recently analyzed commits and skips duplicates within a time window.
+ */
+@Service
+public class WebhookDeduplicationService {
+
+    private static final Logger log = LoggerFactory.getLogger(WebhookDeduplicationService.class);
+
+    /**
+     * Time window in seconds to consider events as duplicates.
+     */
+    private static final long DEDUP_WINDOW_SECONDS = 30;
+
+    /**
+     * Cache of recently analyzed commits.
+     * Key: "projectId:commitHash"
+     * Value: timestamp when the analysis was triggered
+     */
+    private final Map<String, Instant> recentCommitAnalyses = new ConcurrentHashMap<>();
+
+    /**
+     * Check if a commit analysis should be skipped as a duplicate.
+     * If not a duplicate, records this commit for future deduplication.
+     *
+     * @param projectId The project ID
+     * @param commitHash The commit being analyzed
+     * @param eventType The webhook event type (for logging)
+     * @return true if this is a duplicate and should be skipped, false if it should proceed
+     */
+    public boolean isDuplicateCommitAnalysis(Long projectId, String commitHash, String eventType) {
+        if (commitHash == null || commitHash.isBlank()) {
+            return false;
+        }
+
+        String key = projectId + ":" + commitHash;
+        Instant now = Instant.now();
+
+        Instant lastAnalysis = recentCommitAnalyses.get(key);
+
+        if (lastAnalysis != null) {
+            long secondsSinceLastAnalysis = now.getEpochSecond() - lastAnalysis.getEpochSecond();
+
+            if (secondsSinceLastAnalysis < DEDUP_WINDOW_SECONDS) {
+                log.info("Skipping duplicate commit analysis: project={}, commit={}, event={}, " +
+                        "lastAnalysis={}s ago (within {}s window)",
+                        projectId, commitHash, eventType, secondsSinceLastAnalysis, DEDUP_WINDOW_SECONDS);
+                return true;
+            }
+        }
+
+        // Record this analysis
+        recentCommitAnalyses.put(key, now);
+
+        // Cleanup old entries
+        cleanupOldEntries(now);
+
+        return false;
+    }
+
+    /**
+     * Remove entries older than the dedup window to prevent memory growth.
+ */ + private void cleanupOldEntries(Instant now) { + recentCommitAnalyses.entrySet().removeIf(entry -> { + long age = now.getEpochSecond() - entry.getValue().getEpochSecond(); + return age > DEDUP_WINDOW_SECONDS * 2; + }); + } +} From e3213617d76e45d1d15cdbec1bc89452420f06a3 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 14:35:33 +0200 Subject: [PATCH 23/34] feat: Enhance job deletion process with logging and persistence context management --- .../rostilos/codecrow/core/service/JobService.java | 3 +++ .../pipeline-agent/src/main/java/module-info.java | 1 + .../generic/processor/WebhookAsyncProcessor.java | 11 +++++++++++ 3 files changed, 15 insertions(+) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 29148529..39a9aca2 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -338,8 +338,11 @@ public void deleteIgnoredJob(Job job, String reason) { } // Delete any logs first (foreign key constraint) jobLogRepository.deleteByJobId(jobId); + log.info("Deleted job logs for ignored job {}", job.getExternalId()); // Delete the job jobRepository.delete(existingJob.get()); + jobRepository.flush(); // Force immediate execution + log.info("Successfully deleted ignored job {} from database", job.getExternalId()); } /** diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/module-info.java b/java-ecosystem/services/pipeline-agent/src/main/java/module-info.java index 739f411f..3188b03a 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/module-info.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/module-info.java @@ -9,6 +9,7 @@ requires spring.beans; requires org.slf4j; requires jakarta.validation; + requires jakarta.persistence; requires spring.web; requires jjwt.api; requires okhttp3; diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 819855ee..f374edea 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -13,6 +13,8 @@ import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.slf4j.Logger; +import jakarta.persistence.EntityManager; +import jakarta.persistence.PersistenceContext; import java.io.IOException; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -73,6 +75,9 @@ public class WebhookAsyncProcessor { private final JobService jobService; private final VcsServiceFactory vcsServiceFactory; + @PersistenceContext + private EntityManager entityManager; + // Self-injection for @Transactional proxy to work from @Async method @Autowired @Lazy @@ -176,6 +181,12 @@ public void processWebhookInTransaction( try { log.info("Deleting ignored job {}", job.getExternalId()); jobService.deleteIgnoredJob(job, result.message()); + // CRITICAL: Detach the job from this transaction's persistence context + // to prevent JPA 
from re-saving it when the outer transaction commits + if (entityManager.contains(job)) { + entityManager.detach(job); + log.info("Detached deleted job {} from persistence context", job.getExternalId()); + } log.info("Successfully deleted ignored job {}", job.getExternalId()); } catch (Exception deleteError) { log.warn("Failed to delete ignored job {}, skipping instead: {}", From ebd0fad92d106dec8dbf6cfd550c1c74ca3f26b2 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 14:43:36 +0200 Subject: [PATCH 24/34] feat: Improve job deletion process with enhanced logging and error handling --- .../codecrow/core/service/JobService.java | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 39a9aca2..31840dcf 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -336,13 +336,21 @@ public void deleteIgnoredJob(Job job, String reason) { log.warn("Cannot delete ignored job {} - not found in database", job.getExternalId()); return; } - // Delete any logs first (foreign key constraint) - jobLogRepository.deleteByJobId(jobId); - log.info("Deleted job logs for ignored job {}", job.getExternalId()); - // Delete the job - jobRepository.delete(existingJob.get()); - jobRepository.flush(); // Force immediate execution - log.info("Successfully deleted ignored job {} from database", job.getExternalId()); + try { + // Delete any logs first (foreign key constraint) + jobLogRepository.deleteByJobId(jobId); + jobLogRepository.flush(); + log.info("Deleted job logs for ignored job {}", job.getExternalId()); + + // Delete the job + log.info("About to delete job entity {} (id={})", job.getExternalId(), jobId); + jobRepository.deleteById(jobId); + jobRepository.flush(); // Force immediate execution + log.info("Successfully deleted ignored job {} from database", job.getExternalId()); + } catch (Exception e) { + log.error("Failed to delete ignored job {}: {} - {}", job.getExternalId(), e.getClass().getSimpleName(), e.getMessage(), e); + throw e; // Re-throw so caller knows deletion failed + } } /** From 092b36138bf1dabf66c74f3d53a7dc768921aa31 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 14:56:56 +0200 Subject: [PATCH 25/34] feat: Add method to delete job by ID in JobRepository and update JobService for direct deletion --- .../repository/job/JobRepository.java | 4 ++++ .../codecrow/core/service/JobService.java | 18 ++++++------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/job/JobRepository.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/job/JobRepository.java index 6dd799b5..b7f75a97 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/job/JobRepository.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/job/JobRepository.java @@ -101,4 +101,8 @@ Page findByProjectIdAndDateRange( @Modifying @Query("DELETE FROM Job j WHERE j.project.id = :projectId") void deleteByProjectId(@Param("projectId") Long projectId); + + @Modifying + @Query("DELETE FROM Job j WHERE j.id = :jobId") + 
void deleteJobById(@Param("jobId") Long jobId); } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 31840dcf..db182728 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -325,31 +325,25 @@ public Job skipJob(Job job, String reason) { @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW) public void deleteIgnoredJob(Job job, String reason) { log.info("Deleting ignored job {} ({}): {}", job.getExternalId(), job.getJobType(), reason); - // Re-fetch the job to ensure we have a fresh entity in this new transaction Long jobId = job.getId(); if (jobId == null) { log.warn("Cannot delete ignored job - job ID is null"); return; } - Optional existingJob = jobRepository.findById(jobId); - if (existingJob.isEmpty()) { - log.warn("Cannot delete ignored job {} - not found in database", job.getExternalId()); - return; - } + try { - // Delete any logs first (foreign key constraint) + // Use direct JPQL queries to avoid JPA entity lifecycle issues + // Delete logs first (foreign key constraint) jobLogRepository.deleteByJobId(jobId); - jobLogRepository.flush(); log.info("Deleted job logs for ignored job {}", job.getExternalId()); - // Delete the job + // Delete the job using direct JPQL query (bypasses entity state tracking) log.info("About to delete job entity {} (id={})", job.getExternalId(), jobId); - jobRepository.deleteById(jobId); - jobRepository.flush(); // Force immediate execution + jobRepository.deleteJobById(jobId); log.info("Successfully deleted ignored job {} from database", job.getExternalId()); } catch (Exception e) { log.error("Failed to delete ignored job {}: {} - {}", job.getExternalId(), e.getClass().getSimpleName(), e.getMessage(), e); - throw e; // Re-throw so caller knows deletion failed + throw e; } } From 61d2620c9bb1977d7e52d3f8d4c38918c0be4d93 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 15:04:14 +0200 Subject: [PATCH 26/34] feat: Simplify job handling by marking ignored jobs as SKIPPED instead of deleting --- .../processor/WebhookAsyncProcessor.java | 30 +++---------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index f374edea..b1144c53 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -13,8 +13,6 @@ import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.slf4j.Logger; -import jakarta.persistence.EntityManager; -import jakarta.persistence.PersistenceContext; import java.io.IOException; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -75,9 +73,6 @@ public class WebhookAsyncProcessor { private final JobService jobService; private final VcsServiceFactory vcsServiceFactory; - @PersistenceContext - private EntityManager 
entityManager; - // Self-injection for @Transactional proxy to work from @Async method @Autowired @Lazy @@ -176,27 +171,10 @@ public void processWebhookInTransaction( if (finalPlaceholderCommentId != null) { deletePlaceholderComment(provider, project, payload, finalPlaceholderCommentId); } - // Delete the job entirely - don't clutter DB with ignored webhooks - // If deletion fails, skip the job instead - try { - log.info("Deleting ignored job {}", job.getExternalId()); - jobService.deleteIgnoredJob(job, result.message()); - // CRITICAL: Detach the job from this transaction's persistence context - // to prevent JPA from re-saving it when the outer transaction commits - if (entityManager.contains(job)) { - entityManager.detach(job); - log.info("Detached deleted job {} from persistence context", job.getExternalId()); - } - log.info("Successfully deleted ignored job {}", job.getExternalId()); - } catch (Exception deleteError) { - log.warn("Failed to delete ignored job {}, skipping instead: {}", - job.getExternalId(), deleteError.getMessage()); - try { - jobService.skipJob(job, result.message()); - } catch (Exception skipError) { - log.error("Failed to skip job {}: {}", job.getExternalId(), skipError.getMessage()); - } - } + // Mark job as SKIPPED - simpler and more reliable than deletion + // which can have transaction/lock issues with concurrent requests + jobService.skipJob(job, result.message()); + log.info("Marked ignored job {} as SKIPPED", job.getExternalId()); return; } From 704a7a256429f514695d8266af9ca7e9724ed45a Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 02:42:43 +0200 Subject: [PATCH 27/34] feat: Enhance AI connection logging and refactor placeholder management in webhook processing --- .../service/BitbucketAiClientService.java | 4 + .../processor/WebhookAsyncProcessor.java | 78 ++----------------- .../command/ReviewCommandProcessor.java | 5 ++ .../generic/utils/CommentPlaceholders.java | 66 ++++++++++++++++ .../AbstractWebhookHandler.java | 1 - .../github/service/GitHubAiClientService.java | 4 + .../gitlab/service/GitLabAiClientService.java | 4 + .../project/service/ProjectService.java | 25 ++++-- 8 files changed, 111 insertions(+), 76 deletions(-) create mode 100644 java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/utils/CommentPlaceholders.java diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java index 142cba75..6ebb0894 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java @@ -124,6 +124,10 @@ public AiAnalysisRequest buildPrAnalysisRequest( VcsConnection vcsConnection = vcsInfo.vcsConnection(); AIConnection aiConnection = project.getAiBinding().getAiConnection(); AIConnection projectAiConnection = project.getAiBinding().getAiConnection(); + + // CRITICAL: Log the AI connection being used for debugging + log.info("Building PR analysis request for project={}, AI model={}, provider={}, aiConnectionId={}", + project.getId(), aiConnection.getAiModel(), aiConnection.getProviderKey(), aiConnection.getId()); // Initialize variables List changedFiles = 
Collections.emptyList(); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index b1144c53..d509c93c 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -6,6 +6,7 @@ import org.rostilos.codecrow.core.persistence.repository.project.ProjectRepository; import org.rostilos.codecrow.core.service.JobService; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; +import org.rostilos.codecrow.pipelineagent.generic.utils.CommentPlaceholders; import org.rostilos.codecrow.pipelineagent.generic.webhookhandler.WebhookHandler; import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; @@ -30,45 +31,6 @@ public class WebhookAsyncProcessor { private static final Logger log = LoggerFactory.getLogger(WebhookAsyncProcessor.class); - /** Comment markers for CodeCrow command responses */ - private static final String CODECROW_COMMAND_MARKER = ""; - private static final String CODECROW_SUMMARY_MARKER = ""; - private static final String CODECROW_REVIEW_MARKER = ""; - - /** Placeholder messages for commands */ - private static final String PLACEHOLDER_ANALYZE = """ - 🔄 **CodeCrow is analyzing this PR...** - - This may take a few minutes depending on the size of the changes. - I'll update this comment with the results when the analysis is complete. - """; - - private static final String PLACEHOLDER_SUMMARIZE = """ - 🔄 **CodeCrow is generating a summary...** - - I'm analyzing the changes and creating diagrams. - This comment will be updated with the summary when ready. - """; - - private static final String PLACEHOLDER_REVIEW = """ - 🔄 **CodeCrow is reviewing this PR...** - - I'm examining the code changes for potential issues. - This comment will be updated with the review results when complete. - """; - - private static final String PLACEHOLDER_ASK = """ - 🔄 **CodeCrow is processing your question...** - - I'm analyzing the context to provide a helpful answer. - """; - - private static final String PLACEHOLDER_DEFAULT = """ - 🔄 **CodeCrow is processing...** - - Please wait while I complete this task. 
- """; - private final ProjectRepository projectRepository; private final JobService jobService; private final VcsServiceFactory vcsServiceFactory; @@ -419,7 +381,7 @@ private void postAskReply(VcsReportingService reportingService, Project project, private void postWithMarker(VcsReportingService reportingService, Project project, WebhookPayload payload, String content, String commandType, String placeholderCommentId, Job job) throws IOException { - String marker = getMarkerForCommandType(commandType); + String marker = CommentPlaceholders.getMarkerForCommandType(commandType); // If we have a placeholder comment, update it instead of creating a new one if (placeholderCommentId != null) { @@ -490,7 +452,7 @@ private void postErrorToVcs(EVcsProvider provider, Project project, WebhookPaylo Long.parseLong(payload.pullRequestId()), placeholderCommentId, content, - CODECROW_COMMAND_MARKER + CommentPlaceholders.CODECROW_COMMAND_MARKER ); log.info("Updated placeholder comment {} with error for PR {}", placeholderCommentId, payload.pullRequestId()); } else { @@ -499,7 +461,7 @@ private void postErrorToVcs(EVcsProvider provider, Project project, WebhookPaylo project, Long.parseLong(payload.pullRequestId()), content, - CODECROW_COMMAND_MARKER + CommentPlaceholders.CODECROW_COMMAND_MARKER ); log.info("Posted error to PR {}", payload.pullRequestId()); } @@ -529,7 +491,7 @@ private void postInfoToVcs(EVcsProvider provider, Project project, WebhookPayloa Long.parseLong(payload.pullRequestId()), placeholderCommentId, infoMessage, - CODECROW_COMMAND_MARKER + CommentPlaceholders.CODECROW_COMMAND_MARKER ); log.info("Updated placeholder comment {} with info message for PR {}", placeholderCommentId, payload.pullRequestId()); } else { @@ -538,7 +500,7 @@ private void postInfoToVcs(EVcsProvider provider, Project project, WebhookPayloa project, Long.parseLong(payload.pullRequestId()), infoMessage, - CODECROW_COMMAND_MARKER + CommentPlaceholders.CODECROW_COMMAND_MARKER ); log.info("Posted info message to PR {}", payload.pullRequestId()); } @@ -641,8 +603,8 @@ private String postPlaceholderComment(EVcsProvider provider, Project project, ? payload.getCodecrowCommand().type().name().toLowerCase() : "default"; - String placeholderContent = getPlaceholderMessage(commandType); - String marker = getMarkerForCommandType(commandType); + String placeholderContent = CommentPlaceholders.getPlaceholderMessage(commandType); + String marker = CommentPlaceholders.getMarkerForCommandType(commandType); // Delete any previous comments with the same marker before posting placeholder try { @@ -691,28 +653,4 @@ private void deletePlaceholderComment(EVcsProvider provider, Project project, log.warn("Failed to delete placeholder comment {}: {}", commentId, e.getMessage()); } } - - /** - * Get the placeholder message for a command type. - */ - private String getPlaceholderMessage(String commandType) { - return switch (commandType.toLowerCase()) { - case "analyze" -> PLACEHOLDER_ANALYZE; - case "summarize" -> PLACEHOLDER_SUMMARIZE; - case "review" -> PLACEHOLDER_REVIEW; - case "ask" -> PLACEHOLDER_ASK; - default -> PLACEHOLDER_DEFAULT; - }; - } - - /** - * Get the comment marker for a command type. 
- */ - private String getMarkerForCommandType(String commandType) { - return switch (commandType.toLowerCase()) { - case "summarize" -> CODECROW_SUMMARY_MARKER; - case "review" -> CODECROW_REVIEW_MARKER; - default -> CODECROW_COMMAND_MARKER; - }; - } } diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java index 8d381e79..92a00954 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java @@ -154,6 +154,11 @@ private ReviewRequest buildReviewRequest(Project project, WebhookPayload payload } AIConnection aiConnection = project.getAiBinding().getAiConnection(); + + // CRITICAL: Log the AI connection being used for debugging + log.info("Building review command request for project={}, AI model={}, provider={}, aiConnectionId={}", + project.getId(), aiConnection.getAiModel(), aiConnection.getProviderKey(), aiConnection.getId()); + String decryptedApiKey = tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted()); // Get VCS credentials using centralized extractor diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/utils/CommentPlaceholders.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/utils/CommentPlaceholders.java new file mode 100644 index 00000000..1145895a --- /dev/null +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/utils/CommentPlaceholders.java @@ -0,0 +1,66 @@ +package org.rostilos.codecrow.pipelineagent.generic.utils; + +public class CommentPlaceholders { + /** Comment markers for CodeCrow command responses */ + public static final String CODECROW_COMMAND_MARKER = ""; + public static final String CODECROW_SUMMARY_MARKER = ""; + public static final String CODECROW_REVIEW_MARKER = ""; + + /** Placeholder messages for commands */ + public static final String PLACEHOLDER_ANALYZE = """ + 🔄 **CodeCrow is analyzing this PR...** + + This may take a few minutes depending on the size of the changes. + I'll update this comment with the results when the analysis is complete. + """; + + public static final String PLACEHOLDER_SUMMARIZE = """ + 🔄 **CodeCrow is generating a summary...** + + I'm analyzing the changes and creating diagrams. + This comment will be updated with the summary when ready. + """; + + public static final String PLACEHOLDER_REVIEW = """ + 🔄 **CodeCrow is reviewing this PR...** + + I'm examining the code changes for potential issues. + This comment will be updated with the review results when complete. + """; + + public static final String PLACEHOLDER_ASK = """ + 🔄 **CodeCrow is processing your question...** + + I'm analyzing the context to provide a helpful answer. + """; + + public static final String PLACEHOLDER_DEFAULT = """ + 🔄 **CodeCrow is processing...** + + Please wait while I complete this task. + """; + + /** + * Get the placeholder message for a command type. 
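A brief usage sketch for the utility above (the helper class below is hypothetical and not part of this patch): callers resolve both the placeholder text and the marker for a command, post the placeholder first, and hand the marker to the VCS reporting layer so the finished result can update or clean up that same comment rather than posting a second one.

package org.rostilos.codecrow.pipelineagent.generic.utils;

// Hypothetical usage sketch only -- this helper is not part of the patch.
public final class CommentPlaceholdersUsageSketch {

    private CommentPlaceholdersUsageSketch() {
    }

    /** Resolves what to post for a command such as "review": [placeholderBody, marker]. */
    public static String[] resolveCommandComment(String commandType) {
        // The placeholder is posted immediately so users see progress; the marker
        // lets later updates (or deletion of stale comments) target the same comment.
        String placeholder = CommentPlaceholders.getPlaceholderMessage(commandType);
        String marker = CommentPlaceholders.getMarkerForCommandType(commandType);
        return new String[] { placeholder, marker };
    }
}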
+ */ + public static String getPlaceholderMessage(String commandType) { + return switch (commandType.toLowerCase()) { + case "analyze" -> PLACEHOLDER_ANALYZE; + case "summarize" -> PLACEHOLDER_SUMMARIZE; + case "review" -> PLACEHOLDER_REVIEW; + case "ask" -> PLACEHOLDER_ASK; + default -> PLACEHOLDER_DEFAULT; + }; + } + + /** + * Get the comment marker for a command type. + */ + public static String getMarkerForCommandType(String commandType) { + return switch (commandType.toLowerCase()) { + case "summarize" -> CODECROW_SUMMARY_MARKER; + case "review" -> CODECROW_REVIEW_MARKER; + default -> CODECROW_COMMAND_MARKER; + }; + } +} diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/AbstractWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/AbstractWebhookHandler.java index 989c4438..e4f5e0e3 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/AbstractWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/AbstractWebhookHandler.java @@ -3,7 +3,6 @@ import org.rostilos.codecrow.core.model.codeanalysis.AnalysisType; import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.project.config.BranchAnalysisConfig; -import org.rostilos.codecrow.core.model.project.config.ProjectConfig; import org.rostilos.codecrow.core.util.BranchPatternMatcher; import java.util.List; diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java index f8f0dc09..69193350 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java @@ -118,6 +118,10 @@ private AiAnalysisRequest buildPrAnalysisRequest( VcsInfo vcsInfo = getVcsInfo(project); VcsConnection vcsConnection = vcsInfo.vcsConnection(); AIConnection aiConnection = project.getAiBinding().getAiConnection(); + + // CRITICAL: Log the AI connection being used for debugging + log.info("Building PR analysis request for project={}, AI model={}, provider={}, aiConnectionId={}", + project.getId(), aiConnection.getAiModel(), aiConnection.getProviderKey(), aiConnection.getId()); // Initialize variables List changedFiles = Collections.emptyList(); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java index ed2be10a..fd736acf 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java @@ -118,6 +118,10 @@ private AiAnalysisRequest buildMrAnalysisRequest( VcsInfo vcsInfo = getVcsInfo(project); VcsConnection vcsConnection = vcsInfo.vcsConnection(); AIConnection aiConnection = 
project.getAiBinding().getAiConnection(); + + // CRITICAL: Log the AI connection being used for debugging + log.info("Building MR analysis request for project={}, AI model={}, provider={}, aiConnectionId={}", + project.getId(), aiConnection.getAiModel(), aiConnection.getProviderKey(), aiConnection.getId()); // Initialize variables List changedFiles = Collections.emptyList(); diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java index ac98db3a..f14718a9 100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java @@ -375,18 +375,33 @@ public void updateRepositorySettings(Long workspaceId, Long projectId, UpdateRep @Transactional public boolean bindAiConnection(Long workspaceId, Long projectId, BindAiConnectionRequest request) throws SecurityException { - Project project = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) + // Use findByIdWithConnections to eagerly fetch aiBinding for proper orphan removal + Project project = projectRepository.findByIdWithConnections(projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); + + // Verify workspace ownership + if (!project.getWorkspace().getId().equals(workspaceId)) { + throw new NoSuchElementException("Project not found in workspace"); + } if (request.getAiConnectionId() != null) { Long aiConnectionId = request.getAiConnectionId(); AIConnection aiConnection = aiConnectionRepository.findByWorkspace_IdAndId(workspaceId, aiConnectionId) .orElseThrow(() -> new NoSuchElementException("Ai connection not found")); - ProjectAiConnectionBinding aiConnectionBinding = new ProjectAiConnectionBinding(); - aiConnectionBinding.setProject(project); - aiConnectionBinding.setAiConnection(aiConnection); - project.setAiConnectionBinding(aiConnectionBinding); + // Check if there's an existing binding that needs to be updated + ProjectAiConnectionBinding existingBinding = project.getAiBinding(); + if (existingBinding != null) { + // Update existing binding instead of creating new one + existingBinding.setAiConnection(aiConnection); + } else { + // Create new binding + ProjectAiConnectionBinding aiConnectionBinding = new ProjectAiConnectionBinding(); + aiConnectionBinding.setProject(project); + aiConnectionBinding.setAiConnection(aiConnection); + project.setAiConnectionBinding(aiConnectionBinding); + } + projectRepository.save(project); return true; } From 2e42ebc3d7ca43a21adca084d7eb2e957e319b1b Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 02:43:30 +0200 Subject: [PATCH 28/34] feat: Add logging for LLM creation and enhance diff snippet extraction for RAG context --- .../mcp-client/llm/llm_factory.py | 3 + .../service/multi_stage_orchestrator.py | 196 ++++++++++++------ .../mcp-client/service/review_service.py | 3 + 3 files changed, 136 insertions(+), 66 deletions(-) diff --git a/python-ecosystem/mcp-client/llm/llm_factory.py b/python-ecosystem/mcp-client/llm/llm_factory.py index 470e7c4e..61b14e85 100644 --- a/python-ecosystem/mcp-client/llm/llm_factory.py +++ b/python-ecosystem/mcp-client/llm/llm_factory.py @@ -139,6 +139,9 @@ def create_llm(ai_model: str, ai_provider: str, ai_api_key: str, temperature: Op # Normalize provider 
provider = LLMFactory._normalize_provider(ai_provider) + # CRITICAL: Log the model being used for debugging + logger.info(f"Creating LLM instance: provider={provider}, model={ai_model}, temperature={temperature}") + # Check for unsupported Gemini thinking models (applies to all providers) LLMFactory._check_unsupported_gemini_model(ai_model) diff --git a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py index f58c702e..e4b9846b 100644 --- a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py +++ b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py @@ -399,6 +399,44 @@ def _format_previous_issues_for_batch(self, issues: List[Any]) -> str: lines.append("=== END PREVIOUS ISSUES ===") return "\n".join(lines) + def _extract_diff_snippets(self, diff_content: str) -> List[str]: + """ + Extract meaningful code snippets from diff content for RAG semantic search. + Focuses on added/modified lines that represent significant code changes. + """ + if not diff_content: + return [] + + snippets = [] + current_snippet_lines = [] + + for line in diff_content.split("\n"): + # Focus on added lines (new code) + if line.startswith("+") and not line.startswith("+++"): + clean_line = line[1:].strip() + # Skip trivial lines + if (clean_line and + len(clean_line) > 10 and # Minimum meaningful length + not clean_line.startswith("//") and # Skip comments + not clean_line.startswith("#") and + not clean_line.startswith("*") and + not clean_line == "{" and + not clean_line == "}" and + not clean_line == ""): + current_snippet_lines.append(clean_line) + + # Batch into snippets of 3-5 lines + if len(current_snippet_lines) >= 3: + snippets.append(" ".join(current_snippet_lines)) + current_snippet_lines = [] + + # Add remaining lines as final snippet + if current_snippet_lines: + snippets.append(" ".join(current_snippet_lines)) + + # Limit to most significant snippets + return snippets[:10] + def _get_diff_snippets_for_batch( self, all_diff_snippets: List[str], @@ -536,6 +574,48 @@ async def _execute_stage_1_file_reviews( logger.info(f"Stage 1 Complete: {len(all_issues)} issues found across {total_files} files") return all_issues + async def _fetch_batch_rag_context( + self, + request: ReviewRequestDto, + batch_file_paths: List[str], + batch_diff_snippets: List[str] + ) -> Optional[Dict[str, Any]]: + """ + Fetch RAG context specifically for this batch of files. + Uses batch file paths and diff snippets for targeted semantic search. 
+ """ + if not self.rag_client: + return None + + try: + # Determine branch for RAG query + rag_branch = request.targetBranchName or request.commitHash or "main" + + logger.info(f"Fetching per-batch RAG context for {len(batch_file_paths)} files") + + rag_response = await self.rag_client.get_pr_context( + workspace=request.projectWorkspace, + project=request.projectNamespace, + branch=rag_branch, + changed_files=batch_file_paths, + diff_snippets=batch_diff_snippets, + pr_title=request.prTitle, + pr_description=request.prDescription, + top_k=10 # Fewer chunks per batch for focused context + ) + + if rag_response and rag_response.get("context"): + context = rag_response.get("context") + chunk_count = len(context.get("relevant_code", [])) + logger.info(f"Per-batch RAG: retrieved {chunk_count} chunks for files {batch_file_paths}") + return context + + return None + + except Exception as e: + logger.warning(f"Failed to fetch per-batch RAG context: {e}") + return None + async def _review_file_batch( self, request: ReviewRequestDto, @@ -550,6 +630,7 @@ async def _review_file_batch( """ batch_files_data = [] batch_file_paths = [] + batch_diff_snippets = [] project_rules = "1. No hardcoded secrets.\n2. Use dependency injection.\n3. Verify all inputs." # For incremental mode, use deltaDiff instead of full diff @@ -560,7 +641,7 @@ async def _review_file_batch( else: diff_source = processed_diff - # Collect file paths and diffs for this batch + # Collect file paths, diffs, and extract snippets for this batch for item in batch_items: file_info = item["file"] batch_file_paths.append(file_info.path) @@ -571,6 +652,9 @@ async def _review_file_batch( for f in diff_source.files: if f.path == file_info.path or f.path.endswith("/" + file_info.path): file_diff = f.content + # Extract code snippets from diff for RAG semantic search + if file_diff: + batch_diff_snippets.extend(self._extract_diff_snippets(file_diff)) break batch_files_data.append({ @@ -582,18 +666,32 @@ async def _review_file_batch( "is_incremental": is_incremental # Pass mode to prompt builder }) - # Use initial RAG context (already fetched with all files/snippets) - # The initial query is more comprehensive - it uses ALL changed files and snippets - # Per-batch filtering is done in _format_rag_context via relevant_files param + # Fetch per-batch RAG context using batch-specific files and diff snippets rag_context_text = "" - if fallback_rag_context: - logger.info(f"Using initial RAG context for batch: {batch_file_paths}") + batch_rag_context = None + + if self.rag_client: + batch_rag_context = await self._fetch_batch_rag_context( + request, batch_file_paths, batch_diff_snippets + ) + + # Use batch-specific RAG context if available, otherwise fall back to initial context + if batch_rag_context: + logger.info(f"Using per-batch RAG context for: {batch_file_paths}") + rag_context_text = self._format_rag_context( + batch_rag_context, + set(batch_file_paths), + pr_changed_files=request.changedFiles + ) + elif fallback_rag_context: + logger.info(f"Using fallback RAG context for batch: {batch_file_paths}") rag_context_text = self._format_rag_context( fallback_rag_context, set(batch_file_paths), pr_changed_files=request.changedFiles ) - logger.info(f"RAG context for batch: {len(rag_context_text)} chars") + + logger.info(f"RAG context for batch: {len(rag_context_text)} chars") # For incremental mode, filter previous issues relevant to this batch # Also pass previous issues in FULL mode if they exist (subsequent PR iterations) @@ -882,13 +980,15 @@ def 
_format_rag_context( ) -> str: """ Format RAG context into a readable string for the prompt. - Includes rich AST metadata (imports, extends, implements) for better LLM context. + + IMPORTANT: We trust RAG's semantic similarity scores for relevance. + The RAG system already uses embeddings to find semantically related code. + We only filter out chunks from files being modified in the PR (stale data from main branch). Args: rag_context: RAG response with code chunks - relevant_files: Files in current batch to prioritize - pr_changed_files: ALL files modified in the PR - chunks from these files - are marked as potentially stale (from main branch, not PR branch) + relevant_files: (UNUSED - kept for API compatibility) - we trust RAG scores instead + pr_changed_files: Files modified in the PR - chunks from these may be stale """ if not rag_context: logger.debug("RAG context is empty or None") @@ -900,35 +1000,32 @@ def _format_rag_context( logger.debug("No chunks found in RAG context (keys: %s)", list(rag_context.keys())) return "" - logger.debug(f"Processing {len(chunks)} RAG chunks for context") - logger.debug(f"PR changed files for filtering: {pr_changed_files[:5] if pr_changed_files else 'none'}...") + logger.info(f"Processing {len(chunks)} RAG chunks (trusting semantic similarity scores)") - # Normalize PR changed files for comparison + # Normalize PR changed files for stale-data detection only pr_changed_set = set() if pr_changed_files: for f in pr_changed_files: pr_changed_set.add(f) - # Also add just the filename for matching if "/" in f: pr_changed_set.add(f.rsplit("/", 1)[-1]) formatted_parts = [] included_count = 0 - skipped_modified = 0 - skipped_relevance = 0 + skipped_stale = 0 + for chunk in chunks: - if included_count >= 15: # Increased from 10 for more context - logger.debug(f"Reached chunk limit of 15, stopping") + if included_count >= 15: + logger.debug(f"Reached chunk limit of 15") break - # Extract metadata metadata = chunk.get("metadata", {}) path = metadata.get("path", chunk.get("path", "unknown")) chunk_type = metadata.get("content_type", metadata.get("type", "code")) score = chunk.get("score", chunk.get("relevance_score", 0)) - # Check if this chunk is from a file being modified in the PR - is_from_modified_file = False + # Only filter: chunks from PR-modified files with LOW scores (likely stale) + # High-score chunks from modified files may still be relevant (other parts of same file) if pr_changed_set: path_filename = path.rsplit("/", 1)[-1] if "/" in path else path is_from_modified_file = ( @@ -936,30 +1033,11 @@ def _format_rag_context( path_filename in pr_changed_set or any(path.endswith(f) or f.endswith(path) for f in pr_changed_set) ) - - # For chunks from modified files: - # - Skip if very low relevance (score < 0.70) - likely not useful - # - Include if moderate+ relevance (score >= 0.70) - context is valuable - if is_from_modified_file: - if score < 0.70: - logger.debug(f"Skipping RAG chunk from modified file (low score): {path} (score={score})") - skipped_modified += 1 - continue - else: - logger.debug(f"Including RAG chunk from modified file (relevant): {path} (score={score})") - - # Optionally filter by relevance to batch files - if relevant_files: - # Include if the chunk's path relates to any batch file - is_relevant = any( - path in f or f in path or - path.rsplit("/", 1)[-1] == f.rsplit("/", 1)[-1] - for f in relevant_files - ) - # Also include chunks with moderate+ score regardless - if not is_relevant and score < 0.70: - logger.debug(f"Skipping RAG 
chunk (not relevant to batch and low score): {path} (score={score})") - skipped_relevance += 1 + + # Skip ONLY low-score chunks from modified files (likely stale/outdated) + if is_from_modified_file and score < 0.70: + logger.debug(f"Skipping stale chunk from modified file: {path} (score={score:.2f})") + skipped_stale += 1 continue text = chunk.get("text", chunk.get("content", "")) @@ -968,38 +1046,27 @@ def _format_rag_context( included_count += 1 - # Build rich metadata context from AST-extracted info - meta_lines = [] - meta_lines.append(f"File: {path}") + # Build rich metadata context + meta_lines = [f"File: {path}"] - # Include namespace/package if available if metadata.get("namespace"): meta_lines.append(f"Namespace: {metadata['namespace']}") elif metadata.get("package"): meta_lines.append(f"Package: {metadata['package']}") - # Include class/function name if metadata.get("primary_name"): meta_lines.append(f"Definition: {metadata['primary_name']}") elif metadata.get("semantic_names"): meta_lines.append(f"Definitions: {', '.join(metadata['semantic_names'][:5])}") - # Include inheritance info (extends, implements) if metadata.get("extends"): extends = metadata["extends"] - if isinstance(extends, list): - meta_lines.append(f"Extends: {', '.join(extends)}") - else: - meta_lines.append(f"Extends: {extends}") + meta_lines.append(f"Extends: {', '.join(extends) if isinstance(extends, list) else extends}") if metadata.get("implements"): implements = metadata["implements"] - if isinstance(implements, list): - meta_lines.append(f"Implements: {', '.join(implements)}") - else: - meta_lines.append(f"Implements: {implements}") + meta_lines.append(f"Implements: {', '.join(implements) if isinstance(implements, list) else implements}") - # Include imports (abbreviated if too many) if metadata.get("imports"): imports = metadata["imports"] if isinstance(imports, list): @@ -1008,13 +1075,11 @@ def _format_rag_context( else: meta_lines.append(f"Imports: {'; '.join(imports[:5])}... (+{len(imports)-5} more)") - # Include parent context (for nested methods) if metadata.get("parent_context"): parent_ctx = metadata["parent_context"] if isinstance(parent_ctx, list): meta_lines.append(f"Parent: {'.'.join(parent_ctx)}") - # Include content type for understanding chunk nature if chunk_type and chunk_type != "code": meta_lines.append(f"Type: {chunk_type}") @@ -1026,11 +1091,10 @@ def _format_rag_context( ) if not formatted_parts: - logger.warning(f"No RAG chunks included in prompt (total: {len(chunks)}, skipped_modified: {skipped_modified}, skipped_relevance: {skipped_relevance}). 
" - f"PR changed files: {pr_changed_files[:5] if pr_changed_files else 'none'}...") + logger.warning(f"No RAG chunks included (total: {len(chunks)}, skipped_stale: {skipped_stale})") return "" - logger.info(f"Included {len(formatted_parts)} RAG chunks in prompt context (total: {len(chunks)}, skipped: {skipped_modified} low-score modified, {skipped_relevance} low relevance)") + logger.info(f"Included {len(formatted_parts)} RAG chunks (skipped {skipped_stale} stale from modified files)") return "\n".join(formatted_parts) def _emit_status(self, state: str, message: str): diff --git a/python-ecosystem/mcp-client/service/review_service.py b/python-ecosystem/mcp-client/service/review_service.py index ef9bb770..ba57a085 100644 --- a/python-ecosystem/mcp-client/service/review_service.py +++ b/python-ecosystem/mcp-client/service/review_service.py @@ -393,6 +393,9 @@ def _create_mcp_client(self, config: Dict[str, Any]) -> MCPClient: def _create_llm(self, request: ReviewRequestDto): """Create LLM instance from request parameters and initialize reranker.""" try: + # Log the model being used for this request + logger.info(f"Creating LLM for project {request.projectId}: provider={request.aiProvider}, model={request.aiModel}") + llm = LLMFactory.create_llm( request.aiModel, request.aiProvider, From d036fa972b1efdc551da88bb1179045f76faffe0 Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 02:44:12 +0200 Subject: [PATCH 29/34] feat: Implement AST-based code splitter and scoring configuration - Added AST-based code splitter using Tree-sitter for accurate code parsing. - Introduced TreeSitterParser for dynamic language loading and caching. - Created scoring configuration for RAG query result reranking with configurable boost factors and priority patterns. - Refactored RAGQueryService to utilize the new scoring configuration for enhanced result ranking. - Improved metadata extraction and handling for better context in scoring. 
--- .../rag-pipeline/src/rag_pipeline/__init__.py | 7 +- .../src/rag_pipeline/core/__init__.py | 11 +- .../src/rag_pipeline/core/ast_splitter.py | 1401 ----------------- .../src/rag_pipeline/core/chunking.py | 171 -- .../src/rag_pipeline/core/index_manager.py | 31 +- .../core/index_manager/__init__.py | 10 + .../core/index_manager/branch_manager.py | 172 ++ .../core/index_manager/collection_manager.py | 164 ++ .../core/index_manager/indexer.py | 398 +++++ .../core/index_manager/manager.py | 290 ++++ .../core/index_manager/point_operations.py | 151 ++ .../core/index_manager/stats_manager.py | 156 ++ .../rag_pipeline/core/semantic_splitter.py | 455 ------ .../rag_pipeline/core/splitter/__init__.py | 53 + .../rag_pipeline/core/splitter/languages.py | 139 ++ .../rag_pipeline/core/splitter/metadata.py | 339 ++++ .../core/splitter/queries/c_sharp.scm | 56 + .../rag_pipeline/core/splitter/queries/go.scm | 26 + .../core/splitter/queries/java.scm | 45 + .../core/splitter/queries/javascript.scm | 42 + .../core/splitter/queries/php.scm | 40 + .../core/splitter/queries/python.scm | 28 + .../core/splitter/queries/rust.scm | 46 + .../core/splitter/queries/typescript.scm | 52 + .../core/splitter/query_runner.py | 360 +++++ .../rag_pipeline/core/splitter/splitter.py | 720 +++++++++ .../rag_pipeline/core/splitter/tree_parser.py | 129 ++ .../src/rag_pipeline/models/scoring_config.py | 232 +++ .../rag_pipeline/services/query_service.py | 75 +- 29 files changed, 3682 insertions(+), 2117 deletions(-) delete mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py delete mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/chunking.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/__init__.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/stats_manager.py delete mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/semantic_splitter.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/__init__.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/languages.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/metadata.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/c_sharp.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/go.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/java.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/javascript.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/php.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/python.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/rust.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/typescript.scm create mode 
100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/query_runner.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/tree_parser.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/models/scoring_config.py diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/__init__.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/__init__.py index 3dc56254..e0fc412d 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/__init__.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/__init__.py @@ -2,7 +2,7 @@ CodeCrow RAG Pipeline A RAG (Retrieval-Augmented Generation) pipeline for code repositories. -Provides indexing and querying capabilities for code using LlamaIndex and MongoDB. +Provides indexing and querying capabilities for code using LlamaIndex, Tree-sitter and Qdrant. """ __version__ = "1.0.0" @@ -11,7 +11,7 @@ from .core.index_manager import RAGIndexManager from .services.query_service import RAGQueryService from .core.loader import DocumentLoader -from .core.chunking import CodeAwareSplitter, FunctionAwareSplitter +from .core.splitter import ASTCodeSplitter from .utils.utils import make_namespace, detect_language_from_path __all__ = [ @@ -21,8 +21,7 @@ "RAGIndexManager", "RAGQueryService", "DocumentLoader", - "CodeAwareSplitter", - "FunctionAwareSplitter", + "ASTCodeSplitter", "make_namespace", "detect_language_from_path", ] diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/__init__.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/__init__.py index dcb908ff..43cff58b 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/__init__.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/__init__.py @@ -1,17 +1,10 @@ """Core functionality for indexing and document processing""" __all__ = [ "DocumentLoader", - "CodeAwareSplitter", - "FunctionAwareSplitter", - "SemanticCodeSplitter", "ASTCodeSplitter", "RAGIndexManager" ] from .index_manager import RAGIndexManager -from .chunking import CodeAwareSplitter, FunctionAwareSplitter -from .semantic_splitter import SemanticCodeSplitter -from .ast_splitter import ASTCodeSplitter -from .loader import DocumentLoader - - +from .splitter import ASTCodeSplitter +from .loader import DocumentLoader \ No newline at end of file diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py deleted file mode 100644 index 691a3b0d..00000000 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py +++ /dev/null @@ -1,1401 +0,0 @@ -""" -AST-based Code Splitter using Tree-sitter for accurate code parsing. - -This module provides true AST-aware code chunking that: -1. Uses Tree-sitter for accurate AST parsing (15+ languages) -2. Splits code into semantic units (classes, functions, methods) -3. Uses RecursiveCharacterTextSplitter for oversized chunks (large methods) -4. Enriches metadata for better RAG retrieval -5. Maintains parent context ("breadcrumbs") for nested structures -6. Uses deterministic IDs for Qdrant deduplication - -Key benefits over regex-based splitting: -- Accurate function/class boundary detection -- Language-aware parsing for 15+ languages -- Better metadata: content_type, language, semantic_names, parent_class -- Handles edge cases (nested functions, decorators, etc.) 
-- Deterministic chunk IDs prevent duplicates on re-indexing -""" - -import re -import hashlib -import logging -from typing import List, Dict, Any, Optional, Set -from pathlib import Path -from dataclasses import dataclass, field -from enum import Enum - -from langchain_text_splitters import RecursiveCharacterTextSplitter, Language -from llama_index.core.schema import Document as LlamaDocument, TextNode - -logger = logging.getLogger(__name__) - - -class ContentType(Enum): - """Content type as determined by AST parsing""" - FUNCTIONS_CLASSES = "functions_classes" # Full function/class definition - SIMPLIFIED_CODE = "simplified_code" # Remaining code with placeholders - FALLBACK = "fallback" # Non-AST parsed content - OVERSIZED_SPLIT = "oversized_split" # Large chunk split by RecursiveCharacterTextSplitter - - -# Map file extensions to LangChain Language enum -EXTENSION_TO_LANGUAGE: Dict[str, Language] = { - # Python - '.py': Language.PYTHON, - '.pyw': Language.PYTHON, - '.pyi': Language.PYTHON, - - # Java/JVM - '.java': Language.JAVA, - '.kt': Language.KOTLIN, - '.kts': Language.KOTLIN, - '.scala': Language.SCALA, - - # JavaScript/TypeScript - '.js': Language.JS, - '.jsx': Language.JS, - '.mjs': Language.JS, - '.cjs': Language.JS, - '.ts': Language.TS, - '.tsx': Language.TS, - - # Systems languages - '.go': Language.GO, - '.rs': Language.RUST, - '.c': Language.C, - '.h': Language.C, - '.cpp': Language.CPP, - '.cc': Language.CPP, - '.cxx': Language.CPP, - '.hpp': Language.CPP, - '.hxx': Language.CPP, - '.cs': Language.CSHARP, - - # Web/Scripting - '.php': Language.PHP, - '.phtml': Language.PHP, # PHP template files (Magento, Zend, etc.) - '.php3': Language.PHP, - '.php4': Language.PHP, - '.php5': Language.PHP, - '.phps': Language.PHP, - '.inc': Language.PHP, # PHP include files - '.rb': Language.RUBY, - '.erb': Language.RUBY, # Ruby template files - '.lua': Language.LUA, - '.pl': Language.PERL, - '.pm': Language.PERL, - '.swift': Language.SWIFT, - - # Markup/Config - '.md': Language.MARKDOWN, - '.markdown': Language.MARKDOWN, - '.html': Language.HTML, - '.htm': Language.HTML, - '.rst': Language.RST, - '.tex': Language.LATEX, - '.proto': Language.PROTO, - '.sol': Language.SOL, - '.hs': Language.HASKELL, - '.cob': Language.COBOL, - '.cbl': Language.COBOL, - '.xml': Language.HTML, # Use HTML splitter for XML -} - -# Languages that support full AST parsing via tree-sitter -AST_SUPPORTED_LANGUAGES = { - Language.PYTHON, Language.JAVA, Language.KOTLIN, Language.JS, Language.TS, - Language.GO, Language.RUST, Language.C, Language.CPP, Language.CSHARP, - Language.PHP, Language.RUBY, Language.SCALA, Language.LUA, Language.PERL, - Language.SWIFT, Language.HASKELL, Language.COBOL -} - -# Tree-sitter language name mapping (tree-sitter-languages uses these names) -LANGUAGE_TO_TREESITTER: Dict[Language, str] = { - Language.PYTHON: 'python', - Language.JAVA: 'java', - Language.KOTLIN: 'kotlin', - Language.JS: 'javascript', - Language.TS: 'typescript', - Language.GO: 'go', - Language.RUST: 'rust', - Language.C: 'c', - Language.CPP: 'cpp', - Language.CSHARP: 'c_sharp', - Language.PHP: 'php', - Language.RUBY: 'ruby', - Language.SCALA: 'scala', - Language.LUA: 'lua', - Language.PERL: 'perl', - Language.SWIFT: 'swift', - Language.HASKELL: 'haskell', -} - -# Node types that represent semantic CHUNKING units (classes, functions) -# NOTE: imports, namespace, inheritance are now extracted DYNAMICALLY from AST -# by pattern matching on node type names - no hardcoded mappings needed! 
-SEMANTIC_NODE_TYPES: Dict[str, Dict[str, List[str]]] = { - 'python': { - 'class': ['class_definition'], - 'function': ['function_definition', 'async_function_definition'], - }, - 'java': { - 'class': ['class_declaration', 'interface_declaration', 'enum_declaration'], - 'function': ['method_declaration', 'constructor_declaration'], - }, - 'javascript': { - 'class': ['class_declaration'], - 'function': ['function_declaration', 'method_definition', 'arrow_function', 'generator_function_declaration'], - }, - 'typescript': { - 'class': ['class_declaration', 'interface_declaration'], - 'function': ['function_declaration', 'method_definition', 'arrow_function'], - }, - 'go': { - 'class': ['type_declaration'], # structs, interfaces - 'function': ['function_declaration', 'method_declaration'], - }, - 'rust': { - 'class': ['struct_item', 'impl_item', 'trait_item', 'enum_item'], - 'function': ['function_item'], - }, - 'c_sharp': { - 'class': ['class_declaration', 'interface_declaration', 'struct_declaration'], - 'function': ['method_declaration', 'constructor_declaration'], - }, - 'kotlin': { - 'class': ['class_declaration', 'object_declaration', 'interface_declaration'], - 'function': ['function_declaration'], - }, - 'php': { - 'class': ['class_declaration', 'interface_declaration', 'trait_declaration'], - 'function': ['function_definition', 'method_declaration'], - }, - 'ruby': { - 'class': ['class', 'module'], - 'function': ['method', 'singleton_method'], - }, - 'cpp': { - 'class': ['class_specifier', 'struct_specifier'], - 'function': ['function_definition'], - }, - 'c': { - 'class': ['struct_specifier'], - 'function': ['function_definition'], - }, - 'scala': { - 'class': ['class_definition', 'object_definition', 'trait_definition'], - 'function': ['function_definition'], - }, -} - -# Metadata extraction patterns (fallback when AST doesn't provide names) -METADATA_PATTERNS = { - 'python': { - 'class': re.compile(r'^class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'^(?:async\s+)?def\s+(\w+)\s*\(', re.MULTILINE), - }, - 'java': { - 'class': re.compile(r'(?:public\s+|private\s+|protected\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), - 'method': re.compile(r'(?:public|private|protected)\s+(?:static\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), - }, - 'javascript': { - 'class': re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), - }, - 'typescript': { - 'class': re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:export\s+)?interface\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), - }, - 'go': { - 'function': re.compile(r'^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(', re.MULTILINE), - 'struct': re.compile(r'^type\s+(\w+)\s+struct\s*\{', re.MULTILINE), - }, - 'rust': { - 'function': re.compile(r'^(?:pub\s+)?(?:async\s+)?fn\s+(\w+)', re.MULTILINE), - 'struct': re.compile(r'^(?:pub\s+)?struct\s+(\w+)', re.MULTILINE), - }, - 'c_sharp': { - 'class': re.compile(r'(?:public\s+|private\s+|internal\s+)?(?:abstract\s+|sealed\s+)?class\s+(\w+)', re.MULTILINE), - 'method': re.compile(r'(?:public|private|protected|internal)\s+(?:static\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), - }, - 'kotlin': { - 'class': re.compile(r'(?:data\s+|sealed\s+|open\s+)?class\s+(\w+)', re.MULTILINE), 
- 'function': re.compile(r'(?:fun|suspend\s+fun)\s+(\w+)\s*\(', re.MULTILINE), - }, - 'php': { - 'class': re.compile(r'(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:public|private|protected|static|\s)*function\s+(\w+)\s*\(', re.MULTILINE), - }, -} - -# Patterns for extracting class inheritance, interfaces, and imports -CLASS_INHERITANCE_PATTERNS = { - 'php': { - 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w\\]+)', re.MULTILINE), - 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w\\]+)?\s+implements\s+([\w\\,\s]+)', re.MULTILINE), - 'use': re.compile(r'^use\s+([\w\\]+)(?:\s+as\s+\w+)?;', re.MULTILINE), - 'namespace': re.compile(r'^namespace\s+([\w\\]+);', re.MULTILINE), - 'type_hint': re.compile(r'@var\s+(\\?[\w\\|]+)', re.MULTILINE), - # PHTML template type hints: /** @var \Namespace\Class $variable */ - 'template_type': re.compile(r'/\*\*\s*@var\s+([\w\\]+)\s+\$\w+\s*\*/', re.MULTILINE), - # Variable type hints in PHPDoc: @param Type $name, @return Type - 'phpdoc_types': re.compile(r'@(?:param|return|throws)\s+([\w\\|]+)', re.MULTILINE), - }, - 'java': { - 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), - 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w.]+)?\s+implements\s+([\w.,\s]+)', re.MULTILINE), - 'import': re.compile(r'^import\s+([\w.]+(?:\.\*)?);', re.MULTILINE), - 'package': re.compile(r'^package\s+([\w.]+);', re.MULTILINE), - }, - 'kotlin': { - 'extends': re.compile(r'class\s+\w+\s*:\s*([\w.]+)(?:\([^)]*\))?', re.MULTILINE), - 'import': re.compile(r'^import\s+([\w.]+(?:\.\*)?)', re.MULTILINE), - 'package': re.compile(r'^package\s+([\w.]+)', re.MULTILINE), - }, - 'python': { - 'extends': re.compile(r'class\s+\w+\s*\(\s*([\w.,\s]+)\s*\)\s*:', re.MULTILINE), - 'import': re.compile(r'^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s*]+)', re.MULTILINE), - }, - 'typescript': { - 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), - 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w.]+)?\s+implements\s+([\w.,\s]+)', re.MULTILINE), - 'import': re.compile(r'^import\s+(?:[\w{},\s*]+\s+from\s+)?["\']([^"\']+)["\'];?', re.MULTILINE), - }, - 'javascript': { - 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), - 'import': re.compile(r'^import\s+(?:[\w{},\s*]+\s+from\s+)?["\']([^"\']+)["\'];?', re.MULTILINE), - 'require': re.compile(r'require\s*\(\s*["\']([^"\']+)["\']\s*\)', re.MULTILINE), - }, - 'c_sharp': { - 'extends': re.compile(r'class\s+\w+\s*:\s*([\w.]+)', re.MULTILINE), - 'implements': re.compile(r'class\s+\w+\s*:\s*(?:[\w.]+\s*,\s*)*([\w.,\s]+)', re.MULTILINE), - 'using': re.compile(r'^using\s+([\w.]+);', re.MULTILINE), - 'namespace': re.compile(r'^namespace\s+([\w.]+)', re.MULTILINE), - }, - 'go': { - 'import': re.compile(r'^import\s+(?:\(\s*)?"([^"]+)"', re.MULTILINE), - 'package': re.compile(r'^package\s+(\w+)', re.MULTILINE), - }, - 'rust': { - 'use': re.compile(r'^use\s+([\w:]+(?:::\{[^}]+\})?);', re.MULTILINE), - 'impl_for': re.compile(r'impl\s+(?:<[^>]+>\s+)?([\w:]+)\s+for\s+([\w:]+)', re.MULTILINE), - }, - 'scala': { - 'extends': re.compile(r'(?:class|object|trait)\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), - 'with': re.compile(r'with\s+([\w.]+)', re.MULTILINE), - 'import': re.compile(r'^import\s+([\w._{}]+)', re.MULTILINE), - 'package': re.compile(r'^package\s+([\w.]+)', re.MULTILINE), - }, -} - - -@dataclass -class ASTChunk: - """Represents a chunk of code from AST parsing""" - content: str - content_type: ContentType - language: str - 
path: str - semantic_names: List[str] = field(default_factory=list) - parent_context: List[str] = field(default_factory=list) # Breadcrumb: ["MyClass", "inner_method"] - docstring: Optional[str] = None - signature: Optional[str] = None - start_line: int = 0 - end_line: int = 0 - node_type: Optional[str] = None - class_metadata: Dict[str, Any] = field(default_factory=dict) # extends, implements from AST - file_metadata: Dict[str, Any] = field(default_factory=dict) # imports, namespace from AST - - -def generate_deterministic_id(path: str, content: str, chunk_index: int = 0) -> str: - """ - Generate a deterministic ID for a chunk based on file path and content. - - This ensures the same code chunk always gets the same ID, preventing - duplicates in Qdrant during re-indexing. - - Args: - path: File path - content: Chunk content - chunk_index: Index of chunk within file (for disambiguation) - - Returns: - Deterministic hex ID string - """ - # Use path + content hash + index for uniqueness - hash_input = f"{path}:{chunk_index}:{content[:500]}" # First 500 chars for efficiency - return hashlib.sha256(hash_input.encode('utf-8')).hexdigest()[:32] - - -def compute_file_hash(content: str) -> str: - """Compute hash of file content for change detection""" - return hashlib.sha256(content.encode('utf-8')).hexdigest() - - -class ASTCodeSplitter: - """ - AST-based code splitter using Tree-sitter for accurate parsing. - - Features: - - True AST parsing via tree-sitter for accurate code structure detection - - Splits code into semantic units (classes, functions, methods) - - Maintains parent context (breadcrumbs) for nested structures - - Falls back to RecursiveCharacterTextSplitter for oversized chunks - - Uses deterministic IDs for Qdrant deduplication - - Enriches metadata for improved RAG retrieval - - Usage: - splitter = ASTCodeSplitter(max_chunk_size=2000) - nodes = splitter.split_documents(documents) - """ - - DEFAULT_MAX_CHUNK_SIZE = 2000 - DEFAULT_MIN_CHUNK_SIZE = 100 - DEFAULT_CHUNK_OVERLAP = 200 - DEFAULT_PARSER_THRESHOLD = 10 # Minimum lines for AST parsing - - def __init__( - self, - max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, - min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, - chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, - parser_threshold: int = DEFAULT_PARSER_THRESHOLD - ): - """ - Initialize AST code splitter. - - Args: - max_chunk_size: Maximum characters per chunk. Larger chunks are split. - min_chunk_size: Minimum characters for a valid chunk. - chunk_overlap: Overlap between chunks when splitting oversized content. - parser_threshold: Minimum lines for AST parsing (smaller files use fallback). - """ - self.max_chunk_size = max_chunk_size - self.min_chunk_size = min_chunk_size - self.chunk_overlap = chunk_overlap - self.parser_threshold = parser_threshold - - # Cache text splitters for oversized chunks - self._splitter_cache: Dict[Language, RecursiveCharacterTextSplitter] = {} - - # Default text splitter for unknown languages - self._default_splitter = RecursiveCharacterTextSplitter( - chunk_size=max_chunk_size, - chunk_overlap=chunk_overlap, - length_function=len, - ) - - # Track if tree-sitter is available - self._tree_sitter_available: Optional[bool] = None - # Cache for language modules and parsers - self._language_cache: Dict[str, Any] = {} - - def _get_tree_sitter_language(self, lang_name: str): - """ - Get tree-sitter Language object for a language name. - Uses the new tree-sitter API with individual language packages. 
- - Note: Different packages have different APIs: - - Most use: module.language() - - PHP uses: module.language_php() - - TypeScript uses: module.language_typescript() - """ - if lang_name in self._language_cache: - return self._language_cache[lang_name] - - try: - from tree_sitter import Language - - # Map language names to their package modules and function names - # Format: (module_name, function_name or None for 'language') - lang_modules = { - 'python': ('tree_sitter_python', 'language'), - 'java': ('tree_sitter_java', 'language'), - 'javascript': ('tree_sitter_javascript', 'language'), - 'typescript': ('tree_sitter_typescript', 'language_typescript'), - 'go': ('tree_sitter_go', 'language'), - 'rust': ('tree_sitter_rust', 'language'), - 'c': ('tree_sitter_c', 'language'), - 'cpp': ('tree_sitter_cpp', 'language'), - 'c_sharp': ('tree_sitter_c_sharp', 'language'), - 'ruby': ('tree_sitter_ruby', 'language'), - 'php': ('tree_sitter_php', 'language_php'), - } - - lang_info = lang_modules.get(lang_name) - if not lang_info: - return None - - module_name, func_name = lang_info - - # Dynamic import of language module - import importlib - lang_module = importlib.import_module(module_name) - - # Get the language function - lang_func = getattr(lang_module, func_name, None) - if not lang_func: - logger.debug(f"Module {module_name} has no {func_name} function") - return None - - # Create Language object using the new API - language = Language(lang_func()) - self._language_cache[lang_name] = language - return language - - except Exception as e: - logger.debug(f"Could not load tree-sitter language '{lang_name}': {e}") - return None - - def _check_tree_sitter(self) -> bool: - """Check if tree-sitter is available""" - if self._tree_sitter_available is None: - try: - from tree_sitter import Parser, Language - import tree_sitter_python as tspython - - # Test with the new API - py_language = Language(tspython.language()) - parser = Parser(py_language) - parser.parse(b"def test(): pass") - - self._tree_sitter_available = True - logger.info("tree-sitter is available and working") - except ImportError as e: - logger.warning(f"tree-sitter not installed: {e}") - self._tree_sitter_available = False - except Exception as e: - logger.warning(f"tree-sitter error: {type(e).__name__}: {e}") - self._tree_sitter_available = False - return self._tree_sitter_available - - def _get_language_from_path(self, path: str) -> Optional[Language]: - """Determine Language enum from file path""" - ext = Path(path).suffix.lower() - return EXTENSION_TO_LANGUAGE.get(ext) - - def _get_treesitter_language(self, language: Language) -> Optional[str]: - """Get tree-sitter language name from Language enum""" - return LANGUAGE_TO_TREESITTER.get(language) - - def _get_text_splitter(self, language: Language) -> RecursiveCharacterTextSplitter: - """Get language-specific text splitter for oversized chunks""" - if language not in self._splitter_cache: - try: - self._splitter_cache[language] = RecursiveCharacterTextSplitter.from_language( - language=language, - chunk_size=self.max_chunk_size, - chunk_overlap=self.chunk_overlap, - ) - except Exception: - # Fallback if language not supported - self._splitter_cache[language] = self._default_splitter - return self._splitter_cache[language] - - def _parse_inheritance_clause(self, clause_text: str, language: str) -> List[str]: - """ - Parse inheritance clause from AST node text to extract class/interface names. 
- - Handles various formats: - - PHP: "extends ParentClass" or "implements Interface1, Interface2" - - Java: "extends Parent" or "implements I1, I2" - - Python: "(Parent1, Parent2)" - - TypeScript/JS: "extends Parent implements Interface" - """ - if not clause_text: - return [] - - # Clean up the clause - text = clause_text.strip() - - # Remove common keywords - for keyword in ['extends', 'implements', 'with', ':']: - text = text.replace(keyword, ' ') - - # Handle parentheses (Python style) - text = text.strip('()') - - # Split by comma and clean up - names = [] - for part in text.split(','): - name = part.strip() - # Remove any generic type parameters - if '<' in name: - name = name.split('<')[0].strip() - # Remove any constructor calls () - if '(' in name: - name = name.split('(')[0].strip() - if name and name not in ('', ' '): - names.append(name) - - return names - - def _parse_namespace(self, ns_text: str, language: str) -> Optional[str]: - """ - Extract namespace/package name from AST node text. - - Handles various formats: - - PHP: "namespace Vendor\\Package\\Module;" - - Java/Kotlin: "package com.example.app;" - - C#: "namespace MyNamespace { ... }" - """ - if not ns_text: - return None - - # Clean up - text = ns_text.strip() - - # Remove keywords and semicolons - for keyword in ['namespace', 'package']: - text = text.replace(keyword, ' ') - - text = text.strip().rstrip(';').rstrip('{').strip() - - return text if text else None - - def _parse_with_ast( - self, - text: str, - language: Language, - path: str - ) -> List[ASTChunk]: - """ - Parse code using AST via tree-sitter. - - Returns list of ASTChunk objects with content and metadata. - """ - if not self._check_tree_sitter(): - return [] - - ts_lang = self._get_treesitter_language(language) - if not ts_lang: - logger.debug(f"No tree-sitter mapping for {language}, using fallback") - return [] - - try: - from tree_sitter import Parser - - # Get Language object for this language - lang_obj = self._get_tree_sitter_language(ts_lang) - if not lang_obj: - logger.debug(f"tree-sitter language '{ts_lang}' not available") - return [] - - # Create parser with the language - parser = Parser(lang_obj) - tree = parser.parse(bytes(text, "utf8")) - - # Extract chunks with breadcrumb context - chunks = self._extract_ast_chunks_with_context( - tree.root_node, - text, - ts_lang, - path - ) - - return chunks - - except Exception as e: - logger.warning(f"AST parsing failed for {path}: {e}") - return [] - - def _extract_ast_chunks_with_context( - self, - root_node, - source_code: str, - language: str, - path: str - ) -> List[ASTChunk]: - """ - Extract function/class chunks from AST tree with parent context (breadcrumbs). - - This solves the "context loss" problem by tracking parent classes/modules - so that a method knows it belongs to a specific class. - - Also extracts file-level metadata dynamically from AST - no hardcoded mappings needed. 
- """ - chunks = [] - processed_ranges: Set[tuple] = set() # Track (start, end) to avoid duplicates - - # IMPORTANT: Tree-sitter returns byte positions, not character positions - # We need to slice bytes and decode, not slice the string directly - source_bytes = source_code.encode('utf-8') - - # File-level metadata collected dynamically from AST - file_metadata: Dict[str, Any] = { - 'imports': [], - 'types_referenced': [], - } - - # Get node types for this language (only for chunking - class/function boundaries) - lang_node_types = SEMANTIC_NODE_TYPES.get(language, {}) - class_types = set(lang_node_types.get('class', [])) - function_types = set(lang_node_types.get('function', [])) - all_semantic_types = class_types | function_types - - def get_node_text(node) -> str: - """Get full text content of a node using byte positions""" - return source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace') - - def extract_identifiers(node) -> List[str]: - """Recursively extract all identifier names from a node""" - identifiers = [] - if node.type in ('identifier', 'name', 'type_identifier', 'qualified_name', 'scoped_identifier'): - identifiers.append(get_node_text(node)) - for child in node.children: - identifiers.extend(extract_identifiers(child)) - return identifiers - - def extract_ast_metadata(node, metadata: Dict[str, Any]) -> None: - """ - Dynamically extract metadata from AST node based on node type patterns. - No hardcoded language mappings - uses common naming conventions in tree-sitter. - """ - node_type = node.type - node_text = get_node_text(node).strip() - - # === IMPORTS (pattern: *import*, *use*, *require*, *include*) === - if any(kw in node_type for kw in ('import', 'use_', 'require', 'include', 'using')): - if node_text and len(node_text) < 500: # Skip huge nodes - metadata['imports'].append(node_text) - return # Don't recurse into import children - - # === NAMESPACE/PACKAGE (pattern: *namespace*, *package*, *module*) === - if any(kw in node_type for kw in ('namespace', 'package', 'module_declaration')): - # Extract the name part - names = extract_identifiers(node) - if names: - metadata['namespace'] = names[0] if len(names) == 1 else '.'.join(names) - elif node_text: - # Fallback: parse from text - metadata['namespace'] = self._parse_namespace(node_text, language) - return - - # Recurse into children - for child in node.children: - extract_ast_metadata(child, metadata) - - def extract_class_metadata_from_ast(node) -> Dict[str, Any]: - """ - Dynamically extract class metadata (extends, implements) from AST. - Uses common tree-sitter naming patterns - no manual mapping needed. 
- """ - meta: Dict[str, Any] = {} - - def find_inheritance(n, depth=0): - """Recursively find inheritance-related nodes""" - node_type = n.type - - # === EXTENDS / SUPERCLASS (pattern: *super*, *base*, *extends*, *heritage*) === - if any(kw in node_type for kw in ('super', 'base_clause', 'extends', 'heritage', 'parent')): - names = extract_identifiers(n) - if names: - meta.setdefault('extends', []).extend(names) - meta['parent_types'] = meta.get('extends', []) - return # Found it, don't go deeper - - # === IMPLEMENTS / INTERFACES (pattern: *implement*, *interface_clause*, *conform*) === - if any(kw in node_type for kw in ('implement', 'interface_clause', 'conform', 'protocol')): - names = extract_identifiers(n) - if names: - meta.setdefault('implements', []).extend(names) - return - - # === TRAIT/MIXIN (pattern: *trait*, *mixin*, *with*) === - if any(kw in node_type for kw in ('trait', 'mixin', 'with_clause')): - names = extract_identifiers(n) - if names: - meta.setdefault('traits', []).extend(names) - return - - # === TYPE PARAMETERS / GENERICS === - if any(kw in node_type for kw in ('type_parameter', 'generic', 'type_argument')): - names = extract_identifiers(n) - if names: - meta.setdefault('type_params', []).extend(names) - return - - # Recurse but limit depth to avoid going too deep - if depth < 5: - for child in n.children: - find_inheritance(child, depth + 1) - - for child in node.children: - find_inheritance(child) - - # Deduplicate - for key in meta: - if isinstance(meta[key], list): - meta[key] = list(dict.fromkeys(meta[key])) # Preserve order, remove dupes - - return meta - - def get_node_name(node) -> Optional[str]: - """Extract name from a node (class/function name)""" - for child in node.children: - if child.type in ('identifier', 'name', 'type_identifier', 'property_identifier'): - return get_node_text(child) - return None - - def traverse(node, parent_context: List[str], depth: int = 0): - """ - Recursively traverse AST and extract semantic chunks with breadcrumbs. 
- - Args: - node: Current AST node - parent_context: List of parent class/function names (breadcrumb) - depth: Current depth in tree - """ - node_range = (node.start_byte, node.end_byte) - - # Check if this is a semantic unit - if node.type in all_semantic_types: - # Skip if already processed (nested in another chunk) - if node_range in processed_ranges: - return - - # Use bytes for slicing since tree-sitter returns byte positions - content = source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace') - - # Calculate line numbers (use bytes for consistency) - start_line = source_bytes[:node.start_byte].count(b'\n') + 1 - end_line = start_line + content.count('\n') - - # Get the name of this node - node_name = get_node_name(node) - - # Determine content type - is_class = node.type in class_types - - # Extract class inheritance metadata dynamically from AST - class_metadata = {} - if is_class: - class_metadata = extract_class_metadata_from_ast(node) - - chunk = ASTChunk( - content=content, - content_type=ContentType.FUNCTIONS_CLASSES, - language=language, - path=path, - semantic_names=[node_name] if node_name else [], - parent_context=list(parent_context), # Copy the breadcrumb - start_line=start_line, - end_line=end_line, - node_type=node.type, - class_metadata=class_metadata, - ) - - chunks.append(chunk) - processed_ranges.add(node_range) - - # If this is a class, traverse children with updated context - if is_class and node_name: - new_context = parent_context + [node_name] - for child in node.children: - traverse(child, new_context, depth + 1) - else: - # Continue traversing children with current context - for child in node.children: - traverse(child, parent_context, depth + 1) - - # First pass: extract file-level metadata (imports, namespace) from entire AST - extract_ast_metadata(root_node, file_metadata) - - # Second pass: extract semantic chunks (classes, functions) - traverse(root_node, []) - - # Clean up file_metadata - remove empty values - clean_file_metadata = {k: v for k, v in file_metadata.items() if v} - - # Create simplified code (skeleton with placeholders) - simplified = self._create_simplified_code(source_code, chunks, language) - if simplified and simplified.strip() and len(simplified.strip()) > 50: - chunks.append(ASTChunk( - content=simplified, - content_type=ContentType.SIMPLIFIED_CODE, - language=language, - path=path, - semantic_names=[], - parent_context=[], - start_line=1, - end_line=source_code.count('\n') + 1, - node_type='simplified', - file_metadata=clean_file_metadata, # Include imports/namespace from AST - )) - - # Also attach file_metadata to all chunks for enriched metadata - for chunk in chunks: - if not chunk.file_metadata: - chunk.file_metadata = clean_file_metadata - - return chunks - - def _create_simplified_code( - self, - source_code: str, - chunks: List[ASTChunk], - language: str - ) -> str: - """ - Create simplified code with placeholders for extracted chunks. - - This gives RAG context about the overall file structure without - including full function/class bodies. 
- - Example output: - # Code for: class MyClass: - # Code for: def my_function(): - if __name__ == "__main__": - main() - """ - if not chunks: - return source_code - - # Get chunks that are functions_classes type (not simplified) - semantic_chunks = [c for c in chunks if c.content_type == ContentType.FUNCTIONS_CLASSES] - - if not semantic_chunks: - return source_code - - # Sort by start position (reverse) to replace from end - sorted_chunks = sorted( - semantic_chunks, - key=lambda x: source_code.find(x.content), - reverse=True - ) - - result = source_code - - # Comment style by language - comment_prefix = { - 'python': '#', - 'javascript': '//', - 'typescript': '//', - 'java': '//', - 'kotlin': '//', - 'go': '//', - 'rust': '//', - 'c': '//', - 'cpp': '//', - 'c_sharp': '//', - 'php': '//', - 'ruby': '#', - 'lua': '--', - 'perl': '#', - 'scala': '//', - }.get(language, '//') - - for chunk in sorted_chunks: - # Find the position of this chunk in the source - pos = result.find(chunk.content) - if pos == -1: - continue - - # Extract first line for placeholder - first_line = chunk.content.split('\n')[0].strip() - # Truncate if too long - if len(first_line) > 60: - first_line = first_line[:60] + '...' - - # Add breadcrumb context to placeholder - breadcrumb = "" - if chunk.parent_context: - breadcrumb = f" (in {'.'.join(chunk.parent_context)})" - - placeholder = f"{comment_prefix} Code for: {first_line}{breadcrumb}\n" - - result = result[:pos] + placeholder + result[pos + len(chunk.content):] - - return result.strip() - - def _extract_metadata( - self, - chunk: ASTChunk, - base_metadata: Dict[str, Any] - ) -> Dict[str, Any]: - """Extract and enrich metadata from an AST chunk""" - metadata = dict(base_metadata) - - # Core AST metadata - metadata['content_type'] = chunk.content_type.value - metadata['node_type'] = chunk.node_type - - # Breadcrumb context (critical for RAG) - if chunk.parent_context: - metadata['parent_context'] = chunk.parent_context - metadata['parent_class'] = chunk.parent_context[-1] if chunk.parent_context else None - metadata['full_path'] = '.'.join(chunk.parent_context + chunk.semantic_names[:1]) - - # Semantic names - if chunk.semantic_names: - metadata['semantic_names'] = chunk.semantic_names[:10] - metadata['primary_name'] = chunk.semantic_names[0] - - # Line numbers - metadata['start_line'] = chunk.start_line - metadata['end_line'] = chunk.end_line - - # === Use AST-extracted metadata (from tree-sitter) === - - # File-level metadata from AST (imports, namespace, package) - if chunk.file_metadata: - if chunk.file_metadata.get('imports'): - metadata['imports'] = chunk.file_metadata['imports'][:20] - if chunk.file_metadata.get('namespace'): - metadata['namespace'] = chunk.file_metadata['namespace'] - if chunk.file_metadata.get('package'): - metadata['package'] = chunk.file_metadata['package'] - - # Class-level metadata from AST (extends, implements) - if chunk.class_metadata: - if chunk.class_metadata.get('extends'): - metadata['extends'] = chunk.class_metadata['extends'] - metadata['parent_types'] = chunk.class_metadata['extends'] - if chunk.class_metadata.get('implements'): - metadata['implements'] = chunk.class_metadata['implements'] - - # Try to extract additional metadata via regex patterns - patterns = METADATA_PATTERNS.get(chunk.language, {}) - - # Extract docstring - docstring = self._extract_docstring(chunk.content, chunk.language) - if docstring: - metadata['docstring'] = docstring[:500] - - # Extract signature - signature = 
self._extract_signature(chunk.content, chunk.language) - if signature: - metadata['signature'] = signature - - # Extract additional names not caught by AST - if not chunk.semantic_names: - names = [] - for pattern_type, pattern in patterns.items(): - matches = pattern.findall(chunk.content) - names.extend(matches) - if names: - metadata['semantic_names'] = list(set(names))[:10] - metadata['primary_name'] = names[0] - - # Fallback: Extract inheritance via regex if AST didn't find it - if 'extends' not in metadata and 'implements' not in metadata: - self._extract_inheritance_metadata(chunk.content, chunk.language, metadata) - - return metadata - - def _extract_inheritance_metadata( - self, - content: str, - language: str, - metadata: Dict[str, Any] - ) -> None: - """Extract inheritance, interfaces, and imports from code chunk""" - inheritance_patterns = CLASS_INHERITANCE_PATTERNS.get(language, {}) - - if not inheritance_patterns: - return - - # Extract extends (parent class) - if 'extends' in inheritance_patterns: - match = inheritance_patterns['extends'].search(content) - if match: - extends = match.group(1).strip() - # Clean up and split multiple classes (for multiple inheritance) - extends_list = [e.strip() for e in extends.split(',') if e.strip()] - if extends_list: - metadata['extends'] = extends_list - metadata['parent_types'] = extends_list # Alias for searchability - - # Extract implements (interfaces) - if 'implements' in inheritance_patterns: - match = inheritance_patterns['implements'].search(content) - if match: - implements = match.group(1).strip() - implements_list = [i.strip() for i in implements.split(',') if i.strip()] - if implements_list: - metadata['implements'] = implements_list - - # Extract imports/use statements - import_key = None - for key in ['import', 'use', 'using', 'require']: - if key in inheritance_patterns: - import_key = key - break - - if import_key: - matches = inheritance_patterns[import_key].findall(content) - if matches: - # Flatten if matches are tuples (from groups in regex) - imports = [] - for m in matches: - if isinstance(m, tuple): - imports.extend([x.strip() for x in m if x and x.strip()]) - else: - imports.append(m.strip()) - if imports: - metadata['imports'] = imports[:20] # Limit to 20 - - # Extract namespace/package - for key in ['namespace', 'package']: - if key in inheritance_patterns: - match = inheritance_patterns[key].search(content) - if match: - metadata[key] = match.group(1).strip() - break - - # Extract Rust impl for - if 'impl_for' in inheritance_patterns: - matches = inheritance_patterns['impl_for'].findall(content) - if matches: - # matches are tuples of (trait, type) - metadata['impl_traits'] = [m[0] for m in matches if m[0]] - metadata['impl_types'] = [m[1] for m in matches if m[1]] - - # Extract Scala with traits - if 'with' in inheritance_patterns: - matches = inheritance_patterns['with'].findall(content) - if matches: - metadata['with_traits'] = matches - - # Extract PHP type hints from docblocks - if 'type_hint' in inheritance_patterns: - matches = inheritance_patterns['type_hint'].findall(content) - if matches: - # Extract unique types referenced in docblocks - type_refs = list(set(matches))[:10] - metadata['type_references'] = type_refs - - # Extract PHTML template type hints (/** @var \Class $var */) - if 'template_type' in inheritance_patterns: - matches = inheritance_patterns['template_type'].findall(content) - if matches: - template_types = list(set(matches))[:10] - # Merge with type_references if exists - existing = 
metadata.get('type_references', []) - metadata['type_references'] = list(set(existing + template_types))[:15] - # Also add to related_classes for better searchability - metadata['related_classes'] = template_types - - # Extract PHP PHPDoc types (@param, @return, @throws) - if 'phpdoc_types' in inheritance_patterns: - matches = inheritance_patterns['phpdoc_types'].findall(content) - if matches: - # Filter and clean type names, handle union types - phpdoc_types = [] - for m in matches: - for t in m.split('|'): - t = t.strip().lstrip('\\') - if t and t[0].isupper(): # Only class names - phpdoc_types.append(t) - if phpdoc_types: - existing = metadata.get('type_references', []) - metadata['type_references'] = list(set(existing + phpdoc_types))[:20] - - def _extract_docstring(self, content: str, language: str) -> Optional[str]: - """Extract docstring from code chunk""" - if language == 'python': - match = re.search(r'"""([\s\S]*?)"""|\'\'\'([\s\S]*?)\'\'\'', content) - if match: - return (match.group(1) or match.group(2)).strip() - - elif language in ('javascript', 'typescript', 'java', 'kotlin', 'c_sharp', 'php', 'go', 'scala'): - match = re.search(r'/\*\*([\s\S]*?)\*/', content) - if match: - doc = match.group(1) - doc = re.sub(r'^\s*\*\s?', '', doc, flags=re.MULTILINE) - return doc.strip() - - elif language == 'rust': - lines = [] - for line in content.split('\n'): - if line.strip().startswith('///'): - lines.append(line.strip()[3:].strip()) - elif lines: - break - if lines: - return '\n'.join(lines) - - return None - - def _extract_signature(self, content: str, language: str) -> Optional[str]: - """Extract function/method signature from code chunk""" - lines = content.split('\n') - - for line in lines[:15]: - line = line.strip() - - if language == 'python': - if line.startswith(('def ', 'async def ', 'class ')): - sig = line - if line.startswith('class ') and ':' in line: - return line.split(':')[0] + ':' - if ')' not in sig and ':' not in sig: - idx = -1 - for i, l in enumerate(lines): - if l.strip() == line: - idx = i - break - if idx >= 0: - for next_line in lines[idx+1:idx+5]: - sig += ' ' + next_line.strip() - if ')' in next_line: - break - if ':' in sig: - return sig.split(':')[0] + ':' - return sig - - elif language in ('java', 'kotlin', 'c_sharp'): - if any(kw in line for kw in ['public ', 'private ', 'protected ', 'internal ', 'fun ']): - if '(' in line and not line.startswith('//'): - return line.split('{')[0].strip() - - elif language in ('javascript', 'typescript'): - if line.startswith(('function ', 'async function ', 'class ')): - return line.split('{')[0].strip() - if '=>' in line and '(' in line: - return line.split('=>')[0].strip() + ' =>' - - elif language == 'go': - if line.startswith('func ') or line.startswith('type '): - return line.split('{')[0].strip() - - elif language == 'rust': - if line.startswith(('fn ', 'pub fn ', 'async fn ', 'pub async fn ', 'impl ', 'struct ', 'trait ')): - return line.split('{')[0].strip() - - return None - - def _split_oversized_chunk( - self, - chunk: ASTChunk, - language: Optional[Language], - base_metadata: Dict[str, Any], - path: str - ) -> List[TextNode]: - """ - Split an oversized chunk using RecursiveCharacterTextSplitter. - - This is used when AST-parsed chunks (e.g., very large classes/functions) - still exceed the max_chunk_size. 
- """ - splitter = ( - self._get_text_splitter(language) - if language and language in AST_SUPPORTED_LANGUAGES - else self._default_splitter - ) - - sub_chunks = splitter.split_text(chunk.content) - nodes = [] - - # Parent ID for linking sub-chunks - parent_id = generate_deterministic_id(path, chunk.content, 0) - - for i, sub_chunk in enumerate(sub_chunks): - if not sub_chunk or not sub_chunk.strip(): - continue - - if len(sub_chunk.strip()) < self.min_chunk_size and len(sub_chunks) > 1: - continue - - metadata = dict(base_metadata) - metadata['content_type'] = ContentType.OVERSIZED_SPLIT.value - metadata['original_content_type'] = chunk.content_type.value - metadata['parent_chunk_id'] = parent_id - metadata['sub_chunk_index'] = i - metadata['total_sub_chunks'] = len(sub_chunks) - - # Preserve breadcrumb context - if chunk.parent_context: - metadata['parent_context'] = chunk.parent_context - metadata['parent_class'] = chunk.parent_context[-1] - - if chunk.semantic_names: - metadata['semantic_names'] = chunk.semantic_names - metadata['primary_name'] = chunk.semantic_names[0] - - # Deterministic ID for this sub-chunk - chunk_id = generate_deterministic_id(path, sub_chunk, i) - - node = TextNode( - id_=chunk_id, - text=sub_chunk, - metadata=metadata - ) - nodes.append(node) - - return nodes - - def split_documents(self, documents: List[LlamaDocument]) -> List[TextNode]: - """ - Split LlamaIndex documents using AST-based parsing. - - Args: - documents: List of LlamaIndex Document objects - - Returns: - List of TextNode objects with enriched metadata and deterministic IDs - """ - all_nodes = [] - - for doc in documents: - path = doc.metadata.get('path', 'unknown') - - # Determine Language enum - language = self._get_language_from_path(path) - - # Check if AST parsing is supported and beneficial - line_count = doc.text.count('\n') + 1 - use_ast = ( - language is not None - and language in AST_SUPPORTED_LANGUAGES - and line_count >= self.parser_threshold - and self._check_tree_sitter() - ) - - if use_ast: - nodes = self._split_with_ast(doc, language) - else: - nodes = self._split_fallback(doc, language) - - all_nodes.extend(nodes) - logger.debug(f"Split {path} into {len(nodes)} chunks (AST={use_ast})") - - return all_nodes - - def _split_with_ast( - self, - doc: LlamaDocument, - language: Language - ) -> List[TextNode]: - """Split document using AST parsing with breadcrumb context""" - text = doc.text - path = doc.metadata.get('path', 'unknown') - - # Try AST parsing - ast_chunks = self._parse_with_ast(text, language, path) - - if not ast_chunks: - return self._split_fallback(doc, language) - - nodes = [] - chunk_counter = 0 - - for ast_chunk in ast_chunks: - # Check if chunk is oversized - if len(ast_chunk.content) > self.max_chunk_size: - # Split oversized chunk - sub_nodes = self._split_oversized_chunk( - ast_chunk, - language, - doc.metadata, - path - ) - nodes.extend(sub_nodes) - chunk_counter += len(sub_nodes) - else: - # Create node with enriched metadata - metadata = self._extract_metadata(ast_chunk, doc.metadata) - metadata['chunk_index'] = chunk_counter - metadata['total_chunks'] = len(ast_chunks) - - # Deterministic ID - chunk_id = generate_deterministic_id(path, ast_chunk.content, chunk_counter) - - node = TextNode( - id_=chunk_id, - text=ast_chunk.content, - metadata=metadata - ) - nodes.append(node) - chunk_counter += 1 - - return nodes - - def _split_fallback( - self, - doc: LlamaDocument, - language: Optional[Language] = None - ) -> List[TextNode]: - """Fallback splitting using 
RecursiveCharacterTextSplitter""" - text = doc.text - path = doc.metadata.get('path', 'unknown') - - if not text or not text.strip(): - return [] - - splitter = ( - self._get_text_splitter(language) - if language and language in AST_SUPPORTED_LANGUAGES - else self._default_splitter - ) - - chunks = splitter.split_text(text) - nodes = [] - text_offset = 0 - - for i, chunk in enumerate(chunks): - if not chunk or not chunk.strip(): - continue - - if len(chunk.strip()) < self.min_chunk_size and len(chunks) > 1: - continue - - # Truncate if too large - if len(chunk) > 30000: - chunk = chunk[:30000] - - # Calculate line numbers - start_line = text[:text_offset].count('\n') + 1 if text_offset > 0 else 1 - chunk_pos = text.find(chunk, text_offset) - if chunk_pos >= 0: - text_offset = chunk_pos + len(chunk) - end_line = start_line + chunk.count('\n') - - # Extract metadata using regex patterns - lang_str = doc.metadata.get('language', 'text') - metadata = dict(doc.metadata) - metadata['content_type'] = ContentType.FALLBACK.value - metadata['chunk_index'] = i - metadata['total_chunks'] = len(chunks) - metadata['start_line'] = start_line - metadata['end_line'] = end_line - - # Try to extract semantic names - patterns = METADATA_PATTERNS.get(lang_str, {}) - names = [] - for pattern_type, pattern in patterns.items(): - matches = pattern.findall(chunk) - names.extend(matches) - if names: - metadata['semantic_names'] = list(set(names))[:10] - metadata['primary_name'] = names[0] - - # Extract class inheritance, interfaces, and imports (also for fallback) - self._extract_inheritance_metadata(chunk, lang_str, metadata) - - # Deterministic ID - chunk_id = generate_deterministic_id(path, chunk, i) - - node = TextNode( - id_=chunk_id, - text=chunk, - metadata=metadata - ) - nodes.append(node) - - return nodes - - @staticmethod - def get_supported_languages() -> List[str]: - """Return list of languages with AST support""" - return list(LANGUAGE_TO_TREESITTER.values()) - - @staticmethod - def is_ast_supported(path: str) -> bool: - """Check if AST parsing is supported for a file""" - ext = Path(path).suffix.lower() - lang = EXTENSION_TO_LANGUAGE.get(ext) - return lang is not None and lang in AST_SUPPORTED_LANGUAGES diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/chunking.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/chunking.py deleted file mode 100644 index 3c8cd8cc..00000000 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/chunking.py +++ /dev/null @@ -1,171 +0,0 @@ -from typing import List -import uuid -from llama_index.core.node_parser import SentenceSplitter -from llama_index.core.schema import Document, TextNode -from ..utils.utils import is_code_file - - -class CodeAwareSplitter: - """ - Code-aware text splitter that handles code and text differently. - - DEPRECATED: Use SemanticCodeSplitter instead, which provides: - - Full AST-aware parsing for multiple languages - - Better metadata extraction (docstrings, signatures, imports) - - Smarter chunk merging and boundary detection - - This class just wraps SentenceSplitter with different chunk sizes for - code vs text. For truly semantic code splitting, use SemanticCodeSplitter. 
- """ - - def __init__(self, code_chunk_size: int = 800, code_overlap: int = 200, - text_chunk_size: int = 1000, text_overlap: int = 200): - self.code_splitter = SentenceSplitter( - chunk_size=code_chunk_size, - chunk_overlap=code_overlap, - separator="\n\n", - ) - - self.text_splitter = SentenceSplitter( - chunk_size=text_chunk_size, - chunk_overlap=text_overlap, - ) - - def split_documents(self, documents: List[Document]) -> List[TextNode]: - """Split documents into chunks based on their language type""" - result = [] - - for doc in documents: - language = doc.metadata.get("language", "text") - is_code = is_code_file(language) - - splitter = self.code_splitter if is_code else self.text_splitter - - nodes = splitter.get_nodes_from_documents([doc]) - - for i, node in enumerate(nodes): - # Skip empty or whitespace-only chunks - if not node.text or not node.text.strip(): - continue - - # Truncate text if too large (>30k chars ≈ 7.5k tokens) - text = node.text - if len(text) > 30000: - text = text[:30000] - - metadata = dict(doc.metadata) - metadata["chunk_index"] = i - metadata["total_chunks"] = len(nodes) - - # Create TextNode with explicit UUID - chunk_node = TextNode( - id_=str(uuid.uuid4()), - text=text, - metadata=metadata - ) - result.append(chunk_node) - - return result - - def split_text_for_language(self, text: str, language: str) -> List[str]: - """Split text based on language type""" - is_code = is_code_file(language) - splitter = self.code_splitter if is_code else self.text_splitter - - temp_doc = Document(text=text, metadata={"language": language}) - nodes = splitter.get_nodes_from_documents([temp_doc]) - - return [node.text for node in nodes] - - -class FunctionAwareSplitter: - """ - Advanced splitter that tries to preserve function boundaries. - - DEPRECATED: Use SemanticCodeSplitter instead, which provides: - - Full AST-aware parsing for multiple languages - - Better metadata extraction (docstrings, signatures, imports) - - Smarter chunk merging and boundary detection - - This class is kept for backward compatibility only. 
- """ - - def __init__(self, max_chunk_size: int = 800, overlap: int = 200): - self.max_chunk_size = max_chunk_size - self.overlap = overlap - self.fallback_splitter = SentenceSplitter( - chunk_size=max_chunk_size, - chunk_overlap=overlap, - ) - - def split_by_functions(self, text: str, language: str) -> List[str]: - """Try to split code by functions/classes""" - - if language == 'python': - return self._split_python(text) - elif language in ['javascript', 'typescript', 'java', 'cpp', 'c', 'go', 'rust', 'php']: - return self._split_brace_language(text) - else: - temp_doc = Document(text=text) - nodes = self.fallback_splitter.get_nodes_from_documents([temp_doc]) - return [node.text for node in nodes] - - def _split_python(self, text: str) -> List[str]: - """Split Python code by top-level definitions""" - lines = text.split('\n') - chunks = [] - current_chunk = [] - - for line in lines: - stripped = line.lstrip() - - if stripped.startswith(('def ', 'class ', 'async def ')): - if current_chunk and len('\n'.join(current_chunk)) > 50: - chunks.append('\n'.join(current_chunk)) - current_chunk = [] - - current_chunk.append(line) - - if len('\n'.join(current_chunk)) > self.max_chunk_size: - chunks.append('\n'.join(current_chunk)) - current_chunk = [] - - if current_chunk: - chunks.append('\n'.join(current_chunk)) - - return chunks if chunks else [text] - - def _split_brace_language(self, text: str) -> List[str]: - """Split brace-based languages by functions/classes""" - chunks = [] - current_chunk = [] - brace_count = 0 - in_function = False - - lines = text.split('\n') - - for line in lines: - if any(keyword in line for keyword in - ['function ', 'class ', 'def ', 'fn ', 'func ', 'public ', 'private ', 'protected ']): - if '{' in line: - in_function = True - - current_chunk.append(line) - - brace_count += line.count('{') - line.count('}') - - if in_function and brace_count == 0 and len(current_chunk) > 3: - chunks.append('\n'.join(current_chunk)) - current_chunk = [] - in_function = False - - if len('\n'.join(current_chunk)) > self.max_chunk_size: - chunks.append('\n'.join(current_chunk)) - current_chunk = [] - brace_count = 0 - - if current_chunk: - chunks.append('\n'.join(current_chunk)) - - return chunks if chunks else [text] - diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py index 1694fcd9..7eebf78e 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py @@ -20,8 +20,7 @@ from ..models.config import RAGConfig, IndexStats from ..utils.utils import make_namespace, make_project_namespace -from .semantic_splitter import SemanticCodeSplitter -from .ast_splitter import ASTCodeSplitter +from .splitter import ASTCodeSplitter from .loader import DocumentLoader from .openrouter_embedding import OpenRouterEmbedding @@ -57,25 +56,15 @@ def __init__(self, config: RAGConfig): Settings.chunk_size = config.chunk_size Settings.chunk_overlap = config.chunk_overlap - # Choose splitter based on environment variable or config - # AST splitter provides better semantic chunking for supported languages - use_ast_splitter = os.environ.get('RAG_USE_AST_SPLITTER', 'true').lower() == 'true' - - if use_ast_splitter: - logger.info("Using ASTCodeSplitter for code chunking (tree-sitter based)") - self.splitter = ASTCodeSplitter( - max_chunk_size=config.chunk_size, - min_chunk_size=min(200, config.chunk_size // 4), - 
chunk_overlap=config.chunk_overlap, - parser_threshold=10 # Minimum lines for AST parsing - ) - else: - logger.info("Using SemanticCodeSplitter for code chunking (regex-based)") - self.splitter = SemanticCodeSplitter( - max_chunk_size=config.chunk_size, - min_chunk_size=min(200, config.chunk_size // 4), - overlap=config.chunk_overlap - ) + # AST splitter with tree-sitter query-based parsing + # Falls back internally when tree-sitter unavailable + logger.info("Using ASTCodeSplitter for code chunking (tree-sitter query-based)") + self.splitter = ASTCodeSplitter( + max_chunk_size=config.chunk_size, + min_chunk_size=min(200, config.chunk_size // 4), + chunk_overlap=config.chunk_overlap, + parser_threshold=10 # Minimum lines for AST parsing + ) self.loader = DocumentLoader(config) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/__init__.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/__init__.py new file mode 100644 index 00000000..d06d664f --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/__init__.py @@ -0,0 +1,10 @@ +""" +Index Manager module for RAG indexing operations. + +Provides RAGIndexManager as the main entry point for indexing repositories, +managing collections, and handling branch-level operations. +""" + +from .manager import RAGIndexManager + +__all__ = ["RAGIndexManager"] diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py new file mode 100644 index 00000000..3782a4f6 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py @@ -0,0 +1,172 @@ +""" +Branch-level operations for RAG indices. + +Handles branch-specific point management within project collections. 
+""" + +import logging +from typing import List, Set, Optional + +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, MatchValue, PointStruct + +logger = logging.getLogger(__name__) + + +class BranchManager: + """Manages branch-level operations within project collections.""" + + def __init__(self, client: QdrantClient): + self.client = client + + def delete_branch_points( + self, + collection_name: str, + branch: str + ) -> bool: + """Delete all points for a specific branch from the collection.""" + logger.info(f"Deleting all points for branch '{branch}' from {collection_name}") + + try: + self.client.delete( + collection_name=collection_name, + points_selector=Filter( + must=[ + FieldCondition( + key="branch", + match=MatchValue(value=branch) + ) + ] + ) + ) + logger.info(f"Successfully deleted all points for branch '{branch}'") + return True + except Exception as e: + logger.error(f"Failed to delete branch '{branch}': {e}") + return False + + def get_branch_point_count( + self, + collection_name: str, + branch: str + ) -> int: + """Get the number of points for a specific branch.""" + try: + result = self.client.count( + collection_name=collection_name, + count_filter=Filter( + must=[ + FieldCondition( + key="branch", + match=MatchValue(value=branch) + ) + ] + ) + ) + return result.count + except Exception as e: + logger.error(f"Failed to get point count for branch '{branch}': {e}") + return 0 + + def get_indexed_branches(self, collection_name: str) -> List[str]: + """Get list of branches that have points in the collection.""" + try: + branches: Set[str] = set() + offset = None + limit = 100 + + while True: + results = self.client.scroll( + collection_name=collection_name, + limit=limit, + offset=offset, + with_payload=["branch"], + with_vectors=False + ) + + points, next_offset = results + + for point in points: + if point.payload and "branch" in point.payload: + branches.add(point.payload["branch"]) + + if next_offset is None or len(points) < limit: + break + offset = next_offset + + return list(branches) + except Exception as e: + logger.error(f"Failed to get indexed branches: {e}") + return [] + + def preserve_other_branch_points( + self, + collection_name: str, + exclude_branch: str + ) -> List[PointStruct]: + """Preserve points from branches other than the one being reindexed. + + Used during full reindex to keep data from other branches. 
+ """ + logger.info(f"Preserving points from branches other than '{exclude_branch}'...") + + preserved_points = [] + offset = None + + try: + while True: + results = self.client.scroll( + collection_name=collection_name, + limit=100, + offset=offset, + scroll_filter=Filter( + must_not=[ + FieldCondition( + key="branch", + match=MatchValue(value=exclude_branch) + ) + ] + ), + with_payload=True, + with_vectors=True + ) + points, next_offset = results + preserved_points.extend(points) + + if next_offset is None or len(points) < 100: + break + offset = next_offset + + logger.info(f"Found {len(preserved_points)} points from other branches to preserve") + return preserved_points + except Exception as e: + logger.warning(f"Could not read existing points: {e}") + return [] + + def copy_points_to_collection( + self, + points: List, + target_collection: str, + batch_size: int = 50 + ) -> None: + """Copy preserved points to a new collection.""" + if not points: + return + + logger.info(f"Copying {len(points)} points to {target_collection}...") + + for i in range(0, len(points), batch_size): + batch = points[i:i + batch_size] + points_to_upsert = [ + PointStruct( + id=p.id, + vector=p.vector, + payload=p.payload + ) for p in batch + ] + self.client.upsert( + collection_name=target_collection, + points=points_to_upsert + ) + + logger.info("Points copied successfully") diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py new file mode 100644 index 00000000..81a3a407 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py @@ -0,0 +1,164 @@ +""" +Qdrant collection and alias management utilities. + +Handles collection creation, alias operations, and resolution. +""" + +import logging +import time +from typing import Optional, List + +from qdrant_client import QdrantClient +from qdrant_client.models import ( + Distance, VectorParams, + CreateAlias, DeleteAlias, CreateAliasOperation, DeleteAliasOperation +) + +logger = logging.getLogger(__name__) + + +class CollectionManager: + """Manages Qdrant collections and aliases.""" + + def __init__(self, client: QdrantClient, embedding_dim: int): + self.client = client + self.embedding_dim = embedding_dim + + def ensure_collection_exists(self, collection_name: str) -> None: + """Ensure Qdrant collection exists with proper configuration. + + If the collection_name is actually an alias, use the aliased collection instead. 
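+
+        New collections are created with a single dense vector space of
+        ``embedding_dim`` dimensions using cosine distance.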
+ """ + if self.alias_exists(collection_name): + logger.info(f"Collection name {collection_name} is an alias, using existing aliased collection") + return + + collections = self.client.get_collections().collections + collection_names = [c.name for c in collections] + logger.debug(f"Existing collections: {collection_names}") + + if collection_name not in collection_names: + logger.info(f"Creating Qdrant collection: {collection_name}") + self.client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams( + size=self.embedding_dim, + distance=Distance.COSINE + ) + ) + logger.info(f"Created collection {collection_name}") + else: + logger.info(f"Collection {collection_name} already exists") + + def create_versioned_collection(self, base_name: str) -> str: + """Create a new versioned collection for atomic swap indexing.""" + versioned_name = f"{base_name}_v{int(time.time())}" + logger.info(f"Creating versioned collection: {versioned_name}") + + self.client.create_collection( + collection_name=versioned_name, + vectors_config=VectorParams( + size=self.embedding_dim, + distance=Distance.COSINE + ) + ) + return versioned_name + + def delete_collection(self, collection_name: str) -> bool: + """Delete a collection.""" + try: + self.client.delete_collection(collection_name) + logger.info(f"Deleted collection: {collection_name}") + return True + except Exception as e: + logger.warning(f"Failed to delete collection {collection_name}: {e}") + return False + + def collection_exists(self, collection_name: str) -> bool: + """Check if a collection exists (not alias).""" + collections = self.client.get_collections().collections + return collection_name in [c.name for c in collections] + + def get_collection_names(self) -> List[str]: + """Get all collection names.""" + collections = self.client.get_collections().collections + return [c.name for c in collections] + + # Alias operations + + def alias_exists(self, alias_name: str) -> bool: + """Check if an alias exists.""" + try: + aliases = self.client.get_aliases() + exists = any(a.alias_name == alias_name for a in aliases.aliases) + logger.debug(f"Checking if alias '{alias_name}' exists: {exists}") + return exists + except Exception as e: + logger.warning(f"Error checking alias {alias_name}: {e}") + return False + + def resolve_alias(self, alias_name: str) -> Optional[str]: + """Resolve an alias to its underlying collection name.""" + try: + aliases = self.client.get_aliases() + for alias in aliases.aliases: + if alias.alias_name == alias_name: + return alias.collection_name + except Exception as e: + logger.debug(f"Error resolving alias {alias_name}: {e}") + return None + + def atomic_alias_swap( + self, + alias_name: str, + new_collection: str, + old_alias_exists: bool + ) -> None: + """Perform atomic alias swap for zero-downtime reindexing.""" + alias_operations = [] + + if old_alias_exists: + alias_operations.append( + DeleteAliasOperation(delete_alias=DeleteAlias(alias_name=alias_name)) + ) + + alias_operations.append( + CreateAliasOperation(create_alias=CreateAlias( + alias_name=alias_name, + collection_name=new_collection + )) + ) + + self.client.update_collection_aliases( + change_aliases_operations=alias_operations + ) + logger.info(f"Alias swap completed: {alias_name} -> {new_collection}") + + def delete_alias(self, alias_name: str) -> bool: + """Delete an alias.""" + try: + self.client.delete_alias(alias_name) + logger.info(f"Deleted alias: {alias_name}") + return True + except Exception as e: + logger.warning(f"Failed 
to delete alias {alias_name}: {e}") + return False + + def cleanup_orphaned_versioned_collections( + self, + base_name: str, + current_target: Optional[str] = None, + exclude_name: Optional[str] = None + ) -> int: + """Clean up orphaned versioned collections from failed indexing attempts.""" + cleaned = 0 + collection_names = self.get_collection_names() + + for coll_name in collection_names: + if coll_name.startswith(f"{base_name}_v") and coll_name != exclude_name: + if current_target != coll_name: + logger.info(f"Cleaning up orphaned versioned collection: {coll_name}") + if self.delete_collection(coll_name): + cleaned += 1 + + return cleaned diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py new file mode 100644 index 00000000..2bf145a4 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py @@ -0,0 +1,398 @@ +""" +Repository indexing operations. + +Handles full repository indexing with atomic swap and streaming processing. +""" + +import gc +import logging +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional, List + +from qdrant_client.models import Filter, FieldCondition, MatchAny, MatchValue + +from ...models.config import RAGConfig, IndexStats +from ...utils.utils import make_namespace +from .collection_manager import CollectionManager +from .branch_manager import BranchManager +from .point_operations import PointOperations +from .stats_manager import StatsManager + +logger = logging.getLogger(__name__) + +# Memory-efficient batch sizes +DOCUMENT_BATCH_SIZE = 50 +INSERT_BATCH_SIZE = 50 + + +class RepositoryIndexer: + """Handles repository indexing operations.""" + + def __init__( + self, + config: RAGConfig, + collection_manager: CollectionManager, + branch_manager: BranchManager, + point_ops: PointOperations, + stats_manager: StatsManager, + splitter, + loader + ): + self.config = config + self.collection_manager = collection_manager + self.branch_manager = branch_manager + self.point_ops = point_ops + self.stats_manager = stats_manager + self.splitter = splitter + self.loader = loader + + def estimate_repository_size( + self, + repo_path: str, + exclude_patterns: Optional[List[str]] = None + ) -> tuple[int, int]: + """Estimate repository size (file count and chunk count) without actually indexing.""" + logger.info(f"Estimating repository size for: {repo_path}") + + repo_path_obj = Path(repo_path) + file_list = list(self.loader.iter_repository_files(repo_path_obj, exclude_patterns)) + file_count = len(file_list) + logger.info(f"Found {file_count} files for estimation") + + if file_count == 0: + return 0, 0 + + SAMPLE_SIZE = 100 + chunk_count = 0 + + if file_count <= SAMPLE_SIZE: + for i in range(0, file_count, DOCUMENT_BATCH_SIZE): + batch = file_list[i:i + DOCUMENT_BATCH_SIZE] + documents = self.loader.load_file_batch( + batch, repo_path_obj, "estimate", "estimate", "estimate", "estimate" + ) + if documents: + chunks = self.splitter.split_documents(documents) + chunk_count += len(chunks) + del chunks + del documents + gc.collect() + else: + import random + sample_files = random.sample(file_list, SAMPLE_SIZE) + sample_chunk_count = 0 + + for i in range(0, len(sample_files), DOCUMENT_BATCH_SIZE): + batch = sample_files[i:i + DOCUMENT_BATCH_SIZE] + documents = self.loader.load_file_batch( + batch, repo_path_obj, "estimate", "estimate", "estimate", "estimate" + ) + if 
documents: + chunks = self.splitter.split_documents(documents) + sample_chunk_count += len(chunks) + del chunks + del documents + + avg_chunks_per_file = sample_chunk_count / SAMPLE_SIZE + chunk_count = int(avg_chunks_per_file * file_count) + logger.info(f"Estimated ~{avg_chunks_per_file:.1f} chunks/file from {SAMPLE_SIZE} samples") + gc.collect() + + logger.info(f"Estimated {chunk_count} chunks from {file_count} files") + return file_count, chunk_count + + def index_repository( + self, + repo_path: str, + workspace: str, + project: str, + branch: str, + commit: str, + alias_name: str, + exclude_patterns: Optional[List[str]] = None + ) -> IndexStats: + """Index entire repository for a branch using atomic swap strategy.""" + logger.info(f"Indexing repository: {workspace}/{project}/{branch} from {repo_path}") + + repo_path_obj = Path(repo_path) + temp_collection_name = self.collection_manager.create_versioned_collection(alias_name) + + # Check existing collection and preserve other branch data + old_collection_exists = self.collection_manager.alias_exists(alias_name) + if not old_collection_exists: + old_collection_exists = self.collection_manager.collection_exists(alias_name) + + existing_other_branch_points = [] + if old_collection_exists: + actual_collection = self.collection_manager.resolve_alias(alias_name) or alias_name + existing_other_branch_points = self.branch_manager.preserve_other_branch_points( + actual_collection, branch + ) + + # Clean up orphaned versioned collections + current_target = self.collection_manager.resolve_alias(alias_name) + self.collection_manager.cleanup_orphaned_versioned_collections( + alias_name, current_target, temp_collection_name + ) + + # Get file list + file_list = list(self.loader.iter_repository_files(repo_path_obj, exclude_patterns)) + total_files = len(file_list) + logger.info(f"Found {total_files} files to index for branch '{branch}'") + + if total_files == 0: + logger.warning("No documents to index") + self.collection_manager.delete_collection(temp_collection_name) + return self.stats_manager.get_branch_stats( + workspace, project, branch, + self.collection_manager.resolve_alias(alias_name) or alias_name + ) + + # Validate limits + if self.config.max_files_per_index > 0 and total_files > self.config.max_files_per_index: + self.collection_manager.delete_collection(temp_collection_name) + raise ValueError( + f"Repository exceeds file limit: {total_files} files (max: {self.config.max_files_per_index})." + ) + + if self.config.max_chunks_per_index > 0: + logger.info("Estimating chunk count before indexing...") + _, estimated_chunks = self.estimate_repository_size(repo_path, exclude_patterns) + if estimated_chunks > self.config.max_chunks_per_index * 1.2: + self.collection_manager.delete_collection(temp_collection_name) + raise ValueError( + f"Repository estimated to exceed chunk limit: ~{estimated_chunks} chunks (max: {self.config.max_chunks_per_index})." 
+ ) + + document_count = 0 + chunk_count = 0 + successful_chunks = 0 + failed_chunks = 0 + + try: + # Copy preserved points from other branches + if existing_other_branch_points: + self.branch_manager.copy_points_to_collection( + existing_other_branch_points, + temp_collection_name, + INSERT_BATCH_SIZE + ) + + # Stream process files in batches + logger.info("Starting memory-efficient streaming indexing...") + batch_num = 0 + total_batches = (total_files + DOCUMENT_BATCH_SIZE - 1) // DOCUMENT_BATCH_SIZE + + for i in range(0, total_files, DOCUMENT_BATCH_SIZE): + batch_num += 1 + file_batch = file_list[i:i + DOCUMENT_BATCH_SIZE] + + documents = self.loader.load_file_batch( + file_batch, repo_path_obj, workspace, project, branch, commit + ) + document_count += len(documents) + + if not documents: + continue + + chunks = self.splitter.split_documents(documents) + batch_chunk_count = len(chunks) + chunk_count += batch_chunk_count + + # Check chunk limit + if self.config.max_chunks_per_index > 0 and chunk_count > self.config.max_chunks_per_index: + self.collection_manager.delete_collection(temp_collection_name) + raise ValueError(f"Repository exceeds chunk limit: {chunk_count}+ chunks.") + + # Process and upsert + success, failed = self.point_ops.process_and_upsert_chunks( + chunks, temp_collection_name, workspace, project, branch + ) + successful_chunks += success + failed_chunks += failed + + logger.info( + f"Batch {batch_num}/{total_batches}: processed {len(documents)} files, " + f"{batch_chunk_count} chunks" + ) + + del documents + del chunks + + if batch_num % 5 == 0: + gc.collect() + + logger.info( + f"Streaming indexing complete: {document_count} files, " + f"{successful_chunks}/{chunk_count} chunks indexed ({failed_chunks} failed)" + ) + + # Verify and perform atomic swap + temp_info = self.point_ops.client.get_collection(temp_collection_name) + if temp_info.points_count == 0: + raise Exception("Temporary collection is empty after indexing") + + self._perform_atomic_swap( + alias_name, temp_collection_name, old_collection_exists + ) + + except Exception as e: + logger.error(f"Indexing failed: {e}") + self.collection_manager.delete_collection(temp_collection_name) + raise e + finally: + del existing_other_branch_points + gc.collect() + + self.stats_manager.store_metadata( + workspace, project, branch, commit, document_count, chunk_count + ) + + namespace = make_namespace(workspace, project, branch) + return IndexStats( + namespace=namespace, + document_count=document_count, + chunk_count=successful_chunks, + last_updated=datetime.now(timezone.utc).isoformat(), + workspace=workspace, + project=project, + branch=branch + ) + + def _perform_atomic_swap( + self, + alias_name: str, + temp_collection_name: str, + old_collection_exists: bool + ) -> None: + """Perform atomic alias swap with migration handling.""" + logger.info("Performing atomic alias swap...") + + is_direct_collection = ( + self.collection_manager.collection_exists(alias_name) and + not self.collection_manager.alias_exists(alias_name) + ) + + old_versioned_name = None + if old_collection_exists and not is_direct_collection: + old_versioned_name = self.collection_manager.resolve_alias(alias_name) + + try: + self.collection_manager.atomic_alias_swap( + alias_name, temp_collection_name, + old_collection_exists and not is_direct_collection + ) + except Exception as alias_err: + if is_direct_collection and "already exists" in str(alias_err).lower(): + logger.info("Migrating from direct collection to alias-based indexing...") + 
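+                # Legacy layout: the target name is a real collection rather than
+                # an alias. Drop it so the same name can be re-created as an alias
+                # to the new versioned collection (there is a brief window between
+                # the two calls where the name resolves to nothing).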
self.collection_manager.delete_collection(alias_name) + self.collection_manager.atomic_alias_swap(alias_name, temp_collection_name, False) + else: + raise alias_err + + if old_versioned_name and old_versioned_name != temp_collection_name: + self.collection_manager.delete_collection(old_versioned_name) + + +class FileOperations: + """Handles individual file update and delete operations.""" + + def __init__( + self, + client, + point_ops: PointOperations, + collection_manager: CollectionManager, + stats_manager: StatsManager, + splitter, + loader + ): + self.client = client + self.point_ops = point_ops + self.collection_manager = collection_manager + self.stats_manager = stats_manager + self.splitter = splitter + self.loader = loader + + def update_files( + self, + file_paths: List[str], + repo_base: str, + workspace: str, + project: str, + branch: str, + commit: str, + collection_name: str + ) -> IndexStats: + """Update specific files in the index (Delete Old -> Insert New).""" + logger.info(f"Updating {len(file_paths)} files in {workspace}/{project} for branch '{branch}'") + + repo_base_obj = Path(repo_base) + file_path_objs = [Path(fp) for fp in file_paths] + + self.collection_manager.ensure_collection_exists(collection_name) + + # Delete old chunks for these files and branch + logger.info(f"Purging existing vectors for {len(file_paths)} files in branch '{branch}'...") + self.client.delete( + collection_name=collection_name, + points_selector=Filter( + must=[ + FieldCondition(key="path", match=MatchAny(any=file_paths)), + FieldCondition(key="branch", match=MatchValue(value=branch)) + ] + ) + ) + + # Load and split new content + documents = self.loader.load_specific_files( + file_paths=file_path_objs, + repo_base=repo_base_obj, + workspace=workspace, + project=project, + branch=branch, + commit=commit + ) + + if not documents: + logger.warning("No documents loaded from provided paths.") + return self.stats_manager.get_project_stats(workspace, project, collection_name) + + chunks = self.splitter.split_documents(documents) + logger.info(f"Generated {len(chunks)} new chunks") + + # Process and upsert + self.point_ops.process_and_upsert_chunks( + chunks, collection_name, workspace, project, branch + ) + + logger.info(f"Successfully updated {len(chunks)} chunks for branch '{branch}'") + return self.stats_manager.get_project_stats(workspace, project, collection_name) + + def delete_files( + self, + file_paths: List[str], + workspace: str, + project: str, + branch: str, + collection_name: str + ) -> IndexStats: + """Delete specific files from the index for a specific branch.""" + logger.info(f"Deleting {len(file_paths)} files from {workspace}/{project} branch '{branch}'") + + try: + self.client.delete( + collection_name=collection_name, + points_selector=Filter( + must=[ + FieldCondition(key="path", match=MatchAny(any=file_paths)), + FieldCondition(key="branch", match=MatchValue(value=branch)) + ] + ) + ) + logger.info(f"Deleted {len(file_paths)} files from branch '{branch}'") + except Exception as e: + logger.warning(f"Error deleting files: {e}") + + return self.stats_manager.get_project_stats(workspace, project, collection_name) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py new file mode 100644 index 00000000..7197f2ad --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py @@ -0,0 +1,290 @@ +""" +Main RAG Index Manager. 
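+
+Each project maps to one Qdrant collection (named roughly
+``<prefix>_<workspace>__<project>``); branches share that collection and are
+separated by a ``branch`` payload field rather than by separate collections.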
+ +Composes all index management components and provides the public API. +""" + +import logging +from typing import Optional, List + +from llama_index.core import Settings +from qdrant_client import QdrantClient + +from ...models.config import RAGConfig, IndexStats +from ...utils.utils import make_namespace, make_project_namespace +from ..splitter import ASTCodeSplitter +from ..loader import DocumentLoader +from ..openrouter_embedding import OpenRouterEmbedding + +from .collection_manager import CollectionManager +from .branch_manager import BranchManager +from .point_operations import PointOperations +from .stats_manager import StatsManager +from .indexer import RepositoryIndexer, FileOperations + +logger = logging.getLogger(__name__) + + +class RAGIndexManager: + """Manage RAG indices for code repositories using Qdrant. + + This is the main entry point for all indexing operations. + """ + + def __init__(self, config: RAGConfig): + self.config = config + + # Qdrant client + self.qdrant_client = QdrantClient(url=config.qdrant_url) + logger.info(f"Connected to Qdrant at {config.qdrant_url}") + + # Embedding model + self.embed_model = OpenRouterEmbedding( + api_key=config.openrouter_api_key, + model=config.openrouter_model, + api_base=config.openrouter_base_url, + timeout=60.0, + max_retries=3, + expected_dim=config.embedding_dim + ) + + # Global settings + Settings.embed_model = self.embed_model + Settings.chunk_size = config.chunk_size + Settings.chunk_overlap = config.chunk_overlap + + # Splitter and loader + logger.info("Using ASTCodeSplitter for code chunking (tree-sitter query-based)") + self.splitter = ASTCodeSplitter( + max_chunk_size=config.chunk_size, + min_chunk_size=min(200, config.chunk_size // 4), + chunk_overlap=config.chunk_overlap, + parser_threshold=10 + ) + self.loader = DocumentLoader(config) + + # Component managers + self._collection_manager = CollectionManager( + self.qdrant_client, config.embedding_dim + ) + self._branch_manager = BranchManager(self.qdrant_client) + self._point_ops = PointOperations( + self.qdrant_client, self.embed_model, batch_size=50 + ) + self._stats_manager = StatsManager( + self.qdrant_client, config.qdrant_collection_prefix + ) + + # Higher-level operations + self._indexer = RepositoryIndexer( + config=config, + collection_manager=self._collection_manager, + branch_manager=self._branch_manager, + point_ops=self._point_ops, + stats_manager=self._stats_manager, + splitter=self.splitter, + loader=self.loader + ) + self._file_ops = FileOperations( + client=self.qdrant_client, + point_ops=self._point_ops, + collection_manager=self._collection_manager, + stats_manager=self._stats_manager, + splitter=self.splitter, + loader=self.loader + ) + + # Collection naming + + def _get_project_collection_name(self, workspace: str, project: str) -> str: + """Generate Qdrant collection name from workspace/project.""" + namespace = make_project_namespace(workspace, project) + return f"{self.config.qdrant_collection_prefix}_{namespace}" + + def _get_collection_name(self, workspace: str, project: str, branch: str) -> str: + """Generate collection name (DEPRECATED - use _get_project_collection_name).""" + namespace = make_namespace(workspace, project, branch) + return f"{self.config.qdrant_collection_prefix}_{namespace}" + + # Repository indexing + + def estimate_repository_size( + self, + repo_path: str, + exclude_patterns: Optional[List[str]] = None + ) -> tuple[int, int]: + """Estimate repository size (file count and chunk count).""" + return 
self._indexer.estimate_repository_size(repo_path, exclude_patterns) + + def index_repository( + self, + repo_path: str, + workspace: str, + project: str, + branch: str, + commit: str, + exclude_patterns: Optional[List[str]] = None + ) -> IndexStats: + """Index entire repository for a branch using atomic swap strategy.""" + alias_name = self._get_project_collection_name(workspace, project) + return self._indexer.index_repository( + repo_path=repo_path, + workspace=workspace, + project=project, + branch=branch, + commit=commit, + alias_name=alias_name, + exclude_patterns=exclude_patterns + ) + + # File operations + + def update_files( + self, + file_paths: List[str], + repo_base: str, + workspace: str, + project: str, + branch: str, + commit: str + ) -> IndexStats: + """Update specific files in the index (Delete Old -> Insert New).""" + collection_name = self._get_project_collection_name(workspace, project) + return self._file_ops.update_files( + file_paths=file_paths, + repo_base=repo_base, + workspace=workspace, + project=project, + branch=branch, + commit=commit, + collection_name=collection_name + ) + + def delete_files( + self, + file_paths: List[str], + workspace: str, + project: str, + branch: str + ) -> IndexStats: + """Delete specific files from the index for a specific branch.""" + collection_name = self._get_project_collection_name(workspace, project) + return self._file_ops.delete_files( + file_paths=file_paths, + workspace=workspace, + project=project, + branch=branch, + collection_name=collection_name + ) + + # Branch operations + + def delete_branch(self, workspace: str, project: str, branch: str) -> bool: + """Delete all points for a specific branch from the project collection.""" + collection_name = self._get_project_collection_name(workspace, project) + + if not self._collection_manager.collection_exists(collection_name): + if not self._collection_manager.alias_exists(collection_name): + logger.warning(f"Collection {collection_name} does not exist") + return False + + return self._branch_manager.delete_branch_points(collection_name, branch) + + def get_branch_point_count(self, workspace: str, project: str, branch: str) -> int: + """Get the number of points for a specific branch.""" + collection_name = self._get_project_collection_name(workspace, project) + + if not self._collection_manager.collection_exists(collection_name): + if not self._collection_manager.alias_exists(collection_name): + return 0 + + return self._branch_manager.get_branch_point_count(collection_name, branch) + + def get_indexed_branches(self, workspace: str, project: str) -> List[str]: + """Get list of branches that have points in the collection.""" + collection_name = self._get_project_collection_name(workspace, project) + + if not self._collection_manager.collection_exists(collection_name): + if not self._collection_manager.alias_exists(collection_name): + return [] + + return self._branch_manager.get_indexed_branches(collection_name) + + # Index management + + def delete_index(self, workspace: str, project: str, branch: str): + """Delete branch data from project index.""" + if branch and branch != "*": + self.delete_branch(workspace, project, branch) + else: + self.delete_project_index(workspace, project) + + def delete_project_index(self, workspace: str, project: str): + """Delete entire project collection (all branches).""" + collection_name = self._get_project_collection_name(workspace, project) + namespace = make_project_namespace(workspace, project) + + logger.info(f"Deleting entire project index 
for {namespace}") + + try: + if self._collection_manager.alias_exists(collection_name): + actual_collection = self._collection_manager.resolve_alias(collection_name) + self._collection_manager.delete_alias(collection_name) + if actual_collection: + self._collection_manager.delete_collection(actual_collection) + else: + self._collection_manager.delete_collection(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + except Exception as e: + logger.warning(f"Failed to delete Qdrant collection: {e}") + + # Statistics + + def _get_index_stats(self, workspace: str, project: str, branch: str) -> IndexStats: + """Get statistics about a branch index (backward compatibility).""" + return self._get_branch_index_stats(workspace, project, branch) + + def _get_branch_index_stats(self, workspace: str, project: str, branch: str) -> IndexStats: + """Get statistics about a specific branch within a project collection.""" + collection_name = self._get_project_collection_name(workspace, project) + return self._stats_manager.get_branch_stats( + workspace, project, branch, collection_name + ) + + def _get_project_index_stats(self, workspace: str, project: str) -> IndexStats: + """Get statistics about a project's index (all branches combined).""" + collection_name = self._get_project_collection_name(workspace, project) + return self._stats_manager.get_project_stats(workspace, project, collection_name) + + def list_indices(self) -> List[IndexStats]: + """List all project indices with branch breakdown.""" + return self._stats_manager.list_all_indices( + self._collection_manager.alias_exists + ) + + # Legacy/compatibility methods + + def _ensure_collection_exists(self, collection_name: str): + """Ensure Qdrant collection exists (legacy compatibility).""" + self._collection_manager.ensure_collection_exists(collection_name) + + def _alias_exists(self, alias_name: str) -> bool: + """Check if an alias exists (legacy compatibility).""" + return self._collection_manager.alias_exists(alias_name) + + def _resolve_alias_to_collection(self, alias_name: str) -> Optional[str]: + """Resolve an alias to its collection (legacy compatibility).""" + return self._collection_manager.resolve_alias(alias_name) + + def _generate_point_id( + self, + workspace: str, + project: str, + branch: str, + path: str, + chunk_index: int + ) -> str: + """Generate deterministic point ID (legacy compatibility).""" + return PointOperations.generate_point_id( + workspace, project, branch, path, chunk_index + ) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py new file mode 100644 index 00000000..b9682bf0 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py @@ -0,0 +1,151 @@ +""" +Point operations for embedding and upserting vectors. + +Handles embedding generation, point creation, and batch upsert operations. 
+""" + +import logging +import uuid +from datetime import datetime, timezone +from typing import List, Dict, Tuple + +from llama_index.core.schema import TextNode +from qdrant_client import QdrantClient +from qdrant_client.models import PointStruct + +logger = logging.getLogger(__name__) + + +class PointOperations: + """Handles point embedding and upsert operations.""" + + def __init__(self, client: QdrantClient, embed_model, batch_size: int = 50): + self.client = client + self.embed_model = embed_model + self.batch_size = batch_size + + @staticmethod + def generate_point_id( + workspace: str, + project: str, + branch: str, + path: str, + chunk_index: int + ) -> str: + """Generate deterministic point ID for upsert (same content = same ID = replace).""" + key = f"{workspace}:{project}:{branch}:{path}:{chunk_index}" + return str(uuid.uuid5(uuid.NAMESPACE_DNS, key)) + + def prepare_chunks_for_embedding( + self, + chunks: List[TextNode], + workspace: str, + project: str, + branch: str + ) -> List[Tuple[str, TextNode]]: + """Prepare chunks with deterministic IDs for embedding. + + Returns list of (point_id, chunk) tuples. + """ + # Group chunks by file path + chunks_by_file: Dict[str, List[TextNode]] = {} + for chunk in chunks: + path = chunk.metadata.get("path", "unknown") + if path not in chunks_by_file: + chunks_by_file[path] = [] + chunks_by_file[path].append(chunk) + + # Assign deterministic IDs + chunk_data = [] + for path, file_chunks in chunks_by_file.items(): + for chunk_index, chunk in enumerate(file_chunks): + point_id = self.generate_point_id(workspace, project, branch, path, chunk_index) + chunk.metadata["indexed_at"] = datetime.now(timezone.utc).isoformat() + chunk_data.append((point_id, chunk)) + + return chunk_data + + def embed_and_create_points( + self, + chunk_data: List[Tuple[str, TextNode]] + ) -> List[PointStruct]: + """Embed chunks and create Qdrant points. + + Args: + chunk_data: List of (point_id, chunk) tuples + + Returns: + List of PointStruct ready for upsert + """ + if not chunk_data: + return [] + + # Batch embed all chunks at once + texts_to_embed = [chunk.text for _, chunk in chunk_data] + embeddings = self.embed_model.get_text_embedding_batch(texts_to_embed) + + # Build points with embeddings + points = [] + for (point_id, chunk), embedding in zip(chunk_data, embeddings): + points.append(PointStruct( + id=point_id, + vector=embedding, + payload={ + **chunk.metadata, + "text": chunk.text, + "_node_content": chunk.text, + } + )) + + return points + + def upsert_points( + self, + collection_name: str, + points: List[PointStruct] + ) -> Tuple[int, int]: + """Upsert points to collection in batches. + + Returns: + Tuple of (successful_count, failed_count) + """ + successful = 0 + failed = 0 + + for i in range(0, len(points), self.batch_size): + batch = points[i:i + self.batch_size] + try: + self.client.upsert( + collection_name=collection_name, + points=batch + ) + successful += len(batch) + except Exception as e: + logger.error(f"Failed to upsert batch starting at {i}: {e}") + failed += len(batch) + + return successful, failed + + def process_and_upsert_chunks( + self, + chunks: List[TextNode], + collection_name: str, + workspace: str, + project: str, + branch: str + ) -> Tuple[int, int]: + """Full pipeline: prepare, embed, and upsert chunks. 
+ + Returns: + Tuple of (successful_count, failed_count) + """ + # Prepare chunks with IDs + chunk_data = self.prepare_chunks_for_embedding( + chunks, workspace, project, branch + ) + + # Embed and create points + points = self.embed_and_create_points(chunk_data) + + # Upsert to collection + return self.upsert_points(collection_name, points) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/stats_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/stats_manager.py new file mode 100644 index 00000000..226f021f --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/stats_manager.py @@ -0,0 +1,156 @@ +""" +Index statistics and metadata operations. +""" + +import logging +from datetime import datetime, timezone +from typing import List, Optional + +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, MatchValue + +from ...models.config import IndexStats +from ...utils.utils import make_namespace, make_project_namespace + +logger = logging.getLogger(__name__) + + +class StatsManager: + """Manages index statistics and metadata.""" + + def __init__(self, client: QdrantClient, collection_prefix: str): + self.client = client + self.collection_prefix = collection_prefix + + def get_branch_stats( + self, + workspace: str, + project: str, + branch: str, + collection_name: str + ) -> IndexStats: + """Get statistics about a specific branch within a project collection.""" + namespace = make_namespace(workspace, project, branch) + + try: + count_result = self.client.count( + collection_name=collection_name, + count_filter=Filter( + must=[ + FieldCondition( + key="branch", + match=MatchValue(value=branch) + ) + ] + ) + ) + chunk_count = count_result.count + + return IndexStats( + namespace=namespace, + document_count=0, + chunk_count=chunk_count, + last_updated=datetime.now(timezone.utc).isoformat(), + workspace=workspace, + project=project, + branch=branch + ) + except Exception: + return IndexStats( + namespace=namespace, + document_count=0, + chunk_count=0, + last_updated="", + workspace=workspace, + project=project, + branch=branch + ) + + def get_project_stats( + self, + workspace: str, + project: str, + collection_name: str + ) -> IndexStats: + """Get statistics about a project's index (all branches combined).""" + namespace = make_project_namespace(workspace, project) + + try: + collection_info = self.client.get_collection(collection_name) + chunk_count = collection_info.points_count + + return IndexStats( + namespace=namespace, + document_count=0, + chunk_count=chunk_count, + last_updated=datetime.now(timezone.utc).isoformat(), + workspace=workspace, + project=project, + branch="*" + ) + except Exception: + return IndexStats( + namespace=namespace, + document_count=0, + chunk_count=0, + last_updated="", + workspace=workspace, + project=project, + branch="*" + ) + + def list_all_indices(self, alias_checker) -> List[IndexStats]: + """List all project indices with branch breakdown. 
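+
+        Collections named ``<prefix>_<workspace>__<project>`` are reported as
+        whole-project stats, while legacy
+        ``<prefix>_<workspace>__<project>__<branch>`` collections are reported
+        per branch.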
+ + Args: + alias_checker: Function to check if name is an alias + """ + indices = [] + collections = self.client.get_collections().collections + + for collection in collections: + if collection.name.startswith(f"{self.collection_prefix}_"): + namespace = collection.name[len(f"{self.collection_prefix}_"):] + parts = namespace.split("__") + + if len(parts) == 2: + # New format: workspace__project + workspace, project = parts + stats = self.get_project_stats( + workspace, project, collection.name + ) + indices.append(stats) + elif len(parts) == 3: + # Legacy format: workspace__project__branch + workspace, project, branch = parts + stats = self.get_branch_stats( + workspace, project, branch, collection.name + ) + indices.append(stats) + + return indices + + def store_metadata( + self, + workspace: str, + project: str, + branch: str, + commit: str, + document_count: int, + chunk_count: int + ) -> None: + """Store/log metadata for an indexing operation.""" + namespace = make_namespace(workspace, project, branch) + + metadata = { + "namespace": namespace, + "workspace": workspace, + "project": project, + "branch": branch, + "commit": commit, + "document_count": document_count, + "chunk_count": chunk_count, + "last_updated": datetime.now(timezone.utc).isoformat(), + } + + logger.info(f"Indexed {namespace}: {document_count} docs, {chunk_count} chunks") diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/semantic_splitter.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/semantic_splitter.py deleted file mode 100644 index 349cdf23..00000000 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/semantic_splitter.py +++ /dev/null @@ -1,455 +0,0 @@ -""" -Semantic Code Splitter - Intelligent code splitting using LangChain's language-aware splitters. - -This module provides smart code chunking that: -1. Uses LangChain's RecursiveCharacterTextSplitter with language-specific separators -2. Supports 25+ programming languages out of the box -3. Enriches metadata with semantic information (function names, imports, etc.) -4. 
Falls back gracefully for unsupported languages -""" - -import re -import hashlib -import logging -from typing import List, Dict, Any, Optional -from dataclasses import dataclass, field -from enum import Enum - -from langchain_text_splitters import RecursiveCharacterTextSplitter, Language -from llama_index.core.schema import Document, TextNode - -logger = logging.getLogger(__name__) - - -class ChunkType(Enum): - """Type of code chunk for semantic understanding""" - CLASS = "class" - FUNCTION = "function" - METHOD = "method" - INTERFACE = "interface" - MODULE = "module" - IMPORTS = "imports" - CONSTANTS = "constants" - DOCUMENTATION = "documentation" - CONFIG = "config" - MIXED = "mixed" - UNKNOWN = "unknown" - - -@dataclass -class CodeBlock: - """Represents a logical block of code""" - content: str - chunk_type: ChunkType - name: Optional[str] = None - parent_name: Optional[str] = None - start_line: int = 0 - end_line: int = 0 - imports: List[str] = field(default_factory=list) - docstring: Optional[str] = None - signature: Optional[str] = None - - -# Map internal language names to LangChain Language enum -LANGUAGE_MAP: Dict[str, Language] = { - 'python': Language.PYTHON, - 'java': Language.JAVA, - 'kotlin': Language.KOTLIN, - 'javascript': Language.JS, - 'typescript': Language.TS, - 'go': Language.GO, - 'rust': Language.RUST, - 'php': Language.PHP, - 'ruby': Language.RUBY, - 'scala': Language.SCALA, - 'swift': Language.SWIFT, - 'c': Language.C, - 'cpp': Language.CPP, - 'csharp': Language.CSHARP, - 'markdown': Language.MARKDOWN, - 'html': Language.HTML, - 'latex': Language.LATEX, - 'rst': Language.RST, - 'lua': Language.LUA, - 'perl': Language.PERL, - 'haskell': Language.HASKELL, - 'solidity': Language.SOL, - 'proto': Language.PROTO, - 'cobol': Language.COBOL, -} - -# Patterns for metadata extraction -METADATA_PATTERNS = { - 'python': { - 'class': re.compile(r'^class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'^(?:async\s+)?def\s+(\w+)\s*\(', re.MULTILINE), - 'import': re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$', re.MULTILINE), - 'docstring': re.compile(r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\''), - }, - 'java': { - 'class': re.compile(r'(?:public\s+|private\s+|protected\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), - 'method': re.compile(r'(?:public|private|protected)\s+(?:static\s+)?(?:final\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), - 'import': re.compile(r'^import\s+[\w.*]+;', re.MULTILINE), - }, - 'javascript': { - 'class': re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), - 'arrow': re.compile(r'(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>', re.MULTILINE), - 'import': re.compile(r'^import\s+.*?from\s+[\'"]([^\'"]+)[\'"]', re.MULTILINE), - }, - 'typescript': { - 'class': re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:export\s+)?interface\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), - 'type': re.compile(r'(?:export\s+)?type\s+(\w+)', re.MULTILINE), - 'import': re.compile(r'^import\s+.*?from\s+[\'"]([^\'"]+)[\'"]', re.MULTILINE), - }, - 'go': { - 'function': re.compile(r'^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(', re.MULTILINE), - 'struct': re.compile(r'^type\s+(\w+)\s+struct\s*\{', re.MULTILINE), - 'interface': 
re.compile(r'^type\s+(\w+)\s+interface\s*\{', re.MULTILINE), - }, - 'rust': { - 'function': re.compile(r'^(?:pub\s+)?(?:async\s+)?fn\s+(\w+)', re.MULTILINE), - 'struct': re.compile(r'^(?:pub\s+)?struct\s+(\w+)', re.MULTILINE), - 'impl': re.compile(r'^impl(?:<[^>]+>)?\s+(?:\w+\s+for\s+)?(\w+)', re.MULTILINE), - 'trait': re.compile(r'^(?:pub\s+)?trait\s+(\w+)', re.MULTILINE), - }, - 'php': { - 'class': re.compile(r'(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'interface\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:public|private|protected|static|\s)*function\s+(\w+)\s*\(', re.MULTILINE), - }, - 'csharp': { - 'class': re.compile(r'(?:public\s+|private\s+|protected\s+)?(?:abstract\s+|sealed\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), - 'method': re.compile(r'(?:public|private|protected)\s+(?:static\s+)?(?:async\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), - }, -} - - -class SemanticCodeSplitter: - """ - Intelligent code splitter using LangChain's language-aware text splitters. - - Features: - - Uses LangChain's RecursiveCharacterTextSplitter with language-specific separators - - Supports 25+ programming languages (Python, Java, JS/TS, Go, Rust, PHP, etc.) - - Enriches chunks with semantic metadata (function names, classes, imports) - - Graceful fallback for unsupported languages - """ - - DEFAULT_CHUNK_SIZE = 1500 - DEFAULT_CHUNK_OVERLAP = 200 - DEFAULT_MIN_CHUNK_SIZE = 100 - - def __init__( - self, - max_chunk_size: int = DEFAULT_CHUNK_SIZE, - min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, - overlap: int = DEFAULT_CHUNK_OVERLAP - ): - self.max_chunk_size = max_chunk_size - self.min_chunk_size = min_chunk_size - self.overlap = overlap - - # Cache splitters for reuse - self._splitter_cache: Dict[str, RecursiveCharacterTextSplitter] = {} - - # Default splitter for unknown languages - self._default_splitter = RecursiveCharacterTextSplitter( - chunk_size=max_chunk_size, - chunk_overlap=overlap, - length_function=len, - is_separator_regex=False, - ) - - @staticmethod - def _make_deterministic_id(namespace: str, path: str, chunk_index: int) -> str: - """Generate deterministic chunk ID for idempotent indexing""" - key = f"{namespace}:{path}:{chunk_index}" - return hashlib.sha256(key.encode()).hexdigest()[:32] - - def _get_splitter(self, language: str) -> RecursiveCharacterTextSplitter: - """Get or create a language-specific splitter""" - if language in self._splitter_cache: - return self._splitter_cache[language] - - lang_enum = LANGUAGE_MAP.get(language.lower()) - - if lang_enum: - splitter = RecursiveCharacterTextSplitter.from_language( - language=lang_enum, - chunk_size=self.max_chunk_size, - chunk_overlap=self.overlap, - ) - self._splitter_cache[language] = splitter - return splitter - - return self._default_splitter - - def split_documents(self, documents: List[Document]) -> List[TextNode]: - """Split documents into semantic chunks with enriched metadata""" - return list(self.iter_split_documents(documents)) - - def iter_split_documents(self, documents: List[Document]): - """Generator that yields chunks one at a time for memory efficiency""" - for doc in documents: - language = doc.metadata.get("language", "text") - path = doc.metadata.get("path", "unknown") - - try: - for node in self._split_document(doc, language): - yield node - except Exception as e: - logger.warning(f"Splitting failed for {path}: {e}, using fallback") - for node in self._fallback_split(doc): - yield node - - 
def _split_document(self, doc: Document, language: str) -> List[TextNode]: - """Split a single document using language-aware splitter""" - text = doc.text - - if not text or not text.strip(): - return [] - - # Get language-specific splitter - splitter = self._get_splitter(language) - - # Split the text - chunks = splitter.split_text(text) - - # Filter empty chunks and convert to nodes with metadata - nodes = [] - text_offset = 0 - - for i, chunk in enumerate(chunks): - if not chunk or not chunk.strip(): - continue - - # Skip very small chunks unless they're standalone - if len(chunk.strip()) < self.min_chunk_size and len(chunks) > 1: - # Try to find and merge with adjacent chunk - continue - - # Calculate approximate line numbers - start_line = text[:text_offset].count('\n') + 1 if text_offset > 0 else 1 - chunk_pos = text.find(chunk, text_offset) - if chunk_pos >= 0: - text_offset = chunk_pos + len(chunk) - end_line = start_line + chunk.count('\n') - - # Extract semantic metadata - metadata = self._extract_metadata(chunk, language, doc.metadata) - metadata.update({ - 'chunk_index': i, - 'total_chunks': len(chunks), - 'start_line': start_line, - 'end_line': end_line, - }) - - chunk_id = self._make_deterministic_id( - metadata.get('namespace', ''), - metadata.get('path', ''), - i - ) - node = TextNode( - id_=chunk_id, - text=chunk, - metadata=metadata - ) - nodes.append(node) - - return nodes - - def _extract_metadata( - self, - chunk: str, - language: str, - base_metadata: Dict[str, Any] - ) -> Dict[str, Any]: - """Extract semantic metadata from a code chunk""" - metadata = dict(base_metadata) - - # Determine chunk type and extract names - chunk_type = ChunkType.MIXED - names = [] - imports = [] - - patterns = METADATA_PATTERNS.get(language.lower(), {}) - - # Check for classes - if 'class' in patterns: - matches = patterns['class'].findall(chunk) - if matches: - chunk_type = ChunkType.CLASS - names.extend(matches) - - # Check for interfaces - if 'interface' in patterns: - matches = patterns['interface'].findall(chunk) - if matches: - chunk_type = ChunkType.INTERFACE - names.extend(matches) - - # Check for functions/methods - if chunk_type == ChunkType.MIXED: - for key in ['function', 'method', 'arrow']: - if key in patterns: - matches = patterns[key].findall(chunk) - if matches: - chunk_type = ChunkType.FUNCTION - names.extend(matches) - break - - # Check for imports - if 'import' in patterns: - import_matches = patterns['import'].findall(chunk) - if import_matches: - imports = import_matches[:10] # Limit - if not names: # Pure import block - chunk_type = ChunkType.IMPORTS - - # Check for documentation files - if language in ('markdown', 'rst', 'text'): - chunk_type = ChunkType.DOCUMENTATION - - # Check for config files - if language in ('json', 'yaml', 'yml', 'toml', 'xml', 'ini'): - chunk_type = ChunkType.CONFIG - - # Extract docstring if present - docstring = self._extract_docstring(chunk, language) - - # Extract function signature - signature = self._extract_signature(chunk, language) - - # Update metadata - metadata['chunk_type'] = chunk_type.value - - if names: - metadata['semantic_names'] = names[:5] # Limit to 5 names - metadata['primary_name'] = names[0] - - if imports: - metadata['imports'] = imports - - if docstring: - metadata['docstring'] = docstring[:500] # Limit size - - if signature: - metadata['signature'] = signature - - return metadata - - def _extract_docstring(self, chunk: str, language: str) -> Optional[str]: - """Extract docstring from code chunk""" - if language == 
'python': - # Python docstrings - match = re.search(r'"""([\s\S]*?)"""|\'\'\'([\s\S]*?)\'\'\'', chunk) - if match: - return (match.group(1) or match.group(2)).strip() - - elif language in ('javascript', 'typescript', 'java', 'csharp', 'php', 'go'): - # JSDoc / JavaDoc style - match = re.search(r'/\*\*([\s\S]*?)\*/', chunk) - if match: - # Clean up the comment - doc = match.group(1) - doc = re.sub(r'^\s*\*\s?', '', doc, flags=re.MULTILINE) - return doc.strip() - - return None - - def _extract_signature(self, chunk: str, language: str) -> Optional[str]: - """Extract function/method signature from code chunk""" - lines = chunk.split('\n') - - for line in lines[:10]: # Check first 10 lines - line = line.strip() - - if language == 'python': - if line.startswith(('def ', 'async def ')): - # Get full signature including multi-line params - sig = line - if ')' not in sig: - # Multi-line signature - idx = lines.index(line.strip()) if line.strip() in lines else -1 - if idx >= 0: - for next_line in lines[idx+1:idx+5]: - sig += ' ' + next_line.strip() - if ')' in next_line: - break - return sig.split(':')[0] + ':' - - elif language in ('java', 'csharp', 'kotlin'): - if any(kw in line for kw in ['public ', 'private ', 'protected ', 'internal ']): - if '(' in line and not line.startswith('//'): - return line.split('{')[0].strip() - - elif language in ('javascript', 'typescript'): - if line.startswith(('function ', 'async function ')): - return line.split('{')[0].strip() - if '=>' in line and '(' in line: - return line.split('=>')[0].strip() + ' =>' - - elif language == 'go': - if line.startswith('func '): - return line.split('{')[0].strip() - - elif language == 'rust': - if line.startswith(('fn ', 'pub fn ', 'async fn ', 'pub async fn ')): - return line.split('{')[0].strip() - - return None - - def _fallback_split(self, doc: Document) -> List[TextNode]: - """Fallback splitting for problematic documents""" - text = doc.text - - if not text or not text.strip(): - return [] - - # Use default splitter - chunks = self._default_splitter.split_text(text) - - nodes = [] - for i, chunk in enumerate(chunks): - if not chunk or not chunk.strip(): - continue - - # Truncate if too large - if len(chunk) > 30000: - chunk = chunk[:30000] - - metadata = dict(doc.metadata) - metadata['chunk_index'] = i - metadata['total_chunks'] = len(chunks) - metadata['chunk_type'] = 'fallback' - - chunk_id = self._make_deterministic_id( - metadata.get('namespace', ''), - metadata.get('path', ''), - i - ) - nodes.append(TextNode( - id_=chunk_id, - text=chunk, - metadata=metadata - )) - - return nodes - - @staticmethod - def get_supported_languages() -> List[str]: - """Return list of supported languages""" - return list(LANGUAGE_MAP.keys()) - - @staticmethod - def get_separators_for_language(language: str) -> Optional[List[str]]: - """Get the separators used for a specific language""" - lang_enum = LANGUAGE_MAP.get(language.lower()) - if lang_enum: - return RecursiveCharacterTextSplitter.get_separators_for_language(lang_enum) - return None diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/__init__.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/__init__.py new file mode 100644 index 00000000..74a91cf0 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/__init__.py @@ -0,0 +1,53 @@ +""" +AST-based code splitter module using Tree-sitter. 
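+
+``ASTCodeSplitter`` is the main entry point; the remaining exports are the
+language tables, metadata helpers, and tree-sitter wrappers it builds on.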
+ +Provides semantic code chunking with: +- Tree-sitter query-based extraction (.scm files) +- Fallback to manual AST traversal +- RecursiveCharacterTextSplitter for oversized chunks +- Rich metadata extraction for RAG +""" + +from .splitter import ASTCodeSplitter, ASTChunk, generate_deterministic_id, compute_file_hash +from .languages import ( + get_language_from_path, + get_treesitter_name, + is_ast_supported, + get_supported_languages, + EXTENSION_TO_LANGUAGE, + AST_SUPPORTED_LANGUAGES, + LANGUAGE_TO_TREESITTER, +) +from .metadata import ContentType, ChunkMetadata, MetadataExtractor +from .tree_parser import TreeSitterParser, get_parser +from .query_runner import QueryRunner, QueryMatch, CapturedNode, get_query_runner + +__all__ = [ + # Main splitter + "ASTCodeSplitter", + "ASTChunk", + "generate_deterministic_id", + "compute_file_hash", + + # Languages + "get_language_from_path", + "get_treesitter_name", + "is_ast_supported", + "get_supported_languages", + "EXTENSION_TO_LANGUAGE", + "AST_SUPPORTED_LANGUAGES", + "LANGUAGE_TO_TREESITTER", + + # Metadata + "ContentType", + "ChunkMetadata", + "MetadataExtractor", + + # Tree-sitter + "TreeSitterParser", + "get_parser", + "QueryRunner", + "QueryMatch", + "CapturedNode", + "get_query_runner", +] diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/languages.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/languages.py new file mode 100644 index 00000000..f4758ff1 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/languages.py @@ -0,0 +1,139 @@ +""" +Language detection and mapping for AST-based code splitting. + +Maps file extensions to tree-sitter language names and LangChain Language enum. +""" + +from pathlib import Path +from typing import Dict, Optional, Set +from langchain_text_splitters import Language + + +# Map file extensions to LangChain Language enum (for RecursiveCharacterTextSplitter fallback) +EXTENSION_TO_LANGUAGE: Dict[str, Language] = { + # Python + '.py': Language.PYTHON, + '.pyw': Language.PYTHON, + '.pyi': Language.PYTHON, + + # Java/JVM + '.java': Language.JAVA, + '.kt': Language.KOTLIN, + '.kts': Language.KOTLIN, + '.scala': Language.SCALA, + + # JavaScript/TypeScript + '.js': Language.JS, + '.jsx': Language.JS, + '.mjs': Language.JS, + '.cjs': Language.JS, + '.ts': Language.TS, + '.tsx': Language.TS, + + # Systems languages + '.go': Language.GO, + '.rs': Language.RUST, + '.c': Language.C, + '.h': Language.C, + '.cpp': Language.CPP, + '.cc': Language.CPP, + '.cxx': Language.CPP, + '.hpp': Language.CPP, + '.hxx': Language.CPP, + '.cs': Language.CSHARP, + + # Web/Scripting + '.php': Language.PHP, + '.phtml': Language.PHP, + '.php3': Language.PHP, + '.php4': Language.PHP, + '.php5': Language.PHP, + '.phps': Language.PHP, + '.inc': Language.PHP, + '.rb': Language.RUBY, + '.erb': Language.RUBY, + '.lua': Language.LUA, + '.pl': Language.PERL, + '.pm': Language.PERL, + '.swift': Language.SWIFT, + + # Markup/Config + '.md': Language.MARKDOWN, + '.markdown': Language.MARKDOWN, + '.html': Language.HTML, + '.htm': Language.HTML, + '.rst': Language.RST, + '.tex': Language.LATEX, + '.proto': Language.PROTO, + '.sol': Language.SOL, + '.hs': Language.HASKELL, + '.cob': Language.COBOL, + '.cbl': Language.COBOL, + '.xml': Language.HTML, +} + +# Languages that support full AST parsing via tree-sitter +AST_SUPPORTED_LANGUAGES: Set[Language] = { + Language.PYTHON, Language.JAVA, Language.KOTLIN, Language.JS, Language.TS, + Language.GO, Language.RUST, Language.C, 
Language.CPP, Language.CSHARP, + Language.PHP, Language.RUBY, Language.SCALA, Language.LUA, Language.PERL, + Language.SWIFT, Language.HASKELL, Language.COBOL +} + +# Map LangChain Language enum to tree-sitter language name +LANGUAGE_TO_TREESITTER: Dict[Language, str] = { + Language.PYTHON: 'python', + Language.JAVA: 'java', + Language.KOTLIN: 'kotlin', + Language.JS: 'javascript', + Language.TS: 'typescript', + Language.GO: 'go', + Language.RUST: 'rust', + Language.C: 'c', + Language.CPP: 'cpp', + Language.CSHARP: 'c_sharp', + Language.PHP: 'php', + Language.RUBY: 'ruby', + Language.SCALA: 'scala', + Language.LUA: 'lua', + Language.PERL: 'perl', + Language.SWIFT: 'swift', + Language.HASKELL: 'haskell', +} + +# Map tree-sitter language name to module info: (module_name, function_name) +TREESITTER_MODULES: Dict[str, tuple] = { + 'python': ('tree_sitter_python', 'language'), + 'java': ('tree_sitter_java', 'language'), + 'javascript': ('tree_sitter_javascript', 'language'), + 'typescript': ('tree_sitter_typescript', 'language_typescript'), + 'go': ('tree_sitter_go', 'language'), + 'rust': ('tree_sitter_rust', 'language'), + 'c': ('tree_sitter_c', 'language'), + 'cpp': ('tree_sitter_cpp', 'language'), + 'c_sharp': ('tree_sitter_c_sharp', 'language'), + 'ruby': ('tree_sitter_ruby', 'language'), + 'php': ('tree_sitter_php', 'language_php'), +} + + +def get_language_from_path(path: str) -> Optional[Language]: + """Determine LangChain Language enum from file path.""" + ext = Path(path).suffix.lower() + return EXTENSION_TO_LANGUAGE.get(ext) + + +def get_treesitter_name(language: Language) -> Optional[str]: + """Get tree-sitter language name from LangChain Language enum.""" + return LANGUAGE_TO_TREESITTER.get(language) + + +def is_ast_supported(path: str) -> bool: + """Check if AST parsing is supported for a file.""" + language = get_language_from_path(path) + return language is not None and language in AST_SUPPORTED_LANGUAGES + + +def get_supported_languages() -> list: + """Return list of languages with AST support.""" + return list(LANGUAGE_TO_TREESITTER.values()) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/metadata.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/metadata.py new file mode 100644 index 00000000..6a2fccd7 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/metadata.py @@ -0,0 +1,339 @@ +""" +Metadata extraction from AST chunks. + +Extracts semantic metadata like docstrings, signatures, inheritance info +from parsed code chunks for improved RAG retrieval. 
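+
+``ChunkMetadata`` is the structured record attached to each chunk, and
+``ContentType`` records how the chunk was produced (function/class extraction,
+simplified code, an oversized split, or plain fallback).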
+""" + +import re +import logging +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field +from enum import Enum + +logger = logging.getLogger(__name__) + + +class ContentType(Enum): + """Content type as determined by AST parsing.""" + FUNCTIONS_CLASSES = "functions_classes" + SIMPLIFIED_CODE = "simplified_code" + FALLBACK = "fallback" + OVERSIZED_SPLIT = "oversized_split" + + +@dataclass +class ChunkMetadata: + """Structured metadata for a code chunk.""" + content_type: ContentType + language: str + path: str + semantic_names: List[str] = field(default_factory=list) + parent_context: List[str] = field(default_factory=list) + docstring: Optional[str] = None + signature: Optional[str] = None + start_line: int = 0 + end_line: int = 0 + node_type: Optional[str] = None + # Class-level metadata + extends: List[str] = field(default_factory=list) + implements: List[str] = field(default_factory=list) + # File-level metadata + imports: List[str] = field(default_factory=list) + namespace: Optional[str] = None + + +class MetadataExtractor: + """ + Extract semantic metadata from code chunks. + + Uses both AST-derived information and regex fallbacks for + comprehensive metadata extraction. + """ + + # Comment prefixes by language + COMMENT_PREFIX: Dict[str, str] = { + 'python': '#', + 'javascript': '//', + 'typescript': '//', + 'java': '//', + 'kotlin': '//', + 'go': '//', + 'rust': '//', + 'c': '//', + 'cpp': '//', + 'c_sharp': '//', + 'php': '//', + 'ruby': '#', + 'lua': '--', + 'perl': '#', + 'scala': '//', + } + + def extract_docstring(self, content: str, language: str) -> Optional[str]: + """Extract docstring from code chunk.""" + if language == 'python': + match = re.search(r'"""([\s\S]*?)"""|\'\'\'([\s\S]*?)\'\'\'', content) + if match: + return (match.group(1) or match.group(2)).strip() + + elif language in ('javascript', 'typescript', 'java', 'kotlin', + 'c_sharp', 'php', 'go', 'scala', 'c', 'cpp'): + # JSDoc / JavaDoc style + match = re.search(r'/\*\*([\s\S]*?)\*/', content) + if match: + doc = match.group(1) + doc = re.sub(r'^\s*\*\s?', '', doc, flags=re.MULTILINE) + return doc.strip() + + elif language == 'rust': + # Rust doc comments + lines = [] + for line in content.split('\n'): + stripped = line.strip() + if stripped.startswith('///'): + lines.append(stripped[3:].strip()) + elif stripped.startswith('//!'): + lines.append(stripped[3:].strip()) + elif lines: + break + if lines: + return '\n'.join(lines) + + return None + + def extract_signature(self, content: str, language: str) -> Optional[str]: + """Extract function/method signature from code chunk.""" + lines = content.split('\n') + + for line in lines[:15]: + line = line.strip() + + if language == 'python': + if line.startswith(('def ', 'async def ', 'class ')): + sig = line + if line.startswith('class ') and ':' in line: + return line.split(':')[0] + ':' + if ')' not in sig and ':' not in sig: + idx = next((i for i, l in enumerate(lines) if l.strip() == line), -1) + if idx >= 0: + for next_line in lines[idx+1:idx+5]: + sig += ' ' + next_line.strip() + if ')' in next_line: + break + if ':' in sig: + return sig.split(':')[0] + ':' + return sig + + elif language in ('java', 'kotlin', 'c_sharp'): + if any(kw in line for kw in ['public ', 'private ', 'protected ', 'internal ', 'fun ']): + if '(' in line and not line.startswith('//'): + return line.split('{')[0].strip() + + elif language in ('javascript', 'typescript'): + if line.startswith(('function ', 'async function ', 'class ')): + return 
line.split('{')[0].strip() + if '=>' in line and '(' in line: + return line.split('=>')[0].strip() + ' =>' + + elif language == 'go': + if line.startswith('func ') or line.startswith('type '): + return line.split('{')[0].strip() + + elif language == 'rust': + if line.startswith(('fn ', 'pub fn ', 'async fn ', 'pub async fn ', + 'impl ', 'struct ', 'trait ', 'enum ')): + return line.split('{')[0].strip() + + elif language == 'php': + if 'function ' in line and '(' in line: + return line.split('{')[0].strip() + if line.startswith('class ') or line.startswith('interface '): + return line.split('{')[0].strip() + + return None + + def extract_names_from_content(self, content: str, language: str) -> List[str]: + """Extract semantic names (function/class names) using regex patterns.""" + patterns = self._get_name_patterns(language) + names = [] + + for pattern in patterns: + matches = pattern.findall(content) + names.extend(matches) + + # Deduplicate while preserving order + seen = set() + unique_names = [] + for name in names: + if name not in seen: + seen.add(name) + unique_names.append(name) + + return unique_names[:10] # Limit to 10 names + + def _get_name_patterns(self, language: str) -> List[re.Pattern]: + """Get regex patterns for extracting names by language.""" + patterns = { + 'python': [ + re.compile(r'^class\s+(\w+)', re.MULTILINE), + re.compile(r'^(?:async\s+)?def\s+(\w+)\s*\(', re.MULTILINE), + ], + 'java': [ + re.compile(r'(?:public\s+|private\s+|protected\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public|private|protected)\s+(?:static\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), + ], + 'javascript': [ + re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), + re.compile(r'(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>', re.MULTILINE), + ], + 'typescript': [ + re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'(?:export\s+)?interface\s+(\w+)', re.MULTILINE), + re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), + re.compile(r'(?:export\s+)?type\s+(\w+)', re.MULTILINE), + ], + 'go': [ + re.compile(r'^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(', re.MULTILINE), + re.compile(r'^type\s+(\w+)\s+(?:struct|interface)\s*\{', re.MULTILINE), + ], + 'rust': [ + re.compile(r'^(?:pub\s+)?(?:async\s+)?fn\s+(\w+)', re.MULTILINE), + re.compile(r'^(?:pub\s+)?struct\s+(\w+)', re.MULTILINE), + re.compile(r'^(?:pub\s+)?trait\s+(\w+)', re.MULTILINE), + re.compile(r'^(?:pub\s+)?enum\s+(\w+)', re.MULTILINE), + ], + 'php': [ + re.compile(r'(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'interface\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public|private|protected|static|\s)*function\s+(\w+)\s*\(', re.MULTILINE), + ], + 'c_sharp': [ + re.compile(r'(?:public\s+|private\s+|internal\s+)?(?:abstract\s+|sealed\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public|private|protected|internal)\s+(?:static\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), + ], + } + return patterns.get(language, []) + + def extract_inheritance(self, content: str, language: str) -> Dict[str, List[str]]: + """Extract inheritance information (extends, implements).""" + result = {'extends': [], 'implements': [], 'imports': []} + + patterns = 
self._get_inheritance_patterns(language) + + if 'extends' in patterns: + match = patterns['extends'].search(content) + if match: + extends = match.group(1).strip() + result['extends'] = [e.strip() for e in extends.split(',') if e.strip()] + + if 'implements' in patterns: + match = patterns['implements'].search(content) + if match: + implements = match.group(1).strip() + result['implements'] = [i.strip() for i in implements.split(',') if i.strip()] + + for key in ('import', 'use', 'using', 'require'): + if key in patterns: + matches = patterns[key].findall(content) + for m in matches: + if isinstance(m, tuple): + result['imports'].extend([x.strip() for x in m if x and x.strip()]) + else: + result['imports'].append(m.strip()) + + # Limit imports + result['imports'] = result['imports'][:20] + + return result + + def _get_inheritance_patterns(self, language: str) -> Dict[str, re.Pattern]: + """Get regex patterns for inheritance extraction.""" + patterns = { + 'python': { + 'extends': re.compile(r'class\s+\w+\s*\(\s*([\w.,\s]+)\s*\)\s*:', re.MULTILINE), + 'import': re.compile(r'^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s*]+)', re.MULTILINE), + }, + 'java': { + 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), + 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w.]+)?\s+implements\s+([\w.,\s]+)', re.MULTILINE), + 'import': re.compile(r'^import\s+([\w.]+(?:\.\*)?);', re.MULTILINE), + }, + 'typescript': { + 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), + 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w.]+)?\s+implements\s+([\w.,\s]+)', re.MULTILINE), + 'import': re.compile(r'^import\s+(?:[\w{},\s*]+\s+from\s+)?["\']([^"\']+)["\'];?', re.MULTILINE), + }, + 'javascript': { + 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), + 'import': re.compile(r'^import\s+(?:[\w{},\s*]+\s+from\s+)?["\']([^"\']+)["\'];?', re.MULTILINE), + 'require': re.compile(r'require\s*\(\s*["\']([^"\']+)["\']\s*\)', re.MULTILINE), + }, + 'php': { + 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w\\]+)', re.MULTILINE), + 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w\\]+)?\s+implements\s+([\w\\,\s]+)', re.MULTILINE), + 'use': re.compile(r'^use\s+([\w\\]+)(?:\s+as\s+\w+)?;', re.MULTILINE), + }, + 'c_sharp': { + 'extends': re.compile(r'class\s+\w+\s*:\s*([\w.]+)', re.MULTILINE), + 'using': re.compile(r'^using\s+([\w.]+);', re.MULTILINE), + }, + 'go': { + 'import': re.compile(r'^import\s+(?:\(\s*)?"([^"]+)"', re.MULTILINE), + }, + 'rust': { + 'use': re.compile(r'^use\s+([\w:]+(?:::\{[^}]+\})?);', re.MULTILINE), + }, + } + return patterns.get(language, {}) + + def get_comment_prefix(self, language: str) -> str: + """Get comment prefix for a language.""" + return self.COMMENT_PREFIX.get(language, '//') + + def build_metadata_dict( + self, + chunk_metadata: ChunkMetadata, + base_metadata: Dict[str, Any] + ) -> Dict[str, Any]: + """Build final metadata dictionary from ChunkMetadata.""" + metadata = dict(base_metadata) + + metadata['content_type'] = chunk_metadata.content_type.value + metadata['node_type'] = chunk_metadata.node_type + metadata['start_line'] = chunk_metadata.start_line + metadata['end_line'] = chunk_metadata.end_line + + if chunk_metadata.parent_context: + metadata['parent_context'] = chunk_metadata.parent_context + metadata['parent_class'] = chunk_metadata.parent_context[-1] + full_path_parts = chunk_metadata.parent_context + chunk_metadata.semantic_names[:1] + metadata['full_path'] = '.'.join(full_path_parts) + + if 
chunk_metadata.semantic_names: + metadata['semantic_names'] = chunk_metadata.semantic_names + metadata['primary_name'] = chunk_metadata.semantic_names[0] + + if chunk_metadata.docstring: + metadata['docstring'] = chunk_metadata.docstring[:500] + + if chunk_metadata.signature: + metadata['signature'] = chunk_metadata.signature + + if chunk_metadata.extends: + metadata['extends'] = chunk_metadata.extends + metadata['parent_types'] = chunk_metadata.extends + + if chunk_metadata.implements: + metadata['implements'] = chunk_metadata.implements + + if chunk_metadata.imports: + metadata['imports'] = chunk_metadata.imports + + if chunk_metadata.namespace: + metadata['namespace'] = chunk_metadata.namespace + + return metadata diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/c_sharp.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/c_sharp.scm new file mode 100644 index 00000000..4a6d7270 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/c_sharp.scm @@ -0,0 +1,56 @@ +; C# tree-sitter queries for AST-based code splitting + +; Using directives +(using_directive) @using + +; Namespace declarations +(namespace_declaration + name: (identifier) @name) @definition.namespace + +(file_scoped_namespace_declaration + name: (identifier) @name) @definition.namespace + +; Class declarations +(class_declaration + name: (identifier) @name) @definition.class + +; Struct declarations +(struct_declaration + name: (identifier) @name) @definition.struct + +; Interface declarations +(interface_declaration + name: (identifier) @name) @definition.interface + +; Enum declarations +(enum_declaration + name: (identifier) @name) @definition.enum + +; Record declarations +(record_declaration + name: (identifier) @name) @definition.record + +; Delegate declarations +(delegate_declaration + name: (identifier) @name) @definition.delegate + +; Method declarations +(method_declaration + name: (identifier) @name) @definition.method + +; Constructor declarations +(constructor_declaration + name: (identifier) @name) @definition.constructor + +; Property declarations +(property_declaration + name: (identifier) @name) @definition.property + +; Field declarations +(field_declaration) @definition.field + +; Event declarations +(event_declaration) @definition.event + +; Attributes +(attribute) @attribute diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/go.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/go.scm new file mode 100644 index 00000000..87bf86ff --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/go.scm @@ -0,0 +1,26 @@ +; Go tree-sitter queries for AST-based code splitting + +; Package clause +(package_clause) @package + +; Import declarations +(import_declaration) @import + +; Type declarations (struct, interface, type alias) +(type_declaration + (type_spec + name: (type_identifier) @name)) @definition.type + +; Function declarations +(function_declaration + name: (identifier) @name) @definition.function + +; Method declarations +(method_declaration + name: (field_identifier) @name) @definition.method + +; Variable declarations +(var_declaration) @definition.variable + +; Constant declarations +(const_declaration) @definition.const diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/java.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/java.scm new file mode 100644 index 
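The regex paths in MetadataExtractor above can be exercised on their own, without tree-sitter. A small sketch on a hand-written Java snippet (import path assumed; expected values follow directly from the regexes shown above):

```python
# Hedged example of MetadataExtractor on a tiny Java chunk.
from rag_pipeline.core.splitter.metadata import MetadataExtractor

java_chunk = """/**
 * Loads a user by id.
 */
public class UserService extends BaseService implements UserApi {
    public User findById(long id) {
        return repository.find(id);
    }
}
"""

extractor = MetadataExtractor()
print(extractor.extract_docstring(java_chunk, "java"))   # 'Loads a user by id.'
print(extractor.extract_signature(java_chunk, "java"))   # 'public User findById(long id)'
print(extractor.extract_names_from_content(java_chunk, "java"))
# -> ['UserService', 'findById']
print(extractor.extract_inheritance(java_chunk, "java"))
# -> {'extends': ['BaseService'], 'implements': ['UserApi'], 'imports': []}
```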
00000000..f9e0bdc9 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/java.scm @@ -0,0 +1,45 @@ +; Java tree-sitter queries for AST-based code splitting + +; Package declaration +(package_declaration) @package + +; Import declarations +(import_declaration) @import + +; Class declarations +(class_declaration + name: (identifier) @name) @definition.class + +; Interface declarations +(interface_declaration + name: (identifier) @name) @definition.interface + +; Enum declarations +(enum_declaration + name: (identifier) @name) @definition.enum + +; Record declarations (Java 14+) +(record_declaration + name: (identifier) @name) @definition.record + +; Annotation type declarations +(annotation_type_declaration + name: (identifier) @name) @definition.annotation + +; Method declarations +(method_declaration + name: (identifier) @name) @definition.method + +; Constructor declarations +(constructor_declaration + name: (identifier) @name) @definition.constructor + +; Field declarations +(field_declaration) @definition.field + +; Annotations (for metadata) +(marker_annotation + name: (identifier) @name) @annotation + +(annotation + name: (identifier) @name) @annotation diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/javascript.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/javascript.scm new file mode 100644 index 00000000..dd8c3561 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/javascript.scm @@ -0,0 +1,42 @@ +; JavaScript/TypeScript tree-sitter queries for AST-based code splitting + +; Import statements +(import_statement) @import + +; Export statements +(export_statement) @export + +; Class declarations +(class_declaration + name: (identifier) @name) @definition.class + +; Function declarations +(function_declaration + name: (identifier) @name) @definition.function + +; Arrow functions assigned to variables +(lexical_declaration + (variable_declarator + name: (identifier) @name + value: (arrow_function))) @definition.function + +; Generator functions +(generator_function_declaration + name: (identifier) @name) @definition.function + +; Method definitions (inside class body) +(method_definition + name: (property_identifier) @name) @definition.method + +; Variable declarations (module-level) +(lexical_declaration) @definition.variable + +(variable_declaration) @definition.variable + +; Interface declarations (TypeScript) +(interface_declaration + name: (type_identifier) @name) @definition.interface + +; Type alias declarations (TypeScript) +(type_alias_declaration + name: (type_identifier) @name) @definition.type diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/php.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/php.scm new file mode 100644 index 00000000..52c29810 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/php.scm @@ -0,0 +1,40 @@ +; PHP tree-sitter queries for AST-based code splitting + +; Namespace definition +(namespace_definition) @namespace + +; Use statements +(namespace_use_declaration) @use + +; Class declarations +(class_declaration + name: (name) @name) @definition.class + +; Interface declarations +(interface_declaration + name: (name) @name) @definition.interface + +; Trait declarations +(trait_declaration + name: (name) @name) @definition.trait + +; Enum declarations (PHP 8.1+) +(enum_declaration + name: (name) @name) @definition.enum + +; 
Function definitions +(function_definition + name: (name) @name) @definition.function + +; Method declarations +(method_declaration + name: (name) @name) @definition.method + +; Property declarations +(property_declaration) @definition.property + +; Const declarations +(const_declaration) @definition.const + +; Attributes (PHP 8.0+) +(attribute) @attribute diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/python.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/python.scm new file mode 100644 index 00000000..6530ad79 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/python.scm @@ -0,0 +1,28 @@ +; Python tree-sitter queries for AST-based code splitting + +; Import statements +(import_statement) @import + +(import_from_statement) @import + +; Class definitions +(class_definition + name: (identifier) @name) @definition.class + +; Function definitions +(function_definition + name: (identifier) @name) @definition.function + +; Decorated definitions (class or function with decorators) +(decorated_definition) @definition.decorated + +; Decorators +(decorator) @decorator + +; Assignment statements (module-level constants) +(assignment + left: (identifier) @name) @definition.assignment + +; Type alias (Python 3.12+) +(type_alias_statement + name: (type) @name) @definition.type_alias diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/rust.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/rust.scm new file mode 100644 index 00000000..b3315f6a --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/rust.scm @@ -0,0 +1,46 @@ +; Rust tree-sitter queries for AST-based code splitting + +; Use declarations +(use_declaration) @use + +; Module declarations +(mod_item + name: (identifier) @name) @definition.module + +; Struct definitions +(struct_item + name: (type_identifier) @name) @definition.struct + +; Enum definitions +(enum_item + name: (type_identifier) @name) @definition.enum + +; Trait definitions +(trait_item + name: (type_identifier) @name) @definition.trait + +; Implementation blocks +(impl_item) @definition.impl + +; Function definitions +(function_item + name: (identifier) @name) @definition.function + +; Type alias +(type_item + name: (type_identifier) @name) @definition.type + +; Constant definitions +(const_item + name: (identifier) @name) @definition.const + +; Static definitions +(static_item + name: (identifier) @name) @definition.static + +; Macro definitions +(macro_definition + name: (identifier) @name) @definition.macro + +; Attributes +(attribute_item) @attribute diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/typescript.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/typescript.scm new file mode 100644 index 00000000..71807c96 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/typescript.scm @@ -0,0 +1,52 @@ +; TypeScript tree-sitter queries for AST-based code splitting +; Uses same patterns as JavaScript plus TypeScript-specific nodes + +; Import statements +(import_statement) @import + +; Export statements +(export_statement) @export + +; Class declarations +(class_declaration + name: (type_identifier) @name) @definition.class + +; Abstract class declarations +(abstract_class_declaration + name: (type_identifier) @name) @definition.class + +; Function declarations +(function_declaration + name: 
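To see how the .scm patterns above are consumed, the same py-tree-sitter calls used by TreeSitterParser and QueryRunner in this patch can be run directly. A minimal sketch (the two inlined patterns mirror python.scm; the tree-sitter version and package names are assumptions based on the imports used elsewhere in this patch):

```python
# Hedged example: compiling and running query patterns with the same API shape
# the patch uses (Language/Parser/Query/QueryCursor, tree-sitter >= 0.24 style).
import tree_sitter_python as tspython
from tree_sitter import Language, Parser, Query, QueryCursor

lang = Language(tspython.language())
parser = Parser(lang)
src = b"class Foo:\n    def bar(self):\n        return 1\n"
tree = parser.parse(src)

query = Query(lang, """
(class_definition name: (identifier) @name) @definition.class
(function_definition name: (identifier) @name) @definition.function
""")

for pattern_id, captures in QueryCursor(query).matches(tree.root_node):
    for capture_name, nodes in captures.items():
        for node in nodes:
            text = src[node.start_byte:node.end_byte].decode()
            print(pattern_id, capture_name, node.type, text.splitlines()[0])
```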
(identifier) @name) @definition.function + +; Arrow functions assigned to variables +(lexical_declaration + (variable_declarator + name: (identifier) @name + value: (arrow_function))) @definition.function + +; Method definitions (inside class body) +(method_definition + name: (property_identifier) @name) @definition.method + +; Interface declarations +(interface_declaration + name: (type_identifier) @name) @definition.interface + +; Type alias declarations +(type_alias_declaration + name: (type_identifier) @name) @definition.type + +; Enum declarations +(enum_declaration + name: (identifier) @name) @definition.enum + +; Module declarations +(module + name: (identifier) @name) @definition.module + +; Variable declarations (module-level) +(lexical_declaration) @definition.variable + +; Ambient declarations +(ambient_declaration) @definition.ambient diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/query_runner.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/query_runner.py new file mode 100644 index 00000000..47da9f26 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/query_runner.py @@ -0,0 +1,360 @@ +""" +Tree-sitter query runner using custom query files with built-in fallback. + +Prefers custom .scm query files for rich metadata extraction (extends, implements, imports), +falling back to built-in TAGS_QUERY only when custom query is unavailable. +""" + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Any, Optional + +from .tree_parser import get_parser +from .languages import TREESITTER_MODULES + +logger = logging.getLogger(__name__) + +# Directory containing custom .scm query files +QUERIES_DIR = Path(__file__).parent / "queries" + +# Languages that have built-in TAGS_QUERY (used as fallback only) +LANGUAGES_WITH_BUILTIN_TAGS = {'python', 'java', 'javascript', 'go', 'rust', 'php'} + +# Languages with custom .scm files for rich metadata (extends, implements, imports) +LANGUAGES_WITH_CUSTOM_QUERY = { + 'python', 'java', 'javascript', 'typescript', 'c_sharp', 'go', 'rust', 'php' +} + + +@dataclass +class CapturedNode: + """Represents a captured AST node from a query.""" + name: str # Capture name (e.g., 'function.name', 'class.body') + text: str # Node text content + start_byte: int + end_byte: int + start_point: tuple # (row, column) + end_point: tuple + node_type: str # Tree-sitter node type + + @property + def start_line(self) -> int: + return self.start_point[0] + 1 # Convert to 1-based + + @property + def end_line(self) -> int: + return self.end_point[0] + 1 + + +@dataclass +class QueryMatch: + """A complete match from a query pattern.""" + pattern_name: str # e.g., 'function', 'class', 'import' + captures: Dict[str, CapturedNode] = field(default_factory=dict) + + def get(self, capture_name: str) -> Optional[CapturedNode]: + """Get a captured node by name.""" + return self.captures.get(capture_name) + + @property + def full_text(self) -> Optional[str]: + """Get the full text of the main capture (pattern_name without suffix).""" + main_capture = self.captures.get(self.pattern_name) + return main_capture.text if main_capture else None + + +class QueryRunner: + """ + Executes tree-sitter queries using custom .scm files with built-in fallback. + + Strategy: + 1. Prefer custom .scm files for rich metadata (extends, implements, imports, decorators) + 2. 
Fall back to built-in TAGS_QUERY only when no custom query exists + + Custom queries capture: @class.extends, @class.implements, @import, @decorator, + @method.visibility, @function.return_type, etc. + + Built-in TAGS_QUERY only captures: @definition.function, @definition.class, @name, @doc + """ + + def __init__(self): + self._query_cache: Dict[str, Any] = {} # lang -> compiled query + self._scm_cache: Dict[str, str] = {} # lang -> raw scm string + self._parser = get_parser() + + def _get_builtin_tags_query(self, lang_name: str) -> Optional[str]: + """Get built-in TAGS_QUERY from language package if available.""" + if lang_name not in LANGUAGES_WITH_BUILTIN_TAGS: + return None + + lang_info = TREESITTER_MODULES.get(lang_name) + if not lang_info: + return None + + module_name = lang_info[0] + try: + import importlib + lang_module = importlib.import_module(module_name) + tags_query = getattr(lang_module, 'TAGS_QUERY', None) + if tags_query: + logger.debug(f"Using built-in TAGS_QUERY for {lang_name}") + return tags_query + except (ImportError, AttributeError) as e: + logger.debug(f"Could not load built-in query for {lang_name}: {e}") + + return None + + def _load_custom_query_file(self, lang_name: str) -> Optional[str]: + """Load custom .scm query file for languages without built-in queries.""" + if lang_name in self._scm_cache: + return self._scm_cache[lang_name] + + query_file = QUERIES_DIR / f"{lang_name}.scm" + + if not query_file.exists(): + logger.debug(f"No custom query file for {lang_name}") + return None + + try: + scm_content = query_file.read_text(encoding='utf-8') + self._scm_cache[lang_name] = scm_content + logger.debug(f"Loaded custom query file for {lang_name}") + return scm_content + except Exception as e: + logger.warning(f"Failed to load query file {query_file}: {e}") + return None + + def _get_query_string(self, lang_name: str) -> Optional[str]: + """Get query string - custom first, then built-in fallback.""" + # Prefer custom .scm for rich metadata (extends, implements, imports) + custom = self._load_custom_query_file(lang_name) + if custom: + return custom + + # Fall back to built-in TAGS_QUERY (limited metadata) + return self._get_builtin_tags_query(lang_name) + + def _try_compile_query(self, lang_name: str, scm_content: str, language: Any) -> Optional[Any]: + """Try to compile a query string, returning None on failure.""" + try: + from tree_sitter import Query + return Query(language, scm_content) + except Exception as e: + logger.debug(f"Query compilation failed for {lang_name}: {e}") + return None + + def _get_compiled_query(self, lang_name: str) -> Optional[Any]: + """Get or compile the query for a language with fallback.""" + if lang_name in self._query_cache: + return self._query_cache[lang_name] + + language = self._parser.get_language(lang_name) + if not language: + return None + + # Try custom .scm first + custom_scm = self._load_custom_query_file(lang_name) + if custom_scm: + query = self._try_compile_query(lang_name, custom_scm, language) + if query: + logger.debug(f"Using custom query for {lang_name}") + self._query_cache[lang_name] = query + return query + else: + logger.debug(f"Custom query failed for {lang_name}, trying built-in") + + # Fallback to built-in TAGS_QUERY + builtin_scm = self._get_builtin_tags_query(lang_name) + if builtin_scm: + query = self._try_compile_query(lang_name, builtin_scm, language) + if query: + logger.debug(f"Using built-in TAGS_QUERY for {lang_name}") + self._query_cache[lang_name] = query + return query + + logger.debug(f"No 
working query available for {lang_name}") + return None + + def run_query( + self, + source_code: str, + lang_name: str, + tree: Optional[Any] = None + ) -> List[QueryMatch]: + """ + Run the query for a language and return all matches. + + Args: + source_code: Source code string + lang_name: Tree-sitter language name + tree: Optional pre-parsed tree (will parse if not provided) + + Returns: + List of QueryMatch objects with captured nodes + """ + query = self._get_compiled_query(lang_name) + if not query: + return [] + + if tree is None: + tree = self._parser.parse(source_code, lang_name) + if tree is None: + return [] + + source_bytes = source_code.encode('utf-8') + + try: + # Use QueryCursor.matches() for pattern-grouped results + # Each match is (pattern_id, {capture_name: [nodes]}) + from tree_sitter import QueryCursor + cursor = QueryCursor(query) + raw_matches = list(cursor.matches(tree.root_node)) + except Exception as e: + logger.warning(f"Query execution failed for {lang_name}: {e}") + return [] + + results: List[QueryMatch] = [] + + for pattern_id, captures_dict in raw_matches: + # Determine pattern type from captures + # Built-in: @definition.function, @definition.class, @name + # Custom: @function, @class, @function.name + + pattern_name = None + main_node = None + name_node = None + doc_node = None + + for capture_name, nodes in captures_dict.items(): + if not nodes: + continue + node = nodes[0] # Take first node for each capture + + # Built-in definition captures + if capture_name.startswith('definition.'): + pattern_name = capture_name[len('definition.'):] + main_node = node + # Built-in @name capture (associated with this pattern) + elif capture_name == 'name': + name_node = node + # Built-in @doc capture + elif capture_name == 'doc': + doc_node = node + # Skip reference captures + elif capture_name.startswith('reference.'): + continue + # Custom query captures: @function, @class + elif '.' 
not in capture_name: + pattern_name = capture_name + main_node = node + + # Skip if no definition pattern found + if not pattern_name or not main_node: + continue + + # Build the QueryMatch + match = QueryMatch(pattern_name=pattern_name) + + # Add main capture + match.captures[pattern_name] = CapturedNode( + name=pattern_name, + text=source_bytes[main_node.start_byte:main_node.end_byte].decode('utf-8', errors='replace'), + start_byte=main_node.start_byte, + end_byte=main_node.end_byte, + start_point=(main_node.start_point.row, main_node.start_point.column), + end_point=(main_node.end_point.row, main_node.end_point.column), + node_type=main_node.type + ) + + # Add name capture if present + if name_node: + match.captures[f'{pattern_name}.name'] = CapturedNode( + name=f'{pattern_name}.name', + text=source_bytes[name_node.start_byte:name_node.end_byte].decode('utf-8', errors='replace'), + start_byte=name_node.start_byte, + end_byte=name_node.end_byte, + start_point=(name_node.start_point.row, name_node.start_point.column), + end_point=(name_node.end_point.row, name_node.end_point.column), + node_type=name_node.type + ) + + # Add doc capture if present + if doc_node: + match.captures[f'{pattern_name}.doc'] = CapturedNode( + name=f'{pattern_name}.doc', + text=source_bytes[doc_node.start_byte:doc_node.end_byte].decode('utf-8', errors='replace'), + start_byte=doc_node.start_byte, + end_byte=doc_node.end_byte, + start_point=(doc_node.start_point.row, doc_node.start_point.column), + end_point=(doc_node.end_point.row, doc_node.end_point.column), + node_type=doc_node.type + ) + + # Process any additional sub-captures from custom queries + for capture_name, nodes in captures_dict.items(): + if '.' in capture_name and not capture_name.startswith(('definition.', 'reference.')): + node = nodes[0] + match.captures[capture_name] = CapturedNode( + name=capture_name, + text=source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace'), + start_byte=node.start_byte, + end_byte=node.end_byte, + start_point=(node.start_point.row, node.start_point.column), + end_point=(node.end_point.row, node.end_point.column), + node_type=node.type + ) + + results.append(match) + + return results + + def get_functions(self, source_code: str, lang_name: str) -> List[QueryMatch]: + """Convenience method to get function/method matches.""" + matches = self.run_query(source_code, lang_name) + return [m for m in matches if m.pattern_name in ('function', 'method')] + + def get_classes(self, source_code: str, lang_name: str) -> List[QueryMatch]: + """Convenience method to get class/struct/interface matches.""" + matches = self.run_query(source_code, lang_name) + return [m for m in matches if m.pattern_name in ('class', 'struct', 'interface', 'trait')] + + def get_imports(self, source_code: str, lang_name: str) -> List[QueryMatch]: + """Convenience method to get import statement matches.""" + matches = self.run_query(source_code, lang_name) + return [m for m in matches if m.pattern_name == 'import'] + + def has_query(self, lang_name: str) -> bool: + """Check if a query is available for this language (custom or built-in).""" + # Check custom file first + query_file = QUERIES_DIR / f"{lang_name}.scm" + if query_file.exists(): + return True + # Check built-in fallback + return lang_name in LANGUAGES_WITH_BUILTIN_TAGS + + def uses_custom_query(self, lang_name: str) -> bool: + """Check if this language uses custom .scm query (rich metadata).""" + query_file = QUERIES_DIR / f"{lang_name}.scm" + return query_file.exists() + + 
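A usage sketch for the QueryRunner above (import path assumed; output depends on the optional tree-sitter packages being installed, otherwise run_query returns an empty list):

```python
# Hedged example: running the singleton QueryRunner over a small Python source.
from rag_pipeline.core.splitter.query_runner import get_query_runner

source = """
import os

class Greeter:
    def greet(self, name):
        return f"hi {name}"
"""

runner = get_query_runner()
if runner.has_query("python"):
    for match in runner.run_query(source, "python"):
        name = match.get(f"{match.pattern_name}.name")
        main = match.captures[match.pattern_name]
        print(match.pattern_name, name.text if name else None, main.start_line)

    # Convenience filters reuse the same matches:
    print([m.full_text for m in runner.get_imports(source, "python")])
```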
def uses_builtin_query(self, lang_name: str) -> bool: + """Check if this language uses built-in TAGS_QUERY (limited metadata).""" + return lang_name in LANGUAGES_WITH_BUILTIN_TAGS and not self.uses_custom_query(lang_name) + + def clear_cache(self): + """Clear compiled query cache.""" + self._query_cache.clear() + self._scm_cache.clear() + + +# Global singleton +_runner_instance: Optional[QueryRunner] = None + + +def get_query_runner() -> QueryRunner: + """Get the global QueryRunner instance.""" + global _runner_instance + if _runner_instance is None: + _runner_instance = QueryRunner() + return _runner_instance diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py new file mode 100644 index 00000000..622892d9 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py @@ -0,0 +1,720 @@ +""" +AST-based Code Splitter using Tree-sitter for accurate code parsing. + +This module provides true AST-aware code chunking that: +1. Uses Tree-sitter queries for efficient pattern matching (15+ languages) +2. Splits code into semantic units (classes, functions, methods) +3. Uses RecursiveCharacterTextSplitter for oversized chunks +4. Enriches metadata for better RAG retrieval +5. Maintains parent context ("breadcrumbs") for nested structures +6. Uses deterministic IDs for Qdrant deduplication +""" + +import hashlib +import logging +from typing import List, Dict, Any, Optional, Set +from pathlib import Path +from dataclasses import dataclass, field + +from langchain_text_splitters import RecursiveCharacterTextSplitter, Language +from llama_index.core.schema import Document as LlamaDocument, TextNode + +from .languages import ( + EXTENSION_TO_LANGUAGE, AST_SUPPORTED_LANGUAGES, LANGUAGE_TO_TREESITTER, + get_language_from_path, get_treesitter_name, is_ast_supported +) +from .tree_parser import get_parser +from .query_runner import get_query_runner, QueryMatch +from .metadata import MetadataExtractor, ContentType, ChunkMetadata + +logger = logging.getLogger(__name__) + + +def generate_deterministic_id(path: str, content: str, chunk_index: int = 0) -> str: + """ + Generate a deterministic ID for a chunk based on file path and content. + + This ensures the same code chunk always gets the same ID, preventing + duplicates in Qdrant during re-indexing. + """ + hash_input = f"{path}:{chunk_index}:{content[:500]}" + return hashlib.sha256(hash_input.encode('utf-8')).hexdigest()[:32] + + +def compute_file_hash(content: str) -> str: + """Compute hash of file content for change detection.""" + return hashlib.sha256(content.encode('utf-8')).hexdigest() + + +@dataclass +class ASTChunk: + """Represents a chunk of code from AST parsing.""" + content: str + content_type: ContentType + language: str + path: str + semantic_names: List[str] = field(default_factory=list) + parent_context: List[str] = field(default_factory=list) + docstring: Optional[str] = None + signature: Optional[str] = None + start_line: int = 0 + end_line: int = 0 + node_type: Optional[str] = None + extends: List[str] = field(default_factory=list) + implements: List[str] = field(default_factory=list) + imports: List[str] = field(default_factory=list) + namespace: Optional[str] = None + + +class ASTCodeSplitter: + """ + AST-based code splitter using Tree-sitter queries for accurate parsing. 
+ + Features: + - Uses .scm query files for declarative pattern matching + - Splits code into semantic units (classes, functions, methods) + - Falls back to RecursiveCharacterTextSplitter when needed + - Uses deterministic IDs for Qdrant deduplication + - Enriches metadata for improved RAG retrieval + + Usage: + splitter = ASTCodeSplitter(max_chunk_size=2000) + nodes = splitter.split_documents(documents) + """ + + DEFAULT_MAX_CHUNK_SIZE = 2000 + DEFAULT_MIN_CHUNK_SIZE = 100 + DEFAULT_CHUNK_OVERLAP = 200 + DEFAULT_PARSER_THRESHOLD = 10 + + def __init__( + self, + max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, + min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, + chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, + parser_threshold: int = DEFAULT_PARSER_THRESHOLD + ): + """ + Initialize AST code splitter. + + Args: + max_chunk_size: Maximum characters per chunk + min_chunk_size: Minimum characters for a valid chunk + chunk_overlap: Overlap between chunks when splitting oversized content + parser_threshold: Minimum lines for AST parsing + """ + self.max_chunk_size = max_chunk_size + self.min_chunk_size = min_chunk_size + self.chunk_overlap = chunk_overlap + self.parser_threshold = parser_threshold + + # Components + self._parser = get_parser() + self._query_runner = get_query_runner() + self._metadata_extractor = MetadataExtractor() + + # Cache text splitters + self._splitter_cache: Dict[Language, RecursiveCharacterTextSplitter] = {} + + # Default splitter + self._default_splitter = RecursiveCharacterTextSplitter( + chunk_size=max_chunk_size, + chunk_overlap=chunk_overlap, + length_function=len, + ) + + def split_documents(self, documents: List[LlamaDocument]) -> List[TextNode]: + """ + Split LlamaIndex documents using AST-based parsing. + + Args: + documents: List of LlamaIndex Document objects + + Returns: + List of TextNode objects with enriched metadata + """ + all_nodes = [] + + for doc in documents: + path = doc.metadata.get('path', 'unknown') + language = get_language_from_path(path) + + line_count = doc.text.count('\n') + 1 + use_ast = ( + language is not None + and language in AST_SUPPORTED_LANGUAGES + and line_count >= self.parser_threshold + and self._parser.is_available() + ) + + if use_ast: + nodes = self._split_with_ast(doc, language) + else: + nodes = self._split_fallback(doc, language) + + all_nodes.extend(nodes) + logger.debug(f"Split {path} into {len(nodes)} chunks (AST={use_ast})") + + return all_nodes + + def _split_with_ast(self, doc: LlamaDocument, language: Language) -> List[TextNode]: + """Split document using AST parsing with query-based extraction.""" + text = doc.text + path = doc.metadata.get('path', 'unknown') + ts_lang = get_treesitter_name(language) + + if not ts_lang: + return self._split_fallback(doc, language) + + # Try query-based extraction first + chunks = self._extract_with_queries(text, ts_lang, path) + + # If no queries available, fall back to traversal-based extraction + if not chunks: + chunks = self._extract_with_traversal(text, ts_lang, path) + + # Still no chunks? 
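An end-to-end sketch of the splitter described above, wrapping source code in a LlamaIndex Document the way the rest of the pipeline does (package import path and installed dependencies are assumptions; when tree-sitter is unavailable the same call transparently uses the text fallback):

```python
# Hedged example: splitting one synthetic Python file with ASTCodeSplitter.
from llama_index.core.schema import Document as LlamaDocument
from rag_pipeline.core.splitter import ASTCodeSplitter

code = "\n".join(
    f"def handler_{i}(event):\n    return {{'status': {i}}}\n" for i in range(20)
)
doc = LlamaDocument(text=code, metadata={"path": "src/handlers.py", "language": "python"})

splitter = ASTCodeSplitter(max_chunk_size=800, chunk_overlap=100)
nodes = splitter.split_documents([doc])

for node in nodes[:5]:
    print(node.id_,
          node.metadata.get("content_type"),
          node.metadata.get("primary_name"),
          node.metadata.get("start_line"))
```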
Use fallback + if not chunks: + return self._split_fallback(doc, language) + + return self._process_chunks(chunks, doc, language, path) + + def _extract_with_queries( + self, + text: str, + lang_name: str, + path: str + ) -> List[ASTChunk]: + """Extract chunks using tree-sitter query files with rich metadata.""" + if not self._query_runner.has_query(lang_name): + return [] + + tree = self._parser.parse(text, lang_name) + if not tree: + return [] + + matches = self._query_runner.run_query(text, lang_name, tree) + if not matches: + return [] + + source_bytes = text.encode('utf-8') + chunks = [] + processed_ranges: Set[tuple] = set() + + # Collect file-level metadata from all matches + imports = [] + namespace = None + decorators_map: Dict[int, List[str]] = {} # line -> decorators + + for match in matches: + # Handle imports (multiple capture variations) + if match.pattern_name in ('import', 'use'): + import_path = ( + match.get('import.path') or + match.get('import') or + match.get('use.path') or + match.get('use') + ) + if import_path: + imports.append(import_path.text.strip().strip('"\'')) + continue + + # Handle namespace/package/module + if match.pattern_name in ('namespace', 'package', 'module'): + ns_cap = match.get(f'{match.pattern_name}.name') or match.get(match.pattern_name) + if ns_cap: + namespace = ns_cap.text.strip() + continue + + # Handle standalone decorators/attributes + if match.pattern_name in ('decorator', 'attribute', 'annotation'): + dec_cap = match.get(f'{match.pattern_name}.name') or match.get(match.pattern_name) + if dec_cap: + line = dec_cap.start_line + if line not in decorators_map: + decorators_map[line] = [] + decorators_map[line].append(dec_cap.text.strip()) + continue + + # Handle main constructs: functions, classes, methods, etc. 
+ semantic_patterns = ( + 'function', 'method', 'class', 'interface', 'struct', 'trait', + 'enum', 'impl', 'constructor', 'closure', 'arrow', 'const', + 'var', 'static', 'type', 'record' + ) + if match.pattern_name in semantic_patterns: + main_cap = match.get(match.pattern_name) + if not main_cap: + continue + + range_key = (main_cap.start_byte, main_cap.end_byte) + if range_key in processed_ranges: + continue + processed_ranges.add(range_key) + + # Get name from various capture patterns + name_cap = ( + match.get(f'{match.pattern_name}.name') or + match.get('name') + ) + name = name_cap.text if name_cap else None + + # Get inheritance (extends/implements/embeds/supertrait) + extends = [] + implements = [] + + for ext_capture in ('extends', 'embeds', 'supertrait', 'base_type'): + cap = match.get(f'{match.pattern_name}.{ext_capture}') + if cap: + extends.extend(self._parse_type_list(cap.text)) + + for impl_capture in ('implements', 'trait'): + cap = match.get(f'{match.pattern_name}.{impl_capture}') + if cap: + implements.extend(self._parse_type_list(cap.text)) + + # Get additional metadata from captures + visibility = match.get(f'{match.pattern_name}.visibility') + return_type = match.get(f'{match.pattern_name}.return_type') + params = match.get(f'{match.pattern_name}.params') + modifiers = [] + + for mod in ('static', 'abstract', 'final', 'async', 'readonly', 'const', 'unsafe'): + if match.get(f'{match.pattern_name}.{mod}'): + modifiers.append(mod) + + chunk = ASTChunk( + content=main_cap.text, + content_type=ContentType.FUNCTIONS_CLASSES, + language=lang_name, + path=path, + semantic_names=[name] if name else [], + parent_context=[], + start_line=main_cap.start_line, + end_line=main_cap.end_line, + node_type=match.pattern_name, + extends=extends, + implements=implements, + ) + + # Extract docstring and signature + chunk.docstring = self._metadata_extractor.extract_docstring(main_cap.text, lang_name) + chunk.signature = self._metadata_extractor.extract_signature(main_cap.text, lang_name) + + chunks.append(chunk) + + # Add imports and namespace to all chunks + for chunk in chunks: + chunk.imports = imports[:30] + chunk.namespace = namespace + + # Create simplified code chunk + if chunks: + simplified = self._create_simplified_code(text, chunks, lang_name) + if simplified and len(simplified.strip()) > 50: + chunks.append(ASTChunk( + content=simplified, + content_type=ContentType.SIMPLIFIED_CODE, + language=lang_name, + path=path, + start_line=1, + end_line=text.count('\n') + 1, + node_type='simplified', + imports=imports[:30], + namespace=namespace, + )) + + return chunks + + def _extract_with_traversal( + self, + text: str, + lang_name: str, + path: str + ) -> List[ASTChunk]: + """Fallback: extract chunks using manual AST traversal.""" + tree = self._parser.parse(text, lang_name) + if not tree: + return [] + + source_bytes = text.encode('utf-8') + chunks = [] + processed_ranges: Set[tuple] = set() + + # Node types for semantic chunking + semantic_types = self._get_semantic_node_types(lang_name) + class_types = set(semantic_types.get('class', [])) + function_types = set(semantic_types.get('function', [])) + all_types = class_types | function_types + + def get_node_text(node) -> str: + return source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace') + + def get_node_name(node) -> Optional[str]: + for child in node.children: + if child.type in ('identifier', 'name', 'type_identifier', 'property_identifier'): + return get_node_text(child) + return None + + def traverse(node, 
parent_context: List[str]): + node_range = (node.start_byte, node.end_byte) + + if node.type in all_types: + if node_range in processed_ranges: + return + + content = get_node_text(node) + start_line = source_bytes[:node.start_byte].count(b'\n') + 1 + end_line = start_line + content.count('\n') + node_name = get_node_name(node) + is_class = node.type in class_types + + chunk = ASTChunk( + content=content, + content_type=ContentType.FUNCTIONS_CLASSES, + language=lang_name, + path=path, + semantic_names=[node_name] if node_name else [], + parent_context=list(parent_context), + start_line=start_line, + end_line=end_line, + node_type=node.type, + ) + + chunk.docstring = self._metadata_extractor.extract_docstring(content, lang_name) + chunk.signature = self._metadata_extractor.extract_signature(content, lang_name) + + # Extract inheritance via regex + inheritance = self._metadata_extractor.extract_inheritance(content, lang_name) + chunk.extends = inheritance.get('extends', []) + chunk.implements = inheritance.get('implements', []) + chunk.imports = inheritance.get('imports', []) + + chunks.append(chunk) + processed_ranges.add(node_range) + + if is_class and node_name: + for child in node.children: + traverse(child, parent_context + [node_name]) + else: + for child in node.children: + traverse(child, parent_context) + + traverse(tree.root_node, []) + + # Create simplified code + if chunks: + simplified = self._create_simplified_code(text, chunks, lang_name) + if simplified and len(simplified.strip()) > 50: + chunks.append(ASTChunk( + content=simplified, + content_type=ContentType.SIMPLIFIED_CODE, + language=lang_name, + path=path, + start_line=1, + end_line=text.count('\n') + 1, + node_type='simplified', + )) + + return chunks + + def _process_chunks( + self, + chunks: List[ASTChunk], + doc: LlamaDocument, + language: Language, + path: str + ) -> List[TextNode]: + """Process AST chunks into TextNodes, handling oversized chunks.""" + nodes = [] + chunk_counter = 0 + + for ast_chunk in chunks: + if len(ast_chunk.content) > self.max_chunk_size: + sub_nodes = self._split_oversized_chunk(ast_chunk, language, doc.metadata, path) + nodes.extend(sub_nodes) + chunk_counter += len(sub_nodes) + else: + metadata = self._build_metadata(ast_chunk, doc.metadata, chunk_counter, len(chunks)) + chunk_id = generate_deterministic_id(path, ast_chunk.content, chunk_counter) + + node = TextNode( + id_=chunk_id, + text=ast_chunk.content, + metadata=metadata + ) + nodes.append(node) + chunk_counter += 1 + + return nodes + + def _split_oversized_chunk( + self, + chunk: ASTChunk, + language: Optional[Language], + base_metadata: Dict[str, Any], + path: str + ) -> List[TextNode]: + """Split an oversized chunk using RecursiveCharacterTextSplitter.""" + splitter = self._get_text_splitter(language) if language else self._default_splitter + sub_chunks = splitter.split_text(chunk.content) + + nodes = [] + parent_id = generate_deterministic_id(path, chunk.content, 0) + + for i, sub_chunk in enumerate(sub_chunks): + if not sub_chunk or not sub_chunk.strip(): + continue + if len(sub_chunk.strip()) < self.min_chunk_size and len(sub_chunks) > 1: + continue + + metadata = dict(base_metadata) + metadata['content_type'] = ContentType.OVERSIZED_SPLIT.value + metadata['original_content_type'] = chunk.content_type.value + metadata['parent_chunk_id'] = parent_id + metadata['sub_chunk_index'] = i + metadata['total_sub_chunks'] = len(sub_chunks) + + if chunk.parent_context: + metadata['parent_context'] = chunk.parent_context + 
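The deduplication behaviour that _process_chunks and _split_oversized_chunk rely on comes from the deterministic IDs; a short sketch of that property (import path assumed, file name purely illustrative):

```python
# Hedged example: same path/content/index always yields the same chunk ID,
# so re-indexing upserts into Qdrant instead of duplicating points.
from rag_pipeline.core.splitter import generate_deterministic_id, compute_file_hash

content = "def add(a, b):\n    return a + b\n"
first = generate_deterministic_id("src/math_utils.py", content, chunk_index=0)
second = generate_deterministic_id("src/math_utils.py", content, chunk_index=0)
assert first == second
assert first != generate_deterministic_id("src/other.py", content, 0)

# compute_file_hash gives a cheap change-detection key for the whole file.
print(len(first), compute_file_hash(content)[:12])
```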
metadata['parent_class'] = chunk.parent_context[-1] + + if chunk.semantic_names: + metadata['semantic_names'] = chunk.semantic_names + metadata['primary_name'] = chunk.semantic_names[0] + + chunk_id = generate_deterministic_id(path, sub_chunk, i) + nodes.append(TextNode(id_=chunk_id, text=sub_chunk, metadata=metadata)) + + return nodes + + def _split_fallback( + self, + doc: LlamaDocument, + language: Optional[Language] = None + ) -> List[TextNode]: + """Fallback splitting using RecursiveCharacterTextSplitter.""" + text = doc.text + path = doc.metadata.get('path', 'unknown') + + if not text or not text.strip(): + return [] + + splitter = self._get_text_splitter(language) if language else self._default_splitter + chunks = splitter.split_text(text) + + nodes = [] + lang_str = doc.metadata.get('language', 'text') + text_offset = 0 + + for i, chunk in enumerate(chunks): + if not chunk or not chunk.strip(): + continue + if len(chunk.strip()) < self.min_chunk_size and len(chunks) > 1: + continue + if len(chunk) > 30000: + chunk = chunk[:30000] + + # Calculate line numbers + start_line = text[:text_offset].count('\n') + 1 if text_offset > 0 else 1 + chunk_pos = text.find(chunk, text_offset) + if chunk_pos >= 0: + text_offset = chunk_pos + len(chunk) + end_line = start_line + chunk.count('\n') + + metadata = dict(doc.metadata) + metadata['content_type'] = ContentType.FALLBACK.value + metadata['chunk_index'] = i + metadata['total_chunks'] = len(chunks) + metadata['start_line'] = start_line + metadata['end_line'] = end_line + + # Extract names via regex + names = self._metadata_extractor.extract_names_from_content(chunk, lang_str) + if names: + metadata['semantic_names'] = names + metadata['primary_name'] = names[0] + + # Extract inheritance + inheritance = self._metadata_extractor.extract_inheritance(chunk, lang_str) + if inheritance.get('extends'): + metadata['extends'] = inheritance['extends'] + metadata['parent_types'] = inheritance['extends'] + if inheritance.get('implements'): + metadata['implements'] = inheritance['implements'] + if inheritance.get('imports'): + metadata['imports'] = inheritance['imports'] + + chunk_id = generate_deterministic_id(path, chunk, i) + nodes.append(TextNode(id_=chunk_id, text=chunk, metadata=metadata)) + + return nodes + + def _build_metadata( + self, + chunk: ASTChunk, + base_metadata: Dict[str, Any], + chunk_index: int, + total_chunks: int + ) -> Dict[str, Any]: + """Build metadata dictionary from ASTChunk.""" + metadata = dict(base_metadata) + + metadata['content_type'] = chunk.content_type.value + metadata['node_type'] = chunk.node_type + metadata['chunk_index'] = chunk_index + metadata['total_chunks'] = total_chunks + metadata['start_line'] = chunk.start_line + metadata['end_line'] = chunk.end_line + + if chunk.parent_context: + metadata['parent_context'] = chunk.parent_context + metadata['parent_class'] = chunk.parent_context[-1] + metadata['full_path'] = '.'.join(chunk.parent_context + chunk.semantic_names[:1]) + + if chunk.semantic_names: + metadata['semantic_names'] = chunk.semantic_names + metadata['primary_name'] = chunk.semantic_names[0] + + if chunk.docstring: + metadata['docstring'] = chunk.docstring[:500] + + if chunk.signature: + metadata['signature'] = chunk.signature + + if chunk.extends: + metadata['extends'] = chunk.extends + metadata['parent_types'] = chunk.extends + + if chunk.implements: + metadata['implements'] = chunk.implements + + if chunk.imports: + metadata['imports'] = chunk.imports + + if chunk.namespace: + metadata['namespace'] = 
chunk.namespace + + return metadata + + def _get_text_splitter(self, language: Language) -> RecursiveCharacterTextSplitter: + """Get language-specific text splitter.""" + if language not in self._splitter_cache: + try: + self._splitter_cache[language] = RecursiveCharacterTextSplitter.from_language( + language=language, + chunk_size=self.max_chunk_size, + chunk_overlap=self.chunk_overlap, + ) + except Exception: + self._splitter_cache[language] = self._default_splitter + return self._splitter_cache[language] + + def _create_simplified_code( + self, + source_code: str, + chunks: List[ASTChunk], + language: str + ) -> str: + """Create simplified code with placeholders for extracted chunks.""" + semantic_chunks = [c for c in chunks if c.content_type == ContentType.FUNCTIONS_CLASSES] + if not semantic_chunks: + return source_code + + sorted_chunks = sorted( + semantic_chunks, + key=lambda x: source_code.find(x.content), + reverse=True + ) + + result = source_code + comment_prefix = self._metadata_extractor.get_comment_prefix(language) + + for chunk in sorted_chunks: + pos = result.find(chunk.content) + if pos == -1: + continue + + first_line = chunk.content.split('\n')[0].strip() + if len(first_line) > 60: + first_line = first_line[:60] + '...' + + breadcrumb = "" + if chunk.parent_context: + breadcrumb = f" (in {'.'.join(chunk.parent_context)})" + + placeholder = f"{comment_prefix} Code for: {first_line}{breadcrumb}\n" + result = result[:pos] + placeholder + result[pos + len(chunk.content):] + + return result.strip() + + def _parse_type_list(self, text: str) -> List[str]: + """Parse a comma-separated list of types.""" + if not text: + return [] + + text = text.strip().strip('()[]') + + # Remove keywords + for kw in ('extends', 'implements', 'with', ':'): + text = text.replace(kw, ' ') + + types = [] + for part in text.split(','): + name = part.strip() + if '<' in name: + name = name.split('<')[0].strip() + if '(' in name: + name = name.split('(')[0].strip() + if name: + types.append(name) + + return types + + def _get_semantic_node_types(self, language: str) -> Dict[str, List[str]]: + """Get semantic node types for manual traversal fallback.""" + types = { + 'python': { + 'class': ['class_definition'], + 'function': ['function_definition'], + }, + 'java': { + 'class': ['class_declaration', 'interface_declaration', 'enum_declaration'], + 'function': ['method_declaration', 'constructor_declaration'], + }, + 'javascript': { + 'class': ['class_declaration'], + 'function': ['function_declaration', 'method_definition', 'arrow_function'], + }, + 'typescript': { + 'class': ['class_declaration', 'interface_declaration'], + 'function': ['function_declaration', 'method_definition', 'arrow_function'], + }, + 'go': { + 'class': ['type_declaration'], + 'function': ['function_declaration', 'method_declaration'], + }, + 'rust': { + 'class': ['struct_item', 'impl_item', 'trait_item', 'enum_item'], + 'function': ['function_item'], + }, + 'c_sharp': { + 'class': ['class_declaration', 'interface_declaration', 'struct_declaration'], + 'function': ['method_declaration', 'constructor_declaration'], + }, + 'php': { + 'class': ['class_declaration', 'interface_declaration', 'trait_declaration'], + 'function': ['function_definition', 'method_declaration'], + }, + } + return types.get(language, {'class': [], 'function': []}) + + @staticmethod + def get_supported_languages() -> List[str]: + """Return list of languages with AST support.""" + return list(LANGUAGE_TO_TREESITTER.values()) + + @staticmethod + def 
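The _parse_type_list helper above normalises raw extends/implements clauses into bare type names; a quick sketch (the helper is private and is called directly here only for illustration, with hypothetical inputs):

```python
# Hedged example: stripping keywords and generics from inheritance clauses.
from rag_pipeline.core.splitter import ASTCodeSplitter

splitter = ASTCodeSplitter()
print(splitter._parse_type_list("extends BaseService<User>, Serializable"))
# -> ['BaseService', 'Serializable']
print(splitter._parse_type_list(": IRepository, IDisposable"))
# -> ['IRepository', 'IDisposable']
```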
is_ast_supported(path: str) -> bool: + """Check if AST parsing is supported for a file.""" + return is_ast_supported(path) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/tree_parser.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/tree_parser.py new file mode 100644 index 00000000..da378a38 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/tree_parser.py @@ -0,0 +1,129 @@ +""" +Tree-sitter parser wrapper with caching and language loading. + +Handles dynamic loading of tree-sitter language modules using the new API (v0.23+). +""" + +import logging +from typing import Dict, Any, Optional + +from .languages import TREESITTER_MODULES + +logger = logging.getLogger(__name__) + + +class TreeSitterParser: + """ + Wrapper for tree-sitter parser with language caching. + + Uses the new tree-sitter API (v0.23+) with individual language packages. + """ + + def __init__(self): + self._language_cache: Dict[str, Any] = {} + self._available: Optional[bool] = None + + def is_available(self) -> bool: + """Check if tree-sitter is available and working.""" + if self._available is None: + try: + from tree_sitter import Parser, Language + import tree_sitter_python as tspython + + py_language = Language(tspython.language()) + parser = Parser(py_language) + parser.parse(b"def test(): pass") + + self._available = True + logger.info("tree-sitter is available and working") + except ImportError as e: + logger.warning(f"tree-sitter not installed: {e}") + self._available = False + except Exception as e: + logger.warning(f"tree-sitter error: {type(e).__name__}: {e}") + self._available = False + return self._available + + def get_language(self, lang_name: str) -> Optional[Any]: + """ + Get tree-sitter Language object for a language name. + + Args: + lang_name: Tree-sitter language name (e.g., 'python', 'java', 'php') + + Returns: + tree_sitter.Language object or None if unavailable + """ + if lang_name in self._language_cache: + return self._language_cache[lang_name] + + if not self.is_available(): + return None + + try: + from tree_sitter import Language + + lang_info = TREESITTER_MODULES.get(lang_name) + if not lang_info: + logger.debug(f"No tree-sitter module mapping for '{lang_name}'") + return None + + module_name, func_name = lang_info + + import importlib + lang_module = importlib.import_module(module_name) + + lang_func = getattr(lang_module, func_name, None) + if not lang_func: + logger.debug(f"Module {module_name} has no {func_name} function") + return None + + language = Language(lang_func()) + self._language_cache[lang_name] = language + return language + + except Exception as e: + logger.debug(f"Could not load tree-sitter language '{lang_name}': {e}") + return None + + def parse(self, source_code: str, lang_name: str) -> Optional[Any]: + """ + Parse source code and return the AST tree. 
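A usage sketch for the TreeSitterParser wrapper above (import path assumed; the optional tree-sitter and tree_sitter_python packages must be installed for parse() to return a tree, otherwise the splitter quietly falls back):

```python
# Hedged example: parsing a snippet through the cached singleton parser.
from rag_pipeline.core.splitter.tree_parser import get_parser

parser = get_parser()                       # process-wide singleton
if parser.is_available():
    tree = parser.parse("def ping():\n    return 'pong'\n", "python")
    if tree is not None:
        print(tree.root_node.type)          # 'module'
        print([child.type for child in tree.root_node.children])
else:
    print("tree-sitter not installed; splitter will use the text fallback")
```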
+ + Args: + source_code: Source code string + lang_name: Tree-sitter language name + + Returns: + tree_sitter.Tree object or None if parsing failed + """ + language = self.get_language(lang_name) + if not language: + return None + + try: + from tree_sitter import Parser + + parser = Parser(language) + tree = parser.parse(bytes(source_code, "utf8")) + return tree + + except Exception as e: + logger.warning(f"Failed to parse code with tree-sitter ({lang_name}): {e}") + return None + + def clear_cache(self): + """Clear the language cache.""" + self._language_cache.clear() + + +# Global singleton instance +_parser_instance: Optional[TreeSitterParser] = None + + +def get_parser() -> TreeSitterParser: + """Get the global TreeSitterParser instance.""" + global _parser_instance + if _parser_instance is None: + _parser_instance = TreeSitterParser() + return _parser_instance diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/scoring_config.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/scoring_config.py new file mode 100644 index 00000000..e1cba601 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/scoring_config.py @@ -0,0 +1,232 @@ +""" +Scoring configuration for RAG query result reranking. + +Provides configurable boost factors and priority patterns that can be +overridden via environment variables. +""" + +import os +from typing import Dict, List +from pydantic import BaseModel, Field +import logging + +logger = logging.getLogger(__name__) + + +def _parse_list_env(env_var: str, default: List[str]) -> List[str]: + """Parse comma-separated environment variable into list.""" + value = os.getenv(env_var) + if not value: + return default + return [item.strip() for item in value.split(',') if item.strip()] + + +def _parse_float_env(env_var: str, default: float) -> float: + """Parse float from environment variable.""" + value = os.getenv(env_var) + if not value: + return default + try: + return float(value) + except ValueError: + logger.warning(f"Invalid float value for {env_var}: {value}, using default {default}") + return default + + +class ContentTypeBoost(BaseModel): + """Boost factors for different content types from AST parsing.""" + + functions_classes: float = Field( + default_factory=lambda: _parse_float_env("RAG_BOOST_FUNCTIONS_CLASSES", 1.2), + description="Boost for full function/class definitions (highest value)" + ) + fallback: float = Field( + default_factory=lambda: _parse_float_env("RAG_BOOST_FALLBACK", 1.0), + description="Boost for regex-based split chunks" + ) + oversized_split: float = Field( + default_factory=lambda: _parse_float_env("RAG_BOOST_OVERSIZED", 0.95), + description="Boost for large chunks that were split" + ) + simplified_code: float = Field( + default_factory=lambda: _parse_float_env("RAG_BOOST_SIMPLIFIED", 0.7), + description="Boost for code with placeholders (context only)" + ) + + def get(self, content_type: str) -> float: + """Get boost factor for a content type.""" + return getattr(self, content_type, 1.0) + + +class FilePriorityPatterns(BaseModel): + """File path patterns for priority-based boosting.""" + + high: List[str] = Field( + default_factory=lambda: _parse_list_env( + "RAG_HIGH_PRIORITY_PATTERNS", + ['service', 'controller', 'handler', 'api', 'core', 'auth', 'security', + 'permission', 'repository', 'dao', 'migration'] + ), + description="Patterns for high-priority files (1.3x boost)" + ) + + medium: List[str] = Field( + default_factory=lambda: _parse_list_env( + "RAG_MEDIUM_PRIORITY_PATTERNS", + 
['model', 'entity', 'dto', 'schema', 'util', 'helper', 'common', + 'shared', 'component', 'hook', 'client', 'integration'] + ), + description="Patterns for medium-priority files (1.1x boost)" + ) + + low: List[str] = Field( + default_factory=lambda: _parse_list_env( + "RAG_LOW_PRIORITY_PATTERNS", + ['test', 'spec', 'config', 'mock', 'fixture', 'stub'] + ), + description="Patterns for low-priority files (0.8x penalty)" + ) + + high_boost: float = Field( + default_factory=lambda: _parse_float_env("RAG_HIGH_PRIORITY_BOOST", 1.3) + ) + medium_boost: float = Field( + default_factory=lambda: _parse_float_env("RAG_MEDIUM_PRIORITY_BOOST", 1.1) + ) + low_boost: float = Field( + default_factory=lambda: _parse_float_env("RAG_LOW_PRIORITY_BOOST", 0.8) + ) + + def get_priority(self, file_path: str) -> tuple: + """ + Get priority level and boost factor for a file path. + + Returns: + Tuple of (priority_name, boost_factor) + """ + path_lower = file_path.lower() + + if any(p in path_lower for p in self.high): + return ('HIGH', self.high_boost) + elif any(p in path_lower for p in self.medium): + return ('MEDIUM', self.medium_boost) + elif any(p in path_lower for p in self.low): + return ('LOW', self.low_boost) + else: + return ('MEDIUM', 1.0) + + +class MetadataBonus(BaseModel): + """Bonus multipliers for metadata presence.""" + + semantic_names: float = Field( + default_factory=lambda: _parse_float_env("RAG_BONUS_SEMANTIC_NAMES", 1.1), + description="Bonus for chunks with extracted semantic names" + ) + docstring: float = Field( + default_factory=lambda: _parse_float_env("RAG_BONUS_DOCSTRING", 1.05), + description="Bonus for chunks with docstrings" + ) + signature: float = Field( + default_factory=lambda: _parse_float_env("RAG_BONUS_SIGNATURE", 1.02), + description="Bonus for chunks with function signatures" + ) + + +class ScoringConfig(BaseModel): + """ + Complete scoring configuration for RAG query reranking. + + All values can be overridden via environment variables: + - RAG_BOOST_FUNCTIONS_CLASSES, RAG_BOOST_FALLBACK, etc. + - RAG_HIGH_PRIORITY_PATTERNS (comma-separated) + - RAG_HIGH_PRIORITY_BOOST, RAG_MEDIUM_PRIORITY_BOOST, etc. + - RAG_BONUS_SEMANTIC_NAMES, RAG_BONUS_DOCSTRING, etc. + + Usage: + config = ScoringConfig() + boost = config.content_type_boost.get('functions_classes') + priority, boost = config.file_priority.get_priority('/src/UserService.java') + """ + + content_type_boost: ContentTypeBoost = Field(default_factory=ContentTypeBoost) + file_priority: FilePriorityPatterns = Field(default_factory=FilePriorityPatterns) + metadata_bonus: MetadataBonus = Field(default_factory=MetadataBonus) + + # Score thresholds + min_relevance_score: float = Field( + default_factory=lambda: _parse_float_env("RAG_MIN_RELEVANCE_SCORE", 0.7), + description="Minimum score threshold for results" + ) + + max_score_cap: float = Field( + default_factory=lambda: _parse_float_env("RAG_MAX_SCORE_CAP", 1.0), + description="Maximum score cap after boosting" + ) + + def calculate_boosted_score( + self, + base_score: float, + file_path: str, + content_type: str, + has_semantic_names: bool = False, + has_docstring: bool = False, + has_signature: bool = False + ) -> tuple: + """ + Calculate final boosted score for a result. + + Args: + base_score: Original similarity score + file_path: File path of the chunk + content_type: Content type (functions_classes, fallback, etc.) 
+ has_semantic_names: Whether chunk has semantic names + has_docstring: Whether chunk has docstring + has_signature: Whether chunk has signature + + Returns: + Tuple of (boosted_score, priority_level) + """ + score = base_score + + # File priority boost + priority, priority_boost = self.file_priority.get_priority(file_path) + score *= priority_boost + + # Content type boost + content_boost = self.content_type_boost.get(content_type) + score *= content_boost + + # Metadata bonuses + if has_semantic_names: + score *= self.metadata_bonus.semantic_names + if has_docstring: + score *= self.metadata_bonus.docstring + if has_signature: + score *= self.metadata_bonus.signature + + # Cap the score + score = min(score, self.max_score_cap) + + return (score, priority) + + +# Global singleton +_scoring_config: ScoringConfig | None = None + + +def get_scoring_config() -> ScoringConfig: + """Get the global ScoringConfig instance.""" + global _scoring_config + if _scoring_config is None: + _scoring_config = ScoringConfig() + logger.info("ScoringConfig initialized with:") + logger.info(f" High priority patterns: {_scoring_config.file_priority.high[:5]}...") + logger.info(f" Content type boosts: functions_classes={_scoring_config.content_type_boost.functions_classes}") + return _scoring_config + + +def reset_scoring_config(): + """Reset the global config (useful for testing).""" + global _scoring_config + _scoring_config = None diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py index dc222980..f9513fb4 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py @@ -8,36 +8,13 @@ from qdrant_client.http.models import Filter, FieldCondition, MatchValue, MatchAny from ..models.config import RAGConfig +from ..models.scoring_config import get_scoring_config, ScoringConfig from ..utils.utils import make_namespace, make_project_namespace from ..core.openrouter_embedding import OpenRouterEmbedding from ..models.instructions import InstructionType, format_query logger = logging.getLogger(__name__) -# File priority patterns for smart RAG -HIGH_PRIORITY_PATTERNS = [ - 'service', 'controller', 'handler', 'api', 'core', 'auth', 'security', - 'permission', 'repository', 'dao', 'migration' -] - -MEDIUM_PRIORITY_PATTERNS = [ - 'model', 'entity', 'dto', 'schema', 'util', 'helper', 'common', - 'shared', 'component', 'hook', 'client', 'integration' -] - -LOW_PRIORITY_PATTERNS = [ - 'test', 'spec', 'config', 'mock', 'fixture', 'stub' -] - -# Content type priorities for AST-based chunks -# functions_classes are more valuable than simplified_code (placeholders) -CONTENT_TYPE_BOOST = { - 'functions_classes': 1.2, # Full function/class definitions - highest value - 'fallback': 1.0, # Regex-based split - normal value - 'oversized_split': 0.95, # Large chunks that were split - slightly lower - 'simplified_code': 0.7, # Code with placeholders - lower value (context only) -} - class RAGQueryService: """Service for querying RAG indices using Qdrant. @@ -864,11 +841,12 @@ def _merge_and_rank_results(self, results: List[Dict], min_score_threshold: floa """ Deduplicate matches and filter by relevance score with priority-based reranking. - Applies three types of boosting: + Uses ScoringConfig for configurable boosting factors: 1. File path priority (service/controller vs test/config) 2. 
Content type priority (functions_classes vs simplified_code) 3. Semantic name bonus (chunks with extracted function/class names) """ + scoring_config = get_scoring_config() grouped = {} # Deduplicate by file_path + content hash @@ -884,43 +862,28 @@ def _merge_and_rank_results(self, results: List[Dict], min_score_threshold: floa unique_results = list(grouped.values()) - # Apply multi-factor score boosting + # Apply multi-factor score boosting using ScoringConfig for result in unique_results: metadata = result.get('metadata', {}) - file_path = metadata.get('path', metadata.get('file_path', '')).lower() + file_path = metadata.get('path', metadata.get('file_path', '')) content_type = metadata.get('content_type', 'fallback') semantic_names = metadata.get('semantic_names', []) + has_docstring = bool(metadata.get('docstring')) + has_signature = bool(metadata.get('signature')) + + boosted_score, priority = scoring_config.calculate_boosted_score( + base_score=result['score'], + file_path=file_path, + content_type=content_type, + has_semantic_names=bool(semantic_names), + has_docstring=has_docstring, + has_signature=has_signature + ) - base_score = result['score'] - - # 1. File path priority boosting - if any(p in file_path for p in HIGH_PRIORITY_PATTERNS): - base_score *= 1.3 - result['_priority'] = 'HIGH' - elif any(p in file_path for p in MEDIUM_PRIORITY_PATTERNS): - base_score *= 1.1 - result['_priority'] = 'MEDIUM' - elif any(p in file_path for p in LOW_PRIORITY_PATTERNS): - base_score *= 0.8 # Penalize test/config files - result['_priority'] = 'LOW' - else: - result['_priority'] = 'MEDIUM' - - # 2. Content type boosting (AST-based metadata) - content_boost = CONTENT_TYPE_BOOST.get(content_type, 1.0) - base_score *= content_boost + result['score'] = boosted_score + result['_priority'] = priority result['_content_type'] = content_type - - # 3. Semantic name bonus - chunks with extracted names are more valuable - if semantic_names: - base_score *= 1.1 # 10% bonus for having semantic names - result['_has_semantic_names'] = True - - # 4. 
Docstring bonus - chunks with docstrings provide better context - if metadata.get('docstring'): - base_score *= 1.05 # 5% bonus for having docstring - - result['score'] = min(1.0, base_score) + result['_has_semantic_names'] = bool(semantic_names) # Filter by threshold filtered = [r for r in unique_results if r['score'] >= min_score_threshold] From 642bda00bc9ad63cd0f8f2fffbb4586a99524876 Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 02:55:20 +0200 Subject: [PATCH 30/34] feat: Enhance lock management in PullRequestAnalysisProcessor and improve code referencing in prompts --- .../processor/analysis/PullRequestAnalysisProcessor.java | 6 +++++- .../mcp-client/service/multi_stage_orchestrator.py | 6 ++++-- .../mcp-client/utils/prompts/prompt_constants.py | 6 ++++++ .../src/rag_pipeline/core/index_manager/indexer.py | 1 + .../src/rag_pipeline/core/index_manager/point_operations.py | 2 +- 5 files changed, 17 insertions(+), 4 deletions(-) diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java index ce7b3292..87443d21 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java @@ -94,8 +94,10 @@ public Map process( // Check if a lock was already acquired by the caller (e.g., webhook handler) // to prevent double-locking which causes unnecessary 2-minute waits String lockKey; + boolean isPreAcquired = false; if (request.getPreAcquiredLockKey() != null && !request.getPreAcquiredLockKey().isBlank()) { lockKey = request.getPreAcquiredLockKey(); + isPreAcquired = true; log.info("Using pre-acquired lock: {} for project={}, PR={}", lockKey, project.getId(), request.getPullRequestId()); } else { Optional acquiredLock = analysisLockService.acquireLockWithWait( @@ -225,7 +227,9 @@ public Map process( return Map.of("status", "error", "message", e.getMessage()); } finally { - analysisLockService.releaseLock(lockKey); + if (!isPreAcquired) { + analysisLockService.releaseLock(lockKey); + } } } diff --git a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py index e4b9846b..bf662f22 100644 --- a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py +++ b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py @@ -410,7 +410,7 @@ def _extract_diff_snippets(self, diff_content: str) -> List[str]: snippets = [] current_snippet_lines = [] - for line in diff_content.split("\n"): + for line in diff_content.splitlines(): # Focus on added lines (new code) if line.startswith("+") and not line.startswith("+++"): clean_line = line[1:].strip() @@ -1084,8 +1084,10 @@ def _format_rag_context( meta_lines.append(f"Type: {chunk_type}") meta_text = "\n".join(meta_lines) + # Use file path as primary identifier, not a number + # This encourages AI to reference by path rather than by chunk number formatted_parts.append( - f"### Related Code #{included_count} (relevance: {score:.2f})\n" + f"### Context from `{path}` (relevance: {score:.2f})\n" f"{meta_text}\n" f"```\n{text}\n```\n" ) diff --git 
a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py index 47e65d1b..1c86a621 100644 --- a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py +++ b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py @@ -356,6 +356,12 @@ CODEBASE CONTEXT (from RAG): {rag_context} +IMPORTANT: When referencing codebase context in your analysis: +- ALWAYS cite the actual file path (e.g., "as seen in `src/service/UserService.java`") +- NEVER reference context by number (e.g., DO NOT say "Related Code #1" or "chunk #3") +- Quote relevant code snippets when needed to support your analysis +- The numbered headers are for your reference only, not for output + {previous_issues} SUGGESTED_FIX_DIFF_FORMAT: diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py index 2bf145a4..d2f3f323 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py @@ -94,6 +94,7 @@ def estimate_repository_size( sample_chunk_count += len(chunks) del chunks del documents + gc.collect() avg_chunks_per_file = sample_chunk_count / SAMPLE_SIZE chunk_count = int(avg_chunks_per_file * file_count) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py index b9682bf0..62ca6fd2 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py @@ -50,7 +50,7 @@ def prepare_chunks_for_embedding( # Group chunks by file path chunks_by_file: Dict[str, List[TextNode]] = {} for chunk in chunks: - path = chunk.metadata.get("path", "unknown") + path = chunk.metadata.get("path", str(uuid.uuid4())) if path not in chunks_by_file: chunks_by_file[path] = [] chunks_by_file[path].append(chunk) From 5add89cf8ba4a2485e139e51d17b265b37635147 Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 22:05:19 +0200 Subject: [PATCH 31/34] feat: Enhance AST processing and metadata extraction in RAG pipeline components --- .../src/rag_pipeline/core/index_manager.py | 3 +- .../src/rag_pipeline/core/loader.py | 21 +- .../rag_pipeline/core/splitter/splitter.py | 646 +++++++++++++++++- .../src/rag_pipeline/models/config.py | 8 +- .../src/rag_pipeline/utils/utils.py | 47 ++ 5 files changed, 696 insertions(+), 29 deletions(-) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py index 7eebf78e..fae2373f 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py @@ -63,7 +63,8 @@ def __init__(self, config: RAGConfig): max_chunk_size=config.chunk_size, min_chunk_size=min(200, config.chunk_size // 4), chunk_overlap=config.chunk_overlap, - parser_threshold=10 # Minimum lines for AST parsing + parser_threshold=3, # Low threshold - AST benefits even small files + enrich_embedding_text=True # Prepend semantic context for better embeddings ) self.loader = DocumentLoader(config) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/loader.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/loader.py 
index 25d92966..05235278 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/loader.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/loader.py @@ -3,7 +3,7 @@ import logging from llama_index.core.schema import Document -from ..utils.utils import detect_language_from_path, should_exclude_file, is_binary_file +from ..utils.utils import detect_language_from_path, should_exclude_file, is_binary_file, clean_archive_path from ..models.config import RAGConfig logger = logging.getLogger(__name__) @@ -106,11 +106,14 @@ def load_file_batch( language = detect_language_from_path(str(full_path)) filetype = full_path.suffix.lstrip('.') + # Clean archive root prefix from path (e.g., 'owner-repo-commit/src/file.php' -> 'src/file.php') + clean_path = clean_archive_path(relative_path_str) + metadata = { "workspace": workspace, "project": project, "branch": branch, - "path": relative_path_str, + "path": clean_path, "commit": commit, "language": language, "filetype": filetype, @@ -190,11 +193,14 @@ def load_from_directory( language = detect_language_from_path(str(file_path)) filetype = file_path.suffix.lstrip('.') + # Clean archive root prefix from path + clean_path = clean_archive_path(relative_path) + metadata = { "workspace": workspace, "project": project, "branch": branch, - "path": relative_path, + "path": clean_path, "commit": commit, "language": language, "filetype": filetype, @@ -207,7 +213,7 @@ def load_from_directory( ) documents.append(doc) - logger.debug(f"Loaded document: {relative_path} ({language})") + logger.debug(f"Loaded document: {clean_path} ({language})") logger.info(f"Loaded {len(documents)} documents from {repo_path} (excluded {excluded_count} files by patterns)") return documents @@ -257,11 +263,14 @@ def load_specific_files( language = detect_language_from_path(str(full_path)) filetype = full_path.suffix.lstrip('.') + # Clean archive root prefix from path + clean_path = clean_archive_path(relative_path) + metadata = { "workspace": workspace, "project": project, "branch": branch, - "path": relative_path, + "path": clean_path, "commit": commit, "language": language, "filetype": filetype, @@ -274,7 +283,7 @@ def load_specific_files( ) documents.append(doc) - logger.debug(f"Loaded document: {relative_path}") + logger.debug(f"Loaded document: {clean_path}") return documents diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py index 622892d9..25c12d8a 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py @@ -48,22 +48,69 @@ def compute_file_hash(content: str) -> str: @dataclass class ASTChunk: - """Represents a chunk of code from AST parsing.""" + """Represents a chunk of code from AST parsing with rich metadata.""" content: str content_type: ContentType language: str path: str + + # Identity semantic_names: List[str] = field(default_factory=list) - parent_context: List[str] = field(default_factory=list) - docstring: Optional[str] = None - signature: Optional[str] = None + node_type: Optional[str] = None + namespace: Optional[str] = None + + # Location start_line: int = 0 end_line: int = 0 - node_type: Optional[str] = None + + # Hierarchy & Context + parent_context: List[str] = field(default_factory=list) # Breadcrumb path + + # Documentation + docstring: Optional[str] = None + signature: Optional[str] = None + + # Type relationships extends: 
List[str] = field(default_factory=list) implements: List[str] = field(default_factory=list) + + # Dependencies imports: List[str] = field(default_factory=list) - namespace: Optional[str] = None + + # --- RICH AST FIELDS (extracted from tree-sitter) --- + + # Methods/functions within this chunk (for classes) + methods: List[str] = field(default_factory=list) + + # Properties/fields within this chunk (for classes) + properties: List[str] = field(default_factory=list) + + # Parameters (for functions/methods) + parameters: List[str] = field(default_factory=list) + + # Return type (for functions/methods) + return_type: Optional[str] = None + + # Decorators/annotations + decorators: List[str] = field(default_factory=list) + + # Modifiers (public, private, static, async, abstract, etc.) + modifiers: List[str] = field(default_factory=list) + + # Called functions/methods (dependencies) + calls: List[str] = field(default_factory=list) + + # Referenced types (type annotations, generics) + referenced_types: List[str] = field(default_factory=list) + + # Variables declared in this chunk + variables: List[str] = field(default_factory=list) + + # Constants defined + constants: List[str] = field(default_factory=list) + + # Generic type parameters (e.g., ) + type_parameters: List[str] = field(default_factory=list) class ASTCodeSplitter: @@ -76,23 +123,36 @@ class ASTCodeSplitter: - Falls back to RecursiveCharacterTextSplitter when needed - Uses deterministic IDs for Qdrant deduplication - Enriches metadata for improved RAG retrieval + - Prepares embedding-optimized text with semantic context + + Chunk Size Strategy: + - text-embedding-3-small supports ~8191 tokens (~32K chars) + - We use 8000 chars as default to keep semantic units intact + - Only truly massive classes/functions get split + - Splitting loses AST benefits, so we avoid it when possible Usage: - splitter = ASTCodeSplitter(max_chunk_size=2000) + splitter = ASTCodeSplitter(max_chunk_size=8000) nodes = splitter.split_documents(documents) """ - DEFAULT_MAX_CHUNK_SIZE = 2000 + # Chunk size considerations: + # - Embedding models (text-embedding-3-small): ~8191 tokens = ~32K chars + # - Most classes/functions: 500-5000 chars + # - Keeping semantic units whole improves retrieval quality + # - Only split when absolutely necessary + DEFAULT_MAX_CHUNK_SIZE = 8000 # ~2000 tokens, fits most semantic units DEFAULT_MIN_CHUNK_SIZE = 100 DEFAULT_CHUNK_OVERLAP = 200 - DEFAULT_PARSER_THRESHOLD = 10 + DEFAULT_PARSER_THRESHOLD = 3 # Low threshold - AST benefits even small files def __init__( self, max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, - parser_threshold: int = DEFAULT_PARSER_THRESHOLD + parser_threshold: int = DEFAULT_PARSER_THRESHOLD, + enrich_embedding_text: bool = True ): """ Initialize AST code splitter. 
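(To show how this constructor is driven end to end, a minimal sketch mirroring the Usage note in the class docstring; the import paths, the sample Document and its metadata keys are assumptions.)

# Minimal sketch, mirroring the class docstring's Usage note.
# Import paths, the sample Document and its metadata are assumptions.
from llama_index.core.schema import Document
from rag_pipeline.core.splitter.splitter import ASTCodeSplitter

splitter = ASTCodeSplitter(
    max_chunk_size=8000,         # keeps whole classes/functions intact in most cases
    parser_threshold=3,          # parse even small files with tree-sitter
    enrich_embedding_text=True,  # prepend the semantic context header to chunk text
)

doc = Document(
    text="class Greeter:\n    def greet(self, name):\n        return 'hi ' + name\n",
    metadata={"path": "src/greeter.py", "language": "python"},
)

for node in splitter.split_documents([doc]):
    print(node.metadata.get("content_type"), node.metadata.get("semantic_names"))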
@@ -101,12 +161,15 @@ def __init__( max_chunk_size: Maximum characters per chunk min_chunk_size: Minimum characters for a valid chunk chunk_overlap: Overlap between chunks when splitting oversized content - parser_threshold: Minimum lines for AST parsing + parser_threshold: Minimum lines for AST parsing (3 recommended) + enrich_embedding_text: Whether to prepend semantic context to chunk text + for better embedding quality """ self.max_chunk_size = max_chunk_size self.min_chunk_size = min_chunk_size self.chunk_overlap = chunk_overlap self.parser_threshold = parser_threshold + self.enrich_embedding_text = enrich_embedding_text # Components self._parser = get_parser() @@ -295,12 +358,16 @@ def _extract_with_queries( node_type=match.pattern_name, extends=extends, implements=implements, + modifiers=modifiers, ) # Extract docstring and signature chunk.docstring = self._metadata_extractor.extract_docstring(main_cap.text, lang_name) chunk.signature = self._metadata_extractor.extract_signature(main_cap.text, lang_name) + # Extract rich AST details (methods, properties, params, calls, etc.) + self._extract_rich_ast_details(chunk, tree, main_cap, lang_name) + chunks.append(chunk) # Add imports and namespace to all chunks @@ -390,6 +457,9 @@ def traverse(node, parent_context: List[str]): chunk.implements = inheritance.get('implements', []) chunk.imports = inheritance.get('imports', []) + # Extract rich AST details directly from this node + self._extract_rich_details_from_node(chunk, node, source_bytes, lang_name) + chunks.append(chunk) processed_ranges.add(node_range) @@ -418,6 +488,339 @@ def traverse(node, parent_context: List[str]): return chunks + def _extract_rich_ast_details( + self, + chunk: ASTChunk, + tree: Any, + captured_node: Any, + lang_name: str + ) -> None: + """ + Extract rich AST details from tree-sitter node by traversing its children. 
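(To make the extraction concrete: a hypothetical result for a small Java service class. The field names are the ASTChunk fields listed above; every value is invented for illustration.)

# Hypothetical rich-AST fields for a small Java service class.
# Field names mirror ASTChunk; the values are invented for illustration.
example_rich_fields = {
    "methods": ["createUser", "findById"],
    "properties": ["userRepository", "passwordEncoder"],
    "parameters": ["request", "id"],
    "return_type": "UserDto",
    "decorators": ["Service", "Transactional"],
    "modifiers": ["public"],
    "calls": ["save", "encode", "orElseThrow"],
    "referenced_types": ["UserRepository", "UserDto", "Optional"],
    "variables": ["user"],
    "type_parameters": [],
}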
+ + This extracts: + - Methods (for classes) + - Properties/fields (for classes) + - Parameters (for functions/methods) + - Return type + - Decorators/annotations + - Called functions/methods + - Referenced types + - Variables + - Type parameters (generics) + """ + source_bytes = chunk.content.encode('utf-8') + + # Find the actual tree-sitter node for this capture + node = self._find_node_at_position( + tree.root_node, + captured_node.start_byte, + captured_node.end_byte + ) + if not node: + return + + # Language-specific node type mappings + node_types = self._get_rich_node_types(lang_name) + + def get_text(n) -> str: + """Get text for a node relative to chunk content.""" + start = n.start_byte - captured_node.start_byte + end = n.end_byte - captured_node.start_byte + if 0 <= start < len(source_bytes) and start < end <= len(source_bytes): + return source_bytes[start:end].decode('utf-8', errors='replace') + return '' + + def extract_identifier(n) -> Optional[str]: + """Extract identifier name from a node.""" + for child in n.children: + if child.type in node_types['identifier']: + return get_text(child) + return None + + def traverse_for_details(n, depth: int = 0): + """Recursively traverse to extract details.""" + if depth > 10: # Prevent infinite recursion + return + + node_type = n.type + + # Extract methods (for classes) + if node_type in node_types['method']: + method_name = extract_identifier(n) + if method_name and method_name not in chunk.methods: + chunk.methods.append(method_name) + + # Extract properties/fields + if node_type in node_types['property']: + prop_name = extract_identifier(n) + if prop_name and prop_name not in chunk.properties: + chunk.properties.append(prop_name) + + # Extract parameters + if node_type in node_types['parameter']: + param_name = extract_identifier(n) + if param_name and param_name not in chunk.parameters: + chunk.parameters.append(param_name) + + # Extract decorators/annotations + if node_type in node_types['decorator']: + dec_text = get_text(n).strip() + if dec_text and dec_text not in chunk.decorators: + # Clean up decorator text + if dec_text.startswith('@'): + dec_text = dec_text[1:] + if '(' in dec_text: + dec_text = dec_text.split('(')[0] + chunk.decorators.append(dec_text) + + # Extract function calls + if node_type in node_types['call']: + call_name = extract_identifier(n) + if call_name and call_name not in chunk.calls: + chunk.calls.append(call_name) + + # Extract type references + if node_type in node_types['type_ref']: + type_text = get_text(n).strip() + if type_text and type_text not in chunk.referenced_types: + # Clean generic params + if '<' in type_text: + type_text = type_text.split('<')[0] + chunk.referenced_types.append(type_text) + + # Extract return type + if node_type in node_types['return_type'] and not chunk.return_type: + chunk.return_type = get_text(n).strip() + + # Extract type parameters (generics) + if node_type in node_types['type_param']: + param_text = get_text(n).strip() + if param_text and param_text not in chunk.type_parameters: + chunk.type_parameters.append(param_text) + + # Extract variables + if node_type in node_types['variable']: + var_name = extract_identifier(n) + if var_name and var_name not in chunk.variables: + chunk.variables.append(var_name) + + # Recurse into children + for child in n.children: + traverse_for_details(child, depth + 1) + + traverse_for_details(node) + + # Limit list sizes to prevent bloat + chunk.methods = chunk.methods[:30] + chunk.properties = chunk.properties[:30] + chunk.parameters 
= chunk.parameters[:20] + chunk.decorators = chunk.decorators[:10] + chunk.calls = chunk.calls[:50] + chunk.referenced_types = chunk.referenced_types[:30] + chunk.variables = chunk.variables[:30] + chunk.type_parameters = chunk.type_parameters[:10] + + def _find_node_at_position(self, root, start_byte: int, end_byte: int) -> Optional[Any]: + """Find the tree-sitter node at the given byte position.""" + def find(node): + if node.start_byte == start_byte and node.end_byte == end_byte: + return node + for child in node.children: + if child.start_byte <= start_byte and child.end_byte >= end_byte: + result = find(child) + if result: + return result + return None + return find(root) + + def _get_rich_node_types(self, language: str) -> Dict[str, List[str]]: + """Get tree-sitter node types for extracting rich details.""" + # Common patterns across languages + common = { + 'identifier': ['identifier', 'name', 'type_identifier', 'property_identifier'], + 'call': ['call_expression', 'call', 'function_call', 'method_invocation'], + 'type_ref': ['type_identifier', 'generic_type', 'type_annotation', 'type'], + 'type_param': ['type_parameter', 'type_parameters', 'generic_parameter'], + } + + types = { + 'python': { + **common, + 'method': ['function_definition'], + 'property': ['assignment', 'expression_statement'], + 'parameter': ['parameter', 'default_parameter', 'typed_parameter'], + 'decorator': ['decorator'], + 'return_type': ['type'], + 'variable': ['assignment'], + }, + 'java': { + **common, + 'method': ['method_declaration', 'constructor_declaration'], + 'property': ['field_declaration'], + 'parameter': ['formal_parameter', 'spread_parameter'], + 'decorator': ['annotation', 'marker_annotation'], + 'return_type': ['type_identifier', 'generic_type', 'void_type'], + 'variable': ['local_variable_declaration'], + }, + 'javascript': { + **common, + 'method': ['method_definition', 'function_declaration'], + 'property': ['field_definition', 'public_field_definition'], + 'parameter': ['formal_parameters', 'required_parameter'], + 'decorator': ['decorator'], + 'return_type': ['type_annotation'], + 'variable': ['variable_declarator'], + }, + 'typescript': { + **common, + 'method': ['method_definition', 'method_signature', 'function_declaration'], + 'property': ['public_field_definition', 'property_signature'], + 'parameter': ['required_parameter', 'optional_parameter'], + 'decorator': ['decorator'], + 'return_type': ['type_annotation'], + 'variable': ['variable_declarator'], + }, + 'go': { + **common, + 'method': ['method_declaration', 'function_declaration'], + 'property': ['field_declaration'], + 'parameter': ['parameter_declaration'], + 'decorator': [], # Go doesn't have decorators + 'return_type': ['type_identifier', 'pointer_type'], + 'variable': ['short_var_declaration', 'var_declaration'], + }, + 'rust': { + **common, + 'method': ['function_item', 'associated_item'], + 'property': ['field_declaration'], + 'parameter': ['parameter'], + 'decorator': ['attribute_item'], + 'return_type': ['type_identifier', 'generic_type'], + 'variable': ['let_declaration'], + }, + 'c_sharp': { + **common, + 'method': ['method_declaration', 'constructor_declaration'], + 'property': ['property_declaration', 'field_declaration'], + 'parameter': ['parameter'], + 'decorator': ['attribute_list', 'attribute'], + 'return_type': ['predefined_type', 'generic_name'], + 'variable': ['variable_declaration'], + }, + 'php': { + **common, + 'method': ['method_declaration', 'function_definition'], + 'property': ['property_declaration'], 
+ 'parameter': ['simple_parameter'], + 'decorator': ['attribute_list'], + 'return_type': ['named_type', 'union_type'], + 'variable': ['property_declaration', 'simple_variable'], + }, + } + + return types.get(language, { + **common, + 'method': [], + 'property': [], + 'parameter': [], + 'decorator': [], + 'return_type': [], + 'variable': [], + }) + + def _extract_rich_details_from_node( + self, + chunk: ASTChunk, + node: Any, + source_bytes: bytes, + lang_name: str + ) -> None: + """ + Extract rich AST details directly from a tree-sitter node. + Used by traversal-based extraction when we already have the node. + """ + node_types = self._get_rich_node_types(lang_name) + + def get_text(n) -> str: + return source_bytes[n.start_byte:n.end_byte].decode('utf-8', errors='replace') + + def extract_identifier(n) -> Optional[str]: + for child in n.children: + if child.type in node_types['identifier']: + return get_text(child) + return None + + def traverse(n, depth: int = 0): + if depth > 10: + return + + node_type = n.type + + if node_type in node_types['method']: + name = extract_identifier(n) + if name and name not in chunk.methods: + chunk.methods.append(name) + + if node_type in node_types['property']: + name = extract_identifier(n) + if name and name not in chunk.properties: + chunk.properties.append(name) + + if node_type in node_types['parameter']: + name = extract_identifier(n) + if name and name not in chunk.parameters: + chunk.parameters.append(name) + + if node_type in node_types['decorator']: + dec_text = get_text(n).strip() + if dec_text and dec_text not in chunk.decorators: + if dec_text.startswith('@'): + dec_text = dec_text[1:] + if '(' in dec_text: + dec_text = dec_text.split('(')[0] + chunk.decorators.append(dec_text) + + if node_type in node_types['call']: + name = extract_identifier(n) + if name and name not in chunk.calls: + chunk.calls.append(name) + + if node_type in node_types['type_ref']: + type_text = get_text(n).strip() + if type_text and type_text not in chunk.referenced_types: + if '<' in type_text: + type_text = type_text.split('<')[0] + chunk.referenced_types.append(type_text) + + if node_type in node_types['return_type'] and not chunk.return_type: + chunk.return_type = get_text(n).strip() + + if node_type in node_types['type_param']: + param_text = get_text(n).strip() + if param_text and param_text not in chunk.type_parameters: + chunk.type_parameters.append(param_text) + + if node_type in node_types['variable']: + name = extract_identifier(n) + if name and name not in chunk.variables: + chunk.variables.append(name) + + for child in n.children: + traverse(child, depth + 1) + + traverse(node) + + # Limit sizes + chunk.methods = chunk.methods[:30] + chunk.properties = chunk.properties[:30] + chunk.parameters = chunk.parameters[:20] + chunk.decorators = chunk.decorators[:10] + chunk.calls = chunk.calls[:50] + chunk.referenced_types = chunk.referenced_types[:30] + chunk.variables = chunk.variables[:30] + chunk.type_parameters = chunk.type_parameters[:10] + def _process_chunks( self, chunks: List[ASTChunk], @@ -438,9 +841,12 @@ def _process_chunks( metadata = self._build_metadata(ast_chunk, doc.metadata, chunk_counter, len(chunks)) chunk_id = generate_deterministic_id(path, ast_chunk.content, chunk_counter) + # Create embedding-enriched text with semantic context + enriched_text = self._create_embedding_text(ast_chunk.content, metadata) + node = TextNode( id_=chunk_id, - text=ast_chunk.content, + text=enriched_text, metadata=metadata ) nodes.append(node) @@ -455,36 +861,87 
@@ def _split_oversized_chunk( base_metadata: Dict[str, Any], path: str ) -> List[TextNode]: - """Split an oversized chunk using RecursiveCharacterTextSplitter.""" + """ + Split an oversized chunk using RecursiveCharacterTextSplitter. + + IMPORTANT: Splitting an AST chunk loses semantic integrity. + We try to preserve what we can: + - Parent context and primary name are kept (they're still relevant) + - Detailed lists (methods, properties, calls) are NOT copied to sub-chunks + because they describe the whole unit, not the fragment + - A summary of the original unit is prepended to help embeddings + """ splitter = self._get_text_splitter(language) if language else self._default_splitter sub_chunks = splitter.split_text(chunk.content) nodes = [] parent_id = generate_deterministic_id(path, chunk.content, 0) + total_sub = len([s for s in sub_chunks if s and s.strip()]) + # Build a brief summary of the original semantic unit + # This helps embeddings understand context even in fragments + unit_summary_parts = [] + if chunk.semantic_names: + unit_summary_parts.append(f"{chunk.node_type or 'code'}: {chunk.semantic_names[0]}") + if chunk.extends: + unit_summary_parts.append(f"extends {', '.join(chunk.extends[:3])}") + if chunk.implements: + unit_summary_parts.append(f"implements {', '.join(chunk.implements[:3])}") + if chunk.methods: + unit_summary_parts.append(f"has {len(chunk.methods)} methods") + + unit_summary = " | ".join(unit_summary_parts) if unit_summary_parts else None + + sub_idx = 0 for i, sub_chunk in enumerate(sub_chunks): if not sub_chunk or not sub_chunk.strip(): continue - if len(sub_chunk.strip()) < self.min_chunk_size and len(sub_chunks) > 1: + if len(sub_chunk.strip()) < self.min_chunk_size and total_sub > 1: continue + # Build metadata for this fragment + # DO NOT copy detailed lists - they don't apply to fragments metadata = dict(base_metadata) metadata['content_type'] = ContentType.OVERSIZED_SPLIT.value metadata['original_content_type'] = chunk.content_type.value metadata['parent_chunk_id'] = parent_id - metadata['sub_chunk_index'] = i - metadata['total_sub_chunks'] = len(sub_chunks) + metadata['sub_chunk_index'] = sub_idx + metadata['total_sub_chunks'] = total_sub + metadata['start_line'] = chunk.start_line + metadata['end_line'] = chunk.end_line + # Keep parent context - still relevant if chunk.parent_context: metadata['parent_context'] = chunk.parent_context metadata['parent_class'] = chunk.parent_context[-1] + # Keep primary name - this fragment belongs to this unit if chunk.semantic_names: - metadata['semantic_names'] = chunk.semantic_names + metadata['semantic_names'] = chunk.semantic_names[:1] # Just the main name metadata['primary_name'] = chunk.semantic_names[0] - chunk_id = generate_deterministic_id(path, sub_chunk, i) - nodes.append(TextNode(id_=chunk_id, text=sub_chunk, metadata=metadata)) + # Add note that this is a fragment + metadata['is_fragment'] = True + metadata['fragment_of'] = chunk.semantic_names[0] if chunk.semantic_names else None + + # For embedding: prepend fragment context + if unit_summary: + fragment_header = f"[Fragment {sub_idx + 1}/{total_sub} of {unit_summary}]" + enriched_text = f"{fragment_header}\n\n{sub_chunk}" + else: + enriched_text = sub_chunk + + chunk_id = generate_deterministic_id(path, sub_chunk, sub_idx) + nodes.append(TextNode(id_=chunk_id, text=enriched_text, metadata=metadata)) + sub_idx += 1 + + # Log when splitting happens - it's a signal the chunk_size might need adjustment + if nodes: + logger.info( + f"Split oversized 
{chunk.node_type or 'chunk'} " + f"'{chunk.semantic_names[0] if chunk.semantic_names else 'unknown'}' " + f"({len(chunk.content)} chars) into {len(nodes)} fragments" + ) return nodes @@ -545,8 +1002,11 @@ def _split_fallback( if inheritance.get('imports'): metadata['imports'] = inheritance['imports'] + # Create embedding-enriched text with semantic context + enriched_text = self._create_embedding_text(chunk, metadata) + chunk_id = generate_deterministic_id(path, chunk, i) - nodes.append(TextNode(id_=chunk_id, text=chunk, metadata=metadata)) + nodes.append(TextNode(id_=chunk_id, text=enriched_text, metadata=metadata)) return nodes @@ -595,8 +1055,154 @@ def _build_metadata( if chunk.namespace: metadata['namespace'] = chunk.namespace + # --- RICH AST METADATA --- + + if chunk.methods: + metadata['methods'] = chunk.methods + + if chunk.properties: + metadata['properties'] = chunk.properties + + if chunk.parameters: + metadata['parameters'] = chunk.parameters + + if chunk.return_type: + metadata['return_type'] = chunk.return_type + + if chunk.decorators: + metadata['decorators'] = chunk.decorators + + if chunk.modifiers: + metadata['modifiers'] = chunk.modifiers + + if chunk.calls: + metadata['calls'] = chunk.calls + + if chunk.referenced_types: + metadata['referenced_types'] = chunk.referenced_types + + if chunk.variables: + metadata['variables'] = chunk.variables + + if chunk.constants: + metadata['constants'] = chunk.constants + + if chunk.type_parameters: + metadata['type_parameters'] = chunk.type_parameters + return metadata + def _create_embedding_text(self, content: str, metadata: Dict[str, Any]) -> str: + """ + Create embedding-optimized text by prepending concise semantic context. + + Design principles: + 1. Keep it SHORT - long headers can skew embeddings for small code chunks + 2. Avoid redundancy - don't repeat info that's obvious from the code + 3. Clean paths - strip commit hashes and archive prefixes + 4. 
Add VALUE - include info that helps semantic matching + + What we include (selectively): + - Clean file path (without commit/archive prefixes) + - Parent context (for nested structures - very valuable) + - Extends/implements (inheritance is critical for understanding) + - Docstring (helps semantic matching) + - For CLASSES: method count (helps identify scope) + - For METHODS: skip redundant method list + """ + if not self.enrich_embedding_text: + return content + + context_parts = [] + + # Clean file path - remove commit hash prefixes and archive structure + path = metadata.get('path', '') + if path: + path = self._clean_path(path) + context_parts.append(f"File: {path}") + + # Parent context - valuable for nested structures + parent_context = metadata.get('parent_context', []) + if parent_context: + context_parts.append(f"In: {'.'.join(parent_context)}") + + # Clean namespace - strip keyword if present + namespace = metadata.get('namespace', '') + if namespace: + ns_clean = namespace.replace('namespace ', '').replace('package ', '').strip().rstrip(';') + if ns_clean: + context_parts.append(f"Namespace: {ns_clean}") + + # Type relationships - very valuable for understanding code structure + extends = metadata.get('extends', []) + implements = metadata.get('implements', []) + if extends: + context_parts.append(f"Extends: {', '.join(extends[:3])}") + if implements: + context_parts.append(f"Implements: {', '.join(implements[:3])}") + + # For CLASSES: show method/property counts (helps understand scope) + # For METHODS/FUNCTIONS: skip - it's redundant + node_type = metadata.get('node_type', '') + is_container = node_type in ('class', 'interface', 'struct', 'trait', 'enum', 'impl') + + if is_container: + methods = metadata.get('methods', []) + properties = metadata.get('properties', []) + if methods and len(methods) > 1: + # Only show if there are multiple methods + context_parts.append(f"Methods({len(methods)}): {', '.join(methods[:8])}") + if properties and len(properties) > 1: + context_parts.append(f"Fields({len(properties)}): {', '.join(properties[:5])}") + + # Docstring - valuable for semantic matching + docstring = metadata.get('docstring', '') + if docstring: + # Take just the first sentence or 100 chars + brief = docstring.split('.')[0][:100].strip() + if brief: + context_parts.append(f"Desc: {brief}") + + # Build final text - only if we have meaningful context + if context_parts: + context_header = " | ".join(context_parts) + return f"[{context_header}]\n\n{content}" + + return content + + def _clean_path(self, path: str) -> str: + """ + Clean file path for embedding text. 
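(For reference, the enriched text assembled by _create_embedding_text above, using the path cleaned here, looks roughly like this; the bracketed pipe-separated header and its labels come from the code, while the concrete path, names and counts are invented.)

# Roughly what _create_embedding_text produces for a class chunk.
# Header format and labels come from the code above; concrete values are invented.
enriched_example = (
    "[File: src/service/UserService.java | Namespace: com.example.user | "
    "Extends: BaseService | Implements: UserOperations | "
    "Methods(6): createUser, findById, updateUser, deleteUser, listUsers, toDto | "
    "Fields(2): userRepository, passwordEncoder | "
    "Desc: Handles user lifecycle operations]\n\n"
    "public class UserService extends BaseService implements UserOperations { ... }"
)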
+ + Removes: + - Commit hash prefixes (e.g., 'owner-repo-abc123def/') + - Archive extraction paths + - Redundant path components + """ + if not path: + return path + + # Split by '/' and look for src/, lib/, app/ etc as anchor points + parts = path.split('/') + + # Common source directory markers + source_markers = {'src', 'lib', 'app', 'source', 'main', 'test', 'tests', 'pkg', 'cmd', 'internal'} + + # Find the first source marker and start from there + for i, part in enumerate(parts): + if part.lower() in source_markers: + return '/'.join(parts[i:]) + + # If no marker found but path has commit-hash-like prefix (40 hex chars or similar) + if parts and len(parts) > 1: + first_part = parts[0] + # Check if first part looks like "owner-repo-commithash" pattern + if '-' in first_part and len(first_part) > 40: + # Skip the first part + return '/'.join(parts[1:]) + + return path + def _get_text_splitter(self, language: Language) -> RecursiveCharacterTextSplitter: """Get language-specific text splitter.""" if language not in self._splitter_cache: diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py index bde5e2bf..72013d78 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py @@ -74,10 +74,14 @@ def validate_api_key(cls, v: str) -> str: logger.info(f"OpenRouter API key loaded: {v[:10]}...{v[-4:]}") return v - chunk_size: int = Field(default=800) + # Chunk size for code files + # text-embedding-3-small supports ~8191 tokens (~32K chars) + # 8000 chars keeps most semantic units (classes, functions) intact + chunk_size: int = Field(default=8000) chunk_overlap: int = Field(default=200) - text_chunk_size: int = Field(default=1000) + # Text chunk size for non-code files (markdown, docs) + text_chunk_size: int = Field(default=2000) text_chunk_overlap: int = Field(default=200) base_index_namespace: str = Field(default="code_rag") diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/utils/utils.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/utils/utils.py index 4b965cb2..8810113b 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/utils/utils.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/utils/utils.py @@ -69,6 +69,53 @@ def make_project_namespace(workspace: str, project: str) -> str: return f"{workspace}__{project}".replace("/", "_").replace(".", "_").lower() +def clean_archive_path(path: str) -> str: + """ + Clean archive root prefix from file paths. + + Bitbucket and other VCS archives often create a root folder like: + - 'owner-repo-commitHash/' (Bitbucket) + - 'repo-branch/' (GitHub) + + This function strips that prefix to get clean paths like 'src/file.php'. 
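(A few expected input/output pairs for this helper, based on the heuristics implemented below; the archive prefix is a made-up example and the import path assumes the package is importable as rag_pipeline.)

# Expected behaviour of clean_archive_path, based on the heuristics below.
# The archive prefix is a made-up example; the import path is an assumption.
from rag_pipeline.utils.utils import clean_archive_path

assert clean_archive_path("rostilos-codecrow-1a2b3c4d5e6f/src/Service/User.php") == "src/Service/User.php"
assert clean_archive_path("src/Service/User.php") == "src/Service/User.php"  # already clean
assert clean_archive_path("docs/README.md") == "docs/README.md"              # 'docs' is a known source marker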
+ + Args: + path: File path potentially with archive prefix + + Returns: + Clean path without archive prefix + """ + if not path: + return path + + parts = Path(path).parts + if len(parts) <= 1: + return path + + first_part = parts[0] + + # Common source directory markers - if first part is one of these, path is already clean + source_markers = {'src', 'lib', 'app', 'source', 'main', 'test', 'tests', + 'pkg', 'cmd', 'internal', 'bin', 'scripts', 'docs'} + if first_part.lower() in source_markers: + return path + + # Check if first part looks like archive root: + # - Contains hyphens (owner-repo-commit pattern) + # - Or is very long (40+ chars for commit hash) + # - Or matches pattern like 'name-hexstring' + looks_like_archive = ( + '-' in first_part and len(first_part) > 20 or # owner-repo-commit + len(first_part) >= 40 or # Just commit hash + (first_part.count('-') >= 2 and any(c.isdigit() for c in first_part)) # Has digits and multiple hyphens + ) + + if looks_like_archive: + return '/'.join(parts[1:]) + + return path + + def should_exclude_file(path: str, excluded_patterns: list[str]) -> bool: """Check if file should be excluded based on patterns. From 1fc484c455871ccba3a6d8ec412aa73b5744b6a1 Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 23:04:50 +0200 Subject: [PATCH 32/34] feat: Improve deduplication strategy in RAGQueryService to prioritize target branch results --- .../rag_pipeline/services/query_service.py | 152 ++++++++++++++---- 1 file changed, 117 insertions(+), 35 deletions(-) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py index f9513fb4..b4078cd8 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py @@ -72,45 +72,64 @@ def _dedupe_by_branch_priority( ) -> List[Dict]: """Deduplicate results by file path, preferring target branch version. - When same file exists in multiple branches, keep only one version: - - Prefer target_branch version (it's the latest) - - Fall back to base_branch version if target doesn't have it - - This preserves cross-file relationships while avoiding duplicates. + When same file exists in multiple branches, keep only the TARGET branch version. + This ensures we review the NEW code, not the OLD code. + + Strategy: + 1. First pass: collect all paths that exist in target branch + 2. 
Second pass: for each result, include it only if: + - It's from target branch, OR + - Its path doesn't exist in target branch (cross-file reference from base) + + This ensures: + - Changed files are always from target branch (the PR's new code) + - Related files from base branch are included only if they don't exist in target """ if not results: return results - # Group by path + chunk position (approximate by content hash) - grouped = {} + # Step 1: Find all paths that exist in target branch + target_branch_paths = set() + for result in results: + metadata = result.get('metadata', {}) + branch = metadata.get('branch', '') + if branch == target_branch: + path = metadata.get('path', metadata.get('file_path', '')) + target_branch_paths.add(path) + + logger.debug(f"Target branch '{target_branch}' has {len(target_branch_paths)} unique paths") + + # Step 2: Filter results - target branch wins for same path + deduped = [] + seen_chunks = set() # Track (path, chunk_identity) to avoid exact duplicates for result in results: metadata = result.get('metadata', {}) path = metadata.get('path', metadata.get('file_path', '')) branch = metadata.get('branch', '') - # Create a key based on path and approximate content position - # Using text hash to distinguish different chunks from same file - text_hash = hash(result.get('text', '')[:200]) # First 200 chars for identity - key = f"{path}:{text_hash}" + # Create chunk identity (path + start of content) + chunk_id = f"{path}:{branch}:{hash(result.get('text', '')[:100])}" - if key not in grouped: - grouped[key] = result - else: - existing_branch = grouped[key].get('metadata', {}).get('branch', '') - - # Prefer target branch, then base branch, then whatever has higher score - if branch == target_branch and existing_branch != target_branch: - grouped[key] = result - elif (branch == base_branch and - existing_branch != target_branch and - existing_branch != base_branch): - grouped[key] = result - elif result['score'] > grouped[key]['score'] and branch == existing_branch: - # Same branch, keep higher score - grouped[key] = result - - return list(grouped.values()) + if chunk_id in seen_chunks: + continue + seen_chunks.add(chunk_id) + + # Include if: + # 1. It's from target branch (always include), OR + # 2. Path doesn't exist in target branch (cross-file reference from base) + if branch == target_branch: + deduped.append(result) + elif path not in target_branch_paths: + # This file only exists in base branch - include for cross-file context + deduped.append(result) + # else: skip - file exists in target branch, use that version instead + + skipped_count = len(results) - len(deduped) + if skipped_count > 0: + logger.info(f"Branch priority: kept {len(deduped)} results, skipped {skipped_count} base branch duplicates") + + return deduped def semantic_search( self, @@ -317,6 +336,38 @@ def get_deterministic_context( logger.info(f"Deterministic context: files={file_paths[:5]}, branches={branches}") + def _apply_branch_priority(points: list, target: str, existing_target_paths: set) -> list: + """Filter points to prioritize target branch. 
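(A concrete illustration of the branch-priority rule described here; this is a simplified restatement of the selection logic, not the service method itself, and all paths, branch names and scores are invented.)

# Simplified restatement of the branch-priority rule; names and scores invented.
results = [
    {"path": "src/UserService.java",  "branch": "feature/login", "score": 0.82},  # target branch
    {"path": "src/UserService.java",  "branch": "main",          "score": 0.91},  # base duplicate
    {"path": "src/LegacyMailer.java", "branch": "main",          "score": 0.74},  # only in base
]
target = "feature/login"
target_paths = {r["path"] for r in results if r["branch"] == target}
kept = [r for r in results if r["branch"] == target or r["path"] not in target_paths]
# kept == UserService.java from feature/login plus LegacyMailer.java from main;
# the higher-scoring base-branch copy of UserService.java is dropped, because the
# review must see the PR's new code, while base-only files stay for cross-file context.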
@@ -317,6 +336,38 @@ def get_deterministic_context(

         logger.info(f"Deterministic context: files={file_paths[:5]}, branches={branches}")

+        def _apply_branch_priority(points: list, target: str, existing_target_paths: set) -> list:
+            """Filter points to prioritize target branch.
+
+            For each unique path:
+            - If path exists in target branch, keep only target branch version
+            - If path only in base branch, keep it (cross-file reference)
+            """
+            if not target or len(branches) == 1:
+                return points
+
+            # Group by path
+            by_path = {}
+            for p in points:
+                path = p.payload.get("path", "")
+                if path not in by_path:
+                    by_path[path] = []
+                by_path[path].append(p)
+
+            # Select best version per path
+            result = []
+            for path, path_points in by_path.items():
+                has_target = any(p.payload.get("branch") == target for p in path_points)
+                if has_target:
+                    # Keep only target branch for this path
+                    result.extend([p for p in path_points if p.payload.get("branch") == target])
+                elif path not in existing_target_paths:
+                    # Path doesn't exist in target - keep base branch version
+                    result.extend(path_points)
+                # else: skip - path exists in target but these results are from base
+
+            return result
+
         all_chunks = []
         changed_files_chunks = {}
         related_definitions = {}
@@ -334,13 +385,17 @@ def get_deterministic_context(
         changed_file_paths = set()
         seen_texts = set()

-        # Build branch filter
+        # Build branch filter - NOTE: branches[0] is the target branch (has priority)
+        target_branch = branches[0] if branches else None
         branch_filter = (
             FieldCondition(key="branch", match=MatchValue(value=branches[0]))
             if len(branches) == 1
             else FieldCondition(key="branch", match=MatchAny(any=branches))
         )

+        # Track which paths exist in target branch (for priority filtering)
+        target_branch_paths = set()
+
         # ========== STEP 1: Get chunks from changed files ==========
         for file_path in file_paths:
             try:
@@ -356,7 +411,7 @@ def get_deterministic_context(
                         FieldCondition(key="path", match=MatchValue(value=normalized_path))
                     ]
                 ),
-                limit=limit_per_file,
+                limit=limit_per_file * len(branches),  # Get more to account for multiple branches
                 with_payload=True,
                 with_vectors=False
             )
@@ -374,7 +429,19 @@ def get_deterministic_context(
                     p for p in all_results
                     if normalized_path in p.payload.get("path", "")
                     or filename == p.payload.get("path", "").rsplit("/", 1)[-1]
-                ][:limit_per_file]
+                ][:limit_per_file * len(branches)]
+
+                # Apply branch priority: if file exists in target branch, only keep target branch version
+                if target_branch and len(branches) > 1:
+                    # Check if any result is from target branch
+                    has_target = any(p.payload.get("branch") == target_branch for p in results)
+                    if has_target:
+                        # Keep only target branch results for this file
+                        results = [p for p in results if p.payload.get("branch") == target_branch]
+                        logger.debug(f"Branch priority: keeping target branch '{target_branch}' for {normalized_path}")
+
+                # Apply limit after filtering
+                results = results[:limit_per_file]

                 chunks_for_file = []
                 for point in results:
@@ -385,6 +452,10 @@ def get_deterministic_context(
                         continue
                     seen_texts.add(text)

+                    # Track which paths exist in target branch
+                    if payload.get("branch") == target_branch:
+                        target_branch_paths.add(payload.get("path", ""))
+
                     chunk = {
                         "text": text,
                         "metadata": {k: v for k, v in payload.items() if k not in ("text", "_node_content")},
@@ -448,11 +519,14 @@ def get_deterministic_context(
                         FieldCondition(key="primary_name", match=MatchAny(any=batch))
                     ]
                 ),
-                limit=200,
+                limit=200 * len(branches),  # Get more to account for multiple branches
                 with_payload=True,
                 with_vectors=False
             )

+            # Apply branch priority filtering
+            results = _apply_branch_priority(results, target_branch, target_branch_paths)
+
             for point in results:
                 payload = point.payload
                 if payload.get("path") in changed_file_paths:
@@ -495,11 +569,14 @@ def get_deterministic_context(
                         FieldCondition(key="parent_class", match=MatchAny(any=batch))
                     ]
                 ),
-                limit=100,
+                limit=100 * len(branches),  # Get more to account for multiple branches
                 with_payload=True,
                 with_vectors=False
             )

+            # Apply branch priority filtering
+            results = _apply_branch_priority(results, target_branch, target_branch_paths)
+
             for point in results:
                 payload = point.payload
                 if payload.get("path") in changed_file_paths:
@@ -542,11 +619,14 @@ def get_deterministic_context(
                         FieldCondition(key="namespace", match=MatchAny(any=batch))
                     ]
                 ),
-                limit=30,
+                limit=30 * len(branches),  # Get more to account for multiple branches
                 with_payload=True,
                 with_vectors=False
             )

+            # Apply branch priority filtering
+            results = _apply_branch_priority(results, target_branch, target_branch_paths)
+
             for point in results:
                 payload = point.payload
                 if payload.get("path") in changed_file_paths:
@@ -590,12 +670,14 @@ def get_deterministic_context(
             "namespace_context": namespace_context,
             "_metadata": {
                 "branches_searched": branches,
+                "target_branch": target_branch,
                 "files_requested": file_paths,
                 "identifiers_extracted": list(identifiers_to_find)[:30],
                 "parent_classes_found": list(parent_classes),
                 "namespaces_found": list(namespaces),
                 "imports_extracted": list(imports_raw)[:30],
-                "extends_extracted": list(extends_raw)[:20]
+                "extends_extracted": list(extends_raw)[:20],
+                "target_branch_paths_found": len(target_branch_paths)
             }
         }
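The same per-path rule drives the _apply_branch_priority helper added to get_deterministic_context above. Below is a minimal, self-contained sketch of that selection, assuming point-like objects that expose a .payload dict the way qdrant_client records do; the names and sample data are illustrative, not taken from the patch.

    from types import SimpleNamespace

    def apply_branch_priority(points, target, existing_target_paths):
        # Group points by file path.
        by_path = {}
        for p in points:
            by_path.setdefault(p.payload.get("path", ""), []).append(p)

        selected = []
        for path, path_points in by_path.items():
            if any(p.payload.get("branch") == target for p in path_points):
                # Target-branch version exists: keep only those points.
                selected.extend(p for p in path_points if p.payload.get("branch") == target)
            elif path not in existing_target_paths:
                # Path only exists on the base branch: keep it as cross-file context.
                selected.extend(path_points)
        return selected

    points = [
        SimpleNamespace(payload={"path": "service.py", "branch": "release/1.2"}),
        SimpleNamespace(payload={"path": "service.py", "branch": "main"}),
        SimpleNamespace(payload={"path": "legacy.py", "branch": "main"}),
    ]
    # service.py keeps only the release/1.2 copy; legacy.py survives because no target-branch version exists.
    kept = apply_branch_priority(points, "release/1.2", existing_target_paths={"service.py"})
    assert [(p.payload["path"], p.payload["branch"]) for p in kept] == [
        ("service.py", "release/1.2"),
        ("legacy.py", "main"),
    ]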
From c03591e119959448e6be522b725236c8438941ee Mon Sep 17 00:00:00 2001
From: rostislav
Date: Wed, 28 Jan 2026 23:05:12 +0200
Subject: [PATCH 33/34] feat: Enhance comments for clarity on target branch indexing and incremental updates in RAG operations

---
 .../analysis/PullRequestAnalysisProcessor.java |  3 ++-
 .../service/RagOperationsServiceImpl.java      | 16 ++++++++--------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java
index 87443d21..5c7783c0 100644
--- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java
+++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java
@@ -162,7 +162,8 @@ public Map process(
                 ? Optional.empty()
                 : Optional.of(allPrAnalyses.get(0));

-        // Ensure branch index exists for target branch if configured
+        // Ensure branch index exists for TARGET branch (e.g., "1.2.1-rc")
+        // This is where the PR will merge TO - we want RAG context from this branch
         ensureRagIndexForTargetBranch(project, request.getTargetBranchName(), consumer);

         VcsAiClientService aiClientService = vcsServiceFactory.getAiClientService(provider);
diff --git a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/RagOperationsServiceImpl.java b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/RagOperationsServiceImpl.java
index e8128be4..56c54d09 100644
--- a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/RagOperationsServiceImpl.java
+++ b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/service/RagOperationsServiceImpl.java
@@ -802,14 +802,16 @@ private boolean ensureBranchIndexUpToDate(
             return true;
         }

+        // Branch index exists but commit changed - do INCREMENTAL update (only new changes)
+        // Full diff vs main is only done on INITIAL indexing (when RagBranchIndex doesn't exist)
         log.info("Branch index outdated for project={}, branch={}: indexed={}, current={} - fetching incremental diff",
                 project.getId(), targetBranch, indexedCommit, currentCommit);

-        // Fetch diff between indexed commit and current commit on this branch
+        // Fetch diff between last indexed commit and current commit (incremental)
         String rawDiff = vcsClient.getBranchDiff(workspaceSlug, repoSlug, indexedCommit, currentCommit);
-        log.info("Incremental diff for branch '{}' ({}..{}): bytes={}",
-                targetBranch, indexedCommit.substring(0, Math.min(7, indexedCommit.length())),
-                currentCommit.substring(0, 7), rawDiff != null ? rawDiff.length() : 0);
+        log.info("Incremental diff for branch '{}' ({} -> {}): bytes={}",
+                targetBranch, indexedCommit.substring(0, 7), currentCommit.substring(0, 7),
+                rawDiff != null ? rawDiff.length() : 0);

         if (rawDiff == null || rawDiff.isEmpty()) {
             log.info("No diff between {} and {} - updating commit hash only", indexedCommit, currentCommit);
@@ -823,13 +825,11 @@ private boolean ensureBranchIndexUpToDate(
         eventConsumer.accept(Map.of(
                 "type", "status",
                 "state", "branch_update",
-                "message", String.format("Updating branch %s index from %s to %s",
-                        targetBranch, indexedCommit.substring(0, Math.min(7, indexedCommit.length())),
-                        currentCommit.substring(0, 7))
+                "message", String.format("Updating branch %s index (incremental: %d bytes)", targetBranch, rawDiff.length())
         ));

         // Trigger incremental update for this branch
-        log.info("Triggering incremental update for branch '{}' with diff of {} bytes",
+        log.info("Triggering incremental branch update for '{}' with {} bytes diff",
                 targetBranch, rawDiff.length());
         triggerIncrementalUpdate(project, targetBranch, currentCommit, rawDiff, eventConsumer);

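The comments added in ensureBranchIndexUpToDate describe a two-mode flow: a full index against the main branch only when no RagBranchIndex record exists yet, otherwise an incremental update covering just the range between the last indexed commit and the current one. A rough sketch of that decision, written in Python purely for illustration; every name below is hypothetical, not the service's actual API.

    def ensure_branch_index_up_to_date(index_store, vcs_client, project_id, branch, current_commit):
        indexed_commit = index_store.get_indexed_commit(project_id, branch)

        if indexed_commit is None:
            # No branch index record yet: INITIAL indexing (full diff vs the main branch).
            index_store.full_index(project_id, branch, current_commit)
            return

        if indexed_commit == current_commit:
            return  # Already up to date.

        # Index exists but the branch moved: INCREMENTAL update with only the new changes.
        raw_diff = vcs_client.get_branch_diff(indexed_commit, current_commit)
        if not raw_diff:
            # Nothing to re-embed; just record the new commit hash.
            index_store.update_commit_hash(project_id, branch, current_commit)
            return

        index_store.incremental_update(project_id, branch, current_commit, raw_diff)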
From 0bb9ca8dc7e61848c14527be38aa45e36ecabc09 Mon Sep 17 00:00:00 2001
From: rostislav
Date: Thu, 29 Jan 2026 00:09:01 +0200
Subject: [PATCH 34/34] feat: Update default configuration values for chunk size and text chunk size in tests

---
 python-ecosystem/rag-pipeline/tests/test_rag.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python-ecosystem/rag-pipeline/tests/test_rag.py b/python-ecosystem/rag-pipeline/tests/test_rag.py
index 4be64146..2a45a1d3 100644
--- a/python-ecosystem/rag-pipeline/tests/test_rag.py
+++ b/python-ecosystem/rag-pipeline/tests/test_rag.py
@@ -60,9 +60,9 @@ def test_config_defaults():
     """Test default configuration"""
     config = RAGConfig()

-    assert config.chunk_size == 800
+    assert config.chunk_size == 8000  # Increased to fit most semantic units
     assert config.chunk_overlap == 200
-    assert config.text_chunk_size == 1000
+    assert config.text_chunk_size == 2000
     assert config.retrieval_top_k == 10
     assert config.similarity_threshold == 0.7
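The new defaults keep whole semantic units (classes, long functions) inside a single code chunk while leaving plain-text content chunked smaller. The stand-in dataclass below mirrors only the fields asserted in test_config_defaults; it is not the real RAGConfig class, whose module and full field list are not shown in this patch.

    from dataclasses import dataclass

    @dataclass
    class RAGConfig:  # stand-in for illustration; real class lives in the rag-pipeline package
        chunk_size: int = 8000        # large enough to keep most whole classes/functions in one chunk
        chunk_overlap: int = 200
        text_chunk_size: int = 2000   # plain-text/markdown content is chunked smaller than code
        retrieval_top_k: int = 10
        similarity_threshold: float = 0.7

    config = RAGConfig()
    assert config.chunk_size == 8000
    assert config.text_chunk_size == 2000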