Skip to content

Commit 77b4f41

Browse files
committed
Added pydantic validation to _get_version_number, compare_versions, restore_version in routes, updated tests for new error messages
1 parent 94581fe commit 77b4f41

File tree

3 files changed

+84
-23
lines changed

3 files changed

+84
-23
lines changed

pydatalab/src/pydatalab/models/versions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class VersionCounter(BaseModel):
7373
)
7474

7575
class Config:
76-
extra = "forbid"
76+
extra = "ignore" # Allow MongoDB's _id field and other internal fields
7777

7878

7979
class RestoreVersionRequest(BaseModel):

pydatalab/src/pydatalab/routes/v0_1/items.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
from pydatalab.models.people import Person
2323
from pydatalab.models.relationships import RelationshipType
2424
from pydatalab.models.utils import generate_unique_refcode
25-
from pydatalab.models.versions import VersionAction
25+
from pydatalab.models.versions import (
26+
CompareVersionsQuery,
27+
RestoreVersionRequest,
28+
VersionAction,
29+
VersionCounter,
30+
)
2631
from pydatalab.mongo import ITEMS_FTS_FIELDS, flask_mongo
2732
from pydatalab.permissions import PUBLIC_USER_ID, active_users_or_get_only, get_default_permissions
2833

@@ -1060,7 +1065,20 @@ def _get_next_version_number(refcode: str) -> int:
10601065
upsert=True,
10611066
return_document=True, # Return the document after update
10621067
)
1063-
return result["counter"]
1068+
1069+
# Validate the result with Pydantic
1070+
try:
1071+
counter_doc = VersionCounter(**result)
1072+
return counter_doc.counter
1073+
except ValidationError as exc:
1074+
LOGGER.error(
1075+
"Version counter validation failed for refcode %s: %s",
1076+
refcode,
1077+
str(exc),
1078+
)
1079+
# Fallback: return raw counter value to prevent blocking saves
1080+
# This should only happen if the document is corrupted
1081+
return result["counter"]
10641082

10651083

10661084
@ITEMS.route("/items/<refcode>/versions/", methods=["GET"])
@@ -1113,14 +1131,19 @@ def compare_versions(refcode):
11131131
if len(refcode.split(":")) != 2:
11141132
refcode = f"{CONFIG.IDENTIFIER_PREFIX}:{refcode}"
11151133

1116-
v1_id = request.args.get("v1")
1117-
v2_id = request.args.get("v2")
1118-
if not v1_id or not v2_id:
1119-
return jsonify({"status": "error", "message": "Both v1 and v2 must be provided"}), 400
1134+
# Validate query parameters using Pydantic model
1135+
try:
1136+
query_params = CompareVersionsQuery(
1137+
v1=request.args.get("v1", ""), v2=request.args.get("v2", "")
1138+
)
1139+
except ValidationError as exc:
1140+
return jsonify(
1141+
{"status": "error", "message": "Invalid query parameters", "errors": exc.errors()}
1142+
), 400
11201143

11211144
try:
1122-
v1_object_id = ObjectId(v1_id)
1123-
v2_object_id = ObjectId(v2_id)
1145+
v1_object_id = ObjectId(query_params.v1)
1146+
v2_object_id = ObjectId(query_params.v2)
11241147
except (InvalidId, TypeError) as e:
11251148
return jsonify({"status": "error", "message": f"Invalid version ID format: {str(e)}"}), 400
11261149

@@ -1168,15 +1191,20 @@ def restore_version(refcode):
11681191
if len(refcode.split(":")) != 2:
11691192
refcode = f"{CONFIG.IDENTIFIER_PREFIX}:{refcode}"
11701193

1171-
req = request.get_json()
1172-
version_id = req.get("version_id")
1173-
if not version_id:
1174-
return jsonify({"status": "error", "message": "version_id must be provided"}), 400
1194+
# Validate request body using Pydantic model
1195+
try:
1196+
restore_request = RestoreVersionRequest(**request.get_json())
1197+
except ValidationError as exc:
1198+
return jsonify(
1199+
{"status": "error", "message": "Invalid request body", "errors": exc.errors()}
1200+
), 400
11751201

11761202
try:
1177-
version_object_id = ObjectId(version_id)
1203+
version_object_id = ObjectId(restore_request.version_id)
11781204
except (InvalidId, TypeError):
1179-
return jsonify({"status": "error", "message": f"Invalid version_id: {version_id}"}), 400
1205+
return jsonify(
1206+
{"status": "error", "message": f"Invalid version_id: {restore_request.version_id}"}
1207+
), 400
11801208

