From 7e607450d11193c173b29b281a511082fa6bbaab Mon Sep 17 00:00:00 2001 From: DanFrolov <61428101+DanFrolov@users.noreply.github.com> Date: Sat, 4 Apr 2026 14:07:58 -0400 Subject: [PATCH 1/2] feat: implement Pydantic v2 validation and unit tests for alpha inference --- pyproject.toml | 1 + src/opengradient/client/alpha.py | 13 +++++++++++++ src/opengradient/types.py | 19 ++++++++++++++++-- tests/test_inference_validation.py | 31 ++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 tests/test_inference_validation.py diff --git a/pyproject.toml b/pyproject.toml index 34754b4e..2bf9647c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "openai>=1.58.1", "pydantic>=2.9.2", "og-x402==0.0.1.dev4" + ] [project.optional-dependencies] diff --git a/src/opengradient/client/alpha.py b/src/opengradient/client/alpha.py index d2957c2d..41aadd62 100644 --- a/src/opengradient/client/alpha.py +++ b/src/opengradient/client/alpha.py @@ -20,6 +20,7 @@ from ..types import HistoricalInputQuery, InferenceMode, InferenceResult, ModelOutput, SchedulerParams from ._conversions import convert_array_to_model_output, convert_to_model_input, convert_to_model_output # type: ignore[attr-defined] from ._utils import get_abi, get_bin, run_with_retry +from ..types import InferenceRequest DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" DEFAULT_API_URL = "https://sdk-devnet.opengradient.ai" @@ -80,6 +81,18 @@ def infer( model_input: Dict[str, Union[str, int, float, List, np.ndarray]], max_retries: Optional[int] = None, ) -> InferenceResult: + # Validate the data using the Pydantic model + validated = InferenceRequest( + model_cid=model_cid, + inference_mode=inference_mode, + model_input=model_input + ) + + # From here on, we use the validated data + # (Optional: Re-assigning ensures we use the cleaned types) + model_cid = validated.model_cid + inference_mode = validated.inference_mode + model_input = validated.model_input """ Perform inference on a model. diff --git a/src/opengradient/types.py b/src/opengradient/types.py index a59293fa..13d94e03 100644 --- a/src/opengradient/types.py +++ b/src/opengradient/types.py @@ -5,10 +5,10 @@ import time from dataclasses import dataclass from enum import Enum, IntEnum -from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union +from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union, Any import numpy as np - +from pydantic import BaseModel, Field, field_validator, ConfigDict class x402SettlementMode(str, Enum): """ @@ -552,3 +552,18 @@ class ModelRepository: class FileUploadResult: modelCid: str size: int + + +class InferenceRequest(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + model_cid: str = Field(..., pattern=r"^(Qm|ba)[1-9A-HJ-NP-Za-km-z]+$") + inference_mode: Any # We let Pydantic handle the Enum conversion + model_input: Dict[str, Any] + + # Enable arbitrary types so Pydantic doesn't complain about numpy arrays + @field_validator('model_input') + @classmethod + def validate_inputs(cls, v: Dict[str, Any]): + if not v: + raise ValueError("model_input cannot be empty.") + return v \ No newline at end of file diff --git a/tests/test_inference_validation.py b/tests/test_inference_validation.py new file mode 100644 index 00000000..c499f8e0 --- /dev/null +++ b/tests/test_inference_validation.py @@ -0,0 +1,31 @@ +import pytest +import opengradient as og +from opengradient.types import InferenceMode +from pydantic import ValidationError + +def test_alpha_infer_validation(): + # Dummy key used for testing validation logic + alpha = og.Alpha(private_key="0x" + "1" * 64) + + # 1. Test that invalid CID and empty input RAISE an error + with pytest.raises(ValidationError): + alpha.infer( + model_cid="invalid_id", + inference_mode=InferenceMode.VANILLA, + model_input={} + ) + + # 2. Test that valid data DOES NOT raise a ValidationError + # Note: It might fail later due to the dummy key, but Pydantic should pass it + try: + alpha.infer( + model_cid="QmXoypizjW3WknFiJnKLwHCnL72vedxjQkDDP1mXWo6uco", + inference_mode=InferenceMode.VANILLA, + model_input={"data": [1, 2, 3]} + ) + except ValidationError: + pytest.fail("Pydantic rejected a valid CID!") + except Exception: + # We ignore other errors (like RPC/Key errors) because + # we are only testing the VALIDATION layer here. + pass \ No newline at end of file From 0cc89c8edac33c9a7d7b51611d08ac5259b42dd8 Mon Sep 17 00:00:00 2001 From: DanFrolov <61428101+DanFrolov@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:22:20 -0400 Subject: [PATCH 2/2] refactor: replace silent pass with explicit exception handling --- tests/test_inference_validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_inference_validation.py b/tests/test_inference_validation.py index c499f8e0..a7d29ee2 100644 --- a/tests/test_inference_validation.py +++ b/tests/test_inference_validation.py @@ -25,7 +25,7 @@ def test_alpha_infer_validation(): ) except ValidationError: pytest.fail("Pydantic rejected a valid CID!") - except Exception: + except Exception as e: # We ignore other errors (like RPC/Key errors) because # we are only testing the VALIDATION layer here. - pass \ No newline at end of file + print(f"Skipping non-validation error: {type(e).__name__}") \ No newline at end of file