diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/mlflow_config.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/mlflow_config.py
new file mode 100644
index 0000000000..c5e3114254
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/mlflow_config.py
@@ -0,0 +1,61 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""MLflow config for SageMaker pipeline."""
+from __future__ import absolute_import
+
+from typing import Dict, Any
+
+
+class MlflowConfig:
+    """MLflow configuration for SageMaker pipeline."""
+
+    def __init__(
+        self,
+        mlflow_resource_arn: str,
+        mlflow_experiment_name: str,
+    ):
+        """Create an MLflow configuration for SageMaker Pipeline.
+
+        Examples:
+            Basic MLflow configuration::
+
+                mlflow_config = MlflowConfig(
+                    mlflow_resource_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/my-server",
+                    mlflow_experiment_name="my-experiment"
+                )
+
+                pipeline = Pipeline(
+                    name="MyPipeline",
+                    steps=[...],
+                    mlflow_config=mlflow_config
+                )
+
+            Runtime override of experiment name::
+
+                # Override experiment name for a specific execution
+                execution = pipeline.start(mlflow_experiment_name="custom-experiment")
+
+        Args:
+            mlflow_resource_arn (str): The ARN of the MLflow tracking server resource.
+            mlflow_experiment_name (str): The name of the MLflow experiment to be used for tracking.
+ """ + self.mlflow_resource_arn = mlflow_resource_arn + self.mlflow_experiment_name = mlflow_experiment_name + + def to_request(self) -> Dict[str, Any]: + """Returns: the request structure.""" + + return { + "MlflowResourceArn": self.mlflow_resource_arn, + "MlflowExperimentName": self.mlflow_experiment_name, + } diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py index 144726f690..b47e7e8d9c 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py @@ -34,7 +34,13 @@ from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT from sagemaker.core.s3 import s3_path_join from sagemaker.core.helper.session_helper import Session -from sagemaker.core.common_utils import resolve_value_from_config, retry_with_backoff, format_tags, Tags +from sagemaker.core.common_utils import ( + resolve_value_from_config, + retry_with_backoff, + format_tags, + Tags, +) + # Orchestration imports (now in mlops) from sagemaker.mlops.workflow.callback_step import CallbackOutput, CallbackStep from sagemaker.mlops.workflow._event_bridge_client_helper import ( @@ -44,19 +50,24 @@ EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, ) from sagemaker.mlops.workflow.lambda_step import LambdaOutput, LambdaStep +from sagemaker.mlops.workflow.mlflow_config import MlflowConfig from sagemaker.core.helper.pipeline_variable import ( RequestType, PipelineVariable, ) + # Primitive imports (stay in core) from sagemaker.core.workflow.execution_variables import ExecutionVariables from sagemaker.core.workflow.parameters import Parameter + # Orchestration imports (now in mlops) from sagemaker.core.workflow.pipeline_definition_config import PipelineDefinitionConfig from sagemaker.mlops.workflow.pipeline_experiment_config import PipelineExperimentConfig from sagemaker.mlops.workflow.parallelism_config import ParallelismConfiguration + # Primitive imports (stay in core) from sagemaker.core.workflow.properties import Properties + # Orchestration imports (now in mlops) from sagemaker.mlops.workflow.selective_execution_config import SelectiveExecutionConfig from sagemaker.core.workflow.step_outputs import StepOutput @@ -87,6 +98,7 @@ def __init__( name: str = "", parameters: Optional[Sequence[Parameter]] = None, pipeline_experiment_config: Optional[PipelineExperimentConfig] = _DEFAULT_EXPERIMENT_CFG, + mlflow_config: Optional[MlflowConfig] = None, steps: Optional[Sequence[Union[Step, StepOutput]]] = None, sagemaker_session: Optional[Session] = None, pipeline_definition_config: Optional[PipelineDefinitionConfig] = _DEFAULT_DEFINITION_CFG, @@ -102,6 +114,8 @@ def __init__( the same name already exists. By default, pipeline name is used as experiment name and execution id is used as the trial name. If set to None, no experiment or trial will be created automatically. + mlflow_config (Optional[MlflowConfig]): If set, the pipeline will be configured + with MLflow tracking for experiment tracking and model versioning. steps (Sequence[Union[Step, StepOutput]]): The list of the non-conditional steps associated with the pipeline. 
                 `if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
@@ -118,6 +132,7 @@
         self.name = name
         self.parameters = parameters if parameters else []
         self.pipeline_experiment_config = pipeline_experiment_config
+        self.mlflow_config = mlflow_config
         self.steps = steps if steps else []
         self.sagemaker_session = sagemaker_session if sagemaker_session else Session()
         self.pipeline_definition_config = pipeline_definition_config
@@ -355,6 +370,7 @@ def start(
         execution_description: str = None,
         parallelism_config: ParallelismConfiguration = None,
         selective_execution_config: SelectiveExecutionConfig = None,
+        mlflow_experiment_name: str = None,
         pipeline_version_id: int = None,
     ):
         """Starts a Pipeline execution in the Workflow service.
