From 6b1f37ab06c1e0f89a505024b63c596fafa508ae Mon Sep 17 00:00:00 2001
From: Muhammad Zoaib
Date: Tue, 27 Jan 2026 17:16:18 +0300
Subject: [PATCH 1/3] added gcp support in llm engine

---
 .../model_engine_server/api/dependencies.py   | 56 +++++++-----
 .../core/celery/celery_autoscaler.py          |  2 +
 .../infra/gateways/__init__.py                |  6 ++
 .../gateways/gcs_file_storage_gateway.py      | 34 +++++++
 .../infra/gateways/gcs_filesystem_gateway.py  | 55 ++++++++++++
 .../gateways/gcs_llm_artifact_gateway.py      | 99 +++++++++++++++++++
 .../redis_queue_endpoint_resource_delegate.py | 92 +++++++++++++++++
 .../infra/repositories/__init__.py            |  4 +
 ...cs_file_llm_fine_tune_events_repository.py | 94 ++++++++++++++++++
 .../gcs_file_llm_fine_tune_repository.py      | 51 ++++++++++
 model-engine/requirements.in                  |  3 +
 11 files changed, 475 insertions(+), 21 deletions(-)
 create mode 100644 model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
 create mode 100644 model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py
 create mode 100644 model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py
 create mode 100644 model-engine/model_engine_server/infra/gateways/resources/redis_queue_endpoint_resource_delegate.py
 create mode 100644 model-engine/model_engine_server/infra/repositories/gcs_file_llm_fine_tune_events_repository.py
 create mode 100644 model-engine/model_engine_server/infra/repositories/gcs_file_llm_fine_tune_repository.py

diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py
index 9c7dd2f76..3d51918b0 100644
--- a/model-engine/model_engine_server/api/dependencies.py
+++ b/model-engine/model_engine_server/api/dependencies.py
@@ -64,6 +64,9 @@
     CeleryTaskQueueGateway,
     DatadogMonitoringMetricsGateway,
     FakeMonitoringMetricsGateway,
+    GCSFileStorageGateway,
+    GCSFilesystemGateway,
+    GCSLLMArtifactGateway,
     LiveAsyncModelEndpointInferenceGateway,
     LiveBatchJobOrchestrationGateway,
     LiveBatchJobProgressGateway,
@@ -100,6 +103,9 @@
 from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import (
     SQSQueueEndpointResourceDelegate,
 )
