diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index ffce1f8..6c3a9dd 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -24,7 +24,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - uv sync --group dev --group docs --group vllm + uv sync --group dev --group docs - name: Check types run: | uv run mypy app diff --git a/app/api/api.py b/app/api/api.py index b9874ab..5dd1522 100644 --- a/app/api/api.py +++ b/app/api/api.py @@ -4,7 +4,7 @@ import os.path import app.api.globals as cms_globals -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Union, Type from concurrent.futures import ThreadPoolExecutor from anyio.lowlevel import RunVar from anyio import CapacityLimiter @@ -20,7 +20,7 @@ from app.api.dependencies import ModelServiceDep from app.api.utils import add_exception_handlers, add_rate_limiter, init_vllm_engine from app.config import Settings -from app.domain import Tags, TagsStreamable +from app.domain import Tags, TagsStreamable, TagsGenerative from app.management.tracker_client import TrackerClient from app.utils import get_settings, unpack_model_data_package, get_model_data_package_base_name from app.exception import ConfigurationException @@ -131,6 +131,11 @@ def get_generative_server(config: Settings, msd_overwritten: Optional[ModelServi app = _load_health_check_router(app) logger.debug("Health check router loaded") + if config.ENABLE_TRAINING_APIS == "true": + app = _load_supervised_training_router(app) + logger.debug("Supervised training router loaded") + app = _load_training_operations(app) + if config.AUTH_USER_ENABLED == "true": app = _load_auth_router(app) logger.debug("Auth router loaded") @@ -198,11 +203,18 @@ def _get_app( streamable: bool = False, generative: bool = False, ) -> FastAPI: - tags_metadata = [{ # type: ignore - "name": tag.name, - "description": tag.value - } for tag in (Tags if not streamable else TagsStreamable)] config = get_settings() + tags: Union[Type[Tags], Type[TagsStreamable], Type[TagsGenerative]] + if generative: + tags = TagsGenerative + elif streamable: + tags = TagsStreamable + else: + tags = Tags + tags_metadata = [{ # type: ignore + "name": tag.name, # type: ignore + "description": tag.value # type: ignore + } for tag in tags] app = FastAPI( title="CogStack ModelServe", summary="A model serving and governance system for CogStack NLP solutions", diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py index 007eb0b..b2d9454 100644 --- a/app/api/routers/generative.py +++ b/app/api/routers/generative.py @@ -1,16 +1,36 @@ +import json import logging +import time +import uuid import app.api.globals as cms_globals +from typing import Union, Iterable, AsyncGenerator from typing_extensions import Annotated +from functools import partial from fastapi import APIRouter, Depends, Request, Body, Query -from fastapi.responses import PlainTextResponse, StreamingResponse -from app.domain import Tags +from fastapi.encoders import jsonable_encoder +from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse +from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR +from app.domain import ( + Tags, + TagsGenerative, + OpenAIChatRequest, + OpenAIChatResponse, + OpenAIEmbeddingsRequest, + OpenAIEmbeddingsResponse, + PromptMessage, + PromptRole, +) from app.model_services.base import AbstractModelService -from app.utils import get_settings +from app.utils import 
get_settings, get_prompt_from_messages from app.api.utils import get_rate_limiter +from app.api.dependencies import validate_tracking_id +from app.management.prometheus_metrics import cms_prompt_tokens, cms_completion_tokens, cms_total_tokens PATH_GENERATE = "/generate" PATH_GENERATE_ASYNC = "/stream/generate" +PATH_OPENAI_COMPLETIONS = "/v1/chat/completions" +PATH_OPENAI_EMBEDDINGS = "/v1/embeddings" router = APIRouter() config = get_settings() @@ -22,7 +42,7 @@ @router.post( PATH_GENERATE, - tags=[Tags.Generative.name], + tags=[TagsGenerative.Generative], response_class=PlainTextResponse, dependencies=[Depends(cms_globals.props.current_active_user)], description="Generate text", @@ -31,27 +51,48 @@ def generate_text( request: Request, prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")], max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512, + temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> PlainTextResponse: """ - Generate text based on the prompt provided. + Generates text based on the prompt provided. Args: request (Request): The request object. prompt (str): The prompt to be sent to the model. max_tokens (int): The maximum number of tokens to generate. + temperature (float): The temperature of the generated text. + tracking_id (Union[str, None]): An optional tracking ID of the requested task. model_service (AbstractModelService): The model service dependency. Returns: PlainTextResponse: A response containing the generated text. """ - return PlainTextResponse(model_service.generate(prompt, max_tokens=max_tokens)) + tracking_id = tracking_id or str(uuid.uuid4()) + if prompt: + return PlainTextResponse( + model_service.generate( + prompt, + max_tokens=max_tokens, + temperature=temperature, + report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE), + ), + headers={"x-cms-tracking-id": tracking_id}, + status_code=HTTP_200_OK, + ) + else: + return PlainTextResponse( + _empty_prompt_error(), + headers={"x-cms-tracking-id": tracking_id}, + status_code=HTTP_400_BAD_REQUEST, + ) @router.post( PATH_GENERATE_ASYNC, - tags=[Tags.Generative.name], + tags=[TagsGenerative.Generative], response_class=StreamingResponse, dependencies=[Depends(cms_globals.props.current_active_user)], description="Generate a stream of texts", @@ -60,22 +101,249 @@ async def generate_text_stream( request: Request, prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")], max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512, + temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> StreamingResponse: """ - Generate a stream of texts in near real-time. + Generates a stream of texts in near real-time. Args: request (Request): The request object. prompt (str): The prompt to be sent to the model. max_tokens (int): The maximum number of tokens to generate. + temperature (float): The temperature of the generated text. + tracking_id (Union[str, None]): An optional tracking ID of the requested task. 
model_service (AbstractModelService): The model service dependency. Returns: StreamingResponse: A streaming response containing the text generated in near real-time. """ - return StreamingResponse( - model_service.generate_async(prompt, max_tokens=max_tokens), - media_type="text/event-stream" - ) + tracking_id = tracking_id or str(uuid.uuid4()) + if prompt: + return StreamingResponse( + model_service.generate_async( + prompt, + max_tokens=max_tokens, + temperature=temperature, + report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE_ASYNC), + ), + media_type="text/event-stream", + headers={"x-cms-tracking-id": tracking_id}, + status_code=HTTP_200_OK, + ) + else: + return StreamingResponse( + _empty_prompt_error(), + media_type="text/event-stream", + headers={"x-cms-tracking-id": tracking_id}, + status_code=HTTP_400_BAD_REQUEST, + ) + + +@router.post( + PATH_OPENAI_COMPLETIONS, + tags=[Tags.OpenAICompatible.name], + response_model=None, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions", +) +def generate_chat_completions( + request: Request, + request_data: Annotated[OpenAIChatRequest, Body( + description="OpenAI-like completion request", media_type="application/json" + )], + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep) +) -> Union[StreamingResponse, JSONResponse]: + """ + Generates chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint. + + Args: + request (Request): The request object. + request_data (OpenAIChatRequest): The request data containing model, messages, and stream. + tracking_id (Union[str, None]): An optional tracking ID of the requested task. + model_service (AbstractModelService): The model service dependency. + + Returns: + StreamingResponse: A OpenAI-like response containing the text generated in near real-time. + JSONResponse: A response containing an error message if the prompt messages are empty. 
+ """ + + messages = request_data.messages + model = model_service.model_name if request_data.model != model_service.model_name else request_data.model + stream = request_data.stream + max_tokens = request_data.max_tokens + temperature = request_data.temperature + tracking_id = tracking_id or str(uuid.uuid4()) + + if not messages: + error_response = { + "error": { + "message": "No prompt messages provided", + "type": "invalid_request_error", + "param": "messages", + "code": "missing_field", + } + } + return JSONResponse( + content=error_response, + status_code=HTTP_400_BAD_REQUEST, + headers={"x-cms-tracking-id": tracking_id}, + ) + + async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGenerator: + data = { + "id": tracking_id, + "object": "chat.completion.chunk", + "choices": [{"delta": {"role": PromptRole.ASSISTANT.value}}], + } + yield f"data: {json.dumps(data)}\n\n" + async for chunk in model_service.generate_async( + prompt, + max_tokens=max_tokens, + temperature=temperature, + report_tokens=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS) + ): + data = { + "choices": [ + { + "delta": {"content": chunk} + } + ], + "object": "chat.completion.chunk", + } + yield f"data: {json.dumps(data)}\n\n" + yield "data: [DONE]\n\n" + + prompt = get_prompt_from_messages(model_service.tokenizer, messages) # type: ignore + if stream: + return StreamingResponse( + _stream(prompt, max_tokens, temperature), + media_type="text/event-stream", + headers={"x-cms-tracking-id": tracking_id}, + ) + else: + generated_text = model_service.generate( + prompt, + max_tokens=max_tokens, + temperature=temperature, + send_metrics=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS), + ) + completion = OpenAIChatResponse( + id=tracking_id, + object="chat.completion", + created=int(time.time()), + model=model, + choices=[ + { + "index": 0, + "message": PromptMessage( + role=PromptRole.ASSISTANT, + content=generated_text, + ), + "finish_reason": "stop", + } + ], + ) + return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id}) + + +@router.post( + PATH_OPENAI_EMBEDDINGS, + tags=[Tags.OpenAICompatible.name], + response_model=None, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Create embeddings based on text(s), similar to OpenAI's /v1/embeddings endpoint", +) +def embed_texts( + request: Request, + request_data: Annotated[OpenAIEmbeddingsRequest, Body( + description="Text(s) to be embedded", media_type="application/json" + )], + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep) +) -> JSONResponse: + """ + Embeds text or a list of texts, mimicking OpenAI's /v1/embeddings endpoint. + + Args: + request (Request): The request object. + request_data (OpenAIEmbeddingsRequest): The request data containing model and input text(s). + tracking_id (Union[str, None]): An optional tracking ID of the requested task. + model_service (AbstractModelService): The model service dependency. + + Returns: + JSONResponse: A response containing the embeddings of the text(s). 
+ """ + tracking_id = tracking_id or str(uuid.uuid4()) + + if not hasattr(model_service, "create_embeddings"): + error_response = { + "error": { + "message": "Model does not support embeddings", + "type": "invalid_request_error", + "param": "model", + "code": "model_not_supported", + } + } + return JSONResponse( + content=error_response, + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + headers={"x-cms-tracking-id": tracking_id}, + ) + + input_text = request_data.input + model = model_service.model_name if request_data.model != model_service.model_name else request_data.model + + if isinstance(input_text, str): + input_texts = [input_text] + else: + input_texts = input_text + + try: + embeddings_data = [] + + for i, embedding in enumerate(model_service.create_embeddings(input_texts)): + embeddings_data.append({ + "object": "embedding", + "embedding": embedding, + "index": i, + }) + + response = OpenAIEmbeddingsResponse(object="list", data=embeddings_data, model=model) + + return JSONResponse( + content=jsonable_encoder(response), + headers={"x-cms-tracking-id": tracking_id}, + ) + + except Exception as e: + logger.error("Failed to create embeddings") + logger.exception(e) + error_response = { + "error": { + "message": f"Failed to create embeddings: {str(e)}", + "type": "server_error", + "code": "internal_error", + } + } + return JSONResponse( + content=error_response, + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + headers={"x-cms-tracking-id": tracking_id}, + ) + + +def _empty_prompt_error() -> Iterable[str]: + yield "ERROR: No prompt text provided\n" + + +def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None: + cms_prompt_tokens.labels(handler=handler).observe(prompt_token_num) + logger.debug("Sent prompt tokens usage: %s", prompt_token_num) + cms_completion_tokens.labels(handler=handler).observe(completion_token_num) + logger.debug("Sent completion tokens usage: %s", completion_token_num) + cms_total_tokens.labels(handler=handler).observe(prompt_token_num + completion_token_num) + logger.debug("Sent total tokens usage: %s", prompt_token_num + completion_token_num) diff --git a/app/api/routers/supervised_training.py b/app/api/routers/supervised_training.py index 89d0d17..1ee7c9b 100644 --- a/app/api/routers/supervised_training.py +++ b/app/api/routers/supervised_training.py @@ -12,9 +12,9 @@ import app.api.globals as cms_globals from app.api.dependencies import validate_tracking_id -from app.domain import Tags +from app.domain import Tags, ModelType from app.model_services.base import AbstractModelService -from app.processors.metrics_collector import concat_trainer_exports +from app.processors.metrics_collector import concat_json_lists, concat_trainer_exports from app.utils import filter_by_concept_ids router = APIRouter() @@ -72,12 +72,19 @@ async def train_supervised( files.append(temp_te) file_names.append("" if te.filename is None else te.filename) - concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) - logger.debug("Training exports concatenated") - data_file = tempfile.NamedTemporaryFile(mode="w") - concatenated = filter_by_concept_ids(cast(Dict[str, Any], concatenated), model_service.info().model_type) - logger.debug("Training exports filtered by concept IDs") - json.dump(concatenated, data_file) + if model_service.info().model_type is not ModelType.HUGGINGFACE_LLM: + concatenated_te = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) + logger.debug("Training 
exports concatenated") + data_file = tempfile.NamedTemporaryFile(mode="w+") + concatenated_te = filter_by_concept_ids(cast(Dict[str, Any], concatenated_te), model_service.info().model_type) + logger.debug("Training exports filtered by concept IDs") + json.dump(concatenated_te, data_file) + else: + concatenated = concat_json_lists([file.name for file in files]) + logger.debug("Training exports concatenated") + data_file = tempfile.NamedTemporaryFile(mode="w+") + json.dump(concatenated, data_file) + data_file.flush() data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) @@ -102,6 +109,7 @@ async def train_supervised( return _get_training_response(training_response, training_id) + def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse: training_accepted, experiment_id, run_id = training_response if training_accepted: diff --git a/app/api/utils.py b/app/api/utils.py index a14714c..87cea26 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -26,8 +26,14 @@ from slowapi.errors import RateLimitExceeded from fastapi_users.jwt import decode_jwt from app.config import Settings -from app.domain import Tags -from app.exception import StartTrainingException, AnnotationException, ConfigurationException, ClientException +from app.domain import TagsGenerative +from app.exception import ( + StartTrainingException, + AnnotationException, + ConfigurationException, + ClientException, + ExtraDependencyRequiredException, +) logger = logging.getLogger("cms") @@ -118,6 +124,24 @@ async def configuration_exception_handler(_: Request, exception: ConfigurationEx logger.exception(exception) return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)}) + @app.exception_handler(ExtraDependencyRequiredException) + async def extra_dependency_exception_handler( + _: Request, + exception: ExtraDependencyRequiredException + ) -> JSONResponse: + """ + Handles extra dependency required exceptions. + + Args: + _ (Request): The request object. + exception (ExtraDependencyRequiredException): The extra dependency required exception. + + Returns: + JSONResponse: A JSON response with a 500 status code and an error message. + """ + logger.exception(exception) + return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)}) + @app.exception_handler(ClientException) async def client_exception_handler(_: Request, exception: ClientException) -> JSONResponse: """ @@ -286,6 +310,7 @@ async def init_vllm_engine(app: FastAPI, """ try: + # Import necessary vLLM components from vllm.utils import FlexibleArgumentParser from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args @@ -299,15 +324,18 @@ async def init_vllm_engine(app: FastAPI, from vllm import SamplingParams, TokensPrompt except ImportError: logger.error("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.") + raise ExtraDependencyRequiredException("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.") parser = FlexibleArgumentParser() parser = make_arg_parser(parser) args = parser.parse_args([]) validate_parsed_serve_args(args) + args.model = model_dir_path args.dtype = "float16" args.served_model_name = [model_name] - # args.tokenizer = model_dir_path + args.max_model_len = 2048 # The default batched length (2048) needs to be higher than max_model_len. 
+ # args.tokenizer = model_dir_path # Uncomment if your tokenizer is in a different path or needs explicit setting. args.log_level = log_level exit_stack = contextlib.AsyncExitStack() @@ -317,9 +345,11 @@ async def init_vllm_engine(app: FastAPI, disable_frontend_multiprocessing=True, ) ) + tokenizer = await engine.get_tokenizer() vllm_config = await engine.get_vllm_config() model_config = await engine.get_model_config() + await init_app_state(engine, vllm_config, app.state, args) async def generate_text( @@ -327,27 +357,32 @@ async def generate_text( prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")], max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512 ) -> StreamingResponse: + """ + Custom endpoint for streaming text generation. + This endpoint takes a raw text prompt and streams back the generated text. + It applies a chat template to the prompt internally for model compatibility. + """ messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] params = SamplingParams(max_tokens=max_tokens) + conversation, _ = parse_chat_messages(messages, model_config, tokenizer, content_format="string") # type: ignore - prompt = TokensPrompt( - prompt_token_ids=apply_hf_chat_template( # type: ignore - tokenizer, - conversation=conversation, - tools=None, - add_generation_prompt=True, - continue_final_message=False, - chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:", - tokenize=True, - ) + prompt_tokens = apply_hf_chat_template( # type: ignore + tokenizer, + conversation=conversation, + tools=None, + add_generation_prompt=True, + continue_final_message=False, + chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:", + tokenize=True, ) + prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens) # type: ignore async def _stream() -> AsyncGenerator[bytes, None]: start = 0 - async for output in engine.generate(request_id=uuid.uuid4().hex, prompt=prompt, sampling_params=params): + async for output in engine.generate(request_id=uuid.uuid4().hex, prompt=prompt_obj, sampling_params=params): text = output.outputs[0].text - yield text[start:] # type: ignore + yield text[start:].encode("utf-8") start = len(text) return StreamingResponse(_stream(), media_type="text/event-stream") @@ -365,7 +400,7 @@ async def _stream() -> AsyncGenerator[bytes, None]: endpoint=endpoint, methods=methods, include_in_schema=True, - tags=[Tags.Generative], + tags=[TagsGenerative.Generative.name], ) app.include_router(router) diff --git a/app/cli/cli.py b/app/cli/cli.py index 8a94647..6003407 100644 --- a/app/cli/cli.py +++ b/app/cli/cli.py @@ -67,6 +67,7 @@ def serve_model( streamable: bool = typer.Option(False, help="Serve the streamable endpoints only"), device: Device = typer.Option(Device.DEFAULT.value, help="The device to serve the model on"), llm_engine: Optional[LlmEngine] = typer.Option(LlmEngine.CMS.value, help="The engine to use for text generation"), + load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"), debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"), ) 
-> None: """ @@ -84,6 +85,7 @@ def serve_model( streamable (bool): Serve the streamable endpoints only. Defaults to False. device (Device): The device to serve the model on. Defaults to Device.DEFAULT. llm_engine (LlmEngine): The inference engine to use. Defaults to LlmEngine.CMS. + load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False. debug (Optional[bool]): Run in debug mode if set to True. """ @@ -135,7 +137,7 @@ def serve_model( if model_path: model_service = model_service_dep() model_service.model_name = model_name - model_service.init_model() + model_service.init_model(load_in_4bit=load_in_4bit) cms_globals.model_manager_dep = ModelManagerDep(model_service) elif mlflow_model_uri: model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path) @@ -187,6 +189,7 @@ def train_model( description: Optional[str] = typer.Option(None, help="The description of the training or change logs"), model_name: Optional[str] = typer.Option(None, help="The string representation of the model name"), device: Device = typer.Option(Device.DEFAULT.value, help="The device to train the model on"), + load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"), debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"), ) -> None: """ @@ -206,6 +209,7 @@ def train_model( description (Optional[str]): The optional description of the training or change logs. model_name (Optional[str]): The optional string representation of the model name. device (Device): The device to train the model on. Defaults to Device.DEFAULT. + load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False. debug (Optional[bool]): Run in debug mode if set to True. 
""" @@ -229,7 +233,7 @@ def train_model( pass model_service = model_service_dep() model_service.model_name = model_name if model_name is not None else "CMS model" - model_service.init_model() + model_service.init_model(load_in_4bit=load_in_4bit) elif mlflow_model_uri: model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path) model_service.model_name = model_name if model_name is not None else "CMS model" diff --git a/app/domain.py b/app/domain.py index c9d38cf..6be1564 100644 --- a/app/domain.py +++ b/app/domain.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict, Any, Union from fastapi import HTTPException from starlette.status import HTTP_400_BAD_REQUEST @@ -27,12 +27,19 @@ class Tags(str, Enum): Evaluating = "Evaluate the deployed model with trainer export" Authentication = "Authenticate registered users" Generative = "Generate text based on the input prompt" + OpenAICompatible = "Compatible with OpenAI APIs" class TagsStreamable(str, Enum): + Metadata = "Get the model card" Streaming = "Retrieve NER entities as a stream by running the model" +class TagsGenerative(str, Enum): + Metadata = "Get the model card" + Generative = "Generate text based on the input prompt" + + class CodeType(str, Enum): SNOMED = "SNOMED" UMLS = "UMLS" @@ -103,6 +110,19 @@ class LlmEngine(Enum): CMS = "CMS" VLLM = "vLLM" +class LlmRole(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + +class LlmTrainerType(Enum): + GRPO = "grpo" + PPO = "ppo" + +class LlmDatasetType(Enum): + JSON = "json" + CSV = "csv" class Annotation(BaseModel): doc_name: Optional[str] = Field(default=None, description="The name of the document to which the annotation belongs") @@ -167,3 +187,42 @@ class Doc(BaseModel): text: str = Field(description="The text from which the entities are extracted") ents: List[Entity] = Field(description="The list of extracted entities") title: Optional[str] = Field(default=None, description="The headline of the text") + + +class PromptRole(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +class PromptMessage(BaseModel): + role: PromptRole = Field(description="The role who generates the message") + content: str = Field(description="The actual text of the message") + + +class OpenAIChatRequest(BaseModel): + messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model") + stream: bool = Field(..., description="Whether to stream the response") + max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0) + model: str = Field(..., description="The name of the model used for generating the completion") + temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0) + + +class OpenAIChatResponse(BaseModel): + id: str = Field(..., description="The unique identifier for the chat completion request") + object: str = Field(..., description="The type of the response") + created: int = Field(..., description="The timestamp when the completion was generated") + model: str = Field(..., description="The name of the model used for generating the completion") + choices: List = Field(..., description="The generated messages and their metadata") + + +class OpenAIEmbeddingsRequest(BaseModel): + input: Union[str, List[str]] = Field(..., description="Input text or list of texts to embed") + model: str = Field(..., description="The 
name of the model used for creating the embeddings") + + +class OpenAIEmbeddingsResponse(BaseModel): + object: str = Field(..., description="The type of the response") + data: List[Dict[str, Any]] = Field(..., description="List of embedding objects") + model: str = Field(..., description="The name of the model used for creating the embeddings") diff --git a/app/exception.py b/app/exception.py index 1b8f9bc..ddba71b 100644 --- a/app/exception.py +++ b/app/exception.py @@ -27,4 +27,12 @@ class ClientException(Exception): class DatasetException(Exception): - """ An exception raised due to dataset errors""" + """An exception raised due to dataset errors""" + + +class DeviceNotAvailableError(RuntimeError): + """An exception raised when a specificy device is required but not available.""" + + +class ExtraDependencyRequiredException(Exception): + """An exception raised when an extra dependency is required but not found.""" diff --git a/app/management/prometheus_metrics.py b/app/management/prometheus_metrics.py index 3f48858..78c5698 100644 --- a/app/management/prometheus_metrics.py +++ b/app/management/prometheus_metrics.py @@ -34,3 +34,24 @@ "Number of bulk-processed documents", ["handler"], ) + +# The histogram metric to track the number of tokens in the messages of the input prompt +cms_prompt_tokens = Histogram( + "cms_prompt_tokens", + "Number of tokens in the messages of the input prompt", + ["handler"], +) + +# The histogram metric to track the number of tokens in the generated assistant reply +cms_completion_tokens = Histogram( + "cms_completion_tokens", + "Number of tokens in the generated assistant reply", + ["handler"], +) + +# The histogram metric to track the total number of tokens used in the prompt and the completion +cms_total_tokens = Histogram( + "cms_total_tokens", + "Number of tokens used in the prompt and the completion", + ["handler"], +) diff --git a/app/model_services/base.py b/app/model_services/base.py index a3c1ccc..a7b6323 100644 --- a/app/model_services/base.py +++ b/app/model_services/base.py @@ -1,6 +1,6 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any, List, Iterable, Tuple, final, Optional, Generic, TypeVar, Protocol, AsyncIterable +from typing import Any, List, Iterable, Tuple, final, Optional, Generic, TypeVar, Protocol, AsyncIterable, Union from app.config import Settings from app.domain import ModelCard, Annotation @@ -17,7 +17,7 @@ def tracker_client(self) -> Any: T = TypeVar("T", bound=_TrainerCommon) class AbstractModelService(ABC, Generic[T]): - """An abstract base class defining the common interface for all model services.""" + """An abstract base class defining the common interface for NER model services.""" @abstractmethod def __init__(self, config: Settings, *args: Any, **kwargs: Any) -> None: @@ -154,10 +154,14 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: raise NotImplementedError @abstractmethod - def init_model(self) -> None: + def init_model(self, *args: Any, **kwargs: Any) -> None: """ Initialises the model and auxiliary resources. + Args: + *args (Any): Additional positional arguments to be passed to this method. + **kwargs (Any): Additional keyword arguments to be passed to this method. + Raises: NotImplementedError: If the method is not implemented by the subclass. 
""" @@ -200,6 +204,29 @@ def generate_async(self, prompt: str, *args: Any, **kwargs: Any) -> AsyncIterabl raise NotImplementedError + def create_embeddings( + self, + text: Union[str, List[str]], + *args: Any, + **kwargs: Any + ) -> Union[List[float], List[List[float]]]: + """ + Creates embeddings for a given text or list of texts. + + Args: + text (Union[str, List[str]]): The text(s) to be embedded. + *args (Any): Additional positional arguments to be passed to this method. + **kwargs (Any): Additional keyword arguments to be passed to this method. + + Returns: + Union[List[float], List[List[float]]]: The embedding vector(s) for the text(s). + + Raises: + NotImplementedError: If the method is not implemented by the subclass. + """ + + raise NotImplementedError + def train_supervised(self, *args: Any, **kwargs: Any) -> Tuple[bool, str, str]: """ Initiates supervised training on the model. diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index 566bbf9..a747739 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -1,18 +1,21 @@ import os import logging import asyncio +import torch from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Tuple, Any, AsyncIterable +from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, TextIO, Callable, Union from transformers import ( AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase, TextIteratorStreamer, + BitsAndBytesConfig, ) from app import __version__ as app_version from app.exception import ConfigurationException from app.model_services.base import AbstractModelService +from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer from app.domain import ModelCard, ModelType, Annotation from app.config import Settings from app.utils import ( @@ -122,13 +125,19 @@ def from_model(cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase) return model_service @staticmethod - def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: + def load_model( + model_file_path: str, + *args: Tuple, + load_in_4bit: bool = False, + **kwargs: Dict[str, Any] + ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: """ Loads a pre-trained model and its tokenizer from a model package file. Args: model_file_path (str): The path to the model package file. *args (Tuple): Additional positional arguments. + load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False. **kwargs (Dict[str, Any]): Additional keyword arguments. 
Returns: @@ -141,7 +150,16 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> model_path = os.path.join(os.path.dirname(model_file_path), get_model_data_package_base_name(model_file_path)) if unpack_model_data_package(model_file_path, model_path): try: - model = AutoModelForCausalLM.from_pretrained(model_path) + if load_in_4bit: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + ) + model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config) + else: + model = AutoModelForCausalLM.from_pretrained(model_path) ensure_tensor_contiguity(model) tokenizer = AutoTokenizer.from_pretrained( model_path, @@ -156,8 +174,14 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> else: raise ConfigurationException(f"Model package archive format is not supported: {model_file_path}") - def init_model(self) -> None: - """Initialises the HuggingFace model and its tokenizer based on the configuration.""" + def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> None: + """Initialises the HuggingFace model and its tokenizer based on the configuration. + + Args: + load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False. + *args (Any): Additional positional arguments to be passed to this method. + **kwargs (Any): Additional keyword arguments to be passed to this method. + """ if all([ hasattr(self, "_model"), @@ -167,9 +191,11 @@ def init_model(self) -> None: ]): logger.warning("Model service is already initialised and can be initialised only once") else: - self._model, self._tokenizer = self.load_model(self._model_pack_path) + self._model, self._tokenizer = self.load_model(self._model_pack_path, load_in_4bit=load_in_4bit) + if non_default_device_is_available(get_settings().DEVICE): + self._model.to(get_settings().DEVICE) if self._enable_trainer: - logger.error("Trainers are not yet implemented for HuggingFace Generative models") + self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self) def info(self) -> ModelCard: """ @@ -191,13 +217,22 @@ def annotate(self, text: str) -> List[Annotation]: def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: raise NotImplementedError("Batch annotation is not yet implemented for HuggingFace Generative models") - def generate(self, prompt: str, max_tokens: int = 512, **kwargs: Any) -> str: + def generate( + self, + prompt: str, + max_tokens: int = 512, + temperature: float = 0.7, + report_tokens: Optional[Callable[[str], None]] = None, + **kwargs: Any + ) -> str: """ Generates text based on the prompt. Args: prompt (str): The prompt for the text generation max_tokens (int): The maximum number of tokens to generate. Defaults to 512. + temperature (float): The temperature for the text generation. Defaults to 0.7. + report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None. **kwargs (Any): Additional keyword arguments to be passed to this method. 
Returns: @@ -214,26 +249,39 @@ def generate(self, prompt: str, max_tokens: int = 512, **kwargs: Any) -> str: inputs=inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=max_tokens, - do_sample=True, - temperature=0.7, + do_sample=False, + temperature=temperature, top_p=0.9, ) outputs = self.model.generate(**generation_kwargs) generated_text = self.tokenizer.decode(outputs[0], skip_prompt=True, skip_special_tokens=True) - - logger.debug("Response generation completed") + if report_tokens: + report_tokens( + prompt_token_num=inputs.input_ids.shape[-1], # type: ignore + completion_token_num=outputs[0].shape[-1], # type: ignore + ) + return generated_text - async def generate_async(self, prompt: str, max_tokens: int = 512, **kwargs: Any) -> AsyncIterable: + async def generate_async( + self, + prompt: str, + max_tokens: int = 512, + temperature: float = 0.7, + report_tokens: Optional[Callable[[str], None]] = None, + **kwargs: Any + ) -> AsyncIterable: """ Asynchronously generates text stream based on the prompt. Args: prompt (str): The prompt for the text generation. max_tokens (int): The maximum number of tokens to generate. Defaults to 512. + temperature (float): The temperature for the text generation. Defaults to 0.7. + report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None. **kwargs (Any): Additional keyword arguments to be passed to the model loader. Returns: @@ -257,18 +305,123 @@ async def generate_async(self, prompt: str, max_tokens: int = 512, **kwargs: Any streamer=streamer, max_new_tokens=max_tokens, do_sample=True, - temperature=0.7, + temperature=temperature, top_p=0.9, ) try: _ = self._text_generator.submit(self.model.generate, **generation_kwargs) + output = "" for content in streamer: yield content + output += content await asyncio.sleep(0.01) + if report_tokens: + report_tokens( + prompt_token_num=inputs.input_ids.shape[-1], # type: ignore + completion_token_num=self.tokenizer( # type: ignore + output, + add_special_tokens=False, + return_tensors="pt" + ).input_ids.shape[-1], + ) except Exception as e: logger.error("An error occurred while generating the response") logger.exception(e) return finally: logger.debug("Chat response generation completed") + + def create_embeddings( + self, + text: Union[str, List[str]], + *args: Any, + **kwargs: Any + ) -> Union[List[float], List[List[float]]]: + """ + Creates embeddings for a given text or list of texts using the model's hidden states. + + Args: + text (Union[str, List[str]]): The text(s) to be embedded. + *args (Any): Additional positional arguments to be passed to this method. + **kwargs (Any): Additional keyword arguments to be passed to this method. + + Returns: + List[float], List[List[float]]: The embedding vector(s) for the text(s). + + Raises: + NotImplementedError: If the model doesn't support embeddings. 
+ """ + + self.model.eval() + + inputs = self.tokenizer( + text, + add_special_tokens=False, + return_tensors="pt", + padding=True, + truncation=True, + ) + + if non_default_device_is_available(self._config.DEVICE): + inputs.to(get_settings().DEVICE) + + with torch.no_grad(): + outputs = self.model(**inputs, output_hidden_states=True) + + last_hidden_state = outputs.hidden_states[-1] + attention_mask = inputs["attention_mask"] + masked_hidden_states = last_hidden_state * attention_mask.unsqueeze(-1) + sum_hidden_states = masked_hidden_states.sum(dim=1) + num_tokens = attention_mask.sum(dim=1, keepdim=True) + embeddings = sum_hidden_states / num_tokens + l2_normalised = torch.nn.functional.normalize(embeddings, p=2, dim=1) + + results = l2_normalised.cpu().numpy().tolist() + return results[0] if isinstance(text, str) else results + + def train_supervised( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> Tuple[bool, str, str]: + """ + Initiates supervised training on the model. + + Args: + data_file (TextIO): The file containing the trainer export data. + epochs (int): The number of training epochs. + log_frequency (int): The number of epochs after which training metrics will be logged. + training_id (str): A unique identifier for the training process. + input_file_name (str): The name of the input file to be logged. + raw_data_files (Optional[List[TextIO]]): Additional raw data files to be logged. Defaults to None. + description (Optional[str]): The description of the training or change logs. Defaults to empty. + synchronised (bool): Whether to wait for the training to complete. + **hyperparams (Dict[str, Any]): Additional hyperparameters for training. + + Returns: + Tuple[bool, str, str]: A tuple with the first element indicating success or failure. + + Raises: + ConfigurationException: If the supervised trainer is not enabled. + """ + if self._supervised_trainer is None: + raise ConfigurationException("The supervised trainer is not enabled") + return self._supervised_trainer.train( + data_file, + epochs, + log_frequency, + training_id, + input_file_name, + raw_data_files, + description, + synchronised, + **hyperparams, + ) diff --git a/app/model_services/huggingface_ner_model.py b/app/model_services/huggingface_ner_model.py index d982836..e741705 100644 --- a/app/model_services/huggingface_ner_model.py +++ b/app/model_services/huggingface_ner_model.py @@ -175,8 +175,13 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> else: raise ConfigurationException(f"Model package archive format is not supported: {model_file_path}") - def init_model(self) -> None: - """Initialises the HuggingFace model, its tokenizer and a NER pipeline based on the configuration.""" + def init_model(self, *args: Any, **kwargs: Any) -> None: + """Initialises the HuggingFace model, its tokenizer and a NER pipeline based on the configuration. + + Args: + *args (Any): Additional positional arguments to be passed to this method. + **kwargs (Any): Additional keyword arguments to be passed to this method. 
+ """ if all([ hasattr(self, "_model"), diff --git a/app/model_services/medcat_model.py b/app/model_services/medcat_model.py index a3a6f2c..9ab5235 100644 --- a/app/model_services/medcat_model.py +++ b/app/model_services/medcat_model.py @@ -119,8 +119,13 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> else: raise ConfigurationException("Model package archive format is not supported") - def init_model(self) -> None: - """Initializes the MedCAT model based on the configuration.""" + def init_model(self, *args: Any, **kwargs: Any) -> None: + """Initializes the MedCAT model based on the configuration. + + Args: + *args (Any): Additional positional arguments to be passed to this method. + **kwargs (Any): Additional keyword arguments to be passed to this method. + """ if hasattr(self, "_model") and isinstance(self._model, CAT): logger.warning("Model service is already initialised and can be initialised only once") diff --git a/app/model_services/medcat_model_deid.py b/app/model_services/medcat_model_deid.py index fe94dde..9ec1248 100644 --- a/app/model_services/medcat_model_deid.py +++ b/app/model_services/medcat_model_deid.py @@ -178,8 +178,13 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: return annotations_list - def init_model(self) -> None: - """Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration.""" + def init_model(self, *args: Any, **kwargs: Any) -> None: + """Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration. + + Args: + *args (Any): Additional positional arguments to be passed to this method. + **kwargs (Any): Additional keyword arguments to be passed to this method. + """ if hasattr(self, "_model") and isinstance(self._model, CAT): logger.warning("Model service is already initialised and can be initialised only once") diff --git a/app/model_services/trf_model_deid.py b/app/model_services/trf_model_deid.py index fb8e3ac..fbf6290 100644 --- a/app/model_services/trf_model_deid.py +++ b/app/model_services/trf_model_deid.py @@ -86,7 +86,7 @@ def load_model( logger.info("Model loaded from %s", unpacked_model_dir) return tokenizer, model - def init_model(self) -> None: + def init_model(self, *args: Any, **kwargs: Any) -> None: if hasattr(self, "_model") and isinstance(self._model, PreTrainedModel): logger.warning("Model service is already initialised and can be initialised only once") else: diff --git a/app/processors/metrics_collector.py b/app/processors/metrics_collector.py index 07f0592..84f74da 100644 --- a/app/processors/metrics_collector.py +++ b/app/processors/metrics_collector.py @@ -194,6 +194,36 @@ def concat_trainer_exports( return combined +def concat_json_lists( + data_file_paths: List[str], + combined_data_file_path: Optional[str] = None, +) -> Union[List[Dict[str, Any]], str]: + """ + Concatenates multiple json list files into a single combined file. + + Args: + data_file_paths (List[str]): List of paths to files each containing a json list. + combined_data_file_path (Optional[str]): The file path where the combined data will be saved. If None, the combined data will be returned as a list. + + + Returns: + Union[List[Dict[str, Any]], str]: The path to the combined data file if `combined_data_file_path` is provided, or the combined data as a list otherwise. 
+ """ + combined: List = [] + for path in data_file_paths: + with open(path, "r") as f: + data = json.load(f) + combined.extend(data) + + if isinstance(combined_data_file_path, str): + with open(combined_data_file_path, "w") as f: + json.dump(combined, f) + + return combined_data_file_path + else: + return combined + + def get_stats_from_trainer_export( trainer_export: Union[str, IO, Dict], return_df: bool = False, diff --git a/app/processors/prompt_factory.py b/app/processors/prompt_factory.py new file mode 100644 index 0000000..ee1a45c --- /dev/null +++ b/app/processors/prompt_factory.py @@ -0,0 +1,262 @@ +class PromptFactory: + + _ALPACA = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'].strip() + '\n' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = '' %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 %}" + "{% set content = system_message + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + + "{% if message['role'] == 'user' %}" + "{{ '### Instruction:\n' + content.strip() + '\n\n'}}" + "{% elif message['role'] == 'assistant' %}" + "{{ '### Response:\n' + content.strip() + '\n\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '### Response:\n' }}" + "{% endif %}" + ) + + _CHAT_ML = ( + "{% for message in messages %}" + "{{'<|im_start|>' + message['role'] + '\n' + message['content'].strip() + '<|im_end|>' + '\n'}}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{'<|im_start|>assistant\n'}}" + "{% endif %}" + ) + + _DEFAULT = ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}" + "{{'<|user|>\n' + message['content'] + eos_token}}" + "{% elif message['role'] == 'system' %}" + "{{'<|system|>\n' + message['content'] + eos_token}}" + "{% elif message['role'] == 'assistant' %}" + "{{'<|assistant|>\n' + message['content'] + eos_token}}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{'<|assistant|>'}}" + "{% endif %}" + "{% endfor %}" + ) + + _FALCON = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = '' %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 %}" + "{{ system_message.strip() }}" + "{% endif %}" + "{{ '\n\n' + message['role'].title() + ': ' + message['content'].strip().replace('\r\n', '\n').replace('\n\n', '\n') }}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{ '\n\nAssistant:' }}" + "{% endif %}" + ) + + _GEMMA = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'].strip() + '\n\n' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = '' %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must 
alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 %}" + "{% set content = system_message + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if (message['role'] == 'assistant') %}" + "{% set role = 'model' %}" + "{% else %}" + "{% set role = message['role'] %}" + "{% endif %}" + "{{ '<start_of_turn>' + role + '\n' + content.strip() + '<end_of_turn>\n' }}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{'<start_of_turn>model\n'}}" + "{% endif %}" + ) + + _LLAMA_2 = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = '<<SYS>>\n' + messages[0]['content'].strip() + '\n<</SYS>>\n\n' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = '' %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 %}" + "{% set content = system_message + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + + _LLAMA_3 = ( + "{{ bos_token }}" + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\n' + messages[0]['content'].strip() + '<|eot_id|>' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = '' %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 %}" + "{{ system_message }}" + "{% endif %}" + "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'].strip() + '<|eot_id|>' }}" + "{% if loop.last and message['role'] == 'user' and add_generation_prompt %}" + "{{ '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}" + "{% endif %}" + "{% endfor %}" + ) + + _MISTRAL = ( + "{{ bos_token }}" + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'].strip() + '\n\n' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = '' %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 %}" + "{% set content = system_message + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ content.strip() + eos_token}}" + "{% endif %}" + "{% endfor %}" + ) + + _PHI_2 = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'].strip() + '\n\n' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = '' %}" + "{% 
endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 %}" + "{% set content = system_message + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ 'Instruct: ' + content.strip() + '\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ 'Output: ' + content.strip() + '\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ 'Output:' }}" + "{% endif %}" + ) + + _PHI_3 = ( + "{{ bos_token }}" + "{% for message in messages %}" + "{% if (message['role'] == 'system') %}" + "{{'<|system|>' + '\n' + message['content'].strip() + '<|end|>' + '\n'}}" + "{% elif (message['role'] == 'user') %}" + "{{'<|user|>' + '\n' + message['content'].strip() + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}" + "{% elif message['role'] == 'assistant' %}" + "{{message['content'].strip() + '<|end|>' + '\n'}}" + "{% endif %}" + "{% endfor %}" + ) + + _QWEN = ( + "{% for message in messages %}" + "{% if loop.first and messages[0]['role'] != 'system' %}" + "{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}" + "{% endif %}" + "{{'<|im_start|>' + message['role'] + '\n' + message['content'].strip() }}" + "{% if (loop.last and add_generation_prompt) or not loop.last %}" + "{{ '<|im_end|>' + '\n'}}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% endif %}" + ) + + @classmethod + def create_chat_template(cls, name: str = "default") -> str: + if name.lower() == "default": + return cls._DEFAULT + elif name.lower() == "alpaca": + return cls._ALPACA + elif name.lower() == "chat_ml": + return cls._CHAT_ML + elif name.lower() == "falcon": + return cls._FALCON + elif name.lower() == "gemma": + return cls._GEMMA + elif name.lower() == "llama_2": + return cls._LLAMA_2 + elif name.lower() == "llama_3": + return cls._LLAMA_3 + elif name.lower() == "mistral": + return cls._MISTRAL + elif name.lower() == "phi_2": + return cls._PHI_2 + elif name.lower() == "phi_3": + return cls._PHI_3 + elif name.lower() == "qwen": + return cls._QWEN + else: + raise ValueError("Invalid template name") diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py new file mode 100644 index 0000000..b85f44b --- /dev/null +++ b/app/trainers/huggingface_llm_trainer.py @@ -0,0 +1,802 @@ +import os +import logging +import math +import torch +import gc +import datasets +import re +import threading +import json +import inspect +import pandas as pd +from typing import final, Dict, TextIO, Optional, Any, List, Tuple, TYPE_CHECKING, Callable +from transformers import __version__ as transformers_version +from transformers import ( + TrainingArguments, + PreTrainedModel, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, + TrainerCallback, + TrainerState, + TrainerControl, +) +from peft import LoraConfig, get_peft_model # type: ignore +from app.management.model_manager import ModelManager +from app.management.tracker_client import TrackerClient +from app.utils import ( + reset_random_seed, + non_default_device_is_available, + create_model_data_package, + get_model_data_package_extension, + load_pydantic_object_from_dict, + get_default_chat_template, + get_default_system_prompt, + 
get_model_data_package_base_name, +) +from app.trainers.base import SupervisedTrainer +from app.domain import ModelType, TrainerBackend, LlmRole, LlmTrainerType, LlmDatasetType, PromptMessage, Device +from app.exception import ( + TrainingCancelledException, + ManagedModelException, + DatasetException, + ConfigurationException, + DeviceNotAvailableError, + ExtraDependencyRequiredException, +) +if TYPE_CHECKING: + from app.model_services.huggingface_llm_model import HuggingFaceLlmModel + +logger = logging.getLogger("cms") + + +class _HuggingFaceLlmTrainerCommon(object): + + @staticmethod + def deploy_model( + model_service: "HuggingFaceLlmModel", + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + ) -> None: + del model_service.model + del model_service.tokenizer + gc.collect() + model_service.model = model + model_service.tokenizer = tokenizer + logger.info("Retrained model deployed") + + +@final +class HuggingFaceLlmSupervisedTrainer(SupervisedTrainer, _HuggingFaceLlmTrainerCommon): + """ + A supervised trainer class for HuggingFace LLM models. + + Args: + model_service (HuggingFaceLlmModel): An instance of the HuggingFace LLM model service. + """ + + MIN_EXAMPLE_COUNT_FOR_TRAINABLE_CONCEPT = 5 + MAX_CONCEPTS_TO_TRACK = 20 + PAD_LABEL_ID = -100 + DEFAULT_LABEL_ID = 0 + CONTINUING_TOKEN_LABEL_ID = 1 + + def __init__(self, model_service: "HuggingFaceLlmModel") -> None: + if not isinstance(model_service.tokenizer, PreTrainedTokenizerFast): + logger.error("The supervised trainer requires a fast tokenizer to function correctly") + raise ManagedModelException("The supervised trainer requires a fast tokenizer to function correctly") + SupervisedTrainer.__init__(self, model_service._config, model_service.model_name) + self._model_service = model_service + self._model_name = model_service.model_name + self._model_pack_path = model_service._model_pack_path + self._retrained_models_dir = os.path.join( + model_service._model_parent_dir, + "retrained", + self._model_name.replace(" ", "_"), + ) + self._model_manager = ModelManager(type(model_service), model_service._config) + self._max_length = model_service.model.config.max_position_embeddings + os.makedirs(self._retrained_models_dir, exist_ok=True) + + def _load_dataset_from_config(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: + """ + Loads training and validation datasets based on configuration in training_params. 
+ + Args: + data_file: The training data file + training_params: Dictionary containing dataset configuration + + Returns: + Tuple of (train_dataset, validation_dataset) + """ + dataset_type = training_params.get("dataset_type", "json") + + # if dataset_type == "huggingface": + # return self._load_huggingface_dataset(training_params) + if dataset_type == LlmDatasetType.JSON.value: + return self._load_json_dataset(data_file, training_params) + elif dataset_type == LlmDatasetType.CSV.value: + return self._load_csv_dataset(data_file, training_params) + else: + raise DatasetException(f"Unsupported dataset type: {dataset_type}") + + @staticmethod + def _set_dataset_format(train_dataset: datasets.Dataset, test_dataset: datasets.Dataset) -> None: + """Sets the format of the datasets based on the dataset structure.""" + + if "messages" in train_dataset.column_names: + train_dataset.set_format(type=None, columns=["messages"]) + test_dataset.set_format(type=None, columns=["messages"]) + elif "question" in train_dataset.column_names and "answer" in train_dataset.column_names: + train_dataset.set_format(type=None, columns=["question", "answer"]) + test_dataset.set_format(type=None, columns=["question", "answer"]) + elif "input" in train_dataset.column_names and "output" in train_dataset.column_names: + train_dataset.set_format(type=None, columns=["input", "output"]) + test_dataset.set_format(type=None, columns=["input", "output"]) + elif "prompt" in train_dataset.column_names and "completion" in train_dataset.column_names: + train_dataset.set_format(type=None, columns=["prompt", "completion"]) + test_dataset.set_format(type=None, columns=["prompt", "completion"]) + elif "problem" in train_dataset.column_names and "solution" in train_dataset.column_names: + train_dataset.set_format(type=None, columns=["problem", "solution"]) + test_dataset.set_format(type=None, columns=["problem", "solution"]) + else: + raise DatasetException("Unsupported dataset format") + + def _load_huggingface_dataset(self, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: + """Loads dataset from HuggingFace Hub.""" + + dataset_id = training_params.get("dataset_id", "AI-MO/NuminaMath-TIR") + test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] + split_ratio = 1 - test_size + train_percentage = int(split_ratio * 100) + test_percentage = 100 - train_percentage + train_split = training_params.get("train_split", f"train[:{train_percentage}%]") + test_split = training_params.get("test_split", f"test[:{test_percentage}%]") + + logger.info(f"Loading HuggingFace dataset: {dataset_id}") + train_dataset, test_dataset = datasets.load_dataset(dataset_id, split=[train_split, test_split]) + self._set_dataset_format(train_dataset, test_dataset) + + return train_dataset, test_dataset + + + def _load_json_dataset(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: + """Loads dataset from JSON file.""" + + data = json.load(data_file) + test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] + split_ratio = 1 - test_size + + if isinstance(data, list): + examples = data + split_idx = int(len(examples) * split_ratio) + train_examples = examples[:split_idx] + test_examples = examples[split_idx:] + elif isinstance(data, dict) and "train" in data and "test" in data: + train_examples = data["train"] + test_examples = data["test"] + elif isinstance(data, dict) and "examples" in data: + examples = data["examples"] + 
split_idx = int(len(examples) * split_ratio) + train_examples = examples[:split_idx] + test_examples = examples[split_idx:] + else: + raise DatasetException("Unsupported JSON format") + + train_dataset = datasets.Dataset.from_list(train_examples) + test_dataset = datasets.Dataset.from_list(test_examples) + self._set_dataset_format(train_dataset, test_dataset) + + return train_dataset, test_dataset + + def _load_csv_dataset(self, data_file: TextIO, training_params: Dict) -> Tuple[datasets.Dataset, datasets.Dataset]: + """Loads dataset from CSV file.""" + + df = pd.read_csv(data_file) + test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] + split_ratio = 1 - test_size + split_idx = int(len(df) * split_ratio) + + train_df = df[:split_idx] + test_df = df[split_idx:] + + train_dataset = datasets.Dataset.from_pandas(train_df) + test_dataset = datasets.Dataset.from_pandas(test_df) + self._set_dataset_format(train_dataset, test_dataset) + + return train_dataset, test_dataset + + def _create_conversation_formatter(self, training_params: Dict) -> Callable: + """ + Creates a conversation formatter based on training parameters. + + Args: + training_params: Dictionary containing formatting configuration + + Returns: + Function that formats examples into conversations + """ + format_config = training_params.get("format_config", {}) + system_prompt = format_config.get("system_prompt", get_default_system_prompt()) + + def make_conversation(example: Dict[str, Any]) -> Dict[str, Any]: + # Handle different input formats + if "messages" in example: + system_content = None + question_content = None + answer_content = None + for message in example.get("messages", []): + msg = load_pydantic_object_from_dict(PromptMessage, message) + if msg.role == LlmRole.SYSTEM: + system_content = msg.content + elif msg.role == LlmRole.USER: + question_content = msg.content + elif msg.role == LlmRole.ASSISTANT: + answer_content = msg.content + + return { + "prompt": [ + {"role": "system", "content": system_prompt if system_content is None else system_content}, + {"role": "user", "content": question_content if question_content is not None else ""}, + ], + "answer": answer_content if answer_content is not None else "", + } + elif "question" in example and "answer" in example: + # Question/Answer format + return { + "prompt": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": example.get("question")}, + ], + "answer": example["answer"], + } + elif "input" in example and "output" in example: + # Input/Output format + return { + "prompt": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": example.get("input")}, + ], + "answer": example["output"], + } + elif "prompt" in example and "completion" in example: + # Prompt/Completion format + return { + "prompt": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": example.get("prompt")}, + ], + "answer": example["completion"], + } + elif "problem" in example and "solution" in example: + # Problem/Solution format + return { + "prompt": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": example.get("problem")}, + ], + "answer": example["solution"], + } + else: + raise DatasetException(f"Cannot determine the conversation format from example: {example}") + + return make_conversation + + def run( + self, + training_params: Dict, + data_file: TextIO, + log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: + """ 
+ Runs the supervised training loop for HuggingFace LLM models. + + Args: + training_params (Dict): A dictionary containing parameters for the training. + data_file (TextIO): The file-like object containing the training data. + log_frequency (int): The frequency at which logs should be recorded (e.g., the number of processed documents or finished epochs). + run_id (str): The run ID of the training job. + description (Optional[str]): The optional description of the training or change logs. + """ + + if self._config.DEVICE != Device.GPU.value: + raise DeviceNotAvailableError("This trainer currently requires a CUDA device") + + try: + from trl import GRPOConfig, GRPOTrainer # , PPOConfig, PPOTrainer + except ImportError: + logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.") + raise ExtraDependencyRequiredException("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.") + + trained_model_pack_path = None + redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true" + skip_save_model = self._config.SKIP_SAVE_MODEL == "true" + results_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "results")) + logs_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "logs")) + reset_random_seed() + eval_mode = training_params["nepochs"] == 0 + self._tracker_client.log_trainer_mode(not eval_mode) + trainer = None + max_seq_length = 1024 + + if not eval_mode: + try: + logger.info("Loading a PEFT model for training...") + model_pack_file_ext = get_model_data_package_extension(self._model_pack_path) + trained_model_pack_path = self._model_pack_path.replace( + model_pack_file_ext, + f"_trained_{run_id}{model_pack_file_ext}", + ) + model, tokenizer = self._model_service.model, self._model_service.tokenizer + trained_model_directory = os.path.join( + os.path.dirname(trained_model_pack_path), + get_model_data_package_base_name(trained_model_pack_path), + ) + + if non_default_device_is_available(self._config.DEVICE): + model.to(self._config.DEVICE) + + train_dataset, test_dataset = self._load_dataset_from_config(data_file, training_params) + make_conversation = self._create_conversation_formatter(training_params) + train_dataset = train_dataset.map(make_conversation) + test_dataset = test_dataset.map(make_conversation) + + if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None: + logger.warning("The tokenizer does not have a chat template. 
Using the default one.") + tokenizer.chat_template = get_default_chat_template() + else: + logger.debug(f"Found a chat template in the tokenizer:\n {tokenizer.chat_template}") + + lora_config = LoraConfig( + task_type="CAUSAL_LM", + r=8, + lora_alpha=32, + lora_dropout=0.1, + target_modules=[ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + ], + ) + + peft_model = get_peft_model(model, lora_config) + + mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client) + cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event) + trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback] + + trainer_type = training_params.get("trainer_type", LlmTrainerType.GRPO.value).lower() + max_prompt_length = max(train_dataset.map( + lambda x: { + "tokens": tokenizer.apply_chat_template( + x["prompt"], + add_generation_prompt=True, + tokenize=True + ) + }, + batched=True, + ).map(lambda x: {"length": len(x["tokens"])})["length"]) + 1 + if trainer_type == LlmTrainerType.PPO.value: + raise NotImplementedError("PPO training is not yet supported for HuggingFace LLM models") + elif trainer_type == LlmTrainerType.GRPO.value: + training_args = GRPOConfig( + output_dir=results_path, + logging_dir=logs_path, + logging_steps=log_frequency, + learning_rate=5e-6, + adam_beta1=0.9, + adam_beta2=0.99, + weight_decay=0.1, + warmup_ratio=0.1, + lr_scheduler_type="cosine", + optim="paged_adamw_8bit", + per_device_train_batch_size=6, # This global batch size must be divisible by the number of generations + gradient_accumulation_steps=1, + num_generations=6, + max_prompt_length=max_prompt_length, + max_completion_length=max_seq_length - max_prompt_length, + num_train_epochs = training_params["nepochs"], + save_steps=250, + max_grad_norm=0.1, + report_to="none", + ) + trainer = GRPOTrainer( + model=peft_model, + processing_class=tokenizer, + reward_funcs=self._get_reward_functions(), + args=training_args, + train_dataset=train_dataset, + eval_dataset=test_dataset, + callbacks=trainer_callbacks, + ) + else: + raise ConfigurationException(f"Unsupported trainer type: {trainer_type}") + + self._tracker_client.log_model_config({**model.config.to_dict(), **peft_model.peft_config}) + self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version) + + logger.info(f"Performing {trainer_type.upper()} training...") + trainer.train() + + if cancel_event_check_callback.training_cancelled: + raise TrainingCancelledException("Training was cancelled by the user") + + if not skip_save_model: + model_pack_file_ext = get_model_data_package_extension(self._config.BASE_MODEL_FILE) + model_pack_file_name = f"{ModelType.HUGGINGFACE_LLM.value}_{run_id}{model_pack_file_ext}" + retrained_model_pack_path = os.path.join(self._retrained_models_dir, model_pack_file_name) + model = peft_model.merge_and_unload() + model.save_pretrained( + trained_model_directory, + safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"), + ) + tokenizer.save_pretrained(trained_model_directory) + create_model_data_package(trained_model_directory, retrained_model_pack_path) + model_uri = self._tracker_client.save_model( + retrained_model_pack_path, + self._model_name, + self._model_manager, + ) + logger.info(f"Retrained model saved: {model_uri}") + else: + logger.info("Skipped saving on the retrained model") + if redeploy: + self.deploy_model(self._model_service, model, tokenizer) + else: + del model + del tokenizer + gc.collect() + logger.info("Skipped 
deployment on the retrained model") + logger.info("Supervised training finished") + self._tracker_client.end_with_success() + except TrainingCancelledException as e: + logger.exception(e) + logger.info("Supervised training was cancelled") + del model + gc.collect() + self._tracker_client.end_with_interruption() + except torch.OutOfMemoryError as e: + logger.exception("Supervised training failed on CUDA OOM") + try: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + try: + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + except Exception: + pass + torch.cuda.synchronize() + except Exception: + pass + self._tracker_client.log_exceptions(e) + self._tracker_client.end_with_failure() + except Exception as e: + logger.exception("Supervised training failed") + self._tracker_client.log_exceptions(e) + self._tracker_client.end_with_failure() + finally: + data_file.close() + with self._training_lock: + self._training_in_progress = False + self._clean_up_training_cache() + self._housekeep_file(trained_model_pack_path) + if trainer is not None: + del trainer + gc.collect() + torch.cuda.empty_cache() + else: + try: + logger.info("Evaluating the running model...") + include_rewards_metrics = training_params.get("include_rewards_metrics", False) + model, tokenizer = self._model_service.model, self._model_service.tokenizer + if non_default_device_is_available(self._config.DEVICE): + model.to(self._config.DEVICE) + + eval_dataset, _ = self._load_dataset_from_config(data_file, training_params) + make_conversation = self._create_conversation_formatter(training_params) + eval_dataset = eval_dataset.map(make_conversation) + max_prompt_length = max(eval_dataset.map( + lambda x: { + "tokens": tokenizer.apply_chat_template( + x["prompt"], + add_generation_prompt=True, + tokenize=True + ) + }, + batched=True, + ).map(lambda x: {"length": len(x["tokens"])})["length"]) + 1 + + training_args = GRPOConfig( + output_dir=results_path, + logging_dir=logs_path, + logging_steps=log_frequency, + per_device_eval_batch_size=6, + num_generations=2, + max_prompt_length=max_prompt_length, + max_completion_length=max_seq_length - max_prompt_length, + num_train_epochs=training_params["nepochs"], + report_to="none", + do_train=False, + do_eval=True, + ) + + mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client) + cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event) + trainer_callbacks = [mlflow_logging_callback, cancel_event_check_callback] + + trainer = GRPOTrainer( + model=model, + processing_class=tokenizer, + args=training_args, + reward_funcs=self._get_reward_functions(), + train_dataset=None, + eval_dataset=eval_dataset, + callbacks=trainer_callbacks, + ) + + eval_metrics = trainer.evaluate() + if "perplexity" not in eval_metrics and "eval_loss" in eval_metrics: + eval_metrics.update({"perplexity": math.exp(eval_metrics["eval_loss"])}) + logger.info(f"Evaluation metrics: {eval_metrics}") + self._tracker_client.send_hf_metrics_logs(eval_metrics, 0) + if include_rewards_metrics: + try: + reward_metrics = self._evaluate_with_rewards( + model=model, + tokenizer=tokenizer, + eval_dataset=eval_dataset, + max_new_tokens=training_args.max_completion_length, + ) + if reward_metrics: + logger.info(f"Reward metrics: {reward_metrics}") + self._tracker_client.send_hf_metrics_logs(reward_metrics, 0) + except Exception as e: + logger.warning(f"Failed to compute reward-based metrics: {e}") + self._tracker_client.end_with_success() + 
logger.info("Model evaluation finished") + except torch.OutOfMemoryError as e: + logger.exception("Evaluation failed on CUDA OOM") + try: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + try: + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + except Exception: + pass + torch.cuda.synchronize() + except Exception: + pass + self._tracker_client.log_exceptions(e) + self._tracker_client.end_with_failure() + except Exception as e: + logger.exception("Evaluation failed") + self._tracker_client.log_exceptions(e) + self._tracker_client.end_with_failure() + finally: + data_file.close() + with self._training_lock: + self._training_in_progress = False + self._clean_up_training_cache() + if trainer is not None: + del trainer + gc.collect() + torch.cuda.empty_cache() + + @staticmethod + def _get_reward_functions() -> List: + + def extract_xml_answer(text: str) -> str: + answer = text.split("")[-1] + answer = answer.split("")[0] + return answer.strip() + + # Reward functions + def correctness_reward_func( + prompts: List, + completions: List, + answer: List, + **kwargs: Dict[str, Any] + ) -> List[float]: + responses = [completion[0]["content"] for completion in completions] + q = prompts[0][-1]["content"] + extracted_responses = [extract_xml_answer(r) for r in responses] + logger.debug( + "%s\nQuestion:\n%s\nAnswer:\n%s\nResponse:\n%s\nExtracted:\n%s", + "-" * 20, + q, + answer[0], + responses[0], + extracted_responses[0] + ) + return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] + + def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + responses = [completion[0]["content"] for completion in completions] + extracted_responses = [extract_xml_answer(r) for r in responses] + return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] + + def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + """Reward function that checks if the completion has a specific format.""" + pattern = r"^\n.*?\n\n\n.*?\n\n$" + responses = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, r) for r in responses] + return [0.5 if match else 0.0 for match in matches] + + def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + """Reward function that checks if the completion has a specific format.""" + pattern = r".*?\s*.*?" 
+ responses = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, r) for r in responses] + return [0.5 if match else 0.0 for match in matches] + + def count_xml(text: str) -> float: + count = 0.0 + if text.count("<reasoning>\n") == 1: + count += 0.125 + if text.count("\n</reasoning>\n") == 1: + count += 0.125 + if text.count("\n<answer>\n") == 1: + count += 0.125 + count -= len(text.split("\n</answer>\n")[-1]) * 0.001 + if text.count("\n</answer>") == 1: + count += 0.125 + count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001 + return count + + def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]: + contents = [completion[0]["content"] for completion in completions] + return [count_xml(c) for c in contents] + + return [ + xmlcount_reward_func, + soft_format_reward_func, + strict_format_reward_func, + int_reward_func, + correctness_reward_func, + ] + + def _evaluate_with_rewards( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + eval_dataset: datasets.Dataset, + max_new_tokens: int, + ) -> Dict[str, float]: + model.eval() + if non_default_device_is_available(self._config.DEVICE): + model.to(self._config.DEVICE) + + reward_funcs = self._get_reward_functions() + reward_sums: Dict[str, float] = {fn.__name__: 0.0 for fn in reward_funcs} + count = 0 + + for example in eval_dataset: + if "prompt" not in example: + continue + messages = example["prompt"] + answer = example.get("answer", "") + + prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = tokenizer(prompt_text, return_tensors="pt") + input_ids = inputs["input_ids"] + attention_mask = inputs.get("attention_mask") + if non_default_device_is_available(self._config.DEVICE): + input_ids = input_ids.to(self._config.DEVICE) + attention_mask = attention_mask.to(self._config.DEVICE) + + with torch.no_grad(): + generated = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + temperature=0.0, + eos_token_id=getattr(tokenizer, "eos_token_id", None), + pad_token_id=getattr(tokenizer, "pad_token_id", 0), + ) + + completion_text = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True) + for fn in reward_funcs: + sig = inspect.signature(fn) + kwargs: Dict[str, Any] = {} + if "prompts" in sig.parameters: + kwargs["prompts"] = [messages] + if "completions" in sig.parameters: + kwargs["completions"] = [({"content": completion_text},)] + if "answer" in sig.parameters: + kwargs["answer"] = [answer] + + try: + rewards = fn(**kwargs) # type: ignore + value = float(rewards[0]) if isinstance(rewards, (list, tuple)) and rewards else float(rewards) + except Exception: + value = 0.0 + + reward_sums[fn.__name__] += value + count += 1 + if count == 0: + return {} + + reward_avgs = {f"reward_{name}": total / count for name, total in reward_sums.items()} + reward_overall_mean = sum(reward_avgs.values()) / len(reward_avgs) if reward_avgs else 0.0 + reward_avgs["reward_overall_mean"] = reward_overall_mean + reward_avgs["reward_samples"] = float(count) + return reward_avgs + + +@final +class MLflowLoggingCallback(TrainerCallback): + """ + A callback class for logging training metrics to MLflow. + + Args: + tracker_client (TrackerClient): An instance of TrackerClient used for logging. 
+ """ + + def __init__(self, tracker_client: TrackerClient) -> None: + self.tracker_client = tracker_client + self.epoch = 0 + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: Dict[str, float], + **kwargs: Dict[str, Any], + ) -> None: + """ + Logs metrics at the end of each epoch. + + Args: + args (TrainingArguments): The arguments used for training. + state (TrainerState): The current state of the Trainer. + control (TrainerControl): The current control of the Trainer. + logs (Dict[str, float]): A dictionary containing the metrics to log. + **kwargs (Dict[str, Any]): Additional keyword arguments. + """ + + if logs is not None: + if logs.get("eval_loss", None) is not None: + logs["perplexity"] = math.exp(logs["eval_loss"]) + self.tracker_client.send_hf_metrics_logs(logs, self.epoch) + self.epoch += 1 + + +@final +class CancelEventCheckCallback(TrainerCallback): + """ + A callback class for checking a cancellation event during training. + + Args: + cancel_event (threading.Event): A threading event that signals whether training should be cancelled. + """ + + def __init__(self, cancel_event: threading.Event) -> None: + self.cancel_event = cancel_event + self.training_cancelled = False + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs: Dict[str, Any], + ) -> None: + """ + Checks if the training should be cancelled at the end of each training step. + + Args: + args (TrainingArguments): The arguments used for training. + state (TrainerState): The current state of the Trainer. + control (TrainerControl): The current control of the Trainer. + **kwargs (Dict[str, Any]): Additional keyword arguments. + """ + + if self.cancel_event.is_set(): + control.should_training_stop = True + self.cancel_event.clear() + self.training_cancelled = True diff --git a/app/utils.py b/app/utils.py index 245da97..47faefa 100644 --- a/app/utils.py +++ b/app/utils.py @@ -20,13 +20,14 @@ from spacy.lang.en import English from spacy.util import filter_spans from safetensors.torch import load_file -from transformers import PreTrainedModel +from transformers import PreTrainedModel, PreTrainedTokenizer from urllib.parse import ParseResult from functools import lru_cache from typing import List, Optional, Dict, Callable, Any, Union, Type, TypeVar from app.config import Settings -from app.domain import Annotation, Entity, CodeType, ModelType, Device +from app.domain import Annotation, Entity, CodeType, ModelType, Device, PromptMessage, PromptRole from app.exception import ManagedModelException +from app.processors.prompt_factory import PromptFactory @lru_cache @@ -546,11 +547,6 @@ def unpack_model_data_package(model_data_file_path: str, model_data_folder_path: elif model_data_file_path.endswith(".tar.gz"): with tarfile.open(model_data_file_path, "r:gz") as f: for member in f.getmembers(): - path_parts = member.name.split(os.sep) - stripped_path = os.sep.join(path_parts[1:]) - if not stripped_path: - continue - member.name = stripped_path f.extract(member, path=model_data_folder_path) return True else: @@ -682,6 +678,24 @@ def load_pydantic_object_from_dict(model: Type[T], obj: Dict) -> T: raise TypeError("Model must have a known method for parsing objects.") +def dump_pydantic_object_to_dict(model: BaseModel) -> Dict: + """ + Dumps the pydantic model object to a dictionary. + + Args: + model (BaseModel): The pydantic model to dump. + + Returns: + Dict: The dictionary object. 
+ """ + + if hasattr(model, "model_dump"): + return model.model_dump(mode="json") # type: ignore + elif hasattr(model, "dict"): + return model.dict() # type: ignore + else: + raise TypeError("Model must have a known method for dumping objects.") + def download_model_package( model_package_url: str, destination_path: str, @@ -721,6 +735,112 @@ def download_model_package( retry_delay *= 2 +def get_default_chat_template() -> str: + """ + Gets the default chat template. + + Returns: + str: The default chat template. + """ + + return ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ '[INST] ' + content + ' [/INST]' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content + ' ' }}" + "{% endif %}" + "{% endfor %}" + ) + + +def get_default_system_prompt() -> str: + """ + Gets the default system prompt. + + Returns: + str: The default system prompt. + """ + return ( + "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " + "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " + "process and answer are enclosed within and tags, respectively, i.e., " + " reasoning process here answer here " + ) + + +def get_prompt_from_messages( + tokenizer: PreTrainedTokenizer, + messages: List[PromptMessage], + override_template: Optional[str] = None, +) -> str: + """ + Generates a prompt from a list of prompt messages. + + Args: + tokenizer (PreTrainedTokenizer): The tokenizer to use for applying the chat template. + messages (List[PromptMessage]): The list of prompt messages to use for generating the prompt. + override_template (str): The name of the chat template to use for generating the prompt. + + Returns: + str: The generated prompt. 
+ """ + if override_template is None: + if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: + prompt = tokenizer.apply_chat_template( + [dump_pydantic_object_to_dict(message) for message in messages], + tokenize=False, + add_generation_prompt=True, + ) + elif hasattr(tokenizer, "default_chat_template") and tokenizer.default_chat_template: + # This largely depends on how older versions of HF tokenizers behave and may not work universally + tokenizer.chat_template = tokenizer.default_chat_template + prompt = tokenizer.apply_chat_template( + [dump_pydantic_object_to_dict(message) for message in messages], + tokenize=False, + add_generation_prompt=True, + ) + else: + system_content = "" + prompt_parts: List[str] = [] + for message in messages: + content = message.content.strip() + if message.role == PromptRole.SYSTEM: + system_content = content + elif message.role == PromptRole.USER: + prompt_parts.append(f"<|user|>\n{content}") + elif message.role == PromptRole.ASSISTANT: + prompt_parts.append(f"<|assistant|>\n{content}") + if system_content: + prompt = f"<|system|>\n{system_content}\n" + "\n".join(prompt_parts) + else: + prompt = "\n".join(prompt_parts) + prompt += "\n<|assistant|>\n" + else: + tokenizer.chat_template = PromptFactory.create_chat_template(name=override_template) + prompt = tokenizer.apply_chat_template( + [dump_pydantic_object_to_dict(message) for message in messages], + tokenize=False, + add_generation_prompt=True, + ) + return prompt + + TYPE_ID_TO_NAME_PATCH = { "32816260": "physical object", "2680757": "observable entity", diff --git a/pyproject.toml b/pyproject.toml index 01b8d3d..0aaaaea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "pynvml~=11.5.3", "toml~=0.10.2", "peft<0.14.0", - "huggingface-hub~=0.32.0", + "huggingface-hub~=0.33.0", ] readme = "README.md" keywords = ["natural-language-processing", "electronic-health-records", "clinical-data"] @@ -71,6 +71,7 @@ dev = [ "locust<2.32.0", "typer-cli~=0.15.1", "types-toml==0.10.8.20240310", + "openai>=1.84.0", ] docs = [ "sphinx~=7.1.2", @@ -81,6 +82,8 @@ docs = [ vllm = [ "vllm~=0.8.5; python_version >= '3.9'", + "trl>=0.11.4", + "bitsandbytes>=0.45.5", ] # For pip versions not supporting PEP 735 @@ -99,6 +102,7 @@ dev = [ "locust<2.32.0", "typer-cli~=0.15.1", "types-toml==0.10.8.20240310", + "openai>=1.84.0", ] docs = [ "sphinx~=7.1.2", @@ -109,6 +113,8 @@ docs = [ vllm = [ "vllm~=0.8.5; python_version >= '3.9'", + "trl>=0.11.4", + "bitsandbytes>=0.45.5", ] [tool.setuptools] diff --git a/tests/app/api/test_api.py b/tests/app/api/test_api.py index 1a0399b..1d5f077 100644 --- a/tests/app/api/test_api.py +++ b/tests/app/api/test_api.py @@ -27,7 +27,6 @@ def test_get_model_server(): assert {"name": "Training", "description": "Trigger model training on input annotations"} in tags assert {"name": "Evaluating", "description": "Evaluate the deployed model with trainer export"} in tags assert {"name": "Authentication", "description": "Authenticate registered users"} in tags - assert {"name": "Generative", "description": "Generate text based on the input prompt"} in tags assert "/info" in paths assert "/process" in paths assert "/process_jsonl" in paths @@ -91,7 +90,8 @@ def test_get_generative_server(): assert isinstance(info["title"], str) assert isinstance(info["summary"], str) assert isinstance(info["version"], str) - assert {"name": "Streaming", "description": "Retrieve NER entities as a stream by running the model"} in tags + assert {"name": "Metadata", 
"description": "Get the model card"} in tags + assert {"name": "Generative", "description": "Generate text based on the input prompt"} in tags assert "/info" in paths assert "/generate" in paths assert "/stream/generate" in paths diff --git a/tests/app/api/test_serving_hf_llm.py b/tests/app/api/test_serving_hf_llm.py index 39e82bf..9b7ea9f 100644 --- a/tests/app/api/test_serving_hf_llm.py +++ b/tests/app/api/test_serving_hf_llm.py @@ -1,8 +1,10 @@ import httpx +import json import pytest import app.api.globals as cms_globals from unittest.mock import create_autospec +from fastapi.testclient import TestClient from app.api.api import get_generative_server from app.model_services.huggingface_llm_model import HuggingFaceLlmModel from app.utils import get_settings @@ -27,14 +29,94 @@ def llm_app(llm_model_service): yield app app.dependency_overrides.clear() +@pytest.fixture(scope="function") +def client(llm_model_service): + llm_model_service.model_name = "HuggingFace LLM model" + llm_model_service.generate.return_value = "Yeah." + llm_model_service.create_embeddings.return_value = [[1.0, 2.0, 3.0]] + app = get_generative_server(config, msd_overwritten=lambda: llm_model_service) + app.dependency_overrides[cms_globals.props.current_active_user] = lambda: None + client = TestClient(app) + yield client + client.app.dependency_overrides.clear() + + +def test_generate(client): + response = client.post( + "/generate?max_tokens=128&temperature=0.7", + data="Alright?", + headers={"Content-Type": "text/plain"}, + ) + + assert response.status_code == 200 + assert response.headers["x-cms-tracking-id"], "x-cms-tracking-id header is missing" + assert response.headers["content-type"] == "text/plain; charset=utf-8" + assert response.text == "Yeah." + @pytest.mark.asyncio async def test_stream_generate(llm_model_service, llm_app): + llm_model_service.generate_async.return_value = "Fine." async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac: response = await ac.post( - "/stream/generate?max_tokens=32", + "/stream/generate?max_tokens=32&temperature=0.7", data="How are you doing?", headers={"Content-Type": "text/plain"}, ) - assert response.status_code == 200 \ No newline at end of file + assert response.status_code == 200 + assert response.headers["x-cms-tracking-id"], "x-cms-tracking-id header is missing" + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + assert response.text == "Fine." + + +@pytest.mark.asyncio +async def test_generate_chat_completions(llm_model_service, llm_app): + llm_model_service.generate.return_value = "I'm a chat bot." + request_data = { + "messages": [ + { + "role": "system", + "content": "You are a chat bot." + }, + { + "role": "user", + "content": "Who are you?" 
+ } + ], + "model": "HuggingFace LLM model", + "stream": True, + "max_tokens": 128, + "temperature": 0.7 + } + async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac: + response = await ac.post( + "/v1/chat/completions?max_tokens=128&temperature=0.7", + data=json.dumps(request_data), + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + assert response.text.startswith("data:") + assert "id" in response.text + assert "chat.completion.chunk" in response.text + + +def test_create_embeddings(client): + request_data = { + "input": ["Alright"], + "model": "HuggingFace LLM model", + } + response = client.post( + "/v1/embeddings", + data=json.dumps(request_data), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + assert response.json() == { + "object": "list", + "data": [{"object": "embedding", "embedding": [1.0, 2.0, 3.0], "index": 0}], + "model": "HuggingFace LLM model" + } diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py index 1f33572..e6ee637 100644 --- a/tests/app/model_services/test_huggingface_llm_model.py +++ b/tests/app/model_services/test_huggingface_llm_model.py @@ -1,4 +1,5 @@ import os +from unittest.mock import MagicMock, patch from tests.app.conftest import MODEL_PARENT_DIR from transformers import PreTrainedModel, PreTrainedTokenizerBase from app import __version__ @@ -43,5 +44,189 @@ def test_info(huggingface_llm_model): def test_generate(huggingface_llm_model): huggingface_llm_model.init_model() - output = huggingface_llm_model.generate("How are you doing?") - assert isinstance(output, str) + huggingface_llm_model.model = MagicMock() + huggingface_llm_model.tokenizer = MagicMock() + mock_send_metrics = MagicMock() + inputs = MagicMock() + inputs.input_ids = MagicMock(shape=[1, 2]) + inputs.attention_mask = MagicMock() + huggingface_llm_model.tokenizer.return_value = inputs + outputs = [MagicMock(shape=[2])] + huggingface_llm_model.model.generate.return_value = outputs + huggingface_llm_model.tokenizer.decode.return_value = "Yeah." + + result = huggingface_llm_model.generate( + prompt="Alright?", + max_tokens=128, + temperature=0.5, + report_tokens=mock_send_metrics + ) + + huggingface_llm_model.tokenizer.assert_called_once_with( + "Alright?", + add_special_tokens=False, + return_tensors="pt", + ) + huggingface_llm_model.model.generate.assert_called_once_with( + inputs=inputs.input_ids, + attention_mask=inputs.attention_mask, + max_new_tokens=128, + do_sample=False, + temperature=0.5, + top_p=0.9, + ) + huggingface_llm_model.tokenizer.decode.assert_called_once_with( + outputs[0], + skip_prompt=True, + skip_special_tokens=True, + ) + mock_send_metrics.assert_called_once_with( + prompt_token_num=2, + completion_token_num=2, + ) + assert result == "Yeah." 
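For reference, the `report_tokens` hook exercised by the test above is invoked with the keyword arguments `prompt_token_num` and `completion_token_num`. A minimal sketch of a caller-supplied reporter, assuming only that contract (the names and wiring below are illustrative, not the router's actual Prometheus-backed reporter):

def log_token_usage(prompt_token_num: int = 0, completion_token_num: int = 0) -> None:
    # Hypothetical reporter passed as report_tokens=...; it relies only on the two keyword arguments above.
    total = prompt_token_num + completion_token_num
    print(f"prompt={prompt_token_num} completion={completion_token_num} total={total}")

# e.g. model_service.generate("Alright?", max_tokens=128, temperature=0.5, report_tokens=log_token_usage)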
+ + +async def test_generate_async(huggingface_llm_model): + huggingface_llm_model.init_model() + huggingface_llm_model.model = MagicMock() + huggingface_llm_model.tokenizer = MagicMock() + mock_send_metrics = MagicMock() + inputs = MagicMock() + inputs.input_ids = MagicMock(shape=[1, 2]) + inputs.attention_mask = MagicMock() + huggingface_llm_model.tokenizer.return_value = inputs + outputs = [MagicMock(shape=[2])] + huggingface_llm_model.model.generate.return_value = outputs + huggingface_llm_model.tokenizer.decode.return_value = "Yeah." + + result = await huggingface_llm_model.generate_async( + prompt="Alright?", + max_tokens=128, + temperature=0.5, + report_tokens=mock_send_metrics + ) + + huggingface_llm_model.tokenizer.assert_called_once_with( + "Alright?", + add_special_tokens=False, + return_tensors="pt", + ) + huggingface_llm_model.model.generate_async.assert_called_once_with( + inputs=inputs.input_ids, + attention_mask=inputs.attention_mask, + max_new_tokens=128, + do_sample=False, + temperature=0.5, + top_p=0.9, + ) + huggingface_llm_model.tokenizer.decode.assert_called_once_with( + outputs[0], + skip_prompt=True, + skip_special_tokens=True, + ) + mock_send_metrics.assert_called_once_with( + prompt_token_num=2, + completion_token_num=2, + ) + assert result == "Yeah." + + +def test_create_embeddings_single_text(huggingface_llm_model): + """Test create_embeddings with single text input.""" + huggingface_llm_model.init_model() + huggingface_llm_model.model = MagicMock() + huggingface_llm_model.tokenizer = MagicMock() + mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()] + mock_outputs = MagicMock() + mock_outputs.hidden_states = mock_hidden_states + mock_last_hidden_state = MagicMock() + mock_last_hidden_state.shape = [1, 3, 768] + mock_hidden_states[-1] = mock_last_hidden_state + mock_attention_mask = MagicMock() + mock_attention_mask.shape = [1, 3] + mock_attention_mask.sum.return_value = MagicMock() + mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock() + mock_inputs = MagicMock() + mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock() + huggingface_llm_model.tokenizer.return_value = mock_inputs + huggingface_llm_model.model.return_value = mock_outputs + expected_result = [0.1, 0.2, 0.3] + mock_embeddings_batch = MagicMock() + mock_first_embedding = MagicMock() + mock_cpu_tensor = MagicMock() + mock_numpy_array = MagicMock() + mock_numpy_array.tolist.return_value = expected_result + mock_embeddings_batch.__getitem__.return_value = mock_first_embedding + mock_first_embedding.cpu.return_value = mock_cpu_tensor + mock_cpu_tensor.numpy.return_value = mock_numpy_array + mock_masked_hidden_states = MagicMock() + mock_sum_hidden_states = MagicMock() + mock_num_tokens = MagicMock() + mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states + mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states + mock_attention_mask.sum.return_value = mock_num_tokens + mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch + + result = huggingface_llm_model.create_embeddings("Alright") + + huggingface_llm_model.model.eval.assert_called_once() + huggingface_llm_model.tokenizer.assert_called_once_with( + "Alright", + add_special_tokens=False, + return_tensors="pt", + padding=True, + truncation=True + ) + huggingface_llm_model.model.assert_called_once_with( + **mock_inputs, + output_hidden_states=True + ) + + assert result is not None + + +def 
test_create_embeddings_list_text(huggingface_llm_model): + huggingface_llm_model.init_model() + huggingface_llm_model.model = MagicMock() + huggingface_llm_model.tokenizer = MagicMock() + mock_hidden_states = [MagicMock(), MagicMock(), MagicMock()] + mock_outputs = MagicMock() + mock_outputs.hidden_states = mock_hidden_states + mock_last_hidden_state = MagicMock() + mock_last_hidden_state.shape = [2, 3, 768] + mock_hidden_states[-1] = mock_last_hidden_state + mock_attention_mask = MagicMock() + mock_attention_mask.shape = [2, 3] + mock_attention_mask.sum.return_value = MagicMock() + mock_attention_mask.sum.return_value.unsqueeze.return_value = MagicMock() + mock_inputs = MagicMock() + mock_inputs.__getitem__.side_effect = lambda key: mock_attention_mask if key == "attention_mask" else MagicMock() + huggingface_llm_model.tokenizer.return_value = mock_inputs + huggingface_llm_model.model.return_value = mock_outputs + mock_embeddings_batch = MagicMock() + mock_first_embedding = MagicMock() + mock_cpu_tensor = MagicMock() + mock_numpy_array = MagicMock() + mock_numpy_array.tolist.return_value = [[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]] + mock_embeddings_batch.__getitem__.return_value = mock_first_embedding + mock_first_embedding.cpu.return_value = mock_cpu_tensor + mock_cpu_tensor.numpy.return_value = mock_numpy_array + mock_masked_hidden_states = MagicMock() + mock_sum_hidden_states = MagicMock() + mock_num_tokens = MagicMock() + mock_last_hidden_state.__mul__.return_value = mock_masked_hidden_states + mock_masked_hidden_states.sum.return_value = mock_sum_hidden_states + mock_attention_mask.sum.return_value = mock_num_tokens + mock_sum_hidden_states.__truediv__.return_value = mock_embeddings_batch + + result = huggingface_llm_model.create_embeddings(["Alright", "Alright"]) + + huggingface_llm_model.tokenizer.assert_called_once_with( + ["Alright", "Alright"], + add_special_tokens=False, + return_tensors="pt", + padding=True, + truncation=True, + ) + assert result is not None diff --git a/tests/app/processors/test_metrics_collector.py b/tests/app/processors/test_metrics_collector.py index bd46fff..6e53275 100644 --- a/tests/app/processors/test_metrics_collector.py +++ b/tests/app/processors/test_metrics_collector.py @@ -7,6 +7,7 @@ from app.processors.metrics_collector import ( sanity_check_model_with_trainer_export, concat_trainer_exports, + concat_json_lists, get_stats_from_trainer_export, get_iaa_scores_per_concept, get_iaa_scores_per_doc, @@ -332,3 +333,56 @@ def test_get_iaa_scores_per_span_and_return_dataframe(): assert len(result["cohens_kappa"]) == 30 assert len(result["iaa_percentage_meta"]) == 30 assert len(result["cohens_kappa_meta"]) == 30 + + +def test_concat_json_lists_return_list(): + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f1: + json.dump([{"question": "question_1", "answer": "answer_1"}, {"question": "question_2", "answer": "answer_2"}], f1) + file1_path = f1.name + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f2: + json.dump([{"question": "question_3", "answer": "answer_3"}], f2) + file2_path = f2.name + + try: + result = concat_json_lists([file1_path, file2_path]) + + assert isinstance(result, list) + assert len(result) == 3 + assert result[0] == {"question": "question_1", "answer": "answer_1"} + assert result[1] == {"question": "question_2", "answer": "answer_2"} + assert result[2] == {"question": "question_3", "answer": "answer_3"} + finally: + os.unlink(file1_path) + os.unlink(file2_path) + + +def 
test_concat_json_lists_save_to_file(): + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f1: + json.dump([{"question": "question_1", "answer": "answer_1"}], f1) + file1_path = f1.name + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f2: + json.dump([{"question": "question_2", "answer": "answer_2"}], f2) + file2_path = f2.name + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as output_file: + output_path = output_file.name + + try: + result = concat_json_lists([file1_path, file2_path], output_path) + + assert isinstance(result, str) + assert result == output_path + + with open(output_path, 'r') as f: + saved_data = json.load(f) + + assert isinstance(saved_data, list) + assert len(saved_data) == 2 + assert saved_data[0] == {"question": "question_1", "answer": "answer_1"} + assert saved_data[1] == {"question": "question_2", "answer": "answer_2"} + finally: + os.unlink(file1_path) + os.unlink(file2_path) + os.unlink(output_path) diff --git a/tests/app/processors/test_prompt_factory.py b/tests/app/processors/test_prompt_factory.py new file mode 100644 index 0000000..ab03388 --- /dev/null +++ b/tests/app/processors/test_prompt_factory.py @@ -0,0 +1,228 @@ +from jinja2.sandbox import ImmutableSandboxedEnvironment +from app.processors.prompt_factory import PromptFactory + + +def test_create_default_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." + }, + ] + template = env.from_string(PromptFactory.create_chat_template("default")) + prompt = template.render( + messages=messages, + bos_token="<|system|>", + eos_token="<|end|>", + add_generation_prompt=True, + ) + assert prompt == "<|system|>\nAlright?<|end|><|user|>\nYeah.<|end|><|assistant|>" + + +def test_create_alpaca_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." + }, + ] + template = env.from_string(PromptFactory.create_chat_template("alpaca")) + prompt = template.render( + messages=messages, + add_generation_prompt=True, + ) + assert prompt == "### Instruction:\nAlright?\nYeah.\n\n### Response:\n" + + +def test_create_chat_ml_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." + }, + ] + template = env.from_string(PromptFactory.create_chat_template("chat_ml")) + prompt = template.render( + messages=messages, + add_generation_prompt=True, + ) + assert prompt == "<|im_start|>system\nAlright?<|im_end|>\n<|im_start|>user\nYeah.<|im_end|>\n<|im_start|>assistant\n" + +def test_create_falcon_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." + }, + ] + template = env.from_string(PromptFactory.create_chat_template("falcon")) + prompt = template.render( + messages=messages, + add_generation_prompt=True, + ) + assert prompt == "Alright?\n\nUser: Yeah.{ '\n\nAssistant:' }}" + +def test_create_gemma_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." 
+ }, + ] + template = env.from_string(PromptFactory.create_chat_template("gemma")) + prompt = template.render( + messages=messages, + add_generation_prompt=True, + ) + assert prompt == "user\nAlright?\n\nYeah.\nmodel\n" + +def test_create_llama_2_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." + }, + ] + template = env.from_string(PromptFactory.create_chat_template("LLAMA_2")) + prompt = template.render( + messages=messages, + bos_token="", + eos_token="", + add_generation_prompt=True, + ) + assert prompt == "[INST] <>\nAlright?\n<>\n\nYeah. [/INST]" + +def test_create_llama_3_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." + }, + ] + template = env.from_string(PromptFactory.create_chat_template("LLAMA_2")) + prompt = template.render( + messages=messages, + bos_token="", + eos_token="", + add_generation_prompt=True, + ) + assert prompt == "[INST] <>\nAlright?\n<>\n\nYeah. [/INST]" + +def test_create_mistral_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." + }, + ] + template = env.from_string(PromptFactory.create_chat_template("mistral")) + prompt = template.render( + messages=messages, + bos_token="", + eos_token="", + add_generation_prompt=True, + ) + assert prompt == "[INST] Alright?\n\nYeah. [/INST]" + +def test_create_phi_2_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." + }, + ] + template = env.from_string(PromptFactory.create_chat_template("phi_2")) + prompt = template.render( + messages=messages, + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + add_generation_prompt=True, + ) + assert prompt == "Instruct: Alright?\n\nYeah.\nOutput:" + +def test_create_phi_3_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." + }, + ] + template = env.from_string(PromptFactory.create_chat_template("phi_3")) + prompt = template.render( + messages=messages, + bos_token="", + eos_token="<|end|>", + add_generation_prompt=True, + ) + assert prompt == "<|system|>\nAlright?<|end|>\n<|user|>\nYeah.<|end|>\n<|assistant|>\n" + +def test_create_qwen_chat_template(): + env = ImmutableSandboxedEnvironment() + messages = [ + { + "role": "system", + "content": "Alright?" + }, + { + "role": "user", + "content": "Yeah." 
+ }, + ] + template = env.from_string(PromptFactory.create_chat_template("qwen")) + prompt = template.render( + messages=messages, + bos_token="", + eos_token="<|end|>", + add_generation_prompt=True, + ) + assert prompt == "<|im_start|>system\nAlright?<|im_end|>\n<|im_start|>user\nYeah.<|im_end|>\n<|im_start|>assistant\n" diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py index 4519350..2f00e10 100644 --- a/tests/app/test_utils.py +++ b/tests/app/test_utils.py @@ -6,7 +6,7 @@ import zipfile import tarfile import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from safetensors.torch import save_file from transformers import PreTrainedModel from urllib.parse import urlparse @@ -33,8 +33,10 @@ pyproject_dependencies_to_pip_requirements, get_model_data_package_base_name, load_pydantic_object_from_dict, + dump_pydantic_object_to_dict, + get_prompt_from_messages, ) -from app.domain import Annotation, Entity +from app.domain import Annotation, Entity, PromptMessage, PromptRole def test_get_code_base_uri(): @@ -393,3 +395,63 @@ def __init__(self): def forward(self, x): return self.linear(x) + + +def test_get_prompt_with_chat_template(): + with patch('transformers.PreTrainedTokenizer') as tok: + mock_tokenizer = tok.return_value + mock_tokenizer.chat_template = "Mock chat template" + mock_tokenizer.apply_chat_template.return_value = "Mock chat template applied" + messages = [ + PromptMessage(content="Alright?", role=PromptRole.USER.value), + PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value), + ] + + prompt = get_prompt_from_messages(mock_tokenizer, messages) + + assert prompt == "Mock chat template applied" + + +def test_get_prompt_with_default_chat_template(): + with patch('transformers.PreTrainedTokenizer') as tok: + mock_tokenizer = tok.return_value + mock_tokenizer.chat_template = None + mock_tokenizer.default_chat_template = "Mock default chat template" + mock_tokenizer.apply_chat_template.return_value = "Mock default chat template applied" + messages = [ + PromptMessage(content="Alright?", role=PromptRole.USER.value), + PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value), + ] + + prompt = get_prompt_from_messages(mock_tokenizer, messages) + + assert prompt == "Mock default chat template applied" + + +def test_get_prompt_without_chat_template(): + with patch('transformers.PreTrainedTokenizer') as tok: + mock_tokenizer = tok.return_value + mock_tokenizer.chat_template = None + mock_tokenizer.default_chat_template = None + messages = [ + PromptMessage(content="You are a helpful assistant.", role=PromptRole.SYSTEM.value), + PromptMessage(content="Alright?", role=PromptRole.USER.value), + PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value), + ] + + prompt = get_prompt_from_messages(mock_tokenizer, messages) + + expected_prompt = "<|system|>\nYou are a helpful assistant.\n<|user|>\nAlright?\n<|assistant|>\nYeah.\n<|assistant|>\n" + assert prompt == expected_prompt + + +def test_get_prompt_with_no_messages(): + with patch('transformers.PreTrainedTokenizer') as tok: + mock_tokenizer = tok.return_value + mock_tokenizer.chat_template = None + mock_tokenizer.default_chat_template = None + messages = [] + + prompt = get_prompt_from_messages(mock_tokenizer, messages) + + assert prompt == "\n<|assistant|>\n" diff --git a/uv.lock b/uv.lock index 0b032d7..255a329 100644 --- a/uv.lock +++ b/uv.lock @@ -734,6 +734,55 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/d6/59/831b66ba317496332d4e9e1a33bcdd14922d6cfecc411dc315a229b67127/bcrypt-4.1.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba4e4cc26610581a6329b3937e02d319f5ad4b85b074846bf4fef8a8cf51e7bb", size = 698384 }, ] +[[package]] +name = "bitsandbytes" +version = "0.45.5" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9' and sys_platform != 'win32'", + "python_full_version < '3.9' and sys_platform == 'win32'", +] +dependencies = [ + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "torch", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/b7/cb5ce4d1a382cf53c19ef06c5fc29e85f5e129b4da6527dd207d90a5b8ad/bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:a5453f30cc6aab6ccaac364e6bf51a7808d3da5f71763dffeb6d9694c59136e4", size = 76059261 }, + { url = "https://files.pythonhosted.org/packages/a6/4c/77b535e025ce780d2ada8271c1e481fb7337c1df2588a52fe1c9bd87d2e8/bitsandbytes-0.45.5-py3-none-win_amd64.whl", hash = "sha256:ed1c61b91d989d6a33fd05737d6edbf5086d8ebc89235ee632c7a19144085da2", size = 75430204 }, +] + +[[package]] +name = "bitsandbytes" +version = "0.47.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version > '3.11' and sys_platform == 'darwin'", + "python_full_version > '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version > '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version > '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.11' and sys_platform == 'darwin'", + "python_full_version == '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.10.*' and sys_platform == 'darwin'", + "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.9.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version == '3.9.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version > '3.11' and sys_platform == 'win32'", + "python_full_version == '3.11' and sys_platform == 'win32'", + "python_full_version == '3.10.*' and sys_platform == 'win32'", + "python_full_version == '3.9.*' and sys_platform == 'win32'", +] +dependencies = [ + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= 
'3.9'" }, + { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/eb/477d6b5602f469c7305fd43eec71d890c39909f615c1d7138f6e7d226eff/bitsandbytes-0.47.0-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:2f805b76891a596025e9e13318b675d08481b9ee650d65e5d2f9d844084c6521", size = 30004641 }, + { url = "https://files.pythonhosted.org/packages/9c/40/91f1a5a694f434bc13cba160045fdc4e867032e627b001bf411048fefd9c/bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:68f3fffd494a47ed1fd7593bfc5dd2ac69b68260599b71b4c4b3a32f90f3b184", size = 61284639 }, + { url = "https://files.pythonhosted.org/packages/18/a9/e07a227f1cd6562844cea2f05ee576b0991a9a91f45965c06034178ba0f6/bitsandbytes-0.47.0-py3-none-win_amd64.whl", hash = "sha256:4880a6d42ca9628b5a571c8cc3093dc3f5f52511e5a9e47d52d569807975531a", size = 60725121 }, +] + [[package]] name = "blake3" version = "1.0.5" @@ -1215,6 +1264,7 @@ dev = [ { name = "locust", version = "2.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "locust", version = "2.31.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "mypy" }, + { name = "openai" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-bdd" }, @@ -1233,6 +1283,10 @@ docs = [ { name = "sphinx-rtd-theme" }, ] vllm = [ + { name = "bitsandbytes", version = "0.45.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "bitsandbytes", version = "0.47.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "trl", version = "0.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "trl", version = "0.15.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "vllm", marker = "python_full_version >= '3.9'" }, ] @@ -1242,6 +1296,7 @@ dev = [ { name = "locust", version = "2.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "locust", version = "2.31.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "mypy" }, + { name = "openai" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-bdd" }, @@ -1260,6 +1315,10 @@ docs = [ { name = "sphinx-rtd-theme" }, ] vllm = [ + { name = "bitsandbytes", version = "0.45.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "bitsandbytes", version = "0.47.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "trl", version = "0.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "trl", version = "0.15.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "vllm", marker = "python_full_version >= '3.9'" }, ] @@ -1267,6 +1326,7 @@ vllm = [ requires-dist = [ { name = "aiosqlite", specifier = "~=0.19.0" }, { name = "asyncpg", specifier = "~=0.27.0" }, + { name = "bitsandbytes", marker = "extra == 'vllm'", specifier = ">=0.45.5" }, { name = "blis", specifier = "<1.0.0" }, { name = "boto3", specifier = "~=1.28.84" }, { name = "click", specifier = 
"<8.2.0" }, @@ -1277,13 +1337,14 @@ requires-dist = [ { name = "fastapi-users-db-sqlalchemy", specifier = "~=5.0.0" }, { name = "graypy", specifier = "~=2.1.0" }, { name = "httpx", marker = "extra == 'dev'", specifier = "~=0.24.1" }, - { name = "huggingface-hub", specifier = "~=0.32.0" }, + { name = "huggingface-hub", specifier = "~=0.33.0" }, { name = "ijson", specifier = "~=3.1.4" }, { name = "locust", marker = "extra == 'dev'", specifier = "<2.32.0" }, { name = "medcat", marker = "python_full_version < '3.9'", specifier = "~=1.13.1" }, { name = "medcat", marker = "python_full_version >= '3.9'", specifier = "~=1.16.0" }, { name = "mlflow", specifier = "~=2.16.2" }, { name = "mypy", marker = "extra == 'dev'", specifier = "~=1.14.0" }, + { name = "openai", marker = "extra == 'dev'", specifier = ">=1.84.0" }, { name = "peft", specifier = "<0.14.0" }, { name = "prometheus-fastapi-instrumentator", specifier = "~=7.0.0" }, { name = "psycopg2-binary", specifier = "~=2.9.4" }, @@ -1306,6 +1367,7 @@ requires-dist = [ { name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = "~=3.0.2" }, { name = "toml", specifier = "~=0.10.2" }, { name = "torch", marker = "python_full_version < '3.9'", specifier = "<2.5.0" }, + { name = "trl", marker = "extra == 'vllm'", specifier = ">=0.11.4" }, { name = "typer", specifier = "~=0.15.1" }, { name = "typer-cli", marker = "extra == 'dev'", specifier = "~=0.15.1" }, { name = "types-toml", marker = "extra == 'dev'", specifier = "==0.10.8.20240310" }, @@ -1320,6 +1382,7 @@ dev = [ { name = "httpx", specifier = "~=0.24.1" }, { name = "locust", specifier = "<2.32.0" }, { name = "mypy", specifier = "~=1.14.0" }, + { name = "openai", specifier = ">=1.84.0" }, { name = "pytest", specifier = "~=7.1.2" }, { name = "pytest-asyncio", specifier = "~=0.23.7" }, { name = "pytest-bdd", specifier = "~=7.2.0" }, @@ -1337,7 +1400,11 @@ docs = [ { name = "sphinx-autodoc-typehints", specifier = "~=2.0.1" }, { name = "sphinx-rtd-theme", specifier = "~=3.0.2" }, ] -vllm = [{ name = "vllm", marker = "python_full_version >= '3.9'", specifier = "~=0.8.5" }] +vllm = [ + { name = "bitsandbytes", specifier = ">=0.45.5" }, + { name = "trl", specifier = ">=0.11.4" }, + { name = "vllm", marker = "python_full_version >= '3.9'", specifier = "~=0.8.5" }, +] [[package]] name = "colorama" @@ -1943,6 +2010,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774 }, ] +[[package]] +name = "docstring-parser" +version = "0.16" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/08/12/9c22a58c0b1e29271051222d8906257616da84135af9ed167c9e28f85cb3/docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e", size = 26565 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/7c/e9fcff7623954d86bdc17782036cbf715ecab1bec4847c008557affe1ca8/docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637", size = 36533 }, +] + [[package]] name = "docutils" version = "0.20.1" @@ -1975,6 +2051,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a3/05/8b171626b850e870fc4433225cd6d5bec5a9916b1c39b3d7c67a60492aeb/email_validator-2.1.2-py3-none-any.whl", hash = 
"sha256:d89f6324e13b1e39889eab7f9ca2f91dc9aebb6fa50a6d8bd4329ab50f251115", size = 30739 }, ] +[[package]] +name = "eval-type-backport" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/ea/8b0ac4469d4c347c6a385ff09dc3c048c2d021696664e26c7ee6791631b5/eval_type_backport-0.2.2.tar.gz", hash = "sha256:f0576b4cf01ebb5bd358d02314d31846af5e07678387486e2c798af0e7d849c1", size = 9079 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/31/55cd413eaccd39125368be33c46de24a1f639f2e12349b0361b4678f3915/eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a", size = 5830 }, +] + [[package]] name = "evaluate" version = "0.4.3" @@ -3263,7 +3348,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.32.4" +version = "0.33.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock", version = "3.16.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, @@ -3276,9 +3361,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/c8/4f7d270285c46324fd66f62159eb16739aa5696f422dba57678a8c6b78e9/huggingface_hub-0.32.4.tar.gz", hash = "sha256:f61d45cd338736f59fb0e97550b74c24ee771bcc92c05ae0766b9116abe720be", size = 424494 } +sdist = { url = "https://files.pythonhosted.org/packages/02/16/5716d03e2b48bcc8e32d9b18ed7e55d2ae52e3d5df146cced9fe0581b5ff/huggingface_hub-0.33.5.tar.gz", hash = "sha256:814097e475646d170c44be4c38f7d381ccc4539156a5ac62a54f53aaf1602ed8", size = 427075 } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/8b/222140f3cfb6f17b0dd8c4b9a0b36bd4ebefe9fb0098ba35d6960abcda0f/huggingface_hub-0.32.4-py3-none-any.whl", hash = "sha256:37abf8826b38d971f60d3625229221c36e53fe58060286db9baf619cfbf39767", size = 512101 }, + { url = "https://files.pythonhosted.org/packages/33/d5/d9e9b75d8dc9cf125fff16fb0cd51d864a29e8b46b6880d8808940989405/huggingface_hub-0.33.5-py3-none-any.whl", hash = "sha256:29b4e64982c2064006021af297e1b17d44c85a8aaf90a0d7efeff7e7d2426296", size = 515705 }, ] [package.optional-dependencies] @@ -3620,10 +3705,88 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899 }, ] +[[package]] +name = "jiter" +version = "0.9.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9' and sys_platform != 'win32'", + "python_full_version < '3.9' and sys_platform == 'win32'", +] +sdist = { url = "https://files.pythonhosted.org/packages/84/72/c28662416d9807bb5a38625eadedb82d4bd14fd2700c308ece7acdb8e89f/jiter-0.9.1.tar.gz", hash = "sha256:7852990068b6e06102ecdc44c1619855a2af63347bfb5e7e009928dcacf04fdd", size = 162540 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/5f/7f6aaca7943c644b4fd220650771f39dbfb74f9690efc6fb8c0d4092a399/jiter-0.9.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:c0163baa7ee85860fdc14cc39263014500df901eeffdf94c1eab9a2d713b2a9d", size = 312882 }, + { url = "https://files.pythonhosted.org/packages/86/0d/aac9eafc5d46bdf5c4f127ac1ce85e434d003bb5e3ae886f5e726a988cf6/jiter-0.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:514d4dd845e0af4da15112502e6fcb952f0721f27f17e530454e379472b90c14", size = 311743 }, + { url = "https://files.pythonhosted.org/packages/b8/54/fab1f4d8634af7bb1ad6dc49bee50ea9f649de0e5309c80192ace739f968/jiter-0.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b879faee1cc1a67fde3f3f370041239fd260ac452bd53e861aa4a94a51e3fd02", size = 1085889 }, + { url = "https://files.pythonhosted.org/packages/bd/86/bf4ed251d8035d5d72a46c8f9969bd5054fad052371cbea0cb161060e660/jiter-0.9.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20a5ce641f93bfb8d8e336f8c4a045e491652f41eaacc707b15b245ece611e72", size = 1117896 }, + { url = "https://files.pythonhosted.org/packages/62/40/b04c40deccd5edd5f2a3853f4a80dc0ddbe157d1d523a573fb3d224315fc/jiter-0.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8575b1d2b49df04ca82d658882f4a432b7ed315a69126a379df4d10aeb416021", size = 1211956 }, + { url = "https://files.pythonhosted.org/packages/85/f0/114e9893e4ef5b423718efe9b3da01117539c333f06ef19543c68c8b7ed1/jiter-0.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc61831699904e0c58e82943f529713833db87acd13f95a3c0feb791f862d47b", size = 1219691 }, + { url = "https://files.pythonhosted.org/packages/02/9a/1aeac4541ce1c59c65dc76dbab642232da3d8db0581df3e61b8943033bd7/jiter-0.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fb733faf4d0e730d6663873249c1acb572fc8bd9dae3836ceda69751f27c5be", size = 352604 }, + { url = "https://files.pythonhosted.org/packages/6b/27/446ec6ca0a25d9d2f45ad546633a2b4a1b6a7f28fb6819c7056b163c5aee/jiter-0.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d903b3bb917c0df24f2ef62f587c8f32f6003cb2f97264109ca56c023262557f", size = 1147136 }, + { url = "https://files.pythonhosted.org/packages/09/9d/c8540bc097b07e106d060c21395c6fa6561223e7366c948a04ef0aa39979/jiter-0.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:eac3eb5206845b170142c016ae467eca523a25459dc9c53fcd8e154ea263406c", size = 1255843 }, + { url = "https://files.pythonhosted.org/packages/d3/61/9b377ecf4e09e325e90f77a7a4859ec933162f58ff5c6b7730aff6352033/jiter-0.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7ea0c20cfc61acc5335bb8ee36d639e6a4ded03f34f878b2b3038bb9f3bb553c", size = 1257536 }, + { url = "https://files.pythonhosted.org/packages/ed/f6/b6754e11ac9d02f05a2d713c0846ce813a69c1f6f7de7f1ae216c4e35ace/jiter-0.9.1-cp310-cp310-win32.whl", hash = "sha256:0f8f812dd6d2b4112db9ab4c1079c4fe73e553a500e936657fdda394fa2517e1", size = 214064 }, + { url = "https://files.pythonhosted.org/packages/1d/cb/7b9c5d6f73499d1fb5e97e36e8078f3bea00d7541a973117eccf9db1e079/jiter-0.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:f7f0198889170e7af6210509803e6527b402efc6c26f42e2896883597a10426f", size = 209952 }, + { url = "https://files.pythonhosted.org/packages/ee/3b/9f9deaef471e346354c832b6627e0d1b9ba3d9611d0e0fd394c2acf2a615/jiter-0.9.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6b8564e3198c4c8d835fc95cc54d6bcbd2fd8dc33a047fecc12c208491196995", size = 312737 }, + { url = "https://files.pythonhosted.org/packages/36/00/76fa6d519f8289aad32ec1caf3716eb700ba48e3212d1dda71e74c385a5c/jiter-0.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:90b92044588d14efe89b394eca735adc4ac096eba82dc75d93c3083b1eebce8d", size = 313357 }, + { url = 
"https://files.pythonhosted.org/packages/b3/e9/f864ebe9ddf07761d5bdd3148b45a5d433c6cbce7c7e8be29baf806fa612/jiter-0.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3505f7f419b355c7788fcaae0dfc4c6ccbc50c0dc3633a2da797e841c5a423dc", size = 1085946 }, + { url = "https://files.pythonhosted.org/packages/82/a1/ed02d4c86d620989dcd392366daa67198961eedaf2e66f7a68f0d3846dba/jiter-0.9.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:93af8c3f4a3bf145c690e857a945eb5c655534bf95c67e1447d85c02e5af64d7", size = 1118090 }, + { url = "https://files.pythonhosted.org/packages/d3/01/d107531d215a57cda3cbc4adfcf3119166dd32adc1c332c1f3f36efd3484/jiter-0.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43b81dd21e260a249780764921b1f9a6379cb31e24e7b61e6bf0799f38ec4b91", size = 1212231 }, + { url = "https://files.pythonhosted.org/packages/45/1e/6801a81a2ef1f917fe9a7d2139e576dd4f53497c309dab9461136922709c/jiter-0.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:db639fad5631b3d1692609f6dd77b64e8578321b7aeec07a026acd2c867c04a5", size = 1219263 }, + { url = "https://files.pythonhosted.org/packages/a5/d4/40082e8666cfdb24461855e9bb29fe77f063cc65a6c903291f2e5225f780/jiter-0.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15356b943e70ca7ab3b587ffaffadc0158467f6c4e0b491e52a0743c4bdf5ba1", size = 350364 }, + { url = "https://files.pythonhosted.org/packages/c4/09/09bc72dd143f76acd55e04c3a45b9f9ee3ed28e00b49924e3702ad041812/jiter-0.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:53a7033a46141ff815518a6972d657c75d8f5946b9315e1c25b07e9677c1ff6c", size = 1146802 }, + { url = "https://files.pythonhosted.org/packages/5b/34/9d15a9c04d5760537b432134447bde94b936ec73dc922b4d14a48def2e1f/jiter-0.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:68cf519a6f00b8127f9be64a37e97e978094438abced5adebe088a98c64bdcff", size = 1256019 }, + { url = "https://files.pythonhosted.org/packages/8f/01/1fcd165fb28968a54bb46a209d5919f7649b96608eef7dc4622ea378b95a/jiter-0.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9098abdd34cd9ddeb04768cc4f5fc725ebd9a52978c488da74e58a837ce93506", size = 1257610 }, + { url = "https://files.pythonhosted.org/packages/9f/87/93ac6a57331dd90e4c896ac852bf8ce6b28b40dace4b9698a207dbb99af2/jiter-0.9.1-cp311-cp311-win32.whl", hash = "sha256:7179ce96aecd096af890dd57b84133e47a59fbde32a77734f09bafa6a4da619e", size = 214515 }, + { url = "https://files.pythonhosted.org/packages/bb/ee/3678b8a3bd5f6471d0a492540e7ff9c63db278d844214458ec5cfb22adb2/jiter-0.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:e6517f5b7b6f60fd77fc1099572f445be19553c6f61b907ab5b413fb7179663f", size = 212258 }, + { url = "https://files.pythonhosted.org/packages/26/ca/1c7438d66969a13938266492de65daf752754ec59f2a3f3716027c7d708f/jiter-0.9.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:95065923a49ae387bab62b1bf5f798beb12e6fb4469a079fdd0ecad64b40b272", size = 313516 }, + { url = "https://files.pythonhosted.org/packages/e8/d9/3a6300309e312f8ed529ae57d565f69abdb520e4f12460cefa7996d0716c/jiter-0.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a179fbc5c7922844a673be35099a3036a7276dc63753c6c81a77c3cb525f2f8d", size = 308161 }, + { url = "https://files.pythonhosted.org/packages/b3/91/2aca15be38514daf8f1a1460fd9c4b652ed09148fe109520298858be7928/jiter-0.9.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:abd30dc5c0183d31faf30ce8279d723809c54b3fe6d95d922d4a4b31bc462799", size = 1086100 }, + { url = "https://files.pythonhosted.org/packages/9f/6f/f7ba3dfe7be08bf58939324e0bb4f4aa605eff7f2c2ac140a41221cf50a4/jiter-0.9.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9765512bdeae269843e6615377f48123432da247e18048d05e9c5685377c241c", size = 1118922 }, + { url = "https://files.pythonhosted.org/packages/b5/4e/b1f4d9bdba81de293e1b8672598300a9195cf3d77b0acc5f331a75695b58/jiter-0.9.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f15cdbdc1e1e89e0d9ea581de63e03975043a4b40ab87d5554fdc440357b771", size = 1212327 }, + { url = "https://files.pythonhosted.org/packages/3e/ab/e417aaf5a62067bd91c5f7ed4e5ab83bd46f349449adde1159ad8e2d3a21/jiter-0.9.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b1a639b2cfe56b5b687c678ed45d68f46dfb922c2f338fdfb227eb500053929d", size = 1220860 }, + { url = "https://files.pythonhosted.org/packages/1e/50/c5ba756c641ca8ebc1e4ff07c03ce5c8ef5052b0238f514436f8de3c9fc4/jiter-0.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41955c9d83c8470de9cc64c97b04a3ffd2f32815bb2c4307f44d8e21542b74df", size = 344077 }, + { url = "https://files.pythonhosted.org/packages/c6/b3/bd7d8d4bad65aa1f4a20562233080054149785c0d7f7b9027e761335d882/jiter-0.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f26f6d42c330e26a6ba3471b390364faad96f3ca965a6c579957810b0c078efa", size = 1148785 }, + { url = "https://files.pythonhosted.org/packages/c0/12/bfd9a167709f96171312d1e0ae2c1be70a167abcc3bff6f3441967e3626a/jiter-0.9.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a23e01bd7e918f27f02d3df8721b8a395211070a8a65aeb353209b8c72720cf", size = 1255962 }, + { url = "https://files.pythonhosted.org/packages/5f/3c/3a79020862d2511b854b350bc9229cf228fd38b836e94f274ca940e22e95/jiter-0.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8a96ad217989dd9df661711c3fa2e6fb2601c4bbb482e89718110bdafbc16c9e", size = 1257561 }, + { url = "https://files.pythonhosted.org/packages/93/d3/7f6f8e57613d4947a872980befa6af19de9252e310ea4a512eed0fe1e064/jiter-0.9.1-cp38-cp38-win32.whl", hash = "sha256:4b180e7baa4747b3834c5a9202b1ba30dc64797f45236d9142cdb2a8807763cf", size = 215019 }, + { url = "https://files.pythonhosted.org/packages/9b/5d/b6f0cd60c8f702936f253644a92dee19e2c82010290e4607af462033351f/jiter-0.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:baf881de1fbc7b3343cce24f75a2ab6350e03fc13d16d00f452929788a6cdc3f", size = 199563 }, + { url = "https://files.pythonhosted.org/packages/4f/3a/a8a4768af26578c87894bb130bcd6fb6c97f4cb36ed7a20a664412d41935/jiter-0.9.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:ec95aa1b433c50b2b129456b4680b239ec93206ea3f86cfd41b6a70be5beb2f3", size = 313942 }, + { url = "https://files.pythonhosted.org/packages/63/74/05977891db48000d985a5f573493c43adf0f190eada670e51b92c9ed9139/jiter-0.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5d92cb50d135dbdd33b638fa2e0c6af25e1d635d38da13aa9ab05d021fb0c869", size = 308160 }, + { url = "https://files.pythonhosted.org/packages/21/54/75f529e90442c8ad41acd8cf08323a4f3dcaa105710b2c8a1fda56e3a462/jiter-0.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b146dc2464f1d96007271d08bdf79288a5f1aa4aae5329eb79dcffb1181c703e", size = 1086503 }, + { url = 
"https://files.pythonhosted.org/packages/bf/fa/02532a7ce7b712c576125d4f2614e77bc897c95b2b15e21ee25f42b3ff34/jiter-0.9.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcf20ba858658ecd54b4710172d92009afa66d41d967c86d11607592a3c220fa", size = 1120444 }, + { url = "https://files.pythonhosted.org/packages/91/c2/ab8cebaea6f2691eddcc5b6c67deb1399adbd85f12ad836f7cd77be78bf8/jiter-0.9.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:147fccc44bebdb672d4c601e9312730488b840d415e201e89c8ea0929a63dacf", size = 1212370 }, + { url = "https://files.pythonhosted.org/packages/13/e3/90dddb7877b67cc0e1ddb864c2ca74314def26ff6542431a6e3061e0f805/jiter-0.9.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a428061aae26efaa6fb690ef9e7d6224aefe4eef7524165d073beb3cdad75f6f", size = 1221210 }, + { url = "https://files.pythonhosted.org/packages/81/76/90ee847519a94a4a1a8bad7addce7019f424aea03c55eacf068469226760/jiter-0.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7164d92bb901784bd3c098ac0b0beae4306ea6c741dbd3a375449a8affc5366", size = 353774 }, + { url = "https://files.pythonhosted.org/packages/59/a6/614a5d672d4b9c6bc9ad34579f0522577a0a78cc265069fca96543a832ca/jiter-0.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:93049a562233808914a2b938b0c745d7049db1667b3f42f0f5cf48e617393ba5", size = 1148581 }, + { url = "https://files.pythonhosted.org/packages/2d/94/c100147c310361fa83e25c4c6ce17723532147580252962b89e6085795c2/jiter-0.9.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f6dcf2cb16cc15d82a018e20eeaf169e6f6cd8c426f4c312ebe11710c623bed2", size = 1256636 }, + { url = "https://files.pythonhosted.org/packages/51/9a/dc82e218ba839052899df555e34f16b8ad1d7da9c01be208f65a5bf0083c/jiter-0.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2da9d485a7c526817cde9ff8b3394fa50ff5b782b86b6896378a3ba8844550f2", size = 1258099 }, + { url = "https://files.pythonhosted.org/packages/58/d5/d853e069624038950265ac0e877985b249049b624e925dab6cd11035140c/jiter-0.9.1-cp39-cp39-win32.whl", hash = "sha256:ea58c155d827d24e5ba8d7958ec4738b26be0894c0881a91d88b39ff48bb06c9", size = 214611 }, + { url = "https://files.pythonhosted.org/packages/cb/8d/7b6b1ee6e3d9d1a06237bbdfe4c6bb21baf323d3f70a0cc8f203de40c6b2/jiter-0.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:be2e911ecdb438951290c2079fe4190e7cc5be9e849df4caeb085b83ed620ff6", size = 211171 }, +] + [[package]] name = "jiter" version = "0.10.0" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version > '3.11' and sys_platform == 'darwin'", + "python_full_version > '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version > '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version > '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.11' and sys_platform == 'darwin'", + "python_full_version == '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.10.*' and sys_platform == 'darwin'", + "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.10.*' and 
platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.9.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version == '3.9.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version > '3.11' and sys_platform == 'win32'", + "python_full_version == '3.11' and sys_platform == 'win32'", + "python_full_version == '3.10.*' and sys_platform == 'win32'", + "python_full_version == '3.9.*' and sys_platform == 'win32'", +] sdist = { url = "https://files.pythonhosted.org/packages/ee/9d/ae7ddb4b8ab3fb1b51faf4deb36cb48a4fbbd7cb36bad6a5fca4741306f7/jiter-0.10.0.tar.gz", hash = "sha256:07a7142c38aacc85194391108dc91b5b57093c978a9932bd86a36862759d9500", size = 162759 } wheels = [ { url = "https://files.pythonhosted.org/packages/be/7e/4011b5c77bec97cb2b572f566220364e3e21b51c48c5bd9c4a9c26b41b67/jiter-0.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:cd2fb72b02478f06a900a5782de2ef47e0396b3e1f7d5aba30daeb1fce66f303", size = 317215 }, @@ -5504,7 +5667,7 @@ version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cublas-cu12", version = "12.1.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9' and sys_platform != 'win32'" }, - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, + { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -5532,7 +5695,7 @@ resolution-markers = [ "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", ] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= 
'3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, @@ -5590,9 +5753,9 @@ resolution-markers = [ "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", ] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, - { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, + { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, @@ -5623,7 +5786,7 @@ resolution-markers = [ "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", ] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = 
"https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, @@ -5717,14 +5880,17 @@ name = "openai" version = "1.84.0" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "anyio", version = "4.5.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "anyio", version = "4.9.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "distro", marker = "python_full_version >= '3.9'" }, - { name = "httpx", marker = "python_full_version >= '3.9'" }, - { name = "jiter", marker = "python_full_version >= '3.9'" }, + { name = "distro" }, + { name = "httpx" }, + { name = "jiter", version = "0.9.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "jiter", version = "0.10.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "pydantic", version = "1.10.22", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "pydantic", version = "2.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "sniffio", marker = "python_full_version >= '3.9'" }, - { name = "tqdm", marker = "python_full_version >= '3.9'" }, - { name = "typing-extensions", marker = "python_full_version >= '3.9'" }, + { name = "sniffio" }, + { name = "tqdm" }, + { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/91/a3/128caf24e116f48fad3e4d5122cdf84db06c5127911849d51663c66158c8/openai-1.84.0.tar.gz", hash = "sha256:4caa43bdab262cc75680ce1a2322cfc01626204074f7e8d9939ab372acf61698", size = 467066 } wheels = [ @@ -8013,6 +8179,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755 }, ] +[[package]] +name = "shtab" +version = "1.7.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/3e/837067b970c1d2ffa936c72f384a63fdec4e186b74da781e921354a94024/shtab-1.7.2.tar.gz", hash = "sha256:8c16673ade76a2d42417f03e57acf239bfb5968e842204c17990cae357d07d6f", size = 45751 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/03/3271b7bb470fbab4adf5bd30b0d32143909d96f3608d815b447357f47f2b/shtab-1.7.2-py3-none-any.whl", hash = "sha256:858a5805f6c137bb0cda4f282d27d08fd44ca487ab4a6a36d2a400263cd0b5c1", size = 14214 }, +] + [[package]] name = "six" version = "1.17.0" @@ -9269,6 +9444,73 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/74/9f12bdedeb110242d8bb1bd621f6605e753ee0cbf73cf7f3a62b8173f190/triton-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:30ceed0eff2c4a73b14eb63e052992f44bbdf175f3fad21e1ac8097a772de7ee", size = 253057866 }, ] +[[package]] +name = "trl" +version = "0.11.4" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9' and sys_platform != 'win32'", + "python_full_version < '3.9' and sys_platform == 'win32'", +] +dependencies = [ + { name = "accelerate", version = "1.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "datasets", marker = "python_full_version < '3.9'" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "torch", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "transformers", version = "4.46.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "tyro", marker = "python_full_version < '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/39/a78c0608190cc412c49631dfb8c3e57f5c5b2fb0d79709071c992e707aa4/trl-0.11.4.tar.gz", hash = "sha256:de52a023fc35d580ab809fd74cd4f362a259e463bb968580e0e97e1b98a0fe79", size = 307304 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/dd/d2cf3dbc1013cee71ceef584f5ab69915fc05d209ef1e276f8652058c350/trl-0.11.4-py3-none-any.whl", hash = "sha256:071d64164c152ef65b44d15f878793b28d3340310c9e157dc3608bbe5fa549a9", size = 316575 }, +] + +[[package]] +name = "trl" +version = "0.15.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version > '3.11' and sys_platform == 'darwin'", + "python_full_version > '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version > '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version > '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.11' and sys_platform == 'darwin'", + "python_full_version == '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.10.*' and sys_platform == 'darwin'", + "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.9.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version == '3.9.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.9.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.9.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version > '3.11' and sys_platform == 'win32'", + "python_full_version == '3.11' and sys_platform == 'win32'", + "python_full_version == '3.10.*' and sys_platform == 'win32'", + "python_full_version == '3.9.*' and 
sys_platform == 'win32'", +] +dependencies = [ + { name = "accelerate", version = "1.7.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "datasets", marker = "python_full_version >= '3.9'" }, + { name = "rich", marker = "python_full_version >= '3.9'" }, + { name = "transformers", version = "4.51.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/fe/ae0d782c48eef56d0ec125ebd05998539ede7cbf0e307a48f9323998b9e7/trl-0.15.2.tar.gz", hash = "sha256:0f82190a058a0a194dbcfae1fe9548b68a0a05b2f4d1824f8db1ae7d949cdd47", size = 333962 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/29/25378447c48359843de0e4ce1995d367210601c3b437ddf1c779b6393d74/trl-0.15.2-py3-none-any.whl", hash = "sha256:bf2b88e3cf5da08cd533dc03273d977965bd5d86c5878f76285fba45d9cb9634", size = 318931 }, +] + +[[package]] +name = "typeguard" +version = "4.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata", version = "8.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "typing-extensions", marker = "python_full_version < '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/5a/91b7c8cfc2e96962442abc9d65c650436dd831910b4d7878980d6596fb98/typeguard-4.4.0.tar.gz", hash = "sha256:463bd8697a65a4aa576a63767c369b1ecfba8a5ba735edfe3223127b6ecfa28c", size = 74399 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/a3/00203767544b597a9e3c57b29a84967b3230f00bdd9aa6a52a73187043b4/typeguard-4.4.0-py3-none-any.whl", hash = "sha256:8ca34c14043f53b2caae7040549ba431770869bcd6287cfa8239db7ecb882b4a", size = 35736 }, +] + [[package]] name = "typer" version = "0.15.4" @@ -9326,6 +9568,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, ] +[[package]] +name = "tyro" +version = "0.9.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "python_full_version < '3.9' and sys_platform == 'win32'" }, + { name = "docstring-parser", marker = "python_full_version < '3.9'" }, + { name = "eval-type-backport", marker = "python_full_version < '3.9'" }, + { name = "rich", marker = "python_full_version < '3.9'" }, + { name = "shtab", marker = "python_full_version < '3.9'" }, + { name = "typeguard", marker = "python_full_version < '3.9'" }, + { name = "typing-extensions", marker = "python_full_version < '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/57/49/ca1698fcc5479fe9c7eff48861ebb671c5a6afba0245ea7cd560a939f281/tyro-0.9.24.tar.gz", hash = "sha256:5a9ef93d1b8e93cff2c5d82789a571d905d152e92af82a3ec96a17d668194df3", size = 303651 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/59/4c865b56babef1aa6a9662879c94507dc62d0173ac7433579d7a2728f7e5/tyro-0.9.24-py3-none-any.whl", hash = "sha256:d8152e47375419752210da455226007b4bb9bd9c65af1de8fb12daf0658c91dc", size = 128326 }, +] + [[package]] name = "tzdata" version = "2025.2" @@ -9805,8 +10065,8 @@ name = "xformers" version = "0.0.29.post2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", version = "1.26.4", source = { registry = 
"https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and sys_platform != 'win32'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version == '3.9.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.9' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/27/ed/04ec7ef97a7e1c836add41ef5a2aef8cbdd45c0190ca42cc08f3c21e2b7b/xformers-0.0.29.post2.tar.gz", hash = "sha256:6ca3d1a6db6f2abff25c1154adee96987f77f4dfd5141771805afa5fc13e9395", size = 8468494 } wheels = [ @@ -9820,12 +10080,12 @@ name = "xgrammar" version = "0.1.18" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ninja", marker = "python_full_version >= '3.9'" }, - { name = "pydantic", version = "2.11.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "sentencepiece", marker = "python_full_version >= '3.9'" }, - { name = "tiktoken", marker = "python_full_version >= '3.9'" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "transformers", version = "4.51.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "ninja", marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, + { name = "pydantic", version = "2.11.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, + { name = "sentencepiece", marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, + { name = "tiktoken", marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, + { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, + { name = 
"transformers", version = "4.51.3", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.9' and platform_machine != 'arm64') or (python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.9' and sys_platform != 'darwin')" }, { name = "triton", version = "3.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/8f/c3/22c9eeab6ee1dd6d0513d227e9d307fd20a0491db58f1f04bc5d566d13dc/xgrammar-0.1.18.tar.gz", hash = "sha256:a0438a0f9262fff1d0e4f184268eb759f094243edce92b67eb7aa5f245c47471", size = 1697230 }