diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index ed54fd9..622fed5 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -47,6 +47,17 @@ class _ChatParams: x402_settlement_mode: x402SettlementMode +def _get_model_id(model) -> str: + """Extract model ID from provider/model-name format, raising ValueError for invalid formats.""" + value = model.value if hasattr(model, "value") else str(model) + parts = value.split("/", 1) + if len(parts) != 2: + raise ValueError( + f"Invalid model identifier '{value}'. Expected 'provider/model-name' format." + ) + return parts[1] + + class LLM: """ LLM inference namespace. @@ -252,7 +263,7 @@ async def completion( Raises: RuntimeError: If the inference fails. """ - model_id = model.split("/")[1] + model_id = _get_model_id(model) payload: Dict = { "model": model_id, "prompt": prompt, @@ -327,7 +338,7 @@ async def chat( RuntimeError: If the inference fails. """ params = _ChatParams( - model=model.split("/")[1], + model=_get_model_id(model), max_tokens=max_tokens, temperature=temperature, stop_sequence=stop_sequence,