+from model_engine_server.infra.gateways.resources.redis_queue_endpoint_resource_delegate import (
+    RedisQueueEndpointResourceDelegate,
+)
 from model_engine_server.infra.gateways.s3_file_storage_gateway import S3FileStorageGateway
 from model_engine_server.infra.repositories import (
     ABSFileLLMFineTuneEventsRepository,
@@ -112,6 +118,8 @@
     DbTriggerRepository,
     ECRDockerRepository,
     FakeDockerRepository,
+    GCSFileLLMFineTuneEventsRepository,
+    GCSFileLLMFineTuneRepository,
     LiveTokenizerRepository,
     LLMFineTuneRepository,
     RedisModelEndpointCacheRepository,
@@ -225,6 +233,9 @@ def _get_external_interfaces(
         queue_delegate = FakeQueueEndpointResourceDelegate()
     elif infra_config().cloud_provider == "azure":
         queue_delegate = ASBQueueEndpointResourceDelegate()
+    elif infra_config().cloud_provider == "gcp":
+        # GCP uses Redis (Memorystore) for Celery, so use Redis-based queue delegate
+        queue_delegate = RedisQueueEndpointResourceDelegate()
     else:
         queue_delegate = SQSQueueEndpointResourceDelegate(
             sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
@@ -238,7 +249,8 @@ def _get_external_interfaces(
     elif infra_config().cloud_provider == "azure":
         inference_task_queue_gateway = servicebus_task_queue_gateway
         infra_task_queue_gateway = servicebus_task_queue_gateway
-    elif infra_config().celery_broker_type_redis:
+    elif infra_config().cloud_provider == "gcp" or infra_config().celery_broker_type_redis:
+        # GCP uses Redis (Memorystore) for Celery broker
         inference_task_queue_gateway = redis_task_queue_gateway
         infra_task_queue_gateway = redis_task_queue_gateway
     else:
@@ -274,16 +286,15 @@ def _get_external_interfaces(
         monitoring_metrics_gateway=monitoring_metrics_gateway,
         use_asyncio=(not CIRCLECI),
     )
-    filesystem_gateway = (
-        ABSFilesystemGateway()
-        if infra_config().cloud_provider == "azure"
-        else S3FilesystemGateway()
-    )
-    llm_artifact_gateway = (
-        ABSLLMArtifactGateway()
-        if infra_config().cloud_provider == "azure"
-        else S3LLMArtifactGateway()
-    )
+    if infra_config().cloud_provider == "azure":
+        filesystem_gateway = ABSFilesystemGateway()
+        llm_artifact_gateway = ABSLLMArtifactGateway()
+    elif infra_config().cloud_provider == "gcp":
+        filesystem_gateway = GCSFilesystemGateway()
+        llm_artifact_gateway = GCSLLMArtifactGateway()
+    else:
+        filesystem_gateway = S3FilesystemGateway()
+        llm_artifact_gateway = S3LLMArtifactGateway()
     model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
         filesystem_gateway=filesystem_gateway
     )
@@ -331,15 +342,17 @@ def _get_external_interfaces(
         llm_fine_tune_repository = ABSFileLLMFineTuneRepository(
             file_path=file_path,
         )
+        llm_fine_tune_events_repository = ABSFileLLMFineTuneEventsRepository()
+    elif infra_config().cloud_provider == "gcp":
+        llm_fine_tune_repository = GCSFileLLMFineTuneRepository(
+            file_path=file_path,
+        )
+        llm_fine_tune_events_repository = GCSFileLLMFineTuneEventsRepository()
     else:
         llm_fine_tune_repository = S3FileLLMFineTuneRepository(
             file_path=file_path,
         )
-    llm_fine_tune_events_repository = (
-        ABSFileLLMFineTuneEventsRepository()
-        if infra_config().cloud_provider == "azure"
-        else S3FileLLMFineTuneEventsRepository()
-    )
+        llm_fine_tune_events_repository = S3FileLLMFineTuneEventsRepository()
     llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService(
         docker_image_batch_job_gateway=docker_image_batch_job_gateway,
         docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository,
@@ -350,11 +363,12 @@ def _get_external_interfaces(
         docker_image_batch_job_gateway=docker_image_batch_job_gateway
     )
 
-    file_storage_gateway = (
-        ABSFileStorageGateway()
-        if infra_config().cloud_provider == "azure"
-        else S3FileStorageGateway()
-    )
+    if infra_config().cloud_provider == "azure":
+        file_storage_gateway = ABSFileStorageGateway()
+    elif infra_config().cloud_provider == "gcp":
+        file_storage_gateway = GCSFileStorageGateway()
+    else:
+        file_storage_gateway = S3FileStorageGateway()
 
     docker_repository: DockerRepository
     if CIRCLECI:
diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py
index 78a4da3f5..a66d01078 100644
--- a/model-engine/model_engine_server/core/celery/celery_autoscaler.py
+++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py
@@ -43,6 +43,7 @@ def excluded_namespaces():
 
 
 ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"
+GCP_MEMORYSTORE_REDIS_BROKER = "redis-gcp-memorystore-message-broker-master"
 SQS_BROKER = "sqs-message-broker-master"
 SERVICEBUS_BROKER = "servicebus-message-broker-master"
 
@@ -589,6 +590,7 @@ async def main():
 
     BROKER_NAME_TO_CLASS = {
         ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
+        GCP_MEMORYSTORE_REDIS_BROKER: RedisBroker(use_elasticache=True),  # GCP Memorystore also doesn't support CONFIG GET
         SQS_BROKER: SQSBroker(),
         SERVICEBUS_BROKER: ASBBroker(),
     }
diff --git a/model-engine/model_engine_server/infra/gateways/__init__.py b/model-engine/model_engine_server/infra/gateways/__init__.py
index f8a3ee6ee..88d958195 100644
--- a/model-engine/model_engine_server/infra/gateways/__init__.py
+++ b/model-engine/model_engine_server/infra/gateways/__init__.py
@@ -10,6 +10,9 @@
 from .datadog_monitoring_metrics_gateway import DatadogMonitoringMetricsGateway
 from .fake_model_primitive_gateway import FakeModelPrimitiveGateway
 from .fake_monitoring_metrics_gateway import FakeMonitoringMetricsGateway
+from .gcs_file_storage_gateway import GCSFileStorageGateway
+from .gcs_filesystem_gateway import GCSFilesystemGateway
+from .gcs_llm_artifact_gateway import GCSLLMArtifactGateway
 from .live_async_model_endpoint_inference_gateway import LiveAsyncModelEndpointInferenceGateway
 from .live_batch_job_orchestration_gateway import LiveBatchJobOrchestrationGateway
 from .live_batch_job_progress_gateway import LiveBatchJobProgressGateway
@@ -37,6 +40,9 @@
     "DatadogMonitoringMetricsGateway",
     "FakeModelPrimitiveGateway",
     "FakeMonitoringMetricsGateway",
+    "GCSFileStorageGateway",
+    "GCSFilesystemGateway",
+    "GCSLLMArtifactGateway",
     "LiveAsyncModelEndpointInferenceGateway",
     "LiveBatchJobOrchestrationGateway",
     "LiveBatchJobProgressGateway",
diff --git a/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py b/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
new file mode 100644
index 000000000..0ce1f67f2
--- /dev/null
+++ b/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
@@ -0,0 +1,34 @@
+from typing import List, Optional
+
+from model_engine_server.domain.gateways.file_storage_gateway import (
+    FileMetadata,
+    FileStorageGateway,
+)
+from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway
+
+
+class GCSFileStorageGateway(FileStorageGateway):
+    """
+    Concrete implementation of a file storage gateway backed by GCS.
+    """
+
+    def __init__(self):
+        self.filesystem_gateway = GCSFilesystemGateway()
+
+    async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]:
+        raise NotImplementedError("GCS file storage not fully implemented yet")
+
+    async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]:
+        raise NotImplementedError("GCS file storage not fully implemented yet")
+
+    async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
+        raise NotImplementedError("GCS file storage not fully implemented yet")
+
+    async def upload_file(self, owner: str, filename: str, content: bytes) -> str:
+        raise NotImplementedError("GCS file storage not fully implemented yet")
+
+    async def delete_file(self, owner: str, file_id: str) -> bool:
+        raise NotImplementedError("GCS file storage not fully implemented yet")
+
+    async def list_files(self, owner: str) -> List[FileMetadata]:
+        raise NotImplementedError("GCS file storage not fully implemented yet")
diff --git a/model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py b/model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py
new file mode 100644
index 000000000..0f717fe10
--- /dev/null
+++ b/model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py
@@ -0,0 +1,55 @@
+import re
+from datetime import timedelta
+from typing import IO
+
+import smart_open
+from google.auth import default
+from google.cloud import storage
+from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway
+
+
+class GCSFilesystemGateway(FilesystemGateway):
+    """
+    Concrete implementation for interacting with a filesystem backed by Google Cloud Storage.
+    """
+
+    def _get_storage_client(self) -> storage.Client:
+        """Build a storage client from the credentials found by google.auth.default()."""
+        credentials, project = default()
+        return storage.Client(credentials=credentials, project=project)
+
+    def open(self, uri: str, mode: str = "rt", **kwargs) -> IO:
+        """Open ``uri`` via smart_open using a GCS client; extra kwargs are ignored."""
+        client = self._get_storage_client()
+        transport_params = {"client": client}
+        return smart_open.open(uri, mode, transport_params=transport_params)
+
+    def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str:
+        """
+        Return a V4 signed GET URL for ``uri`` valid for ``expiration`` seconds.
+
+        ``uri`` must be gs://bucket/key or https://storage.googleapis.com/bucket/key;
+        anything else raises ValueError. Extra kwargs go to Blob.generate_signed_url.
+        """
+        # Parse gs://bucket/key format
+        match = re.search(r"^gs://([^/]+)/(.*?)$", uri)
+        if not match:
+            # Try https://storage.googleapis.com/bucket/key format
+            match = re.search(r"^https://storage\.googleapis\.com/([^/]+)/(.*?)$", uri)
+        if not match:
+            # Raise instead of assert: asserts are stripped under ``python -O``.
+            raise ValueError(f"Invalid GCS URI: {uri}")
+
+        bucket_name, blob_name = match.group(1), match.group(2)
+
+        client = self._get_storage_client()
+        bucket = client.bucket(bucket_name)
+        blob = bucket.blob(blob_name)
+
+        signed_url = blob.generate_signed_url(
+            version="v4",
+            expiration=timedelta(seconds=expiration),
+            method="GET",
+            **kwargs,
+        )
+        return signed_url
diff --git a/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py
new file mode 100644
index 000000000..e3f56c1df
--- /dev/null
+++ b/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py
@@ -0,0 +1,99 @@
+import json
+import os
+from typing import Any, Dict, List
+
+from google.auth import default
+from google.cloud import storage
+from model_engine_server.common.config import get_model_cache_directory_name, hmi_config
+from model_engine_server.core.loggers import logger_name, make_logger
+from model_engine_server.core.utils.url import parse_attachment_url
+from model_engine_server.domain.gateways import LLMArtifactGateway
+
+logger = make_logger(logger_name())
+
+
+def _get_gcs_client() -> storage.Client:
+    """Build a storage client from the credentials found by google.auth.default()."""
+    credentials, project = default()
+    return storage.Client(credentials=credentials, project=project)
+
+
+class GCSLLMArtifactGateway(LLMArtifactGateway):
+    """
+    Concrete implementation using Google Cloud Storage.
+    """
+
+    def list_files(self, path: str, **kwargs) -> List[str]:
+        """Return the names of all blobs under ``path``."""
+        parsed_remote = parse_attachment_url(path, clean_key=False)
+        bucket_name = parsed_remote.bucket
+        prefix = parsed_remote.key
+
+        client = _get_gcs_client()
+        bucket = client.bucket(bucket_name)
+        blobs = bucket.list_blobs(prefix=prefix)
+        return [blob.name for blob in blobs]
+
+    def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]:
+        """
+        Download every blob under ``path`` into ``target_path``, mirroring the key layout.
+
+        Existing local files are kept unless ``overwrite`` is true. Returns the local paths.
+        """
+        parsed_remote = parse_attachment_url(path, clean_key=False)
+        bucket_name = parsed_remote.bucket
+        prefix = parsed_remote.key
+
+        client = _get_gcs_client()
+        bucket = client.bucket(bucket_name)
+
+        downloaded_files: List[str] = []
+        for blob in bucket.list_blobs(prefix=prefix):
+            if blob.name.endswith("/"):
+                # Folder placeholder object: there is no file to write locally.
+                continue
+            # Listed names always start with ``prefix``; strip only that leading occurrence.
+            file_path_suffix = blob.name[len(prefix) :].lstrip("/")
+            local_path = os.path.join(target_path, file_path_suffix).rstrip("/")
+
+            if not overwrite and os.path.exists(local_path):
+                downloaded_files.append(local_path)
+                continue
+
+            local_dir = os.path.dirname(local_path)
+            if local_dir:
+                os.makedirs(local_dir, exist_ok=True)
+
+            logger.info(f"Downloading {blob.name} to {local_path}")
+            blob.download_to_filename(local_path)
+            downloaded_files.append(local_path)
+        return downloaded_files
+
+    def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
+        """Return gs:// URLs of ``owner``'s fine-tuned weight files for ``model_name``."""
+        parsed_remote = parse_attachment_url(
+            hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
+        )
+        bucket_name = parsed_remote.bucket
+        fine_tuned_weights_prefix = parsed_remote.key
+
+        client = _get_gcs_client()
+        bucket = client.bucket(bucket_name)
+
+        model_files: List[str] = []
+        model_cache_name = get_model_cache_directory_name(model_name)
+        prefix = f"{fine_tuned_weights_prefix}/{owner}/{model_cache_name}"
+        for blob in bucket.list_blobs(prefix=prefix):
+            model_files.append(f"gs://{bucket_name}/{blob.name}")
+        return model_files
+
+    def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
+        """Load and parse ``config.json`` stored under ``path``."""
+        parsed_remote = parse_attachment_url(path, clean_key=False)
+        bucket_name = parsed_remote.bucket
+        key = os.path.join(parsed_remote.key, "config.json")
+
+        client = _get_gcs_client()
+        bucket = client.bucket(bucket_name)
+        blob = bucket.blob(key)
+        return json.loads(blob.download_as_text())
diff --git a/model-engine/model_engine_server/infra/gateways/resources/redis_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/redis_queue_endpoint_resource_delegate.py
new file mode 100644
index 000000000..87282025e
--- /dev/null
+++ b/model-engine/model_engine_server/infra/gateways/resources/redis_queue_endpoint_resource_delegate.py
@@ -0,0 +1,92 @@
+"""Redis-based queue endpoint resource delegate for GCP deployments.
+
+When using Redis (Memorystore) as the Celery broker on GCP, queues are implicit
+Redis lists that don't need explicit creation/deletion. This delegate manages
+queue lifecycle for async endpoints using Redis.
+"""
+
+from typing import Any, Dict, Sequence
+
+from model_engine_server.core.celery.app import get_redis_instance
+from model_engine_server.core.loggers import logger_name, make_logger
+from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
+    QueueEndpointResourceDelegate,
+    QueueInfo,
+)
+
+logger = make_logger(logger_name())
+
+__all__: Sequence[str] = ("RedisQueueEndpointResourceDelegate",)
+
+
+class RedisQueueEndpointResourceDelegate(QueueEndpointResourceDelegate):
+    """
+    Redis-based queue delegate for GCP deployments using Memorystore.
+
+    Redis queues (lists) are created implicitly when messages are pushed,
+    so this delegate mainly handles queue name management and metrics retrieval.
+    """
+
+    async def create_queue_if_not_exists(
+        self,
+        endpoint_id: str,
+        endpoint_name: str,
+        endpoint_created_by: str,
+        endpoint_labels: Dict[str, Any],
+    ) -> QueueInfo:
+        """
+        For Redis, queues are created implicitly. We just return the queue name.
+        The queue_url is None since Redis doesn't use URLs for queue access.
+        """
+        queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
+        logger.info(f"Redis queue ready for endpoint: {queue_name}")
+        return QueueInfo(queue_name=queue_name, queue_url=None)
+
+    async def delete_queue(self, endpoint_id: str) -> None:
+        """
+        Delete the Redis queue (list) for the endpoint.
+        This removes all pending messages in the queue.
+        """
+        queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
+        try:
+            redis = get_redis_instance()
+            try:
+                # Delete the queue (Redis list)
+                deleted = redis.delete(queue_name)
+            finally:
+                redis.close()
+            if deleted:
+                logger.info(f"Deleted Redis queue: {queue_name}")
+            else:
+                logger.info(f"Redis queue already empty or doesn't exist: {queue_name}")
+        except Exception as e:
+            logger.warning(f"Error deleting Redis queue {queue_name}: {e}")
+
+    async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]:
+        """
+        Get queue attributes including the approximate number of messages.
+        """
+        queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
+        try:
+            redis = get_redis_instance()
+            try:
+                queue_length = redis.llen(queue_name)
+            finally:
+                redis.close()
+
+            # Return in a format compatible with the existing code
+            # that checks for "Attributes.ApproximateNumberOfMessages"
+            return {
+                "name": queue_name,
+                "Attributes": {
+                    "ApproximateNumberOfMessages": str(queue_length),
+                },
+            }
+        except Exception as e:
+            logger.warning(f"Error getting Redis queue attributes for {queue_name}: {e}")
+            return {
+                "name": queue_name,
+                "Attributes": {
+                    "ApproximateNumberOfMessages": "0",
+                },
+            }
diff --git a/model-engine/model_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py
index f14cf69f7..f8bdd05e3 100644
--- a/model-engine/model_engine_server/infra/repositories/__init__.py
+++ b/model-engine/model_engine_server/infra/repositories/__init__.py
@@ -12,6 +12,8 @@
 from .ecr_docker_repository import ECRDockerRepository
 from .fake_docker_repository import FakeDockerRepository
 from .feature_flag_repository import FeatureFlagRepository
