Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
uv sync --group dev --group docs --group vllm
uv sync --group dev --group docs
- name: Check types
run: |
uv run mypy app
Expand Down
24 changes: 18 additions & 6 deletions app/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os.path
import app.api.globals as cms_globals

from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, Union, Type
from concurrent.futures import ThreadPoolExecutor
from anyio.lowlevel import RunVar
from anyio import CapacityLimiter
Expand All @@ -20,7 +20,7 @@
from app.api.dependencies import ModelServiceDep
from app.api.utils import add_exception_handlers, add_rate_limiter, init_vllm_engine
from app.config import Settings
from app.domain import Tags, TagsStreamable
from app.domain import Tags, TagsStreamable, TagsGenerative
from app.management.tracker_client import TrackerClient
from app.utils import get_settings, unpack_model_data_package, get_model_data_package_base_name
from app.exception import ConfigurationException
Expand Down Expand Up @@ -131,6 +131,11 @@ def get_generative_server(config: Settings, msd_overwritten: Optional[ModelServi
app = _load_health_check_router(app)
logger.debug("Health check router loaded")

if config.ENABLE_TRAINING_APIS == "true":
app = _load_supervised_training_router(app)
logger.debug("Supervised training router loaded")
app = _load_training_operations(app)

if config.AUTH_USER_ENABLED == "true":
app = _load_auth_router(app)
logger.debug("Auth router loaded")
Expand Down Expand Up @@ -198,11 +203,18 @@ def _get_app(
streamable: bool = False,
generative: bool = False,
) -> FastAPI:
tags_metadata = [{ # type: ignore
"name": tag.name,
"description": tag.value
} for tag in (Tags if not streamable else TagsStreamable)]
config = get_settings()
tags: Union[Type[Tags], Type[TagsStreamable], Type[TagsGenerative]]
if generative:
tags = TagsGenerative
elif streamable:
tags = TagsStreamable
else:
tags = Tags
tags_metadata = [{ # type: ignore
"name": tag.name, # type: ignore
"description": tag.value # type: ignore
} for tag in tags]
app = FastAPI(
title="CogStack ModelServe",
summary="A model serving and governance system for CogStack NLP solutions",
Expand Down
Loading