|
5 | 5 | # spell-checker:ignore ollama, hnsw, mult, ocid, testset |
6 | 6 |
|
7 | 7 | from typing import Optional, Literal, Union |
8 | | -from pydantic import BaseModel, Field, PrivateAttr |
| 8 | +from pydantic import BaseModel, Field, PrivateAttr, model_validator |
9 | 9 |
|
10 | 10 | from langchain_core.messages import ChatMessage |
11 | 11 | import oracledb |
|
18 | 18 | DistanceMetrics = Literal["COSINE", "EUCLIDEAN_DISTANCE", "DOT_PRODUCT"] |
19 | 19 | IndexTypes = Literal["HNSW", "IVF"] |
20 | 20 |
|
| 21 | +# ModelAPIs |
| 22 | +EmbedAPI = Literal[ |
| 23 | + "OllamaEmbeddings", |
| 24 | + "OCIGenAIEmbeddings", |
| 25 | + "CompatOpenAIEmbeddings", |
| 26 | + "OpenAIEmbeddings", |
| 27 | + "CohereEmbeddings", |
| 28 | + "HuggingFaceEndpointEmbeddings", |
| 29 | +] |
| 30 | +LlAPI = Literal[ |
| 31 | + "ChatOllama", |
| 32 | + "ChatOCIGenAI", |
| 33 | + "CompatOpenAI", |
| 34 | + "Perplexity", |
| 35 | + "OpenAI", |
| 36 | + "Cohere", |
| 37 | +] |
| 38 | + |
21 | 39 |
|
22 | 40 | ##################################################### |
23 | 41 | # Database |
@@ -110,6 +128,21 @@ class Model(ModelAccess, LanguageModelParameters, EmbeddingModelParameters): |
110 | 128 | openai_compat: bool = Field(default=True, description="Is the API OpenAI compatible?") |
111 | 129 | status: Statuses = Field(default="UNVERIFIED", description="Status (read-only)", readOnly=True) |
112 | 130 |
|
| 131 | + @model_validator(mode="after") |
| 132 | + def check_api_matches_type(self): |
| 133 | + """Validate valid API""" |
| 134 | + ll_apis = LlAPI.__args__ |
| 135 | + embed_apis = EmbedAPI.__args__ |
| 136 | + |
| 137 | + if not self.api or self.api == "unset": |
| 138 | + return self |
| 139 | + |
| 140 | + if self.type == "ll" and self.api not in ll_apis: |
| 141 | + raise ValueError(f"API '{self.api}' is not valid for type 'll'. Must be one of: {ll_apis}") |
| 142 | + if self.type == "embed" and self.api not in embed_apis: |
| 143 | + raise ValueError(f"API '{self.api}' is not valid for type 'embed'. Must be one of: {embed_apis}") |
| 144 | + return self |
| 145 | + |
113 | 146 |
|
114 | 147 | ##################################################### |
115 | 148 | # Oracle Cloud Infrastructure |
|
0 commit comments