+from .gcs_file_llm_fine_tune_events_repository import GCSFileLLMFineTuneEventsRepository
+from .gcs_file_llm_fine_tune_repository import GCSFileLLMFineTuneRepository
 from .live_tokenizer_repository import LiveTokenizerRepository
 from .llm_fine_tune_repository import LLMFineTuneRepository
 from .model_endpoint_cache_repository import ModelEndpointCacheRepository
@@ -34,6 +36,8 @@
     "ECRDockerRepository",
     "FakeDockerRepository",
     "FeatureFlagRepository",
+    "GCSFileLLMFineTuneEventsRepository",
+    "GCSFileLLMFineTuneRepository",
     "LiveTokenizerRepository",
     "LLMFineTuneRepository",
     "ModelEndpointRecordRepository",
diff --git a/model-engine/model_engine_server/infra/repositories/gcs_file_llm_fine_tune_events_repository.py b/model-engine/model_engine_server/infra/repositories/gcs_file_llm_fine_tune_events_repository.py
new file mode 100644
index 000000000..02e21eb8f
--- /dev/null
+++ b/model-engine/model_engine_server/infra/repositories/gcs_file_llm_fine_tune_events_repository.py
@@ -0,0 +1,94 @@
+import json
+from json.decoder import JSONDecodeError
+from typing import IO, List
+
+import smart_open
+from google.auth import default
+from google.cloud import storage
+from model_engine_server.core.config import infra_config
+from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent
+from model_engine_server.domain.exceptions import ObjectNotFoundException
+from model_engine_server.domain.repositories.llm_fine_tune_events_repository import (
+    LLMFineTuneEventsRepository,
+)
+
+# Echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py
+GCS_HF_USER_FINE_TUNED_WEIGHTS_PREFIX = (
+    f"gs://{infra_config().s3_bucket}/hosted-model-inference/fine_tuned_weights"
+)
+
+
+class GCSFileLLMFineTuneEventsRepository(LLMFineTuneEventsRepository):
+    """Fine-tune events repository that reads per-job JSONL event files from GCS."""
+
+    def __init__(self):
+        pass
+
+    def _get_gcs_client(self):
+        credentials, project = default()
+        return storage.Client(credentials=credentials, project=project)
+
+    def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO:
+        client = self._get_gcs_client()
+        transport_params = {"client": client}
+        return smart_open.open(uri, mode, transport_params=transport_params)
+
+    def _get_model_cache_directory_name(self, model_name: str):
+        """How huggingface maps model names to directory names in their cache for model files.
+        We adopt this when storing model cache files in GCS.
+
+        Args:
+            model_name (str): Name of the huggingface model
+        """
+        name = "models--" + model_name.replace("/", "--")
+        return name
+
+    def _get_file_location(self, user_id: str, model_endpoint_name: str):
+        model_cache_name = self._get_model_cache_directory_name(model_endpoint_name)
+        gcs_file_location = (
+            f"{GCS_HF_USER_FINE_TUNED_WEIGHTS_PREFIX}/{user_id}/{model_cache_name}.jsonl"
+        )
+        return gcs_file_location
+
+    async def get_fine_tune_events(
+        self, user_id: str, model_endpoint_name: str
+    ) -> List[LLMFineTuneEvent]:
+        """
+        Read the events file for the fine-tune and return one event per line.
+
+        Lines that are not valid JSON become plain "info" events. Any failure to read or
+        parse the file is reported as ObjectNotFoundException.
+        """
+        gcs_file_location = self._get_file_location(
+            user_id=user_id, model_endpoint_name=model_endpoint_name
+        )
+        try:
+            with self._open(gcs_file_location, "r") as f:
+                lines = f.readlines()
+                final_events = []
+                for line in lines:
+                    try:
+                        event_dict = json.loads(line)
+                        event = LLMFineTuneEvent(
+                            timestamp=event_dict["timestamp"],
+                            message=str(event_dict["message"]),
+                            level=event_dict.get("level", "info"),
+                        )
+                    except JSONDecodeError:
+                        event = LLMFineTuneEvent(
+                            message=line,
+                            level="info",
+                        )
+                    final_events.append(event)
+                return final_events
+        except Exception as exc:
+            raise ObjectNotFoundException from exc
+
+    async def initialize_events(self, user_id: str, model_endpoint_name: str) -> None:
+        """Create (or truncate) the empty events file for the fine-tune."""
+        gcs_file_location = self._get_file_location(
+            user_id=user_id, model_endpoint_name=model_endpoint_name
+        )
+        # The object is only written to GCS once the writer is closed.
+        with self._open(gcs_file_location, "w"):
+            pass
diff --git a/model-engine/model_engine_server/infra/repositories/gcs_file_llm_fine_tune_repository.py b/model-engine/model_engine_server/infra/repositories/gcs_file_llm_fine_tune_repository.py
new file mode 100644
index 000000000..6096c79f7
--- /dev/null
+++ b/model-engine/model_engine_server/infra/repositories/gcs_file_llm_fine_tune_repository.py
@@ -0,0 +1,51 @@
+import json
+from typing import IO, Dict, Optional
+
+import smart_open
+from google.auth import default
+from google.cloud import storage
+from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate
+from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository
+
+
+class GCSFileLLMFineTuneRepository(LLMFineTuneRepository):
+    def __init__(self, file_path: str):
+        self.file_path = file_path
+
+    def _get_gcs_client(self):
+        credentials, project = default()
+        return storage.Client(credentials=credentials, project=project)
+
+    def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO:
+        client = self._get_gcs_client()
+        transport_params = {"client": client}
+        return smart_open.open(uri, mode, transport_params=transport_params)
+
+    @staticmethod
+    def _get_key(model_name, fine_tuning_method):
+        return f"{model_name}-{fine_tuning_method}"
+
+    async def get_job_template_for_model(
+        self, model_name: str, fine_tuning_method: str
+    ) -> Optional[LLMFineTuneTemplate]:
+        with self._open(self.file_path, "r") as f:
+            data = json.load(f)
+            key = self._get_key(model_name, fine_tuning_method)
+            job_template_dict = data.get(key, None)
+            if job_template_dict is None:
+                return None
+            return LLMFineTuneTemplate.parse_obj(job_template_dict)
+
+    async def write_job_template_for_model(
+        self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate
+    ):
+        with self._open(self.file_path, "r") as f:
+            data: Dict = json.load(f)
+            key = self._get_key(model_name, fine_tuning_method)
+            data[key] = dict(job_template)
+        with self._open(self.file_path, "w") as f:
+            json.dump(data, f)
+
+    async def initialize_data(self):
+        with self._open(self.file_path, "w") as f:
+            json.dump({}, f)
diff --git a/model-engine/requirements.in b/model-engine/requirements.in
index 3d4162daa..d452c0e6c 100644
--- a/model-engine/requirements.in
+++ b/model-engine/requirements.in
@@ -9,6 +9,9 @@ azure-identity~=1.15.0
 azure-keyvault-secrets~=4.7.0
 azure-servicebus~=7.11.4
 azure-storage-blob~=12.19.0
