Commit 7040a29

wip: move llm functions into models module
1 parent fd48488 commit 7040a29

File tree

3 files changed: +455 -497 lines

src/kit/models/llm_client.py

Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
"""LLM client interfaces and implementations."""

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union

from kit.models.base import LLMError
from kit.models.config import AnthropicConfig, GoogleConfig, OpenAIConfig
from kit.models.llm_utils import count_openai_chat_tokens

# Conditionally import google.genai
try:
    import google.genai as genai
    from google.genai import types as genai_types
except ImportError:
    genai = None  # type: ignore
    genai_types = None  # type: ignore

logger = logging.getLogger(__name__)

# Constants
OPENAI_MAX_PROMPT_TOKENS = 15000  # Max tokens for the prompt to OpenAI


class LLMClient(ABC):
    """Base class for LLM clients."""

    @abstractmethod
    def generate_completion(self, system_prompt: str, user_prompt: str, model_name: Optional[str] = None) -> str:
        """Generate a completion from the LLM.

        Args:
            system_prompt: The system prompt to use.
            user_prompt: The user prompt to use.
            model_name: Optional model name to override the default.

        Returns:
            The generated completion text.

        Raises:
            LLMError: If there was an error generating the completion.
        """
        pass

    @staticmethod
    def create_client(config: Union[OpenAIConfig, AnthropicConfig, GoogleConfig]) -> "LLMClient":
        """Factory method to create an appropriate LLM client.

        Args:
            config: The LLM configuration to use.

        Returns:
            An LLMClient instance.

        Raises:
            TypeError: If config is None or an unsupported configuration type.
            LLMError: If there was an error initializing the client.
        """
        # Require a valid config
        if config is None:
            raise TypeError("LLM configuration must be provided")

        if isinstance(config, OpenAIConfig):
            return OpenAIClient(config)
        elif isinstance(config, AnthropicConfig):
            return AnthropicClient(config)
        elif isinstance(config, GoogleConfig):
            return GoogleClient(config)
        else:
            raise TypeError(f"Unsupported LLM configuration type: {type(config)}")
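
A minimal usage sketch of the factory (editorial, not part of the commit; the OpenAIConfig constructor fields are inferred from the attributes this module reads, so treat them as assumptions):

# Hypothetical usage; OpenAIConfig(...) fields mirror the self.config.* reads below.
from kit.models.config import OpenAIConfig
from kit.models.llm_client import LLMClient

cfg = OpenAIConfig(api_key="sk-...", model="gpt-4o", temperature=0.2, max_tokens=1024)
client = LLMClient.create_client(cfg)  # dispatches to OpenAIClient
text = client.generate_completion(
    system_prompt="You are a concise code reviewer.",
    user_prompt="Review: def add(a, b): return a + b",
)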


class OpenAIClient(LLMClient):
    """Client for OpenAI's API."""

    def __init__(self, config: OpenAIConfig):
        """Initialize with OpenAI configuration.

        Args:
            config: The OpenAI configuration.

        Raises:
            LLMError: If the OpenAI SDK is not available.
        """
        self.config = config
        try:
            from openai import OpenAI

            if self.config.base_url:
                self.client = OpenAI(api_key=self.config.api_key, base_url=self.config.base_url)
            else:
                self.client = OpenAI(api_key=self.config.api_key)
        except ImportError:
            raise LLMError("OpenAI SDK (openai) not available. Please install it.")

    def generate_completion(self, system_prompt: str, user_prompt: str, model_name: Optional[str] = None) -> str:
        """Generate a completion using OpenAI's API.

        Args:
            system_prompt: The system prompt to use.
            user_prompt: The user prompt to use.
            model_name: Optional model name to override the config's model.

        Returns:
            The generated completion text.

        Raises:
            LLMError: If there was an error generating the completion.
        """
        # Use provided model_name or fall back to config
        actual_model = model_name if model_name is not None else self.config.model

        messages_for_api = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        # Check token count
        prompt_token_count = count_openai_chat_tokens(messages_for_api, actual_model)
        if prompt_token_count is not None and prompt_token_count > OPENAI_MAX_PROMPT_TOKENS:
            return f"Completion generation failed: OpenAI prompt too large ({prompt_token_count} tokens). Limit is {OPENAI_MAX_PROMPT_TOKENS} tokens."

        try:
            response = self.client.chat.completions.create(
                model=actual_model,
                messages=messages_for_api,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
            )

            if response.usage:
                logger.debug(f"OpenAI API usage: {response.usage}")

            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error communicating with OpenAI API: {e}")
            raise LLMError(f"Error communicating with OpenAI API: {e}") from e
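
Worth noting for callers: an oversized prompt makes generate_completion return a sentinel error string rather than raising LLMError. A hedged caller-side sketch (client and logger continue from the earlier example; process() is a placeholder):

# Hypothetical defensive handling of the oversized-prompt sentinel;
# the prefix matches the literal f-string returned above.
result = client.generate_completion(system_prompt="...", user_prompt="...")
if result.startswith("Completion generation failed:"):
    logger.warning("No completion produced: %s", result)
else:
    process(result)  # placeholder for real downstream handling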


class AnthropicClient(LLMClient):
    """Client for Anthropic's API."""

    def __init__(self, config: AnthropicConfig):
        """Initialize with Anthropic configuration.

        Args:
            config: The Anthropic configuration.

        Raises:
            LLMError: If the Anthropic SDK is not available.
        """
        self.config = config
        try:
            from anthropic import Anthropic

            self.client = Anthropic(api_key=self.config.api_key)
        except ImportError:
            raise LLMError("Anthropic SDK (anthropic) not available. Please install it.")

    def generate_completion(self, system_prompt: str, user_prompt: str, model_name: Optional[str] = None) -> str:
        """Generate a completion using Anthropic's API.

        Args:
            system_prompt: The system prompt to use.
            user_prompt: The user prompt to use.
            model_name: Optional model name to override the config's model.

        Returns:
            The generated completion text.

        Raises:
            LLMError: If there was an error generating the completion.
        """
        # Use provided model_name or fall back to config
        actual_model = model_name if model_name is not None else self.config.model

        try:
            response = self.client.messages.create(
                model=actual_model,
                system=system_prompt,
                messages=[{"role": "user", "content": user_prompt}],
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
            )

            return response.content[0].text
        except Exception as e:
            logger.error(f"Error communicating with Anthropic API: {e}")
            raise LLMError(f"Error communicating with Anthropic API: {e}") from e
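
Each client also accepts a per-call model_name that overrides config.model for a single request; a sketch, with the AnthropicConfig constructor again assumed rather than taken from this diff, and the model ids as examples only:

# Hypothetical per-call model override.
a_cfg = AnthropicConfig(api_key="sk-ant-...", model="claude-3-5-sonnet-latest")
a_client = AnthropicClient(a_cfg)
quick = a_client.generate_completion(
    system_prompt="Answer in one sentence.",
    user_prompt="What does this module do?",
    model_name="claude-3-5-haiku-latest",  # overrides a_cfg.model for this call
)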


class GoogleClient(LLMClient):
    """Client for Google's Generative AI API."""

    def __init__(self, config: GoogleConfig):
        """Initialize with Google configuration.

        Args:
            config: The Google configuration.

        Raises:
            LLMError: If the Google Gen AI SDK is not available.
        """
        self.config = config
        if genai is None:
            raise LLMError("Google Gen AI SDK (google-genai) not available. Please install it.")

        try:
            self.client = genai.Client(api_key=self.config.api_key)
        except Exception as e:
            raise LLMError(f"Error initializing Google Gen AI client: {e}") from e

    def generate_completion(self, system_prompt: str, user_prompt: str, model_name: Optional[str] = None) -> str:
        """Generate a completion using Google's Generative AI API.

        Args:
            system_prompt: The system prompt to use (Note: currently not passed to the API; see TODO below).
            user_prompt: The user prompt to use.
            model_name: Optional model name to override the config's model.

        Returns:
            The generated completion text.

        Raises:
            LLMError: If there was an error generating the completion.
        """
        # Use provided model_name or fall back to config
        actual_model = model_name if model_name is not None else self.config.model

        if genai_types is None:
            raise LLMError(
                "Google Gen AI SDK (google-genai) types not available. SDK might not be installed correctly."
            )

        # Prepare generation config from model_kwargs
        generation_config_params: Dict[str, Any] = (
            self.config.model_kwargs.copy() if self.config.model_kwargs is not None else {}
        )

        if self.config.temperature is not None:
            generation_config_params["temperature"] = self.config.temperature
        if self.config.max_output_tokens is not None:
            generation_config_params["max_output_tokens"] = self.config.max_output_tokens

        final_sdk_params = generation_config_params if generation_config_params else None

        # TODO: Incorporate system_prompt, either by prepending it to user_prompt
        # or via the SDK's system_instruction config field; it is currently unused.

        try:
            # The google-genai SDK takes generation parameters via the `config`
            # keyword (a GenerateContentConfig or plain dict), not `generation_config`.
            response = self.client.models.generate_content(
                model=actual_model, contents=user_prompt, config=final_sdk_params
            )

            # Check for blocked prompt
            if (
                hasattr(response, "prompt_feedback")
                and response.prompt_feedback
                and response.prompt_feedback.block_reason
            ):
                logger.warning(f"Google LLM prompt blocked. Reason: {response.prompt_feedback.block_reason}")
                return f"Completion generation failed: Prompt blocked by API (Reason: {response.prompt_feedback.block_reason})"

            # Check for empty response
            if not response.text:
                logger.warning(f"Google LLM returned no text. Response: {response}")
                return "Completion generation failed: No text returned by API."

            return response.text
        except Exception as e:
            logger.error(f"Error communicating with Google Gen AI API: {e}")
            raise LLMError(f"Error communicating with Google Gen AI API: {e}") from e
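
Until the TODO above lands, a caller can fold the system prompt into the user prompt before hitting GoogleClient, as the TODO itself suggests; a minimal workaround sketch (the separator format is an arbitrary choice, not part of this module):

# Hypothetical workaround for the currently unused system_prompt.
def google_completion_with_system(client: GoogleClient, system_prompt: str, user_prompt: str) -> str:
    combined = f"{system_prompt}\n\n{user_prompt}" if system_prompt else user_prompt
    # Send everything as the user prompt; the system slot stays empty.
    return client.generate_completion(system_prompt="", user_prompt=combined)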
