Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorials/foundation-model-timeseries.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ The following `model_id` values are currently supported. Chronos-2 models native

## Data

The examples use a [retail sales](https://autogluon.s3.amazonaws.com/datasets/timeseries/retail_sales/) dataset with weekly sales for 1,115 stores. Load the historical observations:
The examples use a retail sales dataset with weekly sales for 1,115 stores. Load the historical observations:

```{code-cell} ipython3
import pandas as pd
Expand Down
17 changes: 15 additions & 2 deletions src/autogluon/cloud/backend/sagemaker_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..utils.dlc_utils import parse_framework_version
from ..utils.misc import MostRecentInsertedOrderedDict
from ..utils.serializers import AutoGluonSerializationWrapper
from ..utils.tag_utils import build_tags
from ..utils.utils import (
convert_image_path_to_encoded_bytes_in_dataframe,
is_image_file,
Expand Down Expand Up @@ -68,6 +69,14 @@ def _realtime_predictor_cls(self) -> Predictor:
"""Class used for realtime endpoint"""
return AutoGluonRealtimePredictor

def _resolve_tags(
    self,
    kwargs: Dict[str, Any],
    extra_tags: Optional[List[Dict[str, str]]] = None,
) -> None:
    """Overwrite ``kwargs['tags']`` in place with the merged tag list.

    The merged list is whatever ``build_tags`` returns for this backend's
    ``predictor_type``, the given ``extra_tags``, and the tags already present
    under ``kwargs['tags']`` (``None`` when the key is absent).

    Parameters
    ----------
    kwargs: Dict[str, Any]
        Keyword-argument dict to update; it is mutated, nothing is returned.
    extra_tags: Optional[List[Dict[str, str]]]
        Additional tags forwarded to ``build_tags`` as ``extra_tags``.
    """
    module = self.predictor_type
    caller_tags = kwargs.get("tags")
    merged = build_tags(module, extra_tags=extra_tags, user_tags=caller_tags)
    kwargs["tags"] = merged

def initialize(self, role: Optional[str] = None, **kwargs) -> None:
"""Initialize the backend.

Expand Down Expand Up @@ -170,6 +179,7 @@ def fit(
autogluon_sagemaker_estimator_kwargs: Optional[Dict] = None,
fit_kwargs: Optional[Dict] = None,
extra_ag_args: Optional[Dict[str, Any]] = None,
extra_tags: Optional[List[Dict[str, str]]] = None,
) -> None:
"""
Fit the predictor with SageMaker.
Expand Down Expand Up @@ -317,6 +327,7 @@ def fit(
)
if fit_kwargs is None:
fit_kwargs = {}
self._resolve_tags(autogluon_sagemaker_estimator_kwargs, extra_tags)

self._fit_job.run(
role=self.role_arn,
Expand Down Expand Up @@ -361,6 +372,7 @@ def deploy(
inference_mode: Literal["realtime", "serverless"] = "realtime",
inference_config: Optional[Dict[str, Any]] = None,
repack: bool = True,
extra_tags: Optional[List[Dict[str, str]]] = None,
) -> None:
"""
Deploy a predictor as a SageMaker endpoint, which can be used to do real-time inference later.
Expand Down Expand Up @@ -525,8 +537,8 @@ def deploy(
env=model_kwargs_env,
**model_kwargs,
)
if deploy_kwargs is None:
deploy_kwargs = {}
deploy_kwargs = copy.deepcopy(deploy_kwargs or {})
self._resolve_tags(deploy_kwargs, extra_tags)

instance_kwargs = {
"instance_type": instance_type,
Expand Down Expand Up @@ -1285,6 +1297,7 @@ def _predict(
if transformer_kwargs is None:
transformer_kwargs = {}
transformer_kwargs = copy.deepcopy(transformer_kwargs)
self._resolve_tags(transformer_kwargs)
user_entry_point = model_kwargs.pop("entry_point", None)
repack_model = False
if predictor_path != self._fit_job.get_output_path() or user_entry_point is not None:
Expand Down
4 changes: 3 additions & 1 deletion src/autogluon/cloud/backend/timeseries_sagemaker_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import pandas as pd

Expand Down Expand Up @@ -34,6 +34,7 @@ def fit(
autogluon_sagemaker_estimator_kwargs: Optional[Dict] = None,
fit_kwargs: Optional[Dict] = None,
extra_ag_args: Optional[Dict[str, Any]] = None,
extra_tags: Optional[List[Dict[str, str]]] = None,
) -> None:
"""Fit a TimeSeriesPredictor in SageMaker.

Expand Down Expand Up @@ -65,6 +66,7 @@ def fit(
autogluon_sagemaker_estimator_kwargs=autogluon_sagemaker_estimator_kwargs,
fit_kwargs=fit_kwargs,
extra_ag_args=extra_ag_args,
extra_tags=extra_tags,
)

def predict_real_time(
Expand Down
2 changes: 2 additions & 0 deletions src/autogluon/cloud/model/foundation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def _deploy_backend(
inference_mode=inference_mode,
inference_config=inference_config,
repack=False,
extra_tags=[{"Key": "autogluon-cloud-model-id", "Value": self.model_id}],
**backend_kwargs,
)
assert self._backend.endpoint is not None
Expand Down Expand Up @@ -584,6 +585,7 @@ def predict(
custom_image_uri=custom_image_uri,
wait=wait,
extra_ag_args=extra_ag_args,
extra_tags=[{"Key": "autogluon-cloud-model-id", "Value": self.model_id}],
**backend_kwargs,
)

Expand Down
4 changes: 2 additions & 2 deletions src/autogluon/cloud/predictor/timeseries_cloud_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,8 @@ def fit_predict(
Backend-specific arguments. Same keys as ``fit()``.
predictions_path: Optional[str]
S3 URL where predictions will be written by the training container (e.g.
``s3://my-bucket/runs/2024-05-01/predictions.csv``). The container's SageMaker execution role must
have ``s3:PutObject`` permission for this location. Defaults to
``s3://my-bucket/runs/2024-05-01/predictions.csv``). The container's SageMaker execution role must have
``s3:PutObject`` permission for this location. Defaults to
``{cloud_output_path}/{job_name}/predictions.csv``. Predictions use AutoGluon's canonical column
names ``item_id`` and ``timestamp``, regardless of the ``id_column`` / ``timestamp_column`` passed in.

Expand Down
27 changes: 27 additions & 0 deletions src/autogluon/cloud/utils/tag_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Tag helpers for SageMaker resources created by autogluon-cloud."""

from __future__ import annotations

import os
from typing import Dict, List, Optional

# Name of the environment variable that switches off every autogluon-cloud generated tag.
DISABLE_DEFAULT_TAGS_ENV = "AG_CLOUD_DISABLE_DEFAULT_TAGS"


def build_tags(
    module: str,
    extra_tags: Optional[List[Dict[str, str]]] = None,
    user_tags: Optional[List[Dict[str, str]]] = None,
) -> List[Dict[str, str]]:
    """Return the final tag list for a SageMaker resource: defaults + extras + user.

    Precedence on a key collision is ``user_tags`` over ``extra_tags`` over the default
    ``autogluon-cloud-module`` tag, so each generated key appears at most once in the result.
    User tags are passed through untouched and always come last.

    Defaults and extras are skipped entirely when ``AG_CLOUD_DISABLE_DEFAULT_TAGS`` is set to
    ``1``, ``true`` or ``yes`` (case-insensitive, surrounding whitespace ignored), so customers in
    tag-restricted AWS orgs can opt out without losing other functionality.

    Parameters
    ----------
    module: str
        Value of the ``autogluon-cloud-module`` default tag.
    extra_tags: Optional[List[Dict[str, str]]]
        Additional generated tags as ``{"Key": ..., "Value": ...}`` dicts.
    user_tags: Optional[List[Dict[str, str]]]
        Tags supplied by the caller, in the same format.

    Returns
    -------
    List[Dict[str, str]]
        A new list; the tag dicts themselves are not copied.
    """
    user = list(user_tags or [])

    flag = os.environ.get(DISABLE_DEFAULT_TAGS_ENV, "").strip().lower()
    if flag in ("1", "true", "yes"):
        return user

    # Collapse generated tags by key so an extra tag that repeats a key (including the module
    # key) replaces the earlier entry instead of producing a duplicate key in the request.
    # The dict keeps first-insertion order, so output order is unchanged when nothing collides.
    generated: Dict[str, Dict[str, str]] = {}
    for tag in [{"Key": "autogluon-cloud-module", "Value": module}, *(extra_tags or [])]:
        generated[tag["Key"]] = tag

    user_keys = {tag["Key"] for tag in user}
    return [tag for key, tag in generated.items() if key not in user_keys] + user
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def replace_image_abspath(data, image_column):
data[image_column] = data[image_column].apply(os.path.abspath)
return data

@staticmethod
def assert_ag_cloud_tags(arn: str, *, module: str, model_id: str = None):
    """Check the autogluon-cloud tags attached to the SageMaker resource ``arn``.

    Tags are fetched with the SageMaker ``list_tags`` call. The ``autogluon-cloud-module``
    tag must equal ``module``; when ``model_id`` is given, the ``autogluon-cloud-model-id``
    tag must equal it as well. A mismatch or missing tag raises ``AssertionError``.
    """
    response = boto3.client("sagemaker").list_tags(ResourceArn=arn)
    tags = {}
    for entry in response["Tags"]:
        tags[entry["Key"]] = entry["Value"]

    assert tags.get("autogluon-cloud-module") == module, f"missing/wrong module tag on {arn}: {tags}"
    if model_id is None:
        return
    assert tags.get("autogluon-cloud-model-id") == model_id, f"missing/wrong model-id tag on {arn}: {tags}"

@staticmethod
def test_endpoint(cloud_predictor, test_data, inference_kwargs=None, **predict_real_time_kwargs):
if inference_kwargs is None:
Expand Down Expand Up @@ -164,13 +175,19 @@ def test_functionality(
assert job_name is not None
assert info["fit_job"]["status"] == "Completed"

sm = boto3.client("sagemaker")
training_job_arn = sm.describe_training_job(TrainingJobName=job_name)["TrainingJobArn"]
CloudTestHelper.assert_ag_cloud_tags(training_job_arn, module=cloud_predictor.predictor_type)

cloud_predictor.attach_job(job_name)

if deploy_kwargs is None:
deploy_kwargs = dict()
if predict_real_time_kwargs is None:
predict_real_time_kwargs = dict()
cloud_predictor.deploy(**deploy_kwargs)
endpoint_arn = sm.describe_endpoint(EndpointName=cloud_predictor.endpoint_name)["EndpointArn"]
CloudTestHelper.assert_ag_cloud_tags(endpoint_arn, module=cloud_predictor.predictor_type)
CloudTestHelper.test_endpoint(
cloud_predictor,
test_data,
Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/general/test_foundation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_deploy_passes_artifact_uri_and_overrides_model_path_to_container_dir():
assert call.kwargs["repack"] is False
serve_cfg = call.kwargs["fm_serve_config"]
assert serve_cfg["hyperparameters"]["model_path"] == "/opt/ml/model/weights"
assert {"Key": "autogluon-cloud-model-id", "Value": "chronos-2"} in call.kwargs["extra_tags"]


def test_deploy_without_artifact_passes_none_predictor_path_and_source_uri():
Expand All @@ -112,6 +113,7 @@ def test_deploy_without_artifact_passes_none_predictor_path_and_source_uri():
assert call.kwargs["repack"] is False
serve_cfg = call.kwargs["fm_serve_config"]
assert serve_cfg["hyperparameters"]["model_path"] == "autogluon/chronos-2"
assert {"Key": "autogluon-cloud-model-id", "Value": "chronos-2"} in call.kwargs["extra_tags"]


def test_deploy_rejects_user_model_path_when_artifact_uri_set():
Expand Down
41 changes: 41 additions & 0 deletions tests/unittests/general/test_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Unit tests for the tag-merging helper used by SagemakerBackend."""

import pytest

from autogluon.cloud.utils.tag_utils import DISABLE_DEFAULT_TAGS_ENV, build_tags


def test_when_no_extras_or_user_then_only_module_tag_is_returned():
    """With neither extra nor user tags, the result is just the module tag."""
    expected = [{"Key": "autogluon-cloud-module", "Value": "timeseries"}]
    assert build_tags("timeseries") == expected


def test_when_extra_tags_provided_then_appended_after_module():
    """Extra tags are placed after the module tag, in the order given."""
    extras = [{"Key": "autogluon-cloud-model-id", "Value": "chronos-2"}]
    expected = [
        {"Key": "autogluon-cloud-module", "Value": "timeseries"},
        {"Key": "autogluon-cloud-model-id", "Value": "chronos-2"},
    ]
    result = build_tags("timeseries", extra_tags=extras)
    assert result == expected


def test_when_user_tag_collides_with_default_then_user_wins():
    """A user tag that reuses the module key replaces the default module tag."""
    override = [{"Key": "autogluon-cloud-module", "Value": "override"}]
    result = build_tags("timeseries", user_tags=override)
    assert result == [{"Key": "autogluon-cloud-module", "Value": "override"}]


def test_when_user_tags_unique_then_appended_after_defaults():
    """User tags whose keys do not collide are appended after the default tag."""
    owner = [{"Key": "Owner", "Value": "team"}]
    expected = [
        {"Key": "autogluon-cloud-module", "Value": "tabular"},
        {"Key": "Owner", "Value": "team"},
    ]
    result = build_tags("tabular", user_tags=owner)
    assert result == expected


@pytest.mark.parametrize("value", ["1", "true", "True", "yes"])
def test_when_disable_env_var_set_then_defaults_and_extras_are_skipped(monkeypatch, value):
    """Extras count as AG-cloud defaults, so the opt-out drops them together with the module tag."""
    monkeypatch.setenv(DISABLE_DEFAULT_TAGS_ENV, value)

    extras = [{"Key": "autogluon-cloud-model-id", "Value": "chronos-2"}]
    owner = [{"Key": "Owner", "Value": "team"}]

    assert build_tags("timeseries") == []
    assert build_tags("timeseries", extra_tags=extras) == []
    # User tags are the only ones that survive the opt-out.
    assert build_tags("timeseries", user_tags=owner) == [{"Key": "Owner", "Value": "team"}]
8 changes: 8 additions & 0 deletions tests/unittests/timeseries/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def test_foundation_model_predict(test_helper, framework_version, retail_sales_d
head = boto3.client("s3").head_object(Bucket=bucket, Key=predictions_key)
assert head["ContentLength"] > 0, "predictions file on S3 should not be empty"

sm = boto3.client("sagemaker")
job_arn = sm.describe_training_job(TrainingJobName=model._backend._fit_job.job_name)["TrainingJobArn"]
test_helper.assert_ag_cloud_tags(job_arn, module="timeseries", model_id="chronos-2")


def test_foundation_model_cache_artifact_then_deploy_serverless(test_helper, framework_version, retail_sales_dataset):
"""Cache model artifact to S3, deploy to a serverless endpoint, and verify predictions."""
Expand All @@ -230,6 +234,8 @@ def test_foundation_model_cache_artifact_then_deploy_serverless(test_helper, fra
inference_mode="serverless",
inference_config={"memory_size_in_mb": 6144},
)
endpoint_arn = boto3.client("sagemaker").describe_endpoint(EndpointName=endpoint.endpoint_name)["EndpointArn"]
test_helper.assert_ag_cloud_tags(endpoint_arn, module="timeseries", model_id="chronos-bolt-tiny")
try:
expected_item_ids = sorted(ds["train_data"][ds["id_column"]].unique())
predictions = endpoint.predict(
Expand Down Expand Up @@ -262,6 +268,8 @@ def test_foundation_model_deploy(test_helper, framework_version, retail_sales_da
endpoint = model.deploy(
custom_image_uri=inference_custom_image_uri,
)
endpoint_arn = boto3.client("sagemaker").describe_endpoint(EndpointName=endpoint.endpoint_name)["EndpointArn"]
test_helper.assert_ag_cloud_tags(endpoint_arn, module="timeseries", model_id="chronos-bolt-tiny")

try:
expected_item_ids = sorted(ds["train_data"][ds["id_column"]].unique())
Expand Down
Loading