diff --git a/backend/app/api/docs/credentials/create.md b/backend/app/api/docs/credentials/create.md index df5779201..31fb87980 100644 --- a/backend/app/api/docs/credentials/create.md +++ b/backend/app/api/docs/credentials/create.md @@ -6,6 +6,7 @@ Credentials are encrypted and stored securely for provider integrations (OpenAI, - **LLM:** openai, sarvamai, google(gemini) - **Observability:** langfuse - **Audio:** elevenlabs +- **Miscellaneous:** webhook_secret ### Examples: @@ -40,7 +41,19 @@ Credentials are encrypted and stored securely for provider integrations (OpenAI, "public_key": "pk-lf-....", "secret_key": "sk-lf-...", "host": "https://cloud.langfuse.com" - } + }, + "webhook_secret": { + "webhook_secret": "webhook_secret" + } } } ``` +#### For registering Webhook Secret +```json +{ + "credential":{ + "webhook_secret":"your-webhook-secret" + } + +} +``` diff --git a/backend/app/api/docs/credentials/update.md b/backend/app/api/docs/credentials/update.md index cf08360d4..f9f792abd 100644 --- a/backend/app/api/docs/credentials/update.md +++ b/backend/app/api/docs/credentials/update.md @@ -32,3 +32,4 @@ The `credential` field accepts **two formats** (both work the same): - **LLM:** openai, sarvamai, google(gemini) - **Observability:** langfuse - **Audio:** elevenlabs +- **Miscellaneous:** webhook_secret diff --git a/backend/app/api/docs/credentials/update_by_org_project.md b/backend/app/api/docs/credentials/update_by_org_project.md index c010871d4..5efe026fe 100644 --- a/backend/app/api/docs/credentials/update_by_org_project.md +++ b/backend/app/api/docs/credentials/update_by_org_project.md @@ -36,3 +36,4 @@ The `credential` field accepts **two formats** (both work the same): - **LLM:** openai, sarvamai, google(gemini) - **Observability:** langfuse - **Audio:** elevenlabs +- **Miscellaneous:** webhook_secret diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 793995422..11d2c888b 100644 --- a/backend/app/core/providers.py +++ 
b/backend/app/core/providers.py @@ -14,6 +14,7 @@ class Provider(str, Enum): GOOGLE = "google" SARVAMAI = "sarvamai" ELEVENLABS = "elevenlabs" + WEBHOOK_SECRET = "webhook_secret" @dataclass @@ -42,6 +43,9 @@ class ProviderConfig: Provider.ELEVENLABS: ProviderConfig( required_fields=["api_key"], sensitive_fields=["api_key"] ), + Provider.WEBHOOK_SECRET: ProviderConfig( + required_fields=["webhook_secret"], sensitive_fields=["webhook_secret"] + ), } diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 0ffedf96b..4acce89e1 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -29,7 +29,7 @@ ) from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_create_collection_job -from app.utils import send_callback, APIResponse +from app.utils import send_callback, get_webhook_secret, APIResponse logger = logging.getLogger(__name__) @@ -274,7 +274,12 @@ def execute_job( ) if creation_request.callback_url: - send_callback(creation_request.callback_url, success_payload) + webhook_secret = get_webhook_secret(project_id, organization_id) + send_callback( + str(creation_request.callback_url), + success_payload, + webhook_secret=webhook_secret, + ) except Exception as err: span.record_exception(err) @@ -303,5 +308,10 @@ def execute_job( if creation_request and creation_request.callback_url and collection_job: failure_payload = build_failure_payload(collection_job, str(err)) - send_callback(creation_request.callback_url, failure_payload) + webhook_secret = get_webhook_secret(project_id, organization_id) + send_callback( + str(creation_request.callback_url), + failure_payload, + webhook_secret=webhook_secret, + ) raise diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index 99cfaa8bf..1c8e8a497 100644 --- 
a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -19,7 +19,7 @@ from app.services.collections.providers.registry import get_llm_provider from app.celery.utils import start_delete_collection_job from app.core.telemetry import log_context -from app.utils import send_callback, APIResponse +from app.utils import send_callback, get_webhook_secret, APIResponse logger = logging.getLogger(__name__) @@ -104,6 +104,7 @@ def build_failure_payload( def _mark_job_failed_and_callback( *, + organization_id: int, project_id: int, collection_id: UUID, job_id: UUID, @@ -146,7 +147,8 @@ def _mark_job_failed_and_callback( collection_id=collection_id, error_message=str(err), ) - send_callback(callback_url, failure_payload) + webhook_secret = get_webhook_secret(project_id, organization_id) + send_callback(callback_url, failure_payload, webhook_secret=webhook_secret) def execute_job( @@ -162,8 +164,11 @@ def execute_job( deletion_request = DeletionRequest(**request) - collection_id = UUID(collection_id) + collection_uuid = UUID(collection_id) job_uuid = UUID(job_id) + callback_url = ( + str(deletion_request.callback_url) if deletion_request.callback_url else None + ) collection_job = None @@ -172,12 +177,12 @@ def execute_job( lifecycle="collection.delete.execute_job", action="delete", collection_job_id=job_id, - collection_id=collection_id, + collection_id=str(collection_uuid), task_id=task_id, project_id=project_id, organization_id=organization_id, ), tracer.start_as_current_span("collections.delete.execute_job") as span: - span.set_attribute("collection.id", str(collection_id)) + span.set_attribute("collection.id", str(collection_uuid)) span.set_attribute("collection.job_id", str(job_uuid)) span.set_attribute("kaapi.project_id", project_id) span.set_attribute("kaapi.organization_id", organization_id) @@ -194,7 +199,9 @@ def execute_job( ), ) - collection = CollectionCrud(session, project_id).read_one(collection_id) 
+ collection = CollectionCrud(session, project_id).read_one( + collection_uuid + ) span.set_attribute("collection.provider", str(collection.provider)) provider = get_llm_provider( @@ -208,7 +215,7 @@ def execute_job( provider.delete(collection) with Session(engine) as session: - CollectionCrud(session, project_id).delete_by_id(collection_id) + CollectionCrud(session, project_id).delete_by_id(collection_uuid) collection_job_crud = CollectionJobCrud(session, project_id) collection_job_crud.update( @@ -223,24 +230,30 @@ def execute_job( logger.info( "[delete_collection.execute_job] Collection deleted successfully | " "{'collection_id': '%s', 'job_id': '%s'}", - str(collection_id), + str(collection_uuid), str(job_uuid), ) - if deletion_request.callback_url and collection_job: + if callback_url and collection_job: success_payload = build_success_payload( collection_job=collection_job, - collection_id=collection_id, + collection_id=collection_uuid, + ) + webhook_secret = get_webhook_secret(project_id, organization_id) + send_callback( + callback_url, + success_payload, + webhook_secret=webhook_secret, ) - send_callback(deletion_request.callback_url, success_payload) except Exception as err: span.record_exception(err) span.set_status(trace.Status(trace.StatusCode.ERROR, str(err))) _mark_job_failed_and_callback( + organization_id=organization_id, project_id=project_id, - collection_id=collection_id, + collection_id=collection_uuid, job_id=job_uuid, err=err, - callback_url=deletion_request.callback_url, + callback_url=callback_url, ) raise diff --git a/backend/app/services/doctransform/job.py b/backend/app/services/doctransform/job.py index cdbe295f8..62ba9b240 100644 --- a/backend/app/services/doctransform/job.py +++ b/backend/app/services/doctransform/job.py @@ -20,10 +20,11 @@ DocTransformationJobPublic, TransformedDocumentPublic, DocTransformationJob, + Project, ) from app.core.cloud import get_cloud_storage from app.celery.utils import start_doctransform_job -from 
app.utils import send_callback, APIResponse +from app.utils import send_callback, get_webhook_secret, APIResponse from app.services.doctransform.registry import convert_document, FORMAT_TO_EXTENSION from app.core.db import engine @@ -117,11 +118,18 @@ def execute_job( tmp_dir: Path | None = None job_for_payload = None # keep latest job snapshot for payloads + webhook_secret: str | None = None try: job_uuid = UUID(job_id) source_uuid = UUID(source_document_id) + if callback_url: + with Session(engine) as db: + project = db.get(Project, project_id) + if project is not None: + webhook_secret = get_webhook_secret(project_id, project.organization_id) + logger.info( "[doc_transform.execute_job] started | job_id=%s | transformer=%s | target=%s | project_id=%s", job_uuid, @@ -222,7 +230,7 @@ def execute_job( ) if callback_url: - send_callback(callback_url, success_payload) + send_callback(callback_url, success_payload, webhook_secret=webhook_secret) except Exception as e: logger.error( @@ -251,7 +259,9 @@ def execute_job( if callback_url and job_for_payload: try: failure_payload = build_failure_payload(job_for_payload, str(e)) - send_callback(callback_url, failure_payload) + send_callback( + callback_url, failure_payload, webhook_secret=webhook_secret + ) except Exception as cb_error: logger.error( "[doc_transform.execute_job] callback failed | job_id=%s | error=%s", diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py index 29f4fc349..a7bde2799 100644 --- a/backend/app/services/llm/chain/executor.py +++ b/backend/app/services/llm/chain/executor.py @@ -13,7 +13,7 @@ from app.models.llm.response import IntermediateChainResponse, LLMChainResponse from app.services.llm.chain.chain import ChainContext, LLMChain from app.services.llm.chain.types import BlockResult -from app.utils import APIResponse, send_callback +from app.utils import APIResponse, get_webhook_secret, send_callback logger = logging.getLogger(__name__) @@ -31,6 
+31,7 @@ def __init__( self._chain = chain self._context = context self._request = request + self._webhook_secret: str | None = None def run(self) -> dict: """Execute the full chain lifecycle. Returns serialized APIResponse.""" @@ -60,6 +61,10 @@ def _setup(self) -> None: status=ChainStatus.RUNNING, ) + self._webhook_secret = get_webhook_secret( + self._context.project_id, self._context.organization_id + ) + def _teardown(self, result: BlockResult) -> dict: """Finalize chain record, send callback, and update job status.""" @@ -76,6 +81,7 @@ def _teardown(self, result: BlockResult) -> dict: send_callback( callback_url=str(self._request.callback_url), data=callback_response.model_dump(), + webhook_secret=self._webhook_secret, ) with Session(engine) as session: JobCrud(session).update( @@ -107,6 +113,7 @@ def _handle_error(self, error: str) -> dict: send_callback( callback_url=str(self._request.callback_url), data=callback_response.model_dump(), + webhook_secret=self._webhook_secret, ) with Session(engine) as session: @@ -166,6 +173,7 @@ def _send_intermediate_callback( send_callback( callback_url=str(self._request.callback_url), data=callback_data.model_dump(), + webhook_secret=self._webhook_secret, ) logger.info( f"[_send_intermediate_callback] Sent intermediate callback | " diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 60dbdb49e..e797040a2 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -47,7 +47,13 @@ ) from app.services.llm.mappers import transform_kaapi_config_to_native from app.services.llm.providers.registry import get_llm_provider -from app.utils import APIResponse, cleanup_temp_file, resolve_input, send_callback +from app.utils import ( + APIResponse, + cleanup_temp_file, + get_webhook_secret, + resolve_input, + send_callback, +) logger = logging.getLogger(__name__) tracer = trace.get_tracer(__name__) @@ -178,15 +184,19 @@ def handle_job_error( job_id: UUID, callback_url: str | 
None, callback_response: APIResponse, + organization_id: int | None = None, + project_id: int | None = None, ) -> dict: """Handle job failure uniformly — send callback and update DB.""" if callback_url: + webhook_secret = get_webhook_secret(project_id, organization_id) with tracer.start_as_current_span("llm.send_callback") as cb_span: cb_span.set_attribute("callback.url", callback_url) cb_span.set_attribute("callback.status", "failure") send_callback( callback_url=callback_url, data=callback_response.model_dump(), + webhook_secret=webhook_secret, ) with Session(engine) as session: @@ -777,6 +787,7 @@ def execute_job( data=result.response, metadata=result.metadata ) if callback_url_str: + webhook_secret = get_webhook_secret(project_id, organization_id) with tracer.start_as_current_span("llm.send_callback") as cb_span: cb_span.set_attribute("callback.url", callback_url_str) cb_span.set_attribute("callback.status", "success") @@ -784,6 +795,7 @@ def execute_job( send_callback( callback_url=callback_url_str, data=callback_response.model_dump(), + webhook_secret=webhook_secret, ) with Session(engine) as session: @@ -796,12 +808,17 @@ def execute_job( ) return callback_response.model_dump() - error_message = result.error or "Unknown error occurred" callback_response = APIResponse.failure_response( - error=error_message, + error=result.error or "Unknown error occurred", metadata=request.request_metadata, ) - return handle_job_error(job_uuid, callback_url_str, callback_response) + return handle_job_error( + job_uuid, + callback_url_str, + callback_response, + organization_id=organization_id, + project_id=project_id, + ) except Exception as e: callback_response = APIResponse.failure_response( @@ -812,7 +829,13 @@ def execute_job( f"[execute_job] Unexpected error: {str(e)} | job_id={job_uuid}, task_id={task_id}", exc_info=True, ) - return handle_job_error(job_uuid, callback_url_str, callback_response) + return handle_job_error( + job_uuid, + callback_url_str, + 
callback_response, + organization_id=organization_id, + project_id=project_id, + ) finally: # Ensure task spans are pushed promptly so Sentry dashboards update faster. flush_telemetry() @@ -936,7 +959,13 @@ def execute_chain_job( error="Unexpected error occurred", metadata=request.request_metadata, ) - return handle_job_error(job_uuid, callback_url_str, callback_response) + return handle_job_error( + job_uuid, + callback_url_str, + callback_response, + organization_id=organization_id, + project_id=project_id, + ) finally: # Ensure task spans are pushed promptly so Sentry dashboards update faster. flush_telemetry() diff --git a/backend/app/services/response/callbacks.py b/backend/app/services/response/callbacks.py index fbf0af1f5..b58dad2d4 100644 --- a/backend/app/services/response/callbacks.py +++ b/backend/app/services/response/callbacks.py @@ -1,5 +1,5 @@ from app.models import ResponsesAPIRequest, ResponsesSyncAPIRequest -from app.utils import APIResponse, send_callback +from app.utils import APIResponse, get_webhook_secret, send_callback def get_additional_data(request: dict) -> dict: @@ -19,10 +19,15 @@ def send_response_callback( callback_url: str, callback_response: APIResponse, request_dict: dict, + organization_id: int, + project_id: int, ) -> None: """Send a standardized callback response to the provided callback URL.""" callback_response = callback_response.model_dump() + + webhook_secret = get_webhook_secret(project_id, organization_id) + send_callback( callback_url, { @@ -34,4 +39,5 @@ def send_response_callback( "error": callback_response.get("error"), "metadata": None, }, + webhook_secret=webhook_secret, ) diff --git a/backend/app/services/response/jobs.py b/backend/app/services/response/jobs.py index 2be98b6a5..bd6f6e916 100644 --- a/backend/app/services/response/jobs.py +++ b/backend/app/services/response/jobs.py @@ -74,4 +74,6 @@ def execute_job( callback_url=request_data.callback_url, callback_response=response, 
request_dict=request_data.model_dump(), + organization_id=organization_id, + project_id=project_id, ) diff --git a/backend/app/tests/core/test_callback_ssrf.py b/backend/app/tests/core/test_callback_ssrf.py index 1aaf5ac44..15dde46a0 100644 --- a/backend/app/tests/core/test_callback_ssrf.py +++ b/backend/app/tests/core/test_callback_ssrf.py @@ -1,3 +1,4 @@ +import json import socket from typing import Any import requests @@ -5,7 +6,15 @@ import pytest -from app.utils import _is_private_ip, validate_callback_url, send_callback +import hashlib +import hmac + +from app.utils import ( + _is_private_ip, + validate_callback_url, + send_callback, + sign_webhook_payload, +) class TestIsPrivateIP: @@ -326,5 +335,46 @@ def test_callback_sends_json_data( send_callback("https://api.example.com/callback", test_data) call_kwargs = mock_session.post.call_args[1] - assert "json" in call_kwargs - assert call_kwargs["json"] == test_data + # Body is now pre-serialized so HMAC signs the exact bytes we send. + assert "data" in call_kwargs + assert ( + call_kwargs["data"] == json.dumps(test_data, separators=(",", ":")).encode() + ) + assert call_kwargs["headers"]["Content-Type"] == "application/json" + # No webhook_secret passed → no signature headers. 
+ assert "X-Webhook-Signature" not in call_kwargs["headers"] + assert "X-Webhook-Timestamp" not in call_kwargs["headers"] + + +class TestSignWebhookPayload: + def test_returns_hex_signature_and_timestamp(self): + sig, ts = sign_webhook_payload("secret", b"body") + assert isinstance(sig, str) and len(sig) == 64 # sha256 hex + assert isinstance(ts, int) and ts > 0 + + def test_deterministic_with_fixed_timestamp(self): + body = b'{"key":"value"}' + sig1, _ = sign_webhook_payload("secret", body, timestamp_ms=1000) + sig2, _ = sign_webhook_payload("secret", body, timestamp_ms=1000) + assert sig1 == sig2 + + def test_different_secrets_produce_different_signatures(self): + body = b"payload" + sig1, _ = sign_webhook_payload("secret-a", body, timestamp_ms=1000) + sig2, _ = sign_webhook_payload("secret-b", body, timestamp_ms=1000) + assert sig1 != sig2 + + def test_different_timestamps_produce_different_signatures(self): + body = b"payload" + sig1, _ = sign_webhook_payload("secret", body, timestamp_ms=1000) + sig2, _ = sign_webhook_payload("secret", body, timestamp_ms=2000) + assert sig1 != sig2 + + def test_signature_matches_manual_hmac(self): + secret, body, ts = "mysecret", b"hello", 999 + expected = hmac.new( + secret.encode(), f"{ts}.".encode() + body, hashlib.sha256 + ).hexdigest() + sig, returned_ts = sign_webhook_payload(secret, body, timestamp_ms=ts) + assert sig == expected + assert returned_ts == ts diff --git a/backend/app/tests/services/doctransformer/test_job/test_callbacks.py b/backend/app/tests/services/doctransformer/test_job/test_callbacks.py new file mode 100644 index 000000000..e090a2224 --- /dev/null +++ b/backend/app/tests/services/doctransformer/test_job/test_callbacks.py @@ -0,0 +1,568 @@ +""" +Tests for doctransform execute_job: callbacks, payload builders, signed URL, and tmp dir cleanup. + +All existing tests pass callback_url=None. 
This file covers the gaps: +- success / failure callbacks (payload structure, single send, webhook secret) +- build_success_payload / build_failure_payload +- tmp dir cleaned up in both success and failure paths +- signed URL included when storage supports it; exception swallowed when it doesn't +""" +import shutil +from datetime import datetime +from io import BytesIO +from typing import Tuple +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from moto import mock_aws +from sqlmodel import Session + +from app.crud import DocTransformationJobCrud +from app.models import ( + Document, + DocTransformJobCreate, + Project, + TransformationStatus, + TransformedDocumentPublic, +) +from app.services.doctransform.job import ( + build_failure_payload, + build_success_payload, + execute_job, +) +from app.tests.services.doctransformer.test_job.utils import ( + DocTransformTestBase, + MockTestTransformer, +) + + +def _make_transformed_doc(document: Document) -> TransformedDocumentPublic: + return TransformedDocumentPublic( + id=uuid4(), + project_id=document.project_id, + fname="output.md", + object_store_url="s3://bucket/key", + source_document_id=document.id, + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ) + + +# --------------------------------------------------------------------------- +# Payload builders — pure logic, no S3 +# --------------------------------------------------------------------------- + + +class TestBuildPayloads: + def test_success_payload_structure( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + payload = build_success_payload(job, _make_transformed_doc(document)) + + assert payload["success"] is True + assert payload["error"] is None + assert "error_message" not in payload["data"] + assert 
payload["data"]["transformed_document"]["fname"] == "output.md" + + def test_failure_payload_structure( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + payload = build_failure_payload(job, "conversion crashed") + + assert payload["success"] is False + assert "conversion crashed" in payload["error"] + assert "error_message" not in payload["data"] + assert payload["data"]["transformed_document"] is None + + +# --------------------------------------------------------------------------- +# Callback — success path +# --------------------------------------------------------------------------- + + +class TestCallbacksSuccess(DocTransformTestBase): + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_success_sends_callback_once_with_correct_payload( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + callback_url = "https://example.com/webhook" + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + patch("app.services.doctransform.job.send_callback") as mock_send, + patch( + "app.services.doctransform.job.get_webhook_secret", return_value=None + ), + patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + mock_session.return_value.__exit__.return_value = None + + execute_job( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url=callback_url, + 
task_instance=None, + ) + + mock_send.assert_called_once() + url_arg, payload_arg = mock_send.call_args.args + assert url_arg == callback_url + assert payload_arg["success"] is True + assert payload_arg["data"]["status"] == TransformationStatus.COMPLETED + + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_success_callback_not_sent_without_callback_url( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + patch("app.services.doctransform.job.send_callback") as mock_send, + patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + mock_session.return_value.__exit__.return_value = None + + execute_job( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url=None, + task_instance=None, + ) + + mock_send.assert_not_called() + + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_webhook_secret_passed_to_send_callback( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + patch("app.services.doctransform.job.send_callback") as mock_send, + patch( + "app.services.doctransform.job.get_webhook_secret", + return_value="my-secret", + ), + patch( + 
"app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + mock_session.return_value.__exit__.return_value = None + + execute_job( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url="https://example.com/webhook", + task_instance=None, + ) + + assert mock_send.call_args.kwargs["webhook_secret"] == "my-secret" + + +# --------------------------------------------------------------------------- +# Callback — failure path +# --------------------------------------------------------------------------- + + +class TestCallbacksFailure(DocTransformTestBase): + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_failure_sends_callback_with_error_payload( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + patch("app.services.doctransform.job.send_callback") as mock_send, + patch( + "app.services.doctransform.job.get_webhook_secret", return_value=None + ), + patch( + "app.services.doctransform.job.convert_document", + side_effect=RuntimeError("converter crashed"), + ), + patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + mock_session.return_value.__exit__.return_value = None + + with pytest.raises(RuntimeError): + execute_job.__wrapped__( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), 
+ callback_url="https://example.com/webhook", + task_instance=None, + ) + + mock_send.assert_called_once() + url_arg, payload_arg = mock_send.call_args.args + assert payload_arg["success"] is False + assert "converter crashed" in payload_arg["error"] + + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_failure_callback_not_sent_without_callback_url( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + patch("app.services.doctransform.job.send_callback") as mock_send, + patch( + "app.services.doctransform.job.convert_document", + side_effect=RuntimeError("crash"), + ), + patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + mock_session.return_value.__exit__.return_value = None + + with pytest.raises(RuntimeError): + execute_job.__wrapped__( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url=None, + task_instance=None, + ) + + mock_send.assert_not_called() + + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_failure_marks_job_failed_before_callback( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + 
patch("app.services.doctransform.job.send_callback"), + patch( + "app.services.doctransform.job.get_webhook_secret", return_value=None + ), + patch( + "app.services.doctransform.job.convert_document", + side_effect=RuntimeError("crash"), + ), + patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + mock_session.return_value.__exit__.return_value = None + + with pytest.raises(RuntimeError): + execute_job.__wrapped__( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url="https://example.com/webhook", + task_instance=None, + ) + + db.refresh(job) + assert job.status == TransformationStatus.FAILED + assert "crash" in job.error_message + + +# --------------------------------------------------------------------------- +# Tmp dir cleanup +# --------------------------------------------------------------------------- + + +class TestTmpDirCleanup(DocTransformTestBase): + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_tmp_dir_removed_on_success( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + removed: list[str] = [] + real_rmtree = shutil.rmtree + + def capture(path, **kw): + removed.append(str(path)) + real_rmtree(path, **kw) + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + patch("app.services.doctransform.job.shutil.rmtree", side_effect=capture), + patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + 
mock_session.return_value.__exit__.return_value = None + + execute_job( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url=None, + task_instance=None, + ) + + assert len(removed) == 1 + + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_tmp_dir_removed_on_failure( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + removed: list[str] = [] + real_rmtree = shutil.rmtree + + def capture(path, **kw): + removed.append(str(path)) + real_rmtree(path, **kw) + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + patch("app.services.doctransform.job.shutil.rmtree", side_effect=capture), + patch( + "app.services.doctransform.job.convert_document", + side_effect=RuntimeError("crash"), + ), + patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + mock_session.return_value.__exit__.return_value = None + + with pytest.raises(RuntimeError): + execute_job.__wrapped__( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url=None, + task_instance=None, + ) + + assert len(removed) == 1 + + +# --------------------------------------------------------------------------- +# Signed URL +# --------------------------------------------------------------------------- + + +class TestSignedUrl(DocTransformTestBase): + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_signed_url_included_in_callback_when_available( + self, 
db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + + mock_storage = MagicMock() + mock_storage.stream.return_value = BytesIO(b"content") + mock_storage.put.return_value = "s3://bucket/transformed" + mock_storage.get_signed_url.return_value = "https://signed.example.com/doc" + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + patch("app.services.doctransform.job.send_callback") as mock_send, + patch( + "app.services.doctransform.job.get_webhook_secret", return_value=None + ), + patch( + "app.services.doctransform.job.get_cloud_storage", + return_value=mock_storage, + ), + patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + mock_session.return_value.__exit__.return_value = None + + execute_job( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url="https://example.com/webhook", + task_instance=None, + ) + + payload = mock_send.call_args.args[1] + assert ( + payload["data"]["transformed_document"]["signed_url"] + == "https://signed.example.com/doc" + ) + + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_signed_url_exception_swallowed_job_still_succeeds( + self, db: Session, test_document: Tuple[Document, Project] + ) -> None: + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job = DocTransformationJobCrud(db, project_id=project.id).create( + DocTransformJobCreate(source_document_id=document.id) + ) + + mock_storage = MagicMock() + mock_storage.stream.return_value = 
BytesIO(b"content") + mock_storage.put.return_value = "s3://bucket/transformed" + mock_storage.get_signed_url.side_effect = Exception("token expired") + + with ( + patch("app.services.doctransform.job.Session") as mock_session, + patch("app.services.doctransform.job.send_callback") as mock_send, + patch( + "app.services.doctransform.job.get_webhook_secret", return_value=None + ), + patch( + "app.services.doctransform.job.get_cloud_storage", + return_value=mock_storage, + ), + patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), + ): + mock_session.return_value.__enter__.return_value = db + mock_session.return_value.__exit__.return_value = None + + execute_job( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url="https://example.com/webhook", + task_instance=None, + ) + + db.refresh(job) + assert job.status == TransformationStatus.COMPLETED + payload = mock_send.call_args.args[1] + assert payload["data"]["transformed_document"]["signed_url"] is None diff --git a/backend/app/utils.py b/backend/app/utils.py index 0d9741a5f..ab9407973 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -2,9 +2,13 @@ import base64 import functools as ft +import hashlib +import hmac import ipaddress +import json import logging import tempfile +import time from dataclasses import dataclass from datetime import timedelta from pathlib import Path @@ -410,7 +414,60 @@ def validate_callback_url(url: str) -> None: raise ValueError(f"Error validating callback URL: {str(e)}") from e -def send_callback(callback_url: str, data: dict[str, Any]) -> bool: +def sign_webhook_payload( + secret: str, raw_body: bytes, timestamp_ms: int | None = None +) -> tuple[str, int]: + """ + Generate an HMAC-SHA256 signature for a webhook payload. + + Signing string format: "." 
+ The receiver must reconstruct the exact same signing string to verify. + + Args: + secret: Shared HMAC secret (pre-registered by the receiver). + raw_body: Exact bytes that will be sent in the HTTP body. + timestamp_ms: Unix timestamp in milliseconds. Generated if not provided. + + Returns: + (hex_signature, timestamp_ms) + """ + if timestamp_ms is None: + timestamp_ms = int(time.time() * 1000) + + signing_string = f"{timestamp_ms}.".encode() + raw_body + signature = hmac.new( + secret.encode(), + signing_string, + hashlib.sha256, + ).hexdigest() + return signature, timestamp_ms + + +def get_webhook_secret( + project_id: int | None, organization_id: int | None +) -> str | None: + """Look up the configured webhook signing secret for this project, or None.""" + if project_id is None or organization_id is None: + return None + # Imported lazily: app.core.db pulls in app.crud, which imports app.utils, + # so a top-level import here would deadlock module initialization. + from app.core.db import engine + + with Session(engine) as session: + creds = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="webhook_secret", + ) + return creds.get("webhook_secret") if isinstance(creds, dict) else None + + +def send_callback( + callback_url: str, + data: dict[str, Any], + webhook_secret: str | None = None, +) -> bool: """ Send results to the callback URL (synchronously) with SSRF protection. @@ -422,10 +479,13 @@ def send_callback(callback_url: str, data: dict[str, Any]) -> bool: - DNS rebinding protection - Redirect following disabled - Strict timeouts + - Optional HMAC-SHA256 signing when webhook_secret is provided Args: callback_url: The HTTPS URL to send the callback to data: The JSON data to send in the POST request + webhook_secret: If provided, sign the request with HMAC-SHA256 and + attach X-Webhook-Signature / X-Webhook-Timestamp headers. 
Returns: bool: True if callback succeeded, False otherwise @@ -435,14 +495,21 @@ def send_callback(callback_url: str, data: dict[str, Any]) -> bool: except ValueError as ve: logger.error(f"[send_callback] Invalid callback URL: {ve}", exc_info=True) return False - try: + raw_body = json.dumps(data, separators=(",", ":")).encode() + headers = {"Content-Type": "application/json"} + + if webhook_secret: + signature, timestamp_ms = sign_webhook_payload(webhook_secret, raw_body) + headers["X-Webhook-Signature"] = signature + headers["X-Webhook-Timestamp"] = str(timestamp_ms) with requests.Session() as session: session.trust_env = False # Ignores environment proxies and other implicit settings for SSRF safety response = session.post( callback_url, - json=data, + data=raw_body, + headers=headers, timeout=( settings.CALLBACK_CONNECT_TIMEOUT, settings.CALLBACK_READ_TIMEOUT,