Skip to content

Commit cb65705

Browse files
committed
feat(dna): add framework detection and middleware patterns
- Add detected_framework field to CodebaseDNA
- Add _detect_framework() for FastAPI/Starlette/Flask/Django/Express/Next/Nest
- Add _extract_middleware_patterns() for framework-specific middleware detection
- Improve _extract_auth_patterns() with Starlette/Flask/Django patterns
- Add middleware_patterns field to output
- Update to_markdown() with framework and middleware sections
1 parent 7ec8e96 commit cb65705

1 file changed

Lines changed: 117 additions & 4 deletions

File tree

backend/services/dna_extractor.py

Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,15 @@ class NamingConventions:
7878
class CodebaseDNA:
7979
"""Complete DNA profile of a codebase"""
8080
repo_id: str
81+
detected_framework: Optional[str] = None
8182
language_distribution: Dict[str, int] = field(default_factory=dict)
8283
auth_patterns: AuthPattern = field(default_factory=AuthPattern)
8384
service_patterns: ServicePattern = field(default_factory=ServicePattern)
8485
database_patterns: DatabasePattern = field(default_factory=DatabasePattern)
8586
error_patterns: ErrorPattern = field(default_factory=ErrorPattern)
8687
logging_patterns: LoggingPattern = field(default_factory=LoggingPattern)
8788
naming_conventions: NamingConventions = field(default_factory=NamingConventions)
89+
middleware_patterns: List[str] = field(default_factory=list)
8890
common_imports: List[str] = field(default_factory=list)
8991
skip_directories: List[str] = field(default_factory=list)
9092
api_versioning: Optional[str] = None
@@ -98,12 +100,23 @@ def to_markdown(self) -> str:
98100
"""Generate markdown DNA document for AI consumption"""
99101
md = f"# Codebase DNA\n\n"
100102

103+
# Framework detection
104+
if self.detected_framework:
105+
md += f"**Detected Framework:** {self.detected_framework}\n\n"
106+
101107
# Language distribution
102108
md += "## Language Distribution\n"
103109
for lang, count in sorted(self.language_distribution.items(), key=lambda x: -x[1]):
104110
md += f"- {lang}: {count} files\n"
105111
md += "\n"
106112

113+
# Middleware patterns
114+
if self.middleware_patterns:
115+
md += "## Middleware Patterns\n"
116+
for mw in self.middleware_patterns:
117+
md += f"- `{mw}`\n"
118+
md += "\n"
119+
107120
# Auth patterns
108121
md += "## Authentication Patterns\n"
109122
if self.auth_patterns.middleware_used:
@@ -229,7 +242,70 @@ def _discover_files(self, repo_path: Path) -> List[Path]:
229242

230243
return files
231244

232-
def _extract_auth_patterns(self, files: List[Path], repo_path: Path) -> AuthPattern:
245+
def _detect_framework(self, files: List[Path]) -> Optional[str]:
    """Detect the primary framework used in the codebase.

    Each file is read as text and searched for literal indicator
    substrings. Every indicator found in a file adds one point to the
    framework it belongs to; the framework with the highest total wins.
    Ties go to the framework that scored first during the scan.

    Args:
        files: Paths of the source files to scan.

    Returns:
        The key of the best-scoring framework (e.g. ``'fastapi'``), or
        ``None`` when no indicator matched in any readable file.
    """
    framework_indicators = {
        'fastapi': ['from fastapi', 'FastAPI()', 'APIRouter'],
        'starlette': ['from starlette', 'Starlette()', 'starlette.routing'],
        'flask': ['from flask', 'Flask(__name__)', '@app.route'],
        'django': ['from django', 'django.conf', 'INSTALLED_APPS'],
        # Both quote styles are accepted for the CommonJS import.
        'express': ['require("express")', "require('express')", 'express()', 'app.use('],
        'nextjs': ['from next', 'getServerSideProps', 'getStaticProps'],
        'nestjs': ['@Module(', '@Injectable(', '@Controller('],
    }

    scores = Counter()
    for file_path in files:
        try:
            content = file_path.read_text(encoding='utf-8', errors='ignore')
        except OSError:
            # Best-effort scan: a missing/unreadable file or a directory is
            # skipped rather than aborting detection for the whole repo.
            continue

        for framework, indicators in framework_indicators.items():
            hits = sum(1 for indicator in indicators if indicator in content)
            # Only record real hits so a zero-score key never makes the
            # counter look non-empty.
            if hits:
                scores[framework] += hits

    if scores:
        return scores.most_common(1)[0][0]
    return None