+# GCP dependencies
+google-auth~=2.25.0
+google-cloud-storage~=2.14.0
 boto3-stubs[essential]~=1.26.67
 boto3~=1.21
 botocore~=1.24

From 21f2d2e075af35f148aa2e7fab6e430e137f7cf7 Mon Sep 17 00:00:00 2001
From: Muhammad Zoaib
Date: Tue, 27 Jan 2026 17:23:18 +0300
Subject: [PATCH 2/3] added gcp support in llm engine

---
 .../gateways/gcs_file_storage_gateway.py      | 82 +++++++++++++++++--
 1 file changed, 76 insertions(+), 6 deletions(-)

diff --git a/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py b/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
index 0ce1f67f2..12fa255ad 100644
--- a/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
+++ b/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
@@ -1,5 +1,9 @@
+import os
 from typing import List, Optional
 
+from google.auth import default
+from google.cloud import storage
+from model_engine_server.core.config import infra_config
 from model_engine_server.domain.gateways.file_storage_gateway import (
     FileMetadata,
     FileStorageGateway,
@@ -7,6 +11,16 @@
 from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway
 
 
+def get_gcs_key(owner: str, file_id: str) -> str:
+    """Return the object key for ``file_id`` inside ``owner``'s folder."""
+    return os.path.join(owner, file_id)
+
+
+def get_gcs_url(owner: str, file_id: str) -> str:
+    """Return the gs:// URL of ``file_id`` in the configured bucket."""
+    return f"gs://{infra_config().s3_bucket}/{get_gcs_key(owner, file_id)}"
+
+
 class GCSFileStorageGateway(FileStorageGateway):
     """
     Concrete implementation of a file storage gateway backed by GCS.
@@ -15,20 +29,76 @@ class GCSFileStorageGateway(FileStorageGateway):
     def __init__(self):
         self.filesystem_gateway = GCSFilesystemGateway()
 
+    def _get_client(self) -> storage.Client:
+        """Build a storage client from the credentials found by google.auth.default()."""
+        credentials, project = default()
+        return storage.Client(credentials=credentials, project=project)
+
+    def _get_bucket(self) -> storage.Bucket:
+        """Return a handle to the configured bucket."""
+        return self._get_client().bucket(infra_config().s3_bucket)
+
     async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]:
-        raise NotImplementedError("GCS file storage not fully implemented yet")
+        """Return a signed download URL for the file."""
+        return self.filesystem_gateway.generate_signed_url(get_gcs_url(owner, file_id))
 
     async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]:
-        raise NotImplementedError("GCS file storage not fully implemented yet")
+        """Return the file's metadata, or None if it cannot be fetched."""
+        try:
+            bucket = self._get_bucket()
+            blob = bucket.blob(get_gcs_key(owner, file_id))
+            blob.reload()  # Fetch metadata
+            return FileMetadata(
+                id=file_id,
+                filename=file_id,
+                size=blob.size,
+                owner=owner,
+                updated_at=blob.updated,
+            )
+        except Exception:
+            return None
 
     async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
-        raise NotImplementedError("GCS file storage not fully implemented yet")
+        """Return the file's text content, or None if it cannot be read."""
+        try:
+            with self.filesystem_gateway.open(get_gcs_url(owner, file_id)) as f:
+                return f.read()
+        except Exception:
+            return None
 
     async def upload_file(self, owner: str, filename: str, content: bytes) -> str:
-        raise NotImplementedError("GCS file storage not fully implemented yet")
+        """Store ``content`` under ``filename`` for ``owner`` and return the file id."""
+        # Write the bytes verbatim: decoding to text would reject non-UTF-8 uploads.
+        with self.filesystem_gateway.open(get_gcs_url(owner, filename), mode="wb") as f:
+            f.write(content)
+        return filename
 
     async def delete_file(self, owner: str, file_id: str) -> bool:
-        raise NotImplementedError("GCS file storage not fully implemented yet")
+        """Delete the file; return False if the deletion fails."""
+        try:
+            bucket = self._get_bucket()
+            blob = bucket.blob(get_gcs_key(owner, file_id))
+            blob.delete()
+            return True
+        except Exception:
+            return False
 
     async def list_files(self, owner: str) -> List[FileMetadata]:
-        raise NotImplementedError("GCS file storage not fully implemented yet")
+        """Return metadata for every file stored under ``owner``'s folder."""
+        bucket = self._get_bucket()
+        # The trailing slash keeps e.g. owner "ab" from matching objects of owner "abc".
+        prefix = f"{owner}/"
+        files = []
+        for blob in bucket.list_blobs(prefix=prefix):
+            # Listed blobs already carry size/updated, so no per-blob reload is needed.
+            file_id = blob.name[len(prefix) :]
+            files.append(
+                FileMetadata(
+                    id=file_id,
+                    filename=file_id,
+                    size=blob.size,
+                    owner=owner,
+                    updated_at=blob.updated,
+                )
+            )
+        return files

From 974f185068a02ec18af7e66d4de7542ffcc96dfb Mon Sep 17 00:00:00 2001
From: Muhammad Zoaib
Date: Tue, 27 Jan 2026 17:24:34 +0300
Subject: [PATCH 3/3] added gcp support in llm engine

---
 .../model_engine_server/core/celery/celery_autoscaler.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py
index a66d01078..1e17a7df5 100644
--- a/model-engine/model_engine_server/core/celery/celery_autoscaler.py
+++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py
@@ -590,7 +590,8 @@ async def main():
 
     BROKER_NAME_TO_CLASS = {
         ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
-        GCP_MEMORYSTORE_REDIS_BROKER: RedisBroker(use_elasticache=True),  # GCP Memorystore also doesn't support CONFIG GET
+        # GCP Memorystore also doesn't support CONFIG GET
+        GCP_MEMORYSTORE_REDIS_BROKER: RedisBroker(use_elasticache=True),
         SQS_BROKER: SQSBroker(),
         SERVICEBUS_BROKER: ASBBroker(),
     }