diff --git a/aipg/__init__.py b/aipg/__init__.py
index 0abed06..7d610d8 100644
--- a/aipg/__init__.py
+++ b/aipg/__init__.py
@@ -7,11 +7,6 @@
 import typer
 from rich import print as rprint
 
-from aipg.assistant import Assistant
-from aipg.configs.app_config import AppConfig
-from aipg.configs.loader import load_config
-from aipg.task import Task
-
 
 @dataclass
 class TimingContext:
@@ -67,6 +62,10 @@ def run_assistant(
     logging.info("Starting Cherry AI Project Generator")
     # Load config with all overrides
     try:
+        # Local imports to avoid heavy dependencies at package import time
+        from aipg.configs.app_config import AppConfig
+        from aipg.configs.loader import load_config
+
         config = load_config(presets, config_path, config_overrides, AppConfig)
         logging.info("Successfully loaded config")
     except Exception as e:
@@ -77,6 +76,9 @@
     with time_block("initializing components", timer):
         rprint("🤖 [bold red] Welcome to Cherry AI Project Generator [/bold red]")
 
+        from aipg.assistant import Assistant
+        from aipg.task import Task
+
         assistant = Assistant(config)
         task = Task(issue_description=issue)
         task = assistant.generate_project(task)
diff --git a/aipg/assistant.py b/aipg/assistant.py
index e595339..85f8d16 100644
--- a/aipg/assistant.py
+++ b/aipg/assistant.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import signal
 import sys
@@ -6,9 +8,7 @@
 from typing import List, Optional, Type
 
 from aipg.configs.app_config import AppConfig
-from aipg.llm import LLMClient
 from aipg.task import Task
-from aipg.task_inference import MicroProjectGenerationInference, TaskInference
 
 logger = logging.getLogger(__name__)
 
@@ -39,18 +39,26 @@ def handle_timeout(signum, frame):
 
 
 class Assistant:
-    def __init__(self, config: AppConfig) -> None:
+    def __init__(self, config: AppConfig, rag_service: Optional[object] = None) -> None:
         self.config = config
-        self.llm = LLMClient(config)
+        self.llm = None  # Lazy initialize to avoid heavy imports
+        self.rag = rag_service
 
     def handle_exception(self, stage: str, exception: Exception):
         raise Exception(str(exception), stage)
 
     def _run_task_inference(
-        self, task_inferences: List[Type[TaskInference]], task: Task
+        self, task_inferences: List[Type[object]], task: Task
     ):
+        class _LLMProxy:
+            def __init__(self, ensure_llm):
+                self._ensure_llm = ensure_llm
+
+            def __getattr__(self, item):
+                return getattr(self._ensure_llm(), item)
+
         for inference_class in task_inferences:
-            inference = inference_class(llm=self.llm)
+            inference = inference_class(llm=_LLMProxy(self._ensure_llm))
             try:
                 with timeout(
                     seconds=self.config.task_timeout,
@@ -62,11 +70,66 @@
                     f"Task inference preprocessing: {inference_class}", e
                 )
 
+    def _ensure_llm(self):
+        if self.llm is None:
+            from aipg.llm import LLMClient
+
+            self.llm = LLMClient(self.config)
+        return self.llm
+
+    def _ensure_rag(self):
+        if self.rag is None:
+            llm = self._ensure_llm()
+            from aipg.rag.integration import build_rag_service
+
+            self.rag = build_rag_service(self.config, llm)
+        return self.rag
+
     def generate_project(self, task: Task) -> Task:
-        task_inferences: List[Type[TaskInference]] = [
-            MicroProjectGenerationInference,
-        ]
+        # 1) Try RAG retrieval first
+        try:
+            rag = self._ensure_rag()
+            retrieved = rag.retrieve(task.issue_description)  # type: ignore[attr-defined]
+        except Exception:
+            retrieved = None
+
+        if retrieved:
+            # Parse and set task fields
+            from aipg.prompting.utils import parse_and_check_json
+            from aipg.prompting.prompt_generator import (
+                MicroTaskGenerationPromptGenerator,
+            )
+
+            parsed = parse_and_check_json(
+                retrieved, expected_keys=MicroTaskGenerationPromptGenerator.fields
+            )
+            for k, v in parsed.items():
+                setattr(task, k, v)
+            return task
+
+        # 2) Generate via LLM
+        # Lazy import to allow test monkeypatching and avoid heavy imports at module load
+        from aipg import task_inference as ti_mod
+
+        task_inferences: List[Type[object]] = [
+            ti_mod.MicroProjectGenerationInference,
+        ]
         self._run_task_inference(task_inferences, task)
 
+        # 3) Save generated to RAG
+        try:
+            rag = self._ensure_rag()
+            import json
+
+            micro_project_json = json.dumps(
+                {
+                    "task_description": task.task_description or "",
+                    "task_goal": task.task_goal or "",
+                    "expert_solution": task.expert_solution or "",
+                }
+            )
+            rag.save(task.issue_description, micro_project_json)  # type: ignore[attr-defined]
+        except Exception:
+            pass
+
         return task
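`_LLMProxy` defers importing and constructing `LLMClient` until a task inference actually touches the client, which keeps `import aipg` and `Assistant(...)` cheap. A minimal, self-contained sketch of the same `__getattr__`-based lazy-proxy pattern (illustrative names only, not code from this diff):

```python
class LazyProxy:
    """Defer building an expensive object until an attribute is first used."""

    def __init__(self, factory):
        self._factory = factory
        self._target = None

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, i.e. for attributes
        # that live on the real object rather than on the proxy itself.
        if self._target is None:
            self._target = self._factory()
        return getattr(self._target, name)


class ExpensiveClient:
    def __init__(self):
        print("heavy construction happens here")

    def query(self, prompt):
        return f"answer to {prompt!r}"


proxy = LazyProxy(ExpensiveClient)  # nothing constructed yet
print(proxy.query("hello"))         # construction happens on first use
```

Unlike this sketch, the proxy in the diff does not cache the client itself: it calls `Assistant._ensure_llm()` on every attribute access, so the single cached `LLMClient` always lives on the `Assistant` instance.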
diff --git a/aipg/configs/app_config.py b/aipg/configs/app_config.py
index 22e7513..5d22e99 100644
--- a/aipg/configs/app_config.py
+++ b/aipg/configs/app_config.py
@@ -35,3 +35,17 @@ class AppConfig(BaseModel):
     task_timeout: int = 3600
     time_limit: int = 14400
     session_id: str = Field(default_factory=lambda: uuid4().hex)
+
+
+class RagConfig(BaseModel):
+    similarity_threshold: float = 0.7
+    k_candidates: int = 5
+    collection_name: str = "micro_projects"
+    chroma_path: str = Field(default=str(Path(PACKAGE_PATH) / "cache" / "chroma"))
+    embedding_model: str = "text-embedding-3-small"
+    embedding_base_url: Optional[str] = None
+    embedding_api_key: Optional[str] = None
+
+
+# Backward compatibility: AppConfig may not include rag in persisted configs
+AppConfig.rag = RagConfig()  # type: ignore[attr-defined]
diff --git a/aipg/rag/__init__.py b/aipg/rag/__init__.py
new file mode 100644
index 0000000..5ca09a8
--- /dev/null
+++ b/aipg/rag/__init__.py
@@ -0,0 +1,10 @@
+from .ports import EmbeddingPort, VectorStorePort
+from .service import RagService, RagResult
+
+__all__ = [
+    "EmbeddingPort",
+    "VectorStorePort",
+    "RagService",
+    "RagResult",
+]
+
diff --git a/aipg/rag/adapters.py b/aipg/rag/adapters.py
new file mode 100644
index 0000000..1890ec0
--- /dev/null
+++ b/aipg/rag/adapters.py
@@ -0,0 +1,107 @@
+from __future__ import annotations
+
+from typing import List, Optional, Callable
+
+try:  # Optional dependency at runtime
+    import chromadb  # type: ignore
+except Exception:  # pragma: no cover - import-time optionality
+    chromadb = None  # type: ignore
+
+try:  # Optional dependency at runtime
+    from openai import OpenAI  # type: ignore
+except Exception:  # pragma: no cover - import-time optionality
+    OpenAI = None  # type: ignore
+
+from .ports import EmbeddingPort, VectorStorePort, RetrievedItem
+
+
+class ChromaVectorStore(VectorStorePort):
+    def __init__(
+        self,
+        collection_name: str = "micro_projects",
+        persist_dir: Optional[str] = None,
+    ) -> None:
+        if chromadb is None:  # pragma: no cover
+            raise ImportError("chromadb is not installed")
+
+        if persist_dir:
+            client = chromadb.PersistentClient(path=persist_dir)
+        else:
+            client = chromadb.Client()
+
+        self.collection = client.get_or_create_collection(
+            name=collection_name, metadata={"hnsw:space": "cosine"}
+        )
+
+    def add(self, ids: List[str], embeddings: List[List[float]], metadatas: List[dict]):
+        self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)
+
+    def query(self, embedding: List[float], k: int) -> List[RetrievedItem]:
+        res = self.collection.query(
+            query_embeddings=[embedding], n_results=k, include=["metadatas"]
+        )
+        items: List[RetrievedItem] = []
+        metadatas = res.get("metadatas") or []
+        if metadatas:
+            for meta in metadatas[0]:
+                issue = meta.get("issue", "")
+                micro_project = meta.get("micro_project", "")
+                items.append(RetrievedItem(issue=issue, micro_project=micro_project, metadata=meta))
+        return items
+
+
+class OpenAIEmbeddingAdapter(EmbeddingPort):
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        model: str = "text-embedding-3-small",
+        client: Optional[object] = None,
+    ) -> None:
+        if client is not None:
+            self.client = client
+        else:
+            if OpenAI is None:  # pragma: no cover
+                raise ImportError("openai is not installed")
+            self.client = OpenAI(api_key=api_key, base_url=base_url)
+        self.model = model
+
+    def embed(self, texts: List[str]) -> List[List[float]]:
+        resp = self.client.embeddings.create(model=self.model, input=texts)
+        return [d.embedding for d in resp.data]
+
+
+def llm_ranker_from_client(llm_query: Callable[[list[dict] | str], Optional[str]]) -> Callable[[str, List[str]], List[float]]:
+    def rank(query: str, candidates: List[str]) -> List[float]:
+        if not candidates:
+            return []
+        numbered = "\n".join([f"{i+1}. {c}" for i, c in enumerate(candidates)])
+        prompt = [
+            {
+                "role": "system",
+                "content": (
+                    "You are a precise similarity rater. Given a query and a list of candidate issues, "
+                    "return a JSON array of floats in [0,1] representing semantic similarity for each candidate."
+                ),
+            },
+            {
+                "role": "user",
+                "content": (
+                    f"Query: {query}\nCandidates:\n{numbered}\n\n"
+                    "Return only JSON like [0.12, 0.5, 0.99]."
+                ),
+            },
+        ]
+        output = llm_query(prompt) or "[]"
+        try:
+            import json
+
+            scores = json.loads(output)
+            if not isinstance(scores, list):
+                raise ValueError("Invalid scores format")
+            return [float(x) for x in scores]
+        except Exception:
+            return [0.0 for _ in candidates]
+
+    return rank
+
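`llm_ranker_from_client` turns the LLM into a plain `(query, candidates) -> scores` callable and falls back to all-zero scores when the model returns malformed JSON. Anything with that shape can be swapped in; for offline runs or deterministic tests, a word-overlap ranker satisfies the same contract (hypothetical helper, not part of this diff):

```python
from typing import List


def overlap_ranker(query: str, candidates: List[str]) -> List[float]:
    """Jaccard word overlap as a stand-in for the LLM similarity rating."""
    q = set(query.lower().split())
    scores: List[float] = []
    for cand in candidates:
        c = set(cand.lower().split())
        scores.append(len(q & c) / len(q | c) if q | c else 0.0)
    return scores


print(overlap_ranker("fix list index error", ["IndexError in Python list", "Null pointer in Java"]))
```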
diff --git a/aipg/rag/integration.py b/aipg/rag/integration.py
new file mode 100644
index 0000000..d51ca45
--- /dev/null
+++ b/aipg/rag/integration.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import json
+from typing import Optional
+
+from aipg.configs.app_config import AppConfig
+from aipg.llm import LLMClient
+from .adapters import ChromaVectorStore, OpenAIEmbeddingAdapter, llm_ranker_from_client
+from .service import RagService
+
+
+def build_rag_service(config: AppConfig, llm: LLMClient) -> RagService:
+    rag = getattr(config, "rag", None)
+    similarity_threshold = getattr(rag, "similarity_threshold", 0.7)
+    k_candidates = getattr(rag, "k_candidates", 5)
+    collection_name = getattr(rag, "collection_name", "micro_projects")
+    chroma_path = getattr(rag, "chroma_path", None)
+    embedding_model = getattr(rag, "embedding_model", "text-embedding-3-small")
+    embedding_base_url = getattr(rag, "embedding_base_url", None) or config.llm.base_url
+    embedding_api_key = getattr(rag, "embedding_api_key", None) or config.llm.api_key
+
+    vector_store = ChromaVectorStore(collection_name=collection_name, persist_dir=chroma_path)
+    embedder = OpenAIEmbeddingAdapter(
+        api_key=embedding_api_key,
+        base_url=embedding_base_url,
+        model=embedding_model,
+    )
+    ranker = llm_ranker_from_client(llm.query)
+
+    # Generator is not used by assistant path, keep simple fallback
+    def generator(issue: str) -> str:
+        return json.dumps({"task_description": issue, "task_goal": "", "expert_solution": ""})
+
+    return RagService(
+        embedder=embedder,
+        vector_store=vector_store,
+        ranker=ranker,
+        generator=generator,
+        similarity_threshold=similarity_threshold,
+        k_candidates=k_candidates,
+    )
+
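`build_rag_service` reads every setting through `getattr(..., default)`, so configs persisted before `RagConfig` existed still load; only `config.llm.base_url` and `config.llm.api_key` are accessed directly and remain required. A quick sketch of the fallback behaviour (hypothetical stand-in config, not project code):

```python
class OldConfig:
    """Stands in for a persisted AppConfig that has no `rag` section."""


rag = getattr(OldConfig(), "rag", None)                # -> None
threshold = getattr(rag, "similarity_threshold", 0.7)  # -> 0.7, the default kicks in
k = getattr(rag, "k_candidates", 5)                    # -> 5
print(threshold, k)
```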
diff --git a/aipg/rag/ports.py b/aipg/rag/ports.py
new file mode 100644
index 0000000..6031983
--- /dev/null
+++ b/aipg/rag/ports.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Iterable, List, Optional, Protocol, Tuple
+
+
+@dataclass
+class RetrievedItem:
+    issue: str
+    micro_project: str
+    metadata: Optional[dict] = None
+
+
+class EmbeddingPort(Protocol):
+    def embed(self, texts: List[str]) -> List[List[float]]:  # pragma: no cover
+        ...
+
+
+class VectorStorePort(ABC):
+    @abstractmethod
+    def add(self, ids: List[str], embeddings: List[List[float]], metadatas: List[dict]):
+        raise NotImplementedError
+
+    @abstractmethod
+    def query(self, embedding: List[float], k: int) -> List[RetrievedItem]:
+        raise NotImplementedError
+
diff --git a/aipg/rag/service.py b/aipg/rag/service.py
new file mode 100644
index 0000000..bd55272
--- /dev/null
+++ b/aipg/rag/service.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable, List, Optional
+
+from .ports import EmbeddingPort, VectorStorePort, RetrievedItem
+
+
+@dataclass
+class RagResult:
+    micro_project: str
+    source: str  # "retrieved" or "generated"
+    matched_issue: Optional[str] = None
+
+
+class RagService:
+    def __init__(
+        self,
+        embedder: EmbeddingPort,
+        vector_store: VectorStorePort,
+        ranker: Callable[[str, List[str]], List[float]],
+        generator: Callable[[str], str],
+        similarity_threshold: float = 0.7,
+        k_candidates: int = 5,
+    ) -> None:
+        self.embedder = embedder
+        self.vector_store = vector_store
+        self.ranker = ranker
+        self.generator = generator
+        self.similarity_threshold = similarity_threshold
+        self.k_candidates = k_candidates
+
+    def get_or_create_micro_project(self, issue: str) -> RagResult:
+        issue_embedding = self.embedder.embed([issue])[0]
+        candidates: List[RetrievedItem] = self.vector_store.query(
+            embedding=issue_embedding, k=self.k_candidates
+        )
+
+        candidate_issues = [c.issue for c in candidates]
+        if candidate_issues:
+            scores = self.ranker(issue, candidate_issues)
+            best_idx = max(range(len(scores)), key=lambda i: scores[i])
+            best_score = scores[best_idx]
+            if best_score >= self.similarity_threshold:
+                best_item = candidates[best_idx]
+                return RagResult(
+                    micro_project=best_item.micro_project,
+                    source="retrieved",
+                    matched_issue=best_item.issue,
+                )
+
+        # Generate new micro-project
+        micro_project = self.generator(issue)
+        # Save to vector store
+        new_embedding = issue_embedding  # reuse computed embedding
+        self.vector_store.add(
+            ids=[issue],
+            embeddings=[new_embedding],
+            metadatas=[{"issue": issue, "micro_project": micro_project}],
+        )
+        return RagResult(micro_project=micro_project, source="generated")
+
+    # Integration methods for external generation
+    def retrieve(self, issue: str) -> Optional[str]:
+        issue_embedding = self.embedder.embed([issue])[0]
+        candidates: List[RetrievedItem] = self.vector_store.query(
+            embedding=issue_embedding, k=self.k_candidates
+        )
+        candidate_issues = [c.issue for c in candidates]
+        if not candidate_issues:
+            return None
+        scores = self.ranker(issue, candidate_issues)
+        if not scores:
+            return None
+        best_idx = max(range(len(scores)), key=lambda i: scores[i])
+        best_score = scores[best_idx]
+        if best_score >= self.similarity_threshold:
+            return candidates[best_idx].micro_project
+        return None
+
+    def save(self, issue: str, micro_project: str) -> None:
+        issue_embedding = self.embedder.embed([issue])[0]
+        self.vector_store.add(
+            ids=[issue],
+            embeddings=[issue_embedding],
+            metadatas=[{"issue": issue, "micro_project": micro_project}],
+        )
+
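`retrieve` returns a stored micro-project only when the ranker's best score clears `similarity_threshold`; otherwise the assistant generates one and hands it back through `save`. A compact round-trip of those two methods with in-memory stand-ins (hypothetical stubs, in the spirit of the fakes in the tests below):

```python
from aipg.rag.ports import RetrievedItem
from aipg.rag.service import RagService


class MemoryStore:
    def __init__(self):
        self.items = []

    def add(self, ids, embeddings, metadatas):
        self.items += [RetrievedItem(m["issue"], m["micro_project"], m) for m in metadatas]

    def query(self, embedding, k):
        return self.items[:k]


class StubEmbedder:
    def embed(self, texts):
        return [[0.0, 1.0] for _ in texts]


service = RagService(
    embedder=StubEmbedder(),
    vector_store=MemoryStore(),
    ranker=lambda query, candidates: [1.0] * len(candidates),  # treat everything as similar
    generator=lambda issue: f"micro-project for: {issue}",
    similarity_threshold=0.7,
)

assert service.retrieve("flaky integration test") is None           # nothing stored yet
service.save("flaky integration test", "stabilize the test first")  # index one entry
assert service.retrieve("flaky integration test") == "stabilize the test first"
```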
diff --git a/aipg/task.py b/aipg/task.py
index 3c44c3f..2a94512 100644
--- a/aipg/task.py
+++ b/aipg/task.py
@@ -1,4 +1,4 @@
-from litellm import dataclass
+from dataclasses import dataclass
 
 
 @dataclass
diff --git a/aipg/task_inference/task_inference.py b/aipg/task_inference/task_inference.py
index b04a130..69c0914 100644
--- a/aipg/task_inference/task_inference.py
+++ b/aipg/task_inference/task_inference.py
@@ -2,7 +2,6 @@
 from typing import Any, Dict, List, Optional
 
 from aipg.exceptions import OutputParserException
-from aipg.llm import LLMClient
 from aipg.prompting.prompt_generator import (
     MicroTaskGenerationPromptGenerator,
     PromptGenerator,
@@ -13,9 +12,9 @@
 
 
 class TaskInference:
-    def __init__(self, llm: LLMClient, *args, **kwargs):
+    def __init__(self, llm: Any, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.llm: LLMClient = llm
+        self.llm = llm
         self.fallback_value: Optional[str] = None
         self.ignored_value: List[str] = []
 
diff --git a/pyproject.toml b/pyproject.toml
index 5c1c664..b7ab819 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,8 @@ dependencies = [
     "tenacity>=9.1.2",
     "tiktoken>=0.11.0",
     "typer>=0.17.3",
+    "chromadb>=0.5.0",
+    "openai>=1.30.0",
 ]
 
 [build-system]
@@ -28,6 +30,7 @@ aipg = "aipg:main"
 dev = [
     "mypy>=1.17.1",
     "ruff>=0.12.12",
+    "pytest>=7.0.0",
 ]
 
 [tool.hatch.build.targets.wheel]
diff --git a/tests/test_assistant_integration.py b/tests/test_assistant_integration.py
new file mode 100644
index 0000000..3005474
--- /dev/null
+++ b/tests/test_assistant_integration.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from aipg.assistant import Assistant
+from aipg.configs.app_config import AppConfig
+from aipg.task import Task
+
+
+class FakeRag:
+    def __init__(self, retrieved: Optional[str] = None):
+        self.retrieved = retrieved
+        self.saved = []
+
+    def retrieve(self, issue: str) -> Optional[str]:
+        return self.retrieved
+
+    def save(self, issue: str, micro_project: str) -> None:
+        self.saved.append((issue, micro_project))
+
+
+def test_assistant_uses_rag_when_available():
+    # retrieved JSON
+    mp = '{"task_description":"A","task_goal":"B","expert_solution":"C"}'
+    assistant = Assistant(AppConfig(), rag_service=FakeRag(retrieved=mp))
+    task = Task("Some issue")
+    out = assistant.generate_project(task)
+    assert out.task_description == "A"
+    assert out.task_goal == "B"
+    assert out.expert_solution == "C"
+
+
+def test_assistant_generates_and_saves_when_not_found(monkeypatch):
+    # Force LLM inference to be a no-op by mocking transform flow
+    # We patch MicroProjectGenerationInference to avoid external LLM calls
+    from aipg import task_inference as ti_mod
+
+    class DummyInference(ti_mod.TaskInference):
+        def transform(self, task: Task) -> Task:  # type: ignore[override]
+            task.task_description = "X"
+            task.task_goal = "Y"
+            task.expert_solution = "Z"
+            return task
+
+    monkeypatch.setattr(ti_mod, "MicroProjectGenerationInference", DummyInference)
+
+    fake_rag = FakeRag(retrieved=None)
+    assistant = Assistant(AppConfig(), rag_service=fake_rag)
+    task = Task("Different issue")
+    out = assistant.generate_project(task)
+    assert out.task_description == "X"
+    assert fake_rag.saved, "Generated project should be saved in RAG"
+
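The two tests above pin the retrieval and generate-then-save paths. The `try`/`except` in `generate_project` also swallows RAG failures; if that behaviour should be locked in, a third case along these lines would cover it (sketch only, reusing `FakeRag`, `Assistant`, `AppConfig`, and `Task` from this file):

```python
class ExplodingRag(FakeRag):
    def retrieve(self, issue: str):
        raise RuntimeError("vector store unavailable")


def test_assistant_falls_back_when_rag_fails(monkeypatch):
    from aipg import task_inference as ti_mod

    class DummyInference(ti_mod.TaskInference):
        def transform(self, task: Task) -> Task:  # type: ignore[override]
            task.task_description = "fallback"
            return task

    monkeypatch.setattr(ti_mod, "MicroProjectGenerationInference", DummyInference)

    assistant = Assistant(AppConfig(), rag_service=ExplodingRag())
    out = assistant.generate_project(Task("Another issue"))
    assert out.task_description == "fallback"
```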
diff --git a/tests/test_rag_service.py b/tests/test_rag_service.py
new file mode 100644
index 0000000..fa22f7e
--- /dev/null
+++ b/tests/test_rag_service.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List
+
+import pytest
+
+from aipg.rag.ports import RetrievedItem
+from aipg.rag.service import RagService
+
+
+class FakeEmbedder:
+    def __init__(self, mapping: dict[str, List[float]]):
+        self.mapping = mapping
+
+    def embed(self, texts: List[str]) -> List[List[float]]:
+        return [self.mapping[t] for t in texts]
+
+
+class FakeVectorStore:
+    def __init__(self, items: list[RetrievedItem] | None = None):
+        self.items = items or []
+        self.add_calls: list[dict] = []
+
+    def add(self, ids: List[str], embeddings: List[List[float]], metadatas: List[dict]):
+        for meta in metadatas:
+            self.items.append(
+                RetrievedItem(
+                    issue=meta["issue"], micro_project=meta["micro_project"], metadata=meta
+                )
+            )
+        self.add_calls.append({"ids": ids, "embeddings": embeddings, "metadatas": metadatas})
+
+    def query(self, embedding: List[float], k: int) -> List[RetrievedItem]:
+        return self.items[:k]
+
+
+@pytest.mark.parametrize(
+    "issue, existing, rank_scores, expect_source",
+    [
+        (
+            "How to fix list index error?",
+            [RetrievedItem(issue="IndexError in Python", micro_project="Fix index handling")],
+            [0.9],
+            "retrieved",
+        ),
+        (
+            "How to fix list index error?",
+            [RetrievedItem(issue="Null pointer in Java", micro_project="Handle nulls")],
+            [0.3],
+            "generated",
+        ),
+    ],
+)
+def test_rag_service_retrieve_or_generate(issue, existing, rank_scores, expect_source):
+    embedder = FakeEmbedder({issue: [0.1, 0.2, 0.3]})
+    store = FakeVectorStore(existing.copy())
+
+    def ranker(q: str, cands: List[str]) -> List[float]:
+        return rank_scores
+
+    def generator(q: str) -> str:
+        return "Generated Micro Project"
+
+    service = RagService(
+        embedder=embedder,
+        vector_store=store,
+        ranker=ranker,
+        generator=generator,
+        similarity_threshold=0.7,
+        k_candidates=5,
+    )
+
+    result = service.get_or_create_micro_project(issue)
+
+    assert result.source == expect_source
+    if expect_source == "retrieved":
+        assert result.micro_project == existing[0].micro_project
+        assert result.matched_issue == existing[0].issue
+        assert store.add_calls == []
+    else:
+        assert result.micro_project == "Generated Micro Project"
+        assert result.matched_issue is None
+        assert store.add_calls, "Generated project must be saved to vector store"
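The parametrized test exercises each branch in isolation; the cross-call behaviour the change is after, generate on the first request and retrieve on the second, can be covered with the same fakes (a possible follow-up test, not part of this diff):

```python
def test_second_call_hits_the_store():
    issue = "How to fix list index error?"
    service = RagService(
        embedder=FakeEmbedder({issue: [0.1, 0.2, 0.3]}),
        vector_store=FakeVectorStore(),
        ranker=lambda q, cands: [1.0] * len(cands),
        generator=lambda q: "Generated Micro Project",
        similarity_threshold=0.7,
        k_candidates=5,
    )

    first = service.get_or_create_micro_project(issue)
    second = service.get_or_create_micro_project(issue)

    assert first.source == "generated"
    assert second.source == "retrieved"
    assert second.micro_project == "Generated Micro Project"
```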