Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,6 @@ def _prepare_and_upload_callable(
stored_function = StoredFunction(
sagemaker_session=sagemaker_session,
s3_base_uri=s3_base_uri,
hmac_key=self.remote_decorator_config.environment_variables[
"REMOTE_FUNCTION_SECRET_KEY"
],
s3_kms_key=self.remote_decorator_config.s3_kms_key,
)
stored_function.save(func)
Expand Down
6 changes: 0 additions & 6 deletions src/sagemaker/remote_function/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ def wrapper(*args, **kwargs):
s3_uri=s3_path_join(
job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
),
hmac_key=job.hmac_key,
)
except ServiceError as serr:
chained_e = serr.__cause__
Expand Down Expand Up @@ -399,7 +398,6 @@ def wrapper(*args, **kwargs):
return serialization.deserialize_obj_from_s3(
sagemaker_session=job_settings.sagemaker_session,
s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
hmac_key=job.hmac_key,
)

if job.describe()["TrainingJobStatus"] == "Stopped":
Expand Down Expand Up @@ -979,7 +977,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
job_return = serialization.deserialize_obj_from_s3(
sagemaker_session=sagemaker_session,
s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
hmac_key=job.hmac_key,
)
except DeserializationError as e:
client_exception = e
Expand All @@ -991,7 +988,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
job_exception = serialization.deserialize_exception_from_s3(
sagemaker_session=sagemaker_session,
s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
hmac_key=job.hmac_key,
)
except ServiceError as serr:
chained_e = serr.__cause__
Expand Down Expand Up @@ -1081,7 +1077,6 @@ def result(self, timeout: float = None) -> Any:
self._return = serialization.deserialize_obj_from_s3(
sagemaker_session=self._job.sagemaker_session,
s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
hmac_key=self._job.hmac_key,
)
self._state = _FINISHED
return self._return
Expand All @@ -1090,7 +1085,6 @@ def result(self, timeout: float = None) -> Any:
self._exception = serialization.deserialize_exception_from_s3(
sagemaker_session=self._job.sagemaker_session,
s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
hmac_key=self._job.hmac_key,
)
except ServiceError as serr:
chained_e = serr.__cause__
Expand Down
6 changes: 0 additions & 6 deletions src/sagemaker/remote_function/core/pipeline_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ class _DelayedReturnResolver:
def __init__(
self,
delayed_returns: List[_DelayedReturn],
hmac_key: str,
properties_resolver: _PropertiesResolver,
parameter_resolver: _ParameterResolver,
execution_variable_resolver: _ExecutionVariableResolver,
Expand All @@ -175,7 +174,6 @@ def __init__(

Args:
delayed_returns: list of delayed returns to resolve.
hmac_key: key used to encrypt serialized and deserialized function and arguments.
properties_resolver: resolver used to resolve step properties.
parameter_resolver: resolver used to pipeline parameters.
execution_variable_resolver: resolver used to resolve execution variables.
Expand All @@ -197,7 +195,6 @@ def deserialization_task(uri):
return uri, deserialize_obj_from_s3(
sagemaker_session=settings["sagemaker_session"],
s3_uri=uri,
hmac_key=hmac_key,
)

with ThreadPoolExecutor() as executor:
Expand Down Expand Up @@ -247,7 +244,6 @@ def resolve_pipeline_variables(
context: Context,
func_args: Tuple,
func_kwargs: Dict,
hmac_key: str,
s3_base_uri: str,
**settings,
):
Expand All @@ -257,7 +253,6 @@ def resolve_pipeline_variables(
context: context for the execution.
func_args: function args.
func_kwargs: function kwargs.
hmac_key: key used to encrypt serialized and deserialized function and arguments.
s3_base_uri: the s3 base uri of the function step that the serialized artifacts
will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
**settings: settings to pass to the deserialization function.
Expand All @@ -280,7 +275,6 @@ def resolve_pipeline_variables(
properties_resolver = _PropertiesResolver(context)
delayed_return_resolver = _DelayedReturnResolver(
delayed_returns=delayed_returns,
hmac_key=hmac_key,
properties_resolver=properties_resolver,
parameter_resolver=parameter_resolver,
execution_variable_resolver=execution_variable_resolver,
Expand Down
41 changes: 15 additions & 26 deletions src/sagemaker/remote_function/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,14 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:

# TODO: use dask serializer in case dask distributed is installed in users' environment.
def serialize_func_to_s3(
func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
):
"""Serializes function and uploads it to S3.

Args:
sagemaker_session (sagemaker.session.Session):
The underlying Boto3 session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
func: function to be serialized and persisted
Raises:
Expand All @@ -169,14 +168,13 @@ def serialize_func_to_s3(

_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(func),
hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
)


def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable:
"""Downloads from S3 and then deserializes data objects.

This method downloads the serialized training job outputs to a temporary directory and
Expand All @@ -186,7 +184,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
sagemaker_session (sagemaker.session.Session):
The underlying sagemaker session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
Returns :
The deserialized function.
Raises:
Expand All @@ -199,14 +196,14 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)

_perform_integrity_check(
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
)

return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)


def serialize_obj_to_s3(
obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
):
"""Serializes data object and uploads it to S3.

Expand All @@ -215,15 +212,13 @@ def serialize_obj_to_s3(
The underlying Boto3 session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
obj: object to be serialized and persisted
Raises:
SerializationError: when fail to serialize object to bytes.
"""

_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(obj),
hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
Expand Down Expand Up @@ -270,14 +265,13 @@ def json_serialize_obj_to_s3(
)


def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
"""Downloads from S3 and then deserializes data objects.

Args:
sagemaker_session (sagemaker.session.Session):
The underlying sagemaker session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
Returns :
Deserialized python objects.
Raises:
Expand All @@ -291,14 +285,14 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)

_perform_integrity_check(
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
)

return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)


def serialize_exception_to_s3(
exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
):
"""Serializes exception with traceback and uploads it to S3.

Expand All @@ -307,7 +301,6 @@ def serialize_exception_to_s3(
The underlying Boto3 session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
exc: Exception to be serialized and persisted
Raises:
SerializationError: when fail to serialize object to bytes.
Expand All @@ -316,7 +309,6 @@ def serialize_exception_to_s3(

_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(exc),
hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
Expand All @@ -325,7 +317,6 @@ def serialize_exception_to_s3(

def _upload_payload_and_metadata_to_s3(
bytes_to_upload: Union[bytes, io.BytesIO],
hmac_key: str,
s3_uri: str,
sagemaker_session: Session,
s3_kms_key,
Expand All @@ -334,15 +325,14 @@ def _upload_payload_and_metadata_to_s3(

Args:
bytes_to_upload (bytes): Serialized bytes to upload.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
sagemaker_session (sagemaker.session.Session):
The underlying Boto3 session which AWS service calls are delegated to.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
"""
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
sha256_hash = _compute_hash(bytes_to_upload)

_upload_bytes_to_s3(
_MetaData(sha256_hash).to_json(),
Expand All @@ -352,14 +342,13 @@ def _upload_payload_and_metadata_to_s3(
)


def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
"""Downloads from S3 and then deserializes exception.

Args:
sagemaker_session (sagemaker.session.Session):
The underlying sagemaker session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
Returns :
Deserialized exception with traceback.
Raises:
Expand All @@ -373,7 +362,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)

_perform_integrity_check(
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
)

return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
Expand All @@ -399,18 +388,18 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
) from e


def _compute_hash(buffer: bytes, secret_key: str) -> str:
"""Compute the hmac-sha256 hash"""
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
def _compute_hash(buffer: bytes) -> str:
"""Compute the sha256 hash"""
return hashlib.sha256(buffer).hexdigest()


def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
def _perform_integrity_check(expected_hash_value: str, buffer: bytes):
"""Performs integrity checks for serialized code/arguments uploaded to s3.

Verifies whether the hash read from s3 matches the hash calculated
during remote function execution.
"""
actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
actual_hash_value = _compute_hash(buffer=buffer)
if not hmac.compare_digest(expected_hash_value, actual_hash_value):
raise DeserializationError(
"Integrity check for the serialized function or data failed. "
Expand Down
11 changes: 0 additions & 11 deletions src/sagemaker/remote_function/core/stored_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(
self,
sagemaker_session: Session,
s3_base_uri: str,
hmac_key: str,
s3_kms_key: str = None,
context: Context = Context(),
):
Expand All @@ -63,13 +62,11 @@ def __init__(
AWS service calls are delegated to.
s3_base_uri: the base uri to which serialized artifacts will be uploaded.
s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
hmac_key: Key used to encrypt serialized and deserialized function and arguments.
context: Build or run context of a pipeline step.
"""
self.sagemaker_session = sagemaker_session
self.s3_base_uri = s3_base_uri
self.s3_kms_key = s3_kms_key
self.hmac_key = hmac_key
self.context = context

self.func_upload_path = s3_path_join(
Expand Down Expand Up @@ -98,7 +95,6 @@ def save(self, func, *args, **kwargs):
sagemaker_session=self.sagemaker_session,
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
s3_kms_key=self.s3_kms_key,
hmac_key=self.hmac_key,
)

logger.info(
Expand All @@ -110,7 +106,6 @@ def save(self, func, *args, **kwargs):
obj=(args, kwargs),
sagemaker_session=self.sagemaker_session,
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
hmac_key=self.hmac_key,
s3_kms_key=self.s3_kms_key,
)

Expand All @@ -128,7 +123,6 @@ def save_pipeline_step_function(self, serialized_data):
)
serialization._upload_payload_and_metadata_to_s3(
bytes_to_upload=serialized_data.func,
hmac_key=self.hmac_key,
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
sagemaker_session=self.sagemaker_session,
s3_kms_key=self.s3_kms_key,
Expand All @@ -140,7 +134,6 @@ def save_pipeline_step_function(self, serialized_data):
)
serialization._upload_payload_and_metadata_to_s3(
bytes_to_upload=serialized_data.args,
hmac_key=self.hmac_key,
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
sagemaker_session=self.sagemaker_session,
s3_kms_key=self.s3_kms_key,
Expand All @@ -156,7 +149,6 @@ def load_and_invoke(self) -> Any:
func = serialization.deserialize_func_from_s3(
sagemaker_session=self.sagemaker_session,
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
hmac_key=self.hmac_key,
)

logger.info(
Expand All @@ -166,15 +158,13 @@ def load_and_invoke(self) -> Any:
args, kwargs = serialization.deserialize_obj_from_s3(
sagemaker_session=self.sagemaker_session,
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
hmac_key=self.hmac_key,
)

logger.info("Resolving pipeline variables")
resolved_args, resolved_kwargs = resolve_pipeline_variables(
self.context,
args,
kwargs,
hmac_key=self.hmac_key,
s3_base_uri=self.s3_base_uri,
sagemaker_session=self.sagemaker_session,
)
Expand All @@ -190,7 +180,6 @@ def load_and_invoke(self) -> Any:
obj=result,
sagemaker_session=self.sagemaker_session,
s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER),
hmac_key=self.hmac_key,
s3_kms_key=self.s3_kms_key,
)

Expand Down
Loading