@@ -369,6 +385,10 @@
                 over the parallelism configuration of the parent pipeline.
             selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration
                 for selective step execution.
+            mlflow_experiment_name (Optional[str]): MLflow experiment name to use for this
+                execution, overriding the experiment name specified in the pipeline's
+                mlflow_config. The override applies to this execution only and does not
+                modify the pipeline definition.
             pipeline_version_id (Optional[str]): version ID of the pipeline to start the
                 execution from. If not specified, uses the latest version ID.
 
@@ -392,6 +412,7 @@
             PipelineExecutionDisplayName=execution_display_name,
             ParallelismConfiguration=parallelism_config,
             SelectiveExecutionConfig=selective_execution_config,
+            MlflowExperimentName=mlflow_experiment_name,
             PipelineVersionId=pipeline_version_id,
         )
         if self.sagemaker_session.local_mode:
@@ -431,14 +452,25 @@ def definition(self) -> str:
                 if self.pipeline_experiment_config is not None
                 else None
             ),
+            "MlflowConfig": (
+                self.mlflow_config.to_request() if self.mlflow_config is not None else None
+            ),
             "Steps": list_to_request(compiled_steps),
         }
-
-        request_dict["PipelineExperimentConfig"] = interpolate(
-            request_dict["PipelineExperimentConfig"], {}, {}, pipeline_name=self.name
-        )
         callback_output_to_step_map = _map_callback_outputs(self.steps)
         lambda_output_to_step_name = _map_lambda_outputs(self.steps)
