56 changes: 35 additions & 21 deletions model-engine/model_engine_server/api/dependencies.py
@@ -64,6 +64,9 @@
CeleryTaskQueueGateway,
DatadogMonitoringMetricsGateway,
FakeMonitoringMetricsGateway,
GCSFileStorageGateway,
GCSFilesystemGateway,
GCSLLMArtifactGateway,
LiveAsyncModelEndpointInferenceGateway,
LiveBatchJobOrchestrationGateway,
LiveBatchJobProgressGateway,
@@ -100,6 +103,9 @@
from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import (
SQSQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.redis_queue_endpoint_resource_delegate import (
RedisQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.s3_file_storage_gateway import S3FileStorageGateway
from model_engine_server.infra.repositories import (
ABSFileLLMFineTuneEventsRepository,
@@ -112,6 +118,8 @@
DbTriggerRepository,
ECRDockerRepository,
FakeDockerRepository,
GCSFileLLMFineTuneEventsRepository,
GCSFileLLMFineTuneRepository,
LiveTokenizerRepository,
LLMFineTuneRepository,
RedisModelEndpointCacheRepository,
@@ -225,6 +233,9 @@ def _get_external_interfaces(
queue_delegate = FakeQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
# GCP uses Redis (Memorystore) for Celery, so use Redis-based queue delegate
queue_delegate = RedisQueueEndpointResourceDelegate()
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
@@ -238,7 +249,8 @@
elif infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().celery_broker_type_redis:
elif infra_config().cloud_provider == "gcp" or infra_config().celery_broker_type_redis:
# GCP uses Redis (Memorystore) for Celery broker
inference_task_queue_gateway = redis_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
else:
@@ -274,16 +286,15 @@ def _get_external_interfaces(
monitoring_metrics_gateway=monitoring_metrics_gateway,
use_asyncio=(not CIRCLECI),
)
filesystem_gateway = (
ABSFilesystemGateway()
if infra_config().cloud_provider == "azure"
else S3FilesystemGateway()
)
llm_artifact_gateway = (
ABSLLMArtifactGateway()
if infra_config().cloud_provider == "azure"
else S3LLMArtifactGateway()
)
if infra_config().cloud_provider == "azure":
filesystem_gateway = ABSFilesystemGateway()
llm_artifact_gateway = ABSLLMArtifactGateway()
elif infra_config().cloud_provider == "gcp":
filesystem_gateway = GCSFilesystemGateway()
llm_artifact_gateway = GCSLLMArtifactGateway()
else:
filesystem_gateway = S3FilesystemGateway()
llm_artifact_gateway = S3LLMArtifactGateway()
model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
filesystem_gateway=filesystem_gateway
)
Expand Down Expand Up @@ -331,15 +342,17 @@ def _get_external_interfaces(
llm_fine_tune_repository = ABSFileLLMFineTuneRepository(
file_path=file_path,
)
llm_fine_tune_events_repository = ABSFileLLMFineTuneEventsRepository()
elif infra_config().cloud_provider == "gcp":
llm_fine_tune_repository = GCSFileLLMFineTuneRepository(
file_path=file_path,
)
llm_fine_tune_events_repository = GCSFileLLMFineTuneEventsRepository()
else:
llm_fine_tune_repository = S3FileLLMFineTuneRepository(
file_path=file_path,
)
llm_fine_tune_events_repository = (
ABSFileLLMFineTuneEventsRepository()
if infra_config().cloud_provider == "azure"
else S3FileLLMFineTuneEventsRepository()
)
llm_fine_tune_events_repository = S3FileLLMFineTuneEventsRepository()
llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService(
docker_image_batch_job_gateway=docker_image_batch_job_gateway,
docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository,
@@ -350,11 +363,12 @@
docker_image_batch_job_gateway=docker_image_batch_job_gateway
)

file_storage_gateway = (
ABSFileStorageGateway()
if infra_config().cloud_provider == "azure"
else S3FileStorageGateway()
)
if infra_config().cloud_provider == "azure":
file_storage_gateway = ABSFileStorageGateway()
elif infra_config().cloud_provider == "gcp":
file_storage_gateway = GCSFileStorageGateway()
else:
file_storage_gateway = S3FileStorageGateway()

docker_repository: DockerRepository
if CIRCLECI:
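The same cloud_provider dispatch is now used for the queue delegate, the Celery task queue gateways, the filesystem and LLM artifact gateways, the fine-tune repositories, and file storage. A minimal sketch of the pattern for the file storage case, using the gateway classes from this diff (the helper name is illustrative, not part of this PR):

# Hypothetical helper showing the provider dispatch used in _get_external_interfaces;
# the function name is illustrative and not part of this PR.
def select_file_storage_gateway(cloud_provider: str):
    if cloud_provider == "azure":
        return ABSFileStorageGateway()
    if cloud_provider == "gcp":
        return GCSFileStorageGateway()
    return S3FileStorageGateway()  # AWS/S3 remains the default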
@@ -43,6 +43,7 @@ def excluded_namespaces():


ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"
GCP_MEMORYSTORE_REDIS_BROKER = "redis-gcp-memorystore-message-broker-master"
SQS_BROKER = "sqs-message-broker-master"
SERVICEBUS_BROKER = "servicebus-message-broker-master"

@@ -589,6 +590,8 @@ async def main():

BROKER_NAME_TO_CLASS = {
ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
# GCP Memorystore also doesn't support CONFIG GET
GCP_MEMORYSTORE_REDIS_BROKER: RedisBroker(use_elasticache=True),
SQS_BROKER: SQSBroker(),
SERVICEBUS_BROKER: ASBBroker(),
}
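Like ElastiCache, GCP Memorystore disallows the CONFIG command, so the new broker name reuses the existing RedisBroker(use_elasticache=True) path rather than introducing a separate broker class. An illustrative lookup (not code from this PR):

# Illustrative only: how the autoscaler resolves the GCP broker entry.
broker = BROKER_NAME_TO_CLASS[GCP_MEMORYSTORE_REDIS_BROKER]  # RedisBroker(use_elasticache=True)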
6 changes: 6 additions & 0 deletions model-engine/model_engine_server/infra/gateways/__init__.py
@@ -10,6 +10,9 @@
from .datadog_monitoring_metrics_gateway import DatadogMonitoringMetricsGateway
from .fake_model_primitive_gateway import FakeModelPrimitiveGateway
from .fake_monitoring_metrics_gateway import FakeMonitoringMetricsGateway
from .gcs_file_storage_gateway import GCSFileStorageGateway
from .gcs_filesystem_gateway import GCSFilesystemGateway
from .gcs_llm_artifact_gateway import GCSLLMArtifactGateway
from .live_async_model_endpoint_inference_gateway import LiveAsyncModelEndpointInferenceGateway
from .live_batch_job_orchestration_gateway import LiveBatchJobOrchestrationGateway
from .live_batch_job_progress_gateway import LiveBatchJobProgressGateway
@@ -37,6 +40,9 @@
"DatadogMonitoringMetricsGateway",
"FakeModelPrimitiveGateway",
"FakeMonitoringMetricsGateway",
"GCSFileStorageGateway",
"GCSFilesystemGateway",
"GCSLLMArtifactGateway",
"LiveAsyncModelEndpointInferenceGateway",
"LiveBatchJobOrchestrationGateway",
"LiveBatchJobProgressGateway",
85 changes: 85 additions & 0 deletions model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
@@ -0,0 +1,85 @@
import os
from typing import List, Optional

from google.auth import default
from google.cloud import storage
from model_engine_server.core.config import infra_config
from model_engine_server.domain.gateways.file_storage_gateway import (
FileMetadata,
FileStorageGateway,
)
from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway


def get_gcs_key(owner: str, file_id: str) -> str:
return os.path.join(owner, file_id)


def get_gcs_url(owner: str, file_id: str) -> str:
return f"gs://{infra_config().s3_bucket}/{get_gcs_key(owner, file_id)}"


class GCSFileStorageGateway(FileStorageGateway):
"""
Concrete implementation of a file storage gateway backed by GCS.
"""

def __init__(self):
self.filesystem_gateway = GCSFilesystemGateway()

def _get_client(self) -> storage.Client:
credentials, project = default()
return storage.Client(credentials=credentials, project=project)

def _get_bucket(self) -> storage.Bucket:
return self._get_client().bucket(infra_config().s3_bucket)

async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]:
return self.filesystem_gateway.generate_signed_url(get_gcs_url(owner, file_id))

async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]:
try:
bucket = self._get_bucket()
blob = bucket.blob(get_gcs_key(owner, file_id))
blob.reload() # Fetch metadata
return FileMetadata(
id=file_id,
filename=file_id,
size=blob.size,
owner=owner,
updated_at=blob.updated,
)
except Exception:
return None

async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
try:
with self.filesystem_gateway.open(get_gcs_url(owner, file_id)) as f:
return f.read()
except Exception:
return None

async def upload_file(self, owner: str, filename: str, content: bytes) -> str:
with self.filesystem_gateway.open(get_gcs_url(owner, filename), mode="w") as f:
f.write(content.decode("utf-8"))
return filename

async def delete_file(self, owner: str, file_id: str) -> bool:
try:
bucket = self._get_bucket()
blob = bucket.blob(get_gcs_key(owner, file_id))
blob.delete()
return True
except Exception:
return False

async def list_files(self, owner: str) -> List[FileMetadata]:
bucket = self._get_bucket()
blobs = bucket.list_blobs(prefix=owner)
files = []
for blob in blobs:
file_id = blob.name.replace(f"{owner}/", "", 1)
file_metadata = await self.get_file(owner, file_id)
if file_metadata is not None:
files.append(file_metadata)
return files
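A short usage sketch for the new file storage gateway (the owner and file names are made up; the bucket is taken from infra_config().s3_bucket as in the code above, and Application Default Credentials must be available):

import asyncio

async def demo():
    gateway = GCSFileStorageGateway()
    # Upload a small object, then read back its metadata and a signed URL.
    await gateway.upload_file(owner="some-team", filename="notes.txt", content=b"hello")
    meta = await gateway.get_file(owner="some-team", file_id="notes.txt")
    url = await gateway.get_url_from_id(owner="some-team", file_id="notes.txt")
    print(meta, url)

asyncio.run(demo())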
45 changes: 45 additions & 0 deletions model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py
@@ -0,0 +1,45 @@
import re
from datetime import timedelta
from typing import IO

import smart_open
from google.auth import default
from google.cloud import storage
from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway


class GCSFilesystemGateway(FilesystemGateway):
"""
Concrete implementation for interacting with a filesystem backed by Google Cloud Storage.
"""

def _get_storage_client(self) -> storage.Client:
credentials, project = default()
return storage.Client(credentials=credentials, project=project)

def open(self, uri: str, mode: str = "rt", **kwargs) -> IO:
client = self._get_storage_client()
transport_params = {"client": client}
return smart_open.open(uri, mode, transport_params=transport_params)

def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str:
# Parse gs://bucket/key format
match = re.search(r"^gs://([^/]+)/(.*?)$", uri)
if not match:
# Try https://storage.googleapis.com/bucket/key format
match = re.search(r"^https://storage\.googleapis\.com/([^/]+)/(.*?)$", uri)
assert match, f"Invalid GCS URI: {uri}"

bucket_name, blob_name = match.group(1), match.group(2)

client = self._get_storage_client()
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name)

signed_url = blob.generate_signed_url(
version="v4",
expiration=timedelta(seconds=expiration),
method="GET",
**kwargs,
)
return signed_url
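A minimal sketch of exercising the filesystem gateway (the bucket and key are placeholders; note that generate_signed_url needs credentials that can sign, e.g. a service account key or IAM signBlob permission):

gateway = GCSFilesystemGateway()
# Read an object through smart_open, then mint a V4 signed URL valid for 10 minutes.
with gateway.open("gs://example-bucket/path/to/object.txt", mode="rt") as f:
    text = f.read()
url = gateway.generate_signed_url("gs://example-bucket/path/to/object.txt", expiration=600)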
86 changes: 86 additions & 0 deletions model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py
@@ -0,0 +1,86 @@
import json
import os
from typing import Any, Dict, List

from google.auth import default
from google.cloud import storage
from model_engine_server.common.config import get_model_cache_directory_name, hmi_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.core.utils.url import parse_attachment_url
from model_engine_server.domain.gateways import LLMArtifactGateway

logger = make_logger(logger_name())


def _get_gcs_client() -> storage.Client:
credentials, project = default()
return storage.Client(credentials=credentials, project=project)


class GCSLLMArtifactGateway(LLMArtifactGateway):
"""
Concrete implementation using Google Cloud Storage.
"""

def list_files(self, path: str, **kwargs) -> List[str]:
parsed_remote = parse_attachment_url(path, clean_key=False)
bucket_name = parsed_remote.bucket
prefix = parsed_remote.key

client = _get_gcs_client()
bucket = client.bucket(bucket_name)
blobs = bucket.list_blobs(prefix=prefix)
return [blob.name for blob in blobs]

def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]:
parsed_remote = parse_attachment_url(path, clean_key=False)
bucket_name = parsed_remote.bucket
prefix = parsed_remote.key

client = _get_gcs_client()
bucket = client.bucket(bucket_name)

downloaded_files: List[str] = []
for blob in bucket.list_blobs(prefix=prefix):
file_path_suffix = blob.name.replace(prefix, "").lstrip("/")
local_path = os.path.join(target_path, file_path_suffix).rstrip("/")

if not overwrite and os.path.exists(local_path):
downloaded_files.append(local_path)
continue

local_dir = "/".join(local_path.split("/")[:-1])
if not os.path.exists(local_dir):
os.makedirs(local_dir)

logger.info(f"Downloading {blob.name} to {local_path}")
blob.download_to_filename(local_path)
downloaded_files.append(local_path)
return downloaded_files

def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
parsed_remote = parse_attachment_url(
hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
)
bucket_name = parsed_remote.bucket
fine_tuned_weights_prefix = parsed_remote.key

client = _get_gcs_client()
bucket = client.bucket(bucket_name)

model_files: List[str] = []
model_cache_name = get_model_cache_directory_name(model_name)
prefix = f"{fine_tuned_weights_prefix}/{owner}/{model_cache_name}"
for blob in bucket.list_blobs(prefix=prefix):
model_files.append(f"gs://{bucket_name}/{blob.name}")
return model_files

def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
parsed_remote = parse_attachment_url(path, clean_key=False)
bucket_name = parsed_remote.bucket
key = os.path.join(parsed_remote.key, "config.json")

client = _get_gcs_client()
bucket = client.bucket(bucket_name)
blob = bucket.blob(key)
return json.loads(blob.download_as_text())
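A brief usage sketch (the gs:// path is a placeholder, and it assumes parse_attachment_url accepts gs:// URIs as the code above implies):

gateway = GCSLLMArtifactGateway()
# List, download, and read the config for a model stored under a GCS prefix.
files = gateway.list_files("gs://example-bucket/models/my-model")
local_paths = gateway.download_files("gs://example-bucket/models/my-model", target_path="/tmp/my-model")
config = gateway.get_model_config("gs://example-bucket/models/my-model")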