|
20 | 20 | from javelin_sdk.services.secret_service import SecretService |
21 | 21 | from javelin_sdk.services.template_service import TemplateService |
22 | 22 | from javelin_sdk.services.trace_service import TraceService |
| 23 | +from javelin_sdk.services.guardrails_service import GuardrailsService |
23 | 24 | from javelin_sdk.tracing_setup import configure_span_exporter |
24 | 25 | import inspect |
25 | 26 | from opentelemetry.trace import SpanKind |
@@ -98,6 +99,7 @@ def __init__(self, config: JavelinConfig) -> None: |
98 | 99 | self.template_service = TemplateService(self) |
99 | 100 | self.trace_service = TraceService(self) |
100 | 101 | self.modelspec_service = ModelSpecService(self) |
| 102 | + self.guardrails_service = GuardrailsService(self) |
101 | 103 |
|
102 | 104 | self.chat = Chat(self) |
103 | 105 | self.completions = Completions(self) |
@@ -899,6 +901,8 @@ def _prepare_request(self, request: Request) -> tuple: |
899 | 901 | is_model_specs=request.is_model_specs, |
900 | 902 | is_reload=request.is_reload, |
901 | 903 | univ_model=request.univ_model_config, |
| 904 | + guardrail=request.guardrail, |
| 905 | + list_guardrails=request.list_guardrails, |
902 | 906 | ) |
903 | 907 | headers = {**self._headers, **(request.headers or {})} |
904 | 908 | return url, headers |
@@ -939,6 +943,8 @@ def _construct_url( |
939 | 943 | is_model_specs: bool = False, |
940 | 944 | is_reload: bool = False, |
941 | 945 | univ_model: Optional[Dict[str, Any]] = None, |
| 946 | + guardrail: Optional[str] = None, |
| 947 | + list_guardrails: bool = False, |
942 | 948 | ) -> str: |
943 | 949 | url_parts = [self.base_url] |
944 | 950 |
|
@@ -993,6 +999,13 @@ def _construct_url( |
993 | 999 | url_parts.extend(["admin", "archives"]) |
994 | 1000 | if archive != "###": |
995 | 1001 | url_parts.append(archive) |
| 1002 | + elif guardrail: |
| 1003 | + if guardrail == "all": |
| 1004 | + url_parts.extend(["guardrails", "apply"]) |
| 1005 | + else: |
| 1006 | + url_parts.extend(["guardrail", guardrail, "apply"]) |
| 1007 | + elif list_guardrails: |
| 1008 | + url_parts.extend(["guardrails", "list"]) |
996 | 1009 | else: |
997 | 1010 | url_parts.extend(["admin", "routes"]) |
998 | 1011 |
|
@@ -1201,6 +1214,12 @@ def _construct_url( |
1201 | 1214 | ) |
1202 | 1215 | ) |
1203 | 1216 |
|
| 1217 | + # Guardrails methods |
| 1218 | + apply_trustsafety = lambda self, text, config=None: self.guardrails_service.apply_trustsafety(text, config) |
| 1219 | + apply_promptinjectiondetection = lambda self, text, config=None: self.guardrails_service.apply_promptinjectiondetection(text, config) |
| 1220 | + apply_guardrails = lambda self, text, guardrails: self.guardrails_service.apply_guardrails(text, guardrails) |
| 1221 | + list_guardrails = lambda self: self.guardrails_service.list_guardrails() |
| 1222 | + |
1204 | 1223 | ## Traces methods |
1205 | 1224 | get_traces = lambda self: self.trace_service.get_traces() |
1206 | 1225 | aget_traces = lambda self: self.trace_service.aget_traces() |
@@ -1286,3 +1305,9 @@ def set_headers(self, headers: Dict[str, str]) -> None: |
1286 | 1305 | headers (Dict[str, str]): A dictionary of headers to set or update. |
1287 | 1306 | """ |
1288 | 1307 | self._headers.update(headers) |
| 1308 | + |
| 1309 | + # Guardrails methods |
| 1310 | + apply_trustsafety = lambda self, text, config=None: self.guardrails_service.apply_trustsafety(text, config) |
| 1311 | + apply_promptinjectiondetection = lambda self, text, config=None: self.guardrails_service.apply_promptinjectiondetection(text, config) |
| 1312 | + apply_guardrails = lambda self, text, guardrails: self.guardrails_service.apply_guardrails(text, guardrails) |
| 1313 | + list_guardrails = lambda self: self.guardrails_service.list_guardrails() |
0 commit comments