diff --git a/application/agents/agent_creator.py b/application/agents/agent_creator.py index 44e895524..f5f54455d 100644 --- a/application/agents/agent_creator.py +++ b/application/agents/agent_creator.py @@ -18,3 +18,41 @@ def create_agent(cls, type, *args, **kwargs): raise ValueError(f"No agent class found for type {type}") return agent_class(*args, **kwargs) + + +DOCSGPT_DEFAULT_ORIGINS = { + "https://app.docsgpt.cloud", + "https://ent.docsgpt.cloud", +} + + +def _is_origin_allowed(agent, origin: str | None) -> bool: + """ + Basic origin whitelist check. + + - If agent has no origin whitelisting enabled/configured, allow all. + - If whitelisting is enabled and Origin is missing, reject. + - Always allow DocsGPT default origins. + - Otherwise, check against agent-configured allowed origins. + """ + # If feature is not enabled or no config on this agent, allow + if not getattr(agent, "origin_whitelist_enabled", False): + return True + + # No Origin header and whitelist enabled → reject + if origin is None: + return False + + # Always allow internal DocsGPT origins + if origin in DOCSGPT_DEFAULT_ORIGINS: + return True + + # Read agent-configured allowed origins + raw_allowed = getattr(agent, "allowed_origins", "") or "" + allowed = {o.strip() for o in raw_allowed.split(",") if o.strip()} + + # If no custom origins, fall back to only default ones (already checked) + if not allowed: + return False + + return origin in allowed diff --git a/tests/api/test_agent_origin_whitelist.py b/tests/api/test_agent_origin_whitelist.py new file mode 100644 index 000000000..504f3c86d --- /dev/null +++ b/tests/api/test_agent_origin_whitelist.py @@ -0,0 +1,35 @@ +# tests/application/test_agent_origin_whitelist.py + +import types +from application.routes import agent_api + + +def _fake_agent(enabled, allowed_origins): + agent = types.SimpleNamespace() + agent.origin_whitelist_enabled = enabled + agent.allowed_origins = allowed_origins + return agent + + +def test_origin_allowed_when_feature_disabled(): + agent = _fake_agent(enabled=False, allowed_origins="") + assert agent_api._is_origin_allowed(agent, None) is True + assert agent_api._is_origin_allowed(agent, "https://example.com") is True + + +def test_origin_rejected_when_missing_and_enabled(): + agent = _fake_agent(enabled=True, allowed_origins="https://example.com") + assert agent_api._is_origin_allowed(agent, None) is False + + +def test_default_docsgpt_origins_always_allowed(): + agent = _fake_agent(enabled=True, allowed_origins="") + assert agent_api._is_origin_allowed(agent, "https://app.docsgpt.cloud") is True + assert agent_api._is_origin_allowed(agent, "https://ent.docsgpt.cloud") is True + + +def test_custom_origin_must_be_in_whitelist(): + agent = _fake_agent(enabled=True, allowed_origins="https://a.com, https://b.com") + assert agent_api._is_origin_allowed(agent, "https://a.com") is True + assert agent_api._is_origin_allowed(agent, "https://b.com") is True + assert agent_api._is_origin_allowed(agent, "https://c.com") is False