+        request_dict["PipelineExperimentConfig"] = interpolate(
+            request_dict["PipelineExperimentConfig"],
+            callback_output_to_step_map=callback_output_to_step_map,
+            lambda_output_to_step_map=lambda_output_to_step_name,
+            pipeline_name=self.name,
+        )
+        request_dict["MlflowConfig"] = interpolate(
+            request_dict["MlflowConfig"],
+            callback_output_to_step_map=callback_output_to_step_map,
+            lambda_output_to_step_map=lambda_output_to_step_name,
+            pipeline_name=self.name,
+        )
         request_dict["Steps"] = interpolate(
             request_dict["Steps"],
             callback_output_to_step_map=callback_output_to_step_map,
@@ -1131,7 +1163,6 @@ def _initialize_adjacency_list(self) -> Dict[str, List[str]]:
                     if isinstance(child_step, Step):
                         dependency_list[child_step.name].add(step.name)
 
-
         adjacency_list = {}
         for step in dependency_list:
             for step_dependency in dependency_list[step]:
@@ -1169,9 +1200,7 @@ is_cyclic_helper(current_step):
                 return True
         return False
 
-    def get_steps_in_sub_dag(
-        self, current_step: Step, sub_dag_steps: Set[str] = None
-    ) -> Set[str]:
+    def get_steps_in_sub_dag(self, current_step: Step, sub_dag_steps: Set[str] = None) -> Set[str]:
         """Get names of all steps (including current step) in the sub dag of current step.
 
         Returns a set of step names in the sub dag.
@@ -1211,4 +1240,4 @@ def __next__(self) -> Step:
         while self.stack:
             return self.step_map.get(self.stack.pop())
 
-        raise StopIteration
\ No newline at end of file
+        raise StopIteration
diff --git a/sagemaker-mlops/tests/integ/test_mlflow_integration.py b/sagemaker-mlops/tests/integ/test_mlflow_integration.py
new file mode 100644
index 0000000000..5faa82d212
--- /dev/null
+++ b/sagemaker-mlops/tests/integ/test_mlflow_integration.py
@@ -0,0 +1,223 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import json
+import time
+import pytest
+
+from sagemaker.core.helper.session_helper import Session, get_execution_role
+from sagemaker.mlops.workflow.callback_step import (
+    CallbackStep,
+    CallbackOutput,
+    CallbackOutputTypeEnum,
+)
+from sagemaker.mlops.workflow.mlflow_config import MlflowConfig
+from sagemaker.mlops.workflow.pipeline import Pipeline
+
+
+@pytest.fixture
+def sagemaker_session():
+    """Return a SageMaker session for integration tests."""
+    return Session()
+
+
+@pytest.fixture
+def role():
+    """Return the execution role ARN."""
+    return get_execution_role()
+
+
+@pytest.fixture
+def region_name(sagemaker_session):
+    """Return the AWS region name."""
+    return sagemaker_session.boto_session.region_name
+
+
+@pytest.fixture
+def pipeline_name():
+    return f"mlflow-test-pipeline-{int(time.time() * 10 ** 7)}"
+
+
+def test_pipeline_definition_with_mlflow_config(sagemaker_session, role, pipeline_name, region_name):
+    """Verify MLflow config appears correctly in pipeline definition when pipeline is created."""
+
+    mlflow_config = MlflowConfig(
+        mlflow_resource_arn=(
+            "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/integ-test-server"
+        ),
+        mlflow_experiment_name="integ-test-experiment",
+    )
+
+    callback_step = CallbackStep(
+        name="test-callback-step",
+        sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
+        inputs={"test_input": "test_value"},
+        outputs=[CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)],
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        steps=[callback_step],
+        mlflow_config=mlflow_config,
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        assert response["PipelineArn"]
+
+        describe_response = pipeline.describe()
+        definition = json.loads(describe_response["PipelineDefinition"])
+
+        assert "MlflowConfig" in definition
+        assert definition["MlflowConfig"] == {
+            "MlflowResourceArn": (
+                "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/integ-test-server"
+            ),
+            "MlflowExperimentName": "integ-test-experiment",
+        }
+
+        assert definition["Version"] == "2020-12-01"
+        assert "Steps" in definition
+        assert len(definition["Steps"]) == 1
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
+def test_pipeline_start_with_mlflow_experiment_override(
+    sagemaker_session, role, pipeline_name, region_name
+):
+    """Verify pipeline can be started with MLflow experiment name override."""
+
+    original_mlflow_config = MlflowConfig(
+        mlflow_resource_arn=(
+            "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/original-server"
+        ),
+        mlflow_experiment_name="original-experiment",
+    )
+
+    callback_step = CallbackStep(
+        name="test-callback-step",
+        sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
+        inputs={"test_input": "test_value"},
+        outputs=[CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)],
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        steps=[callback_step],
+        mlflow_config=original_mlflow_config,
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        assert response["PipelineArn"]
+
+        describe_response = pipeline.describe()
+        definition = json.loads(describe_response["PipelineDefinition"])
+        assert definition["MlflowConfig"]["MlflowExperimentName"] == "original-experiment"
+
+        execution = pipeline.start(mlflow_experiment_name="runtime-override-experiment")
+
+        assert execution.arn
+        execution_response = execution.describe()
+        assert execution_response["PipelineExecutionStatus"] in ["Executing", "Succeeded", "Failed"]
+
+        assert execution_response.get("MLflowConfig", {}).get("MlflowExperimentName") == "runtime-override-experiment"
+        assert execution_response.get("MLflowConfig", {}).get("MlflowResourceArn") == "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/original-server"
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
+def test_pipeline_update_with_mlflow_config(sagemaker_session, role, pipeline_name, region_name):
+    """Verify pipeline can be updated to add or modify MLflow config."""
+
+    callback_step = CallbackStep(
+        name="test-callback-step",
+        sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
+        inputs={"test_input": "test_value"},
+        outputs=[CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)],
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        steps=[callback_step],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        assert response["PipelineArn"]
+
+        describe_response = pipeline.describe()
+        definition = json.loads(describe_response["PipelineDefinition"])
+        assert definition["MlflowConfig"] is None
+
+        mlflow_config = MlflowConfig(
+            mlflow_resource_arn=(
+                "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/update-test-server"
+            ),
+            mlflow_experiment_name="update-test-experiment",
+        )
+        pipeline.mlflow_config = mlflow_config
+
+        update_response = pipeline.update(role)
+        assert update_response["PipelineArn"]
+
+        describe_response = pipeline.describe()
+        definition = json.loads(describe_response["PipelineDefinition"])
+        assert "MlflowConfig" in definition
+        assert definition["MlflowConfig"] == {
+            "MlflowResourceArn": (
+                "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/update-test-server"
+            ),
+            "MlflowExperimentName": "update-test-experiment",
+        }
+
+        pipeline.mlflow_config.mlflow_experiment_name = "modified-experiment"
+
+        update_response2 = pipeline.update(role)
+        assert update_response2["PipelineArn"]
+
+        describe_response = pipeline.describe()
+        definition = json.loads(describe_response["PipelineDefinition"])
+        assert definition["MlflowConfig"]["MlflowExperimentName"] == "modified-experiment"
+        assert (
+            definition["MlflowConfig"]["MlflowResourceArn"]
+            == "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/update-test-server"
+        )
+
+        pipeline.mlflow_config = None
+
+        update_response3 = pipeline.update(role)
+        assert update_response3["PipelineArn"]
+
+        describe_response = pipeline.describe()
+        definition = json.loads(describe_response["PipelineDefinition"])
+        assert definition["MlflowConfig"] is None
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
diff --git a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py
index 1550a95c36..80837ba2a6 100644
--- a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py
+++ b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py
@@ -375,6 +375,7 @@ def test_get_function_step_result_success(mock_session):
         "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT},
         "OutputDataConfig": {"S3OutputPath": "s3://bucket/path/exec-id/step1/results"},
         "TrainingJobStatus": "Completed",
+        "Environment": {},
     }
 
     with patch("sagemaker.mlops.workflow.pipeline.deserialize_obj_from_s3", return_value="result"):
@@ -497,6 +498,7 @@ def test_pipeline_execution_result_terminal_failure(mock_session):
         "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT},
         "OutputDataConfig": {"S3OutputPath": "s3://bucket/path/exec-id/step1/results"},
         "TrainingJobStatus": "Completed",
+        "Environment": {},
     }
 
     with patch.object(execution, "wait", side_effect=WaiterError("name", "Waiter encountered a terminal failure state", {})):
@@ -514,6 +516,7 @@ def test_get_function_step_result_obsolete_s3_path(mock_session):
         "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT},
         "OutputDataConfig": {"S3OutputPath": "s3://bucket/different/path"},
         "TrainingJobStatus": "Completed",
+        "Environment": {},
     }
 
     with patch("sagemaker.mlops.workflow.pipeline.deserialize_obj_from_s3", return_value="result"):
diff --git a/sagemaker-mlops/tests/unit/workflow/test_pipeline_mlflow_config.py b/sagemaker-mlops/tests/unit/workflow/test_pipeline_mlflow_config.py
new file mode 100644
index 0000000000..bd23f19409
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/workflow/test_pipeline_mlflow_config.py
@@ -0,0 +1,166 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import json
+import pytest
+from unittest.mock import Mock, patch
+
+from sagemaker.mlops.workflow.mlflow_config import MlflowConfig
+from sagemaker.mlops.workflow.pipeline import Pipeline
+from sagemaker.mlops.workflow.steps import Step, StepTypeEnum
+
+
+@pytest.fixture
+def mock_session():
+    """Create a mock SageMaker session for testing."""
+    session = Mock()
+    session.boto_session.client.return_value = Mock()
+    session.sagemaker_client = Mock()
+    session.local_mode = False
+    return session
+
+
+def ordered(obj):
+    """Recursively sort dict keys for comparison."""
+    if isinstance(obj, dict):
+        return {k: ordered(v) for k, v in sorted(obj.items())}
+    if isinstance(obj, list):
+        return [ordered(x) for x in obj]
+    return obj
+
+
+class CustomStep(Step):
+    """Custom step for testing."""
+
+    def __init__(self, name, input_data):
+        super(CustomStep, self).__init__(name=name, step_type=StepTypeEnum.TRAINING, depends_on=[])
+        self.input_data = input_data
+
+    @property
+    def arguments(self):
+        return {"input_data": self.input_data}
+
+    @property
+    def properties(self):
+        return None
+
+
+def test_pipeline_with_mlflow_config(mock_session):
+    mlflow_config = MlflowConfig(
+        mlflow_resource_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/training-test",
+        mlflow_experiment_name="training-test-experiment",
+    )
+    pipeline = Pipeline(
+        name="MyPipeline",
+        parameters=[],
+        steps=[CustomStep(name="MyStep", input_data="input")],
+        mlflow_config=mlflow_config,
+        sagemaker_session=mock_session,
+    )
+
+    pipeline_definition = json.loads(pipeline.definition())
+    assert ordered(pipeline_definition) == ordered(
+        {
+            "Version": "2020-12-01",
+            "Metadata": {},
+            "Parameters": [],
+            "PipelineExperimentConfig": {
+                "ExperimentName": {"Get": "Execution.PipelineName"},
+                "TrialName": {"Get": "Execution.PipelineExecutionId"},
+            },
+            "MlflowConfig": {
+                "MlflowResourceArn": "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/training-test",
+                "MlflowExperimentName": "training-test-experiment",
+            },
+            "Steps": [
+                {
+                    "Name": "MyStep",
+                    "Type": "Training",
+                    "Arguments": {"input_data": "input"},
+                }
+            ],
+        }
+    )
+
+
+def test_pipeline_without_mlflow_config(mock_session):
+    pipeline = Pipeline(
+        name="MyPipeline",
+        parameters=[],
+        steps=[CustomStep(name="MyStep", input_data="input")],
+        mlflow_config=None,
+        sagemaker_session=mock_session,
+    )
+
+    pipeline_definition = json.loads(pipeline.definition())
+    assert pipeline_definition.get("MlflowConfig") is None
+
+
+def test_pipeline_start_with_mlflow_experiment_name(mock_session):
+    mock_session.sagemaker_client.start_pipeline_execution.return_value = {
+        "PipelineExecutionArn": "my:arn"
+    }
+    pipeline = Pipeline(
+        name="MyPipeline",
+        parameters=[],
+        steps=[],
+        sagemaker_session=mock_session,
+    )
+
+    # Test starting with MLflow experiment name
+    pipeline.start(mlflow_experiment_name="my-experiment")
+    mock_session.sagemaker_client.start_pipeline_execution.assert_called_with(
+        PipelineName="MyPipeline", MlflowExperimentName="my-experiment"
+    )
+
+    # Test starting without MLflow experiment name
+    pipeline.start()
+    mock_session.sagemaker_client.start_pipeline_execution.assert_called_with(
+        PipelineName="MyPipeline",
+    )
+
+
+def test_pipeline_update_with_mlflow_config(mock_session):
+    """Test that pipeline.update() includes MLflow config in the definition sent to the API."""
+
+    pipeline = Pipeline(
+        name="MyPipeline",
+        steps=[CustomStep(name="MyStep", input_data="input")],
+        sagemaker_session=mock_session,
+    )
+
+    initial_definition = json.loads(pipeline.definition())
+    assert initial_definition.get("MlflowConfig") is None
+
+    mlflow_config = MlflowConfig(
+        mlflow_resource_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/update-test",
+        mlflow_experiment_name="update-test-experiment",
+    )
+    pipeline.mlflow_config = mlflow_config
+
+    with patch(
+        "sagemaker.mlops.workflow.pipeline.resolve_value_from_config", return_value="dummy-role"
+    ):
+        pipeline.update("dummy-role")
+
+    mock_session.sagemaker_client.update_pipeline.assert_called_once()
+    call_args = mock_session.sagemaker_client.update_pipeline.call_args
+
+    pipeline_definition_arg = call_args[1]["PipelineDefinition"]
+    definition = json.loads(pipeline_definition_arg)
+
+    assert definition["MlflowConfig"] == {
+        "MlflowResourceArn": "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/update-test",
+        "MlflowExperimentName": "update-test-experiment",
+    }