11811209
# Check permissions - user must have write access
11821210
current_item = flask_mongo.db.items.find_one(
@@ -1279,7 +1307,7 @@ def restore_version(refcode):
12791307
return jsonify(
12801308
{
12811309
"status": "success",
1282-
"restored_version": version_id,
1310+
"restored_version": restore_request.version_id,
12831311
"new_version_number": next_version_number,
12841312
}
12851313
), 200

pydatalab/tests/server/test_item_versions.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,22 +234,43 @@ def test_compare_versions_missing_parameters(self, client, sample_with_version):
234234
"""Test comparing versions with missing parameters."""
235235
refcode = sample_with_version.refcode.split(":")[1]
236236

237-
# Missing v2
237+
# Missing v2 - request.args.get() returns "" for missing params, which fails ObjectId validation
238238
response = client.get(f"/items/{refcode}/compare-versions/?v1=some_id")
239239
assert response.status_code == 400
240-
assert "Both v1 and v2 must be provided" in response.json["message"]
241-
242-
# Missing v1
240+
assert response.json["message"] == "Invalid query parameters"
241+
assert "errors" in response.json
242+
errors = response.json["errors"]
243+
# Should have error for v2 (empty string is invalid ObjectId)
244+
v2_errors = [e for e in errors if "v2" in str(e["loc"])]
245+
assert len(v2_errors) == 1
246+
assert "valid objectid" in v2_errors[0]["msg"].lower()
247+
248+
# Missing v1 - same behavior
243249
response = client.get(f"/items/{refcode}/compare-versions/?v2=some_id")
244250
assert response.status_code == 400
251+
assert response.json["message"] == "Invalid query parameters"
252+
assert "errors" in response.json
253+
errors = response.json["errors"]
254+
# Should have error for v1 (empty string is invalid ObjectId)
255+
v1_errors = [e for e in errors if "v1" in str(e["loc"])]
256+
assert len(v1_errors) == 1
257+
assert "valid objectid" in v1_errors[0]["msg"].lower()
245258

246259
def test_compare_versions_invalid_id(self, client, sample_with_version):
247260
"""Test comparing versions with invalid ID format."""
248261
refcode = sample_with_version.refcode.split(":")[1]
249262
response = client.get(f"/items/{refcode}/compare-versions/?v1=invalid&v2=invalid")
250263

251264
assert response.status_code == 400
252-
assert "Invalid version ID format" in response.json["message"]
265+
assert response.json["message"] == "Invalid query parameters"
266+
# Check Pydantic's structured error response
267+
assert "errors" in response.json
268+
errors = response.json["errors"]
269+
# Should have errors for both v1 and v2
270+
assert len(errors) == 2
271+
for error in errors:
272+
assert error["loc"][0] in ["v1", "v2"]
273+
assert "valid ObjectId" in error["msg"]
253274

254275
def test_compare_versions_detects_changes(self, client, sample_with_version):
255276
"""Test that compare_versions properly detects changes using DeepDiff."""
@@ -417,15 +438,27 @@ def test_restore_version_missing_version_id(self, client, sample_with_version):
417438
response = client.post(f"/items/{refcode}/restore-version/", json={})
418439

419440
assert response.status_code == 400
420-
assert "version_id must be provided" in response.json["message"]
441+
assert response.json["message"] == "Invalid request body"
442+
# Check Pydantic's structured error response
443+
assert "errors" in response.json
444+
errors = response.json["errors"]
445+
assert len(errors) == 1
446+
assert errors[0]["loc"] == ["version_id"]
447+
assert "required" in errors[0]["msg"].lower()
421448

422449
def test_restore_version_invalid_id(self, client, sample_with_version):
423450
"""Test restoring with invalid version ID."""
424451
refcode = sample_with_version.refcode.split(":")[1]
425452
response = client.post(f"/items/{refcode}/restore-version/", json={"version_id": "invalid"})
426453

427454
assert response.status_code == 400
428-
assert "Invalid version_id" in response.json["message"]
455+
assert response.json["message"] == "Invalid request body"
456+
# Check Pydantic's structured error response
457+
assert "errors" in response.json
458+
errors = response.json["errors"]
459+
assert len(errors) == 1
460+
assert errors[0]["loc"] == ["version_id"]
461+
assert "valid ObjectId" in errors[0]["msg"]
429462

430463
def test_restore_version_nonexistent(self, client, sample_with_version):
431464
"""Test restoring non-existent version."""

0 commit comments

Comments
 (0)