Skip to content

Commit 5a02586

Browse files
committed
added prompt argument validation in detect and detect_async method call detect
1 parent c099cd6 commit 5a02586

File tree

1 file changed

+14
-40
lines changed

1 file changed

+14
-40
lines changed

zenguard/zenguard.py

Lines changed: 14 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from dataclasses import dataclass
66
from enum import Enum
7-
from typing import Optional
7+
from typing import Any, Dict, Optional
88

99
import httpx
1010
from openai import OpenAI
@@ -117,16 +117,18 @@ def detect(self, detectors: list[Detector], prompt: str):
117117
"""
118118
Uses detectors to evaluate the prompt and return the results.
119119
"""
120+
if prompt.isspace() or prompt == "":
121+
raise ValueError("Prompt must not be an empty string or whitespace string")
122+
120123
if len(detectors) == 0:
121124
raise ValueError("No detectors were provided")
122125

123-
url = self._backend
126+
json: Dict[str, Any] = {"messages": [prompt]}
124127
if len(detectors) == 1:
125-
url += convert_detector_to_api(detectors[0])
126-
json = {"messages": [prompt]}
128+
url = f"{self._backend}{convert_detector_to_api(detectors[0])}"
127129
else:
128-
url += "v1/detect"
129-
json = {"messages": [prompt], "in_parallel": True, "detectors": detectors}
130+
url = f"{self._backend}v1/detect"
131+
json["detectors"] = detectors
130132

131133
try:
132134
response = httpx.post(
@@ -135,49 +137,21 @@ def detect(self, detectors: list[Detector], prompt: str):
135137
headers={"x-api-key": self._api_key},
136138
timeout=20,
137139
)
138-
response.raise_for_status()
140+
if response.status_code != 200:
141+
raise RuntimeError(
142+
f"Received an unexpected status code: {response.status_code}\nResponse content: {response.json()}"
143+
)
144+
return response.json()
139145
except httpx.RequestError as e:
140146
raise RuntimeError(
141147
f"An error occurred while making the request: {str(e)}"
142148
) from e
143-
except httpx.HTTPStatusError as e:
144-
raise RuntimeError(
145-
f"Received an unexpected status code: {response.status_code}\nResponse content: {response.json()}"
146-
) from e
147-
148-
return response.json()
149149

150150
def detect_async(self, detectors: list[Detector], prompt: str):
151151
"""
152152
Same as detect function but asynchroneous.
153153
"""
154-
if len(detectors) == 0:
155-
raise ValueError("No detectors were provided")
156-
157-
if detectors[0] != Detector.PROMPT_INJECTION:
158-
raise ValueError(
159-
"Only Prompt Injection detector is supported for async detection"
160-
)
161-
162-
url = self._backend + convert_detector_to_api(detectors[0]) + "_async"
163-
json = {"messages": [prompt]}
164-
165-
try:
166-
response = httpx.post(
167-
url,
168-
json=json,
169-
headers={"x-api-key": self._api_key},
170-
timeout=20,
171-
)
172-
response.raise_for_status()
173-
except httpx.RequestError as e:
174-
raise RuntimeError(
175-
f"An error occurred while making the request: {str(e)}"
176-
) from e
177-
except httpx.HTTPStatusError as e:
178-
raise RuntimeError(
179-
f"Received an unexpected status code: {response.status_code}\nResponse content: {response.json()}"
180-
) from e
154+
return self.detect(detectors, prompt)
181155

182156
def _attack_zenguard(self, detector: Detector, attacks: list[str]):
183157
attacks = tqdm(attacks)

0 commit comments

Comments
 (0)