Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"openai>=1.58.1",
"pydantic>=2.9.2",
"og-x402==0.0.1.dev4"

]

[project.optional-dependencies]
Expand Down
13 changes: 13 additions & 0 deletions src/opengradient/client/alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ..types import HistoricalInputQuery, InferenceMode, InferenceResult, ModelOutput, SchedulerParams
from ._conversions import convert_array_to_model_output, convert_to_model_input, convert_to_model_output # type: ignore[attr-defined]
from ._utils import get_abi, get_bin, run_with_retry
from ..types import InferenceRequest

DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai"
DEFAULT_API_URL = "https://sdk-devnet.opengradient.ai"
Expand Down Expand Up @@ -80,6 +81,18 @@ def infer(
model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
max_retries: Optional[int] = None,
) -> InferenceResult:
# Validate the data using the Pydantic model
validated = InferenceRequest(
model_cid=model_cid,
inference_mode=inference_mode,
model_input=model_input
)

# From here on, we use the validated data
# (Optional: Re-assigning ensures we use the cleaned types)
model_cid = validated.model_cid
inference_mode = validated.inference_mode
model_input = validated.model_input
"""
Perform inference on a model.

Expand Down
19 changes: 17 additions & 2 deletions src/opengradient/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import time
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union, Any

import numpy as np

from pydantic import BaseModel, Field, field_validator, ConfigDict

class x402SettlementMode(str, Enum):
"""
Expand Down Expand Up @@ -552,3 +552,18 @@ class ModelRepository:
class FileUploadResult:
modelCid: str
size: int


class InferenceRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model_cid: str = Field(..., pattern=r"^(Qm|ba)[1-9A-HJ-NP-Za-km-z]+$")
inference_mode: Any # We let Pydantic handle the Enum conversion
model_input: Dict[str, Any]

# Enable arbitrary types so Pydantic doesn't complain about numpy arrays
@field_validator('model_input')
@classmethod
def validate_inputs(cls, v: Dict[str, Any]):
if not v:
raise ValueError("model_input cannot be empty.")
return v
31 changes: 31 additions & 0 deletions tests/test_inference_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import opengradient as og
from opengradient.types import InferenceMode
from pydantic import ValidationError

def test_alpha_infer_validation():
# Dummy key used for testing validation logic
alpha = og.Alpha(private_key="0x" + "1" * 64)

# 1. Test that invalid CID and empty input RAISE an error
with pytest.raises(ValidationError):
alpha.infer(
model_cid="invalid_id",
inference_mode=InferenceMode.VANILLA,
model_input={}
)

# 2. Test that valid data DOES NOT raise a ValidationError
# Note: It might fail later due to the dummy key, but Pydantic should pass it
try:
alpha.infer(
model_cid="QmXoypizjW3WknFiJnKLwHCnL72vedxjQkDDP1mXWo6uco",
inference_mode=InferenceMode.VANILLA,
model_input={"data": [1, 2, 3]}
)
except ValidationError:
pytest.fail("Pydantic rejected a valid CID!")
except Exception as e:
# We ignore other errors (like RPC/Key errors) because
# we are only testing the VALIDATION layer here.
print(f"Skipping non-validation error: {type(e).__name__}")