271+
272+
def _extract_middleware_patterns(self, files: List[Path], framework: Optional[str]) -> List[str]:
    """Extract middleware patterns found in the given files.

    Each file is read as text and checked for a fixed set of middleware
    markers (Starlette/ASGI classes and registration calls, FastAPI
    ``Depends(...)``, Express ``app.use(``, Flask ``@app.before_request``).

    Args:
        files: Paths of the source files to scan.
        framework: Detected framework name. Accepted for interface
            compatibility; it is not consulted, so every check runs
            against every file regardless of its value.

    Returns:
        The distinct pattern labels found, sorted so repeated runs over
        the same files produce the same list.
    """
    patterns = set()

    for file_path in files:
        try:
            content = file_path.read_text(encoding='utf-8', errors='ignore')
        except OSError:
            # Best-effort scan: skip files that cannot be read.
            continue

        # Starlette/ASGI middleware
        if 'class' in content and 'Middleware' in content:
            patterns.update(re.findall(r'class\s+(\w*Middleware\w*)', content))
        if 'Middleware(' in content:
            patterns.add('Middleware(cls)')
        if 'app.add_middleware' in content:
            patterns.add('app.add_middleware()')

        # FastAPI Depends (only bare identifiers are captured)
        if 'Depends(' in content:
            for dep in re.findall(r'Depends\((\w+)\)', content):
                patterns.add(f'Depends({dep})')

        # Express middleware
        if 'app.use(' in content:
            patterns.add('app.use(middleware)')

        # Flask decorators
        if '@app.before_request' in content:
            patterns.add('@app.before_request')

    # Sorted rather than raw set order, which can differ between runs.
    return sorted(patterns)
307+
308+
def _extract_auth_patterns(self, files: List[Path], repo_path: Path, framework: Optional[str] = None) -> AuthPattern:
233309
"""Extract authentication patterns from codebase"""
234310
pattern = AuthPattern()
235311

@@ -240,19 +316,49 @@ def _extract_auth_patterns(self, files: List[Path], repo_path: Path) -> AuthPatt
240316
try:
241317
content = file_path.read_text(encoding='utf-8', errors='ignore')
242318

243-
# Detect middleware imports
319+
# FastAPI patterns
244320
if 'require_auth' in content:
245321
pattern.middleware_used.append('require_auth')
246322
if 'public_auth' in content:
247323
pattern.middleware_used.append('public_auth')
248324
if 'Depends(' in content and 'auth' in content.lower():
249325
pattern.auth_decorators.append('Depends(require_auth)')
250326

327+
# Starlette patterns
328+
if 'AuthenticationMiddleware' in content:
329+
pattern.middleware_used.append('AuthenticationMiddleware')
330+
if 'AuthCredentials' in content:
331+
pattern.auth_context_type = 'AuthCredentials'
332+
if 'AuthenticationBackend' in content:
333+
pattern.middleware_used.append('AuthenticationBackend')
334+
if 'requires(' in content:
335+
scopes = re.findall(r'requires\([\'"](\w+)[\'"]\)', content)
336+
for scope in scopes:
337+
pattern.auth_decorators.append(f'@requires("{scope}")')
338+
339+
# Flask patterns
340+
if 'login_required' in content:
341+
pattern.auth_decorators.append('@login_required')
342+
if 'flask_login' in content:
343+
pattern.middleware_used.append('flask_login')
344+
if 'current_user' in content:
345+
pattern.auth_context_type = 'current_user'
346+
347+
# Django patterns
348+
if '@login_required' in content:
349+
pattern.auth_decorators.append('@login_required')
350+
if 'permission_required' in content:
351+
pattern.auth_decorators.append('@permission_required')
352+
if 'request.user' in content:
353+
pattern.auth_context_type = 'request.user'
354+
251355
# Detect ownership checks
252356
if 'get_repo_or_404' in content:
253357
pattern.ownership_checks.append('get_repo_or_404(repo_id, auth.user_id)')
254358
if 'verify_ownership' in content:
255359
pattern.ownership_checks.append('verify_ownership')
360+
if 'user_id' in content and ('==' in content or '.filter(' in content):
361+
pattern.ownership_checks.append('user_id check')
256362

257363
# Detect AuthContext
258364
if 'AuthContext' in content:
@@ -516,15 +622,20 @@ def extract_dna(self, repo_path: str, repo_id: str) -> CodebaseDNA:
516622
files = self._discover_files(repo_path)
517623
logger.info(f"Found {len(files)} code files")
518624

625+
# Detect framework first
626+
detected_framework = self._detect_framework(files)
627+
logger.info(f"Detected framework: {detected_framework}")
628+
519629
# Language distribution
520630
lang_dist = Counter()
521631
for f in files:
522632
lang = self._detect_language(str(f))
523633
if lang != 'unknown':
524634
lang_dist[lang] += 1
525635

526-
# Extract all patterns
527-
auth_patterns = self._extract_auth_patterns(files, repo_path)
636+
# Extract all patterns (pass framework where needed)
637+
auth_patterns = self._extract_auth_patterns(files, repo_path, detected_framework)
638+
middleware_patterns = self._extract_middleware_patterns(files, detected_framework)
528639
service_patterns = self._extract_service_patterns(files, repo_path)
529640
database_patterns = self._extract_database_patterns(files, repo_path)
530641
error_patterns = self._extract_error_patterns(files)
@@ -535,13 +646,15 @@ def extract_dna(self, repo_path: str, repo_id: str) -> CodebaseDNA:
535646

536647
dna = CodebaseDNA(
537648
repo_id=repo_id,
649+
detected_framework=detected_framework,
538650
language_distribution=dict(lang_dist),
539651
auth_patterns=auth_patterns,
540652
service_patterns=service_patterns,
541653
database_patterns=database_patterns,
542654
error_patterns=error_patterns,
543655
logging_patterns=logging_patterns,
544656
naming_conventions=naming_conventions,
657+
middleware_patterns=middleware_patterns,
545658
common_imports=common_imports,
546659
skip_directories=list(self.SKIP_DIRS),
547660
api_versioning=api_versioning,

0 commit comments

Comments
 (0)