12 changes: 7 additions & 5 deletions aipg/__init__.py
@@ -7,11 +7,6 @@
import typer
from rich import print as rprint

from aipg.assistant import Assistant
from aipg.configs.app_config import AppConfig
from aipg.configs.loader import load_config
from aipg.task import Task


@dataclass
class TimingContext:
@@ -67,6 +62,10 @@ def run_assistant(
logging.info("Starting Cherry AI Project Generator")
# Load config with all overrides
try:
# Local imports to avoid heavy dependencies at package import time
from aipg.configs.app_config import AppConfig
from aipg.configs.loader import load_config

config = load_config(presets, config_path, config_overrides, AppConfig)
logging.info("Successfully loaded config")
except Exception as e:
@@ -77,6 +76,9 @@

with time_block("initializing components", timer):
rprint("🤖 [bold red] Welcome to Cherry AI Project Generator [/bold red]")
from aipg.assistant import Assistant
from aipg.task import Task

assistant = Assistant(config)
task = Task(issue_description=issue)
task = assistant.generate_project(task)
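The __init__.py change above moves heavy imports into the function bodies that need them, so `import aipg` stays cheap and the cost is only paid when a command actually runs. A stand-alone sketch of the pattern, with json standing in for a heavy dependency (illustrative only, not code from this PR):

import time


def run_command() -> str:
    # Heavy dependency imported lazily: the defining module imports instantly,
    # and the import cost is paid only the first time this function is called.
    import json  # stand-in for something heavy like chromadb or openai

    return json.dumps({"status": "ok"})


start = time.perf_counter()
print(run_command(), f"(first call: {time.perf_counter() - start:.4f}s)")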
81 changes: 72 additions & 9 deletions aipg/assistant.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
import signal
import sys
@@ -6,9 +8,7 @@
from typing import List, Optional, Type

from aipg.configs.app_config import AppConfig
from aipg.llm import LLMClient
from aipg.task import Task
from aipg.task_inference import MicroProjectGenerationInference, TaskInference

logger = logging.getLogger(__name__)

@@ -39,18 +39,26 @@ def handle_timeout(signum, frame):


class Assistant:
def __init__(self, config: AppConfig) -> None:
def __init__(self, config: AppConfig, rag_service: Optional[object] = None) -> None:
self.config = config
self.llm = LLMClient(config)
self.llm = None # Lazy initialize to avoid heavy imports
self.rag = rag_service

def handle_exception(self, stage: str, exception: Exception):
raise Exception(str(exception), stage)

def _run_task_inference(
self, task_inferences: List[Type[TaskInference]], task: Task
self, task_inferences: List[Type[object]], task: Task
):
class _LLMProxy:
def __init__(self, ensure_llm):
self._ensure_llm = ensure_llm

def __getattr__(self, item):
return getattr(self._ensure_llm(), item)

for inference_class in task_inferences:
inference = inference_class(llm=self.llm)
inference = inference_class(llm=_LLMProxy(self._ensure_llm))
try:
with timeout(
seconds=self.config.task_timeout,
@@ -62,11 +70,66 @@ def _run_task_inference(
f"Task inference preprocessing: {inference_class}", e
)

def _ensure_llm(self):
if self.llm is None:
from aipg.llm import LLMClient

self.llm = LLMClient(self.config)
return self.llm

def _ensure_rag(self):
if self.rag is None:
llm = self._ensure_llm()
from aipg.rag.integration import build_rag_service

self.rag = build_rag_service(self.config, llm)
return self.rag

def generate_project(self, task: Task) -> Task:
task_inferences: List[Type[TaskInference]] = [
MicroProjectGenerationInference,
]
# 1) Try RAG retrieval first
try:
rag = self._ensure_rag()
retrieved = rag.retrieve(task.issue.description) # type: ignore[attr-defined]
except Exception:
retrieved = None

if retrieved:
# Parse and set task fields
from aipg.prompting.utils import parse_and_check_json
from aipg.prompting.prompt_generator import (
MicroTaskGenerationPromptGenerator,
)

parsed = parse_and_check_json(
retrieved, expected_keys=MicroTaskGenerationPromptGenerator.fields
)
for k, v in parsed.items():
setattr(task, k, v)
return task

# 2) Generate via LLM
# Lazy import to allow test monkeypatching and avoid heavy imports at module load
from aipg import task_inference as ti_mod

task_inferences: List[Type[object]] = [
ti_mod.MicroProjectGenerationInference,
]
self._run_task_inference(task_inferences, task)

# 3) Save generated to RAG
try:
rag = self._ensure_rag()
import json

micro_project_json = json.dumps(
{
"task_description": task.task_description or "",
"task_goal": task.task_goal or "",
"expert_solution": task.expert_solution or "",
}
)
rag.save(task.issue.description, micro_project_json) # type: ignore[attr-defined]
except Exception:
pass

return task
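
The generate_project() above follows a cache-aside flow: try the RAG store first, generate with the LLM only on a miss, then save the result so the next similar issue is a hit. A stand-alone sketch of that flow with an in-memory stand-in for the RAG service (DummyRag and the generate_project function below are illustrative; the real retrieve/save live in aipg/rag/service.py, which is not shown in this diff):

import json


class DummyRag:
    """In-memory stand-in for the RAG service: retrieve() on a hit, save() after a miss."""

    def __init__(self) -> None:
        self._store: dict = {}

    def retrieve(self, issue: str):
        return self._store.get(issue)

    def save(self, issue: str, micro_project: str) -> None:
        self._store[issue] = micro_project


def generate_project(issue: str, rag: DummyRag) -> str:
    cached = rag.retrieve(issue)
    if cached:
        return cached  # 1) reuse a previously generated micro-project
    # 2) in the real assistant this is where the LLM task inference runs
    result = json.dumps({"task_description": issue, "task_goal": "", "expert_solution": ""})
    rag.save(issue, result)  # 3) persist so the next similar issue is served from the store
    return result


rag = DummyRag()
first = generate_project("add logging to the ETL job", rag)
second = generate_project("add logging to the ETL job", rag)
print(first == second)  # True: the second call is a cache hit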
14 changes: 14 additions & 0 deletions aipg/configs/app_config.py
@@ -35,3 +35,17 @@ class AppConfig(BaseModel):
task_timeout: int = 3600
time_limit: int = 14400
session_id: str = Field(default_factory=lambda: uuid4().hex)


class RagConfig(BaseModel):
similarity_threshold: float = 0.7
k_candidates: int = 5
collection_name: str = "micro_projects"
chroma_path: str = Field(default=str(Path(PACKAGE_PATH) / "cache" / "chroma"))
embedding_model: str = "text-embedding-3-small"
embedding_base_url: Optional[str] = None
embedding_api_key: Optional[str] = None


# Backward compatibility: AppConfig may not include rag in persisted configs
AppConfig.rag = RagConfig() # type: ignore[attr-defined]
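
build_rag_service() (further down) reads these settings through getattr with defaults, so configs persisted before the rag section existed keep working. A minimal sketch of that fallback behavior (OldConfig is a contrived stand-in, not the project's AppConfig):

from pydantic import BaseModel


class OldConfig(BaseModel):
    # a persisted config that predates the rag section
    task_timeout: int = 3600


config = OldConfig()
rag = getattr(config, "rag", None)                     # None on legacy configs
threshold = getattr(rag, "similarity_threshold", 0.7)  # falls back to the default
k = getattr(rag, "k_candidates", 5)
print(threshold, k)  # 0.7 5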
10 changes: 10 additions & 0 deletions aipg/rag/__init__.py
@@ -0,0 +1,10 @@
from .ports import EmbeddingPort, VectorStorePort
from .service import RagService, RagResult

__all__ = [
"EmbeddingPort",
"VectorStorePort",
"RagService",
"RagResult",
]

107 changes: 107 additions & 0 deletions aipg/rag/adapters.py
@@ -0,0 +1,107 @@
from __future__ import annotations

from typing import List, Optional, Callable

try: # Optional dependency at runtime
import chromadb # type: ignore
except Exception: # pragma: no cover - import-time optionality
chromadb = None # type: ignore

try: # Optional dependency at runtime
from openai import OpenAI # type: ignore
except Exception: # pragma: no cover - import-time optionality
OpenAI = None # type: ignore

from .ports import EmbeddingPort, VectorStorePort, RetrievedItem


class ChromaVectorStore(VectorStorePort):
def __init__(
self,
collection_name: str = "micro_projects",
persist_dir: Optional[str] = None,
) -> None:
if chromadb is None: # pragma: no cover
raise ImportError("chromadb is not installed")

if persist_dir:
client = chromadb.PersistentClient(path=persist_dir)
else:
client = chromadb.Client()

self.collection = client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)

def add(self, ids: List[str], embeddings: List[List[float]], metadatas: List[dict]):
self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)

def query(self, embedding: List[float], k: int) -> List[RetrievedItem]:
res = self.collection.query(
query_embeddings=[embedding], n_results=k, include=["metadatas"]
)
items: List[RetrievedItem] = []
metadatas = res.get("metadatas") or []
if metadatas:
for meta in metadatas[0]:
issue = meta.get("issue", "")
micro_project = meta.get("micro_project", "")
items.append(RetrievedItem(issue=issue, micro_project=micro_project, metadata=meta))
return items


class OpenAIEmbeddingAdapter(EmbeddingPort):
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model: str = "text-embedding-3-small",
client: Optional[object] = None,
) -> None:
if client is not None:
self.client = client
else:
if OpenAI is None: # pragma: no cover
raise ImportError("openai is not installed")
self.client = OpenAI(api_key=api_key, base_url=base_url)
self.model = model

def embed(self, texts: List[str]) -> List[List[float]]:
resp = self.client.embeddings.create(model=self.model, input=texts)
return [d.embedding for d in resp.data]


def llm_ranker_from_client(llm_query: Callable[[list[dict] | str], Optional[str]]) -> Callable[[str, List[str]], List[float]]:
def rank(query: str, candidates: List[str]) -> List[float]:
if not candidates:
return []
numbered = "\n".join([f"{i+1}. {c}" for i, c in enumerate(candidates)])
prompt = [
{
"role": "system",
"content": (
"You are a precise similarity rater. Given a query and a list of candidate issues, "
"return a JSON array of floats in [0,1] representing semantic similarity for each candidate."
),
},
{
"role": "user",
"content": (
f"Query: {query}\nCandidates:\n{numbered}\n\n"
"Return only JSON like [0.12, 0.5, 0.99]."
),
},
]
output = llm_query(prompt) or "[]"
try:
import json

scores = json.loads(output)
if not isinstance(scores, list):
raise ValueError("Invalid scores format")
return [float(x) for x in scores]
except Exception:
return [0.0 for _ in candidates]

return rank
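
llm_ranker_from_client only needs a callable that returns text, so it can be exercised without a real LLM client. A usage sketch, assuming the package from this branch is importable (fake_query is illustrative, not part of the PR):

from aipg.rag.adapters import llm_ranker_from_client


def fake_query(prompt):
    # A real client would send `prompt` to the LLM; here we fake three similarity scores.
    return "[0.9, 0.2, 0.55]"


rank = llm_ranker_from_client(fake_query)
scores = rank(
    "set up CI for a Python repo",
    ["add a GitHub Actions workflow", "fix a typo in README", "write unit tests"],
)
print(scores)  # [0.9, 0.2, 0.55]; malformed model output falls back to all zeros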

42 changes: 42 additions & 0 deletions aipg/rag/integration.py
@@ -0,0 +1,42 @@
from __future__ import annotations

import json
from typing import Optional

from aipg.configs.app_config import AppConfig
from aipg.llm import LLMClient
from .adapters import ChromaVectorStore, OpenAIEmbeddingAdapter, llm_ranker_from_client
from .service import RagService


def build_rag_service(config: AppConfig, llm: LLMClient) -> RagService:
rag = getattr(config, "rag", None)
similarity_threshold = getattr(rag, "similarity_threshold", 0.7)
k_candidates = getattr(rag, "k_candidates", 5)
collection_name = getattr(rag, "collection_name", "micro_projects")
chroma_path = getattr(rag, "chroma_path", None)
embedding_model = getattr(rag, "embedding_model", "text-embedding-3-small")
embedding_base_url = getattr(rag, "embedding_base_url", None) or config.llm.base_url
embedding_api_key = getattr(rag, "embedding_api_key", None) or config.llm.api_key

vector_store = ChromaVectorStore(collection_name=collection_name, persist_dir=chroma_path)
embedder = OpenAIEmbeddingAdapter(
api_key=embedding_api_key,
base_url=embedding_base_url,
model=embedding_model,
)
ranker = llm_ranker_from_client(llm.query)

# Generator is not used by assistant path, keep simple fallback
def generator(issue: str) -> str:
return json.dumps({"task_description": issue, "task_goal": "", "expert_solution": ""})

return RagService(
embedder=embedder,
vector_store=vector_store,
ranker=ranker,
generator=generator,
similarity_threshold=similarity_threshold,
k_candidates=k_candidates,
)

28 changes: 28 additions & 0 deletions aipg/rag/ports.py
@@ -0,0 +1,28 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterable, List, Optional, Protocol, Tuple


@dataclass
class RetrievedItem:
issue: str
micro_project: str
metadata: Optional[dict] = None


class EmbeddingPort(Protocol):
def embed(self, texts: List[str]) -> List[List[float]]: # pragma: no cover
...


class VectorStorePort(ABC):
@abstractmethod
def add(self, ids: List[str], embeddings: List[List[float]], metadatas: List[dict]):
raise NotImplementedError

@abstractmethod
def query(self, embedding: List[float], k: int) -> List[RetrievedItem]:
raise NotImplementedError
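
Because VectorStorePort is a small ABC, a test double is straightforward. A sketch of an in-memory implementation ranked by cosine similarity, assuming the package from this branch is importable (InMemoryVectorStore is illustrative, not part of the PR):

from typing import List

from aipg.rag.ports import RetrievedItem, VectorStorePort


def _cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = sum(x * x for x in a) ** 0.5
    nb = sum(y * y for y in b) ** 0.5
    return dot / (na * nb) if na and nb else 0.0


class InMemoryVectorStore(VectorStorePort):
    """Keeps embeddings in plain lists; handy for unit tests that avoid chromadb."""

    def __init__(self) -> None:
        self._embeddings: List[List[float]] = []
        self._metadatas: List[dict] = []

    def add(self, ids: List[str], embeddings: List[List[float]], metadatas: List[dict]):
        self._embeddings.extend(embeddings)
        self._metadatas.extend(metadatas)

    def query(self, embedding: List[float], k: int) -> List[RetrievedItem]:
        scored = sorted(
            zip(self._embeddings, self._metadatas),
            key=lambda pair: _cosine(pair[0], embedding),
            reverse=True,
        )
        return [
            RetrievedItem(issue=m.get("issue", ""), micro_project=m.get("micro_project", ""), metadata=m)
            for _, m in scored[:k]
        ]


store = InMemoryVectorStore()
store.add(["1"], [[1.0, 0.0]], [{"issue": "add CI", "micro_project": "{}"}])
print(store.query([1.0, 0.1], k=1)[0].issue)  # "add CI"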
