Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions application/agents/agent_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,41 @@ def create_agent(cls, type, *args, **kwargs):
raise ValueError(f"No agent class found for type {type}")

return agent_class(*args, **kwargs)


DOCSGPT_DEFAULT_ORIGINS = {
"https://app.docsgpt.cloud",
"https://ent.docsgpt.cloud",
}


def _is_origin_allowed(agent, origin: str | None) -> bool:
"""
Basic origin whitelist check.

- If agent has no origin whitelisting enabled/configured, allow all.
- If whitelisting is enabled and Origin is missing, reject.
- Always allow DocsGPT default origins.
- Otherwise, check against agent-configured allowed origins.
"""
# If feature is not enabled or no config on this agent, allow
if not getattr(agent, "origin_whitelist_enabled", False):
return True

# No Origin header and whitelist enabled → reject
if origin is None:
return False

# Always allow internal DocsGPT origins
if origin in DOCSGPT_DEFAULT_ORIGINS:
return True

# Read agent-configured allowed origins
raw_allowed = getattr(agent, "allowed_origins", "") or ""
allowed = {o.strip() for o in raw_allowed.split(",") if o.strip()}

# If no custom origins, fall back to only default ones (already checked)
if not allowed:
return False

return origin in allowed
35 changes: 35 additions & 0 deletions tests/api/test_agent_origin_whitelist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# tests/application/test_agent_origin_whitelist.py

import types
from application.routes import agent_api


def _fake_agent(enabled, allowed_origins):
agent = types.SimpleNamespace()
agent.origin_whitelist_enabled = enabled
agent.allowed_origins = allowed_origins
return agent


def test_origin_allowed_when_feature_disabled():
agent = _fake_agent(enabled=False, allowed_origins="")
assert agent_api._is_origin_allowed(agent, None) is True
assert agent_api._is_origin_allowed(agent, "https://example.com") is True


def test_origin_rejected_when_missing_and_enabled():
agent = _fake_agent(enabled=True, allowed_origins="https://example.com")
assert agent_api._is_origin_allowed(agent, None) is False


def test_default_docsgpt_origins_always_allowed():
agent = _fake_agent(enabled=True, allowed_origins="")
assert agent_api._is_origin_allowed(agent, "https://app.docsgpt.cloud") is True
assert agent_api._is_origin_allowed(agent, "https://ent.docsgpt.cloud") is True


def test_custom_origin_must_be_in_whitelist():
agent = _fake_agent(enabled=True, allowed_origins="https://a.com, https://b.com")
assert agent_api._is_origin_allowed(agent, "https://a.com") is True
assert agent_api._is_origin_allowed(agent, "https://b.com") is True
assert agent_api._is_origin_allowed(agent, "https://c.com") is False