diff --git a/src/sagemaker/feature_store/feature_processor/_config_uploader.py b/src/sagemaker/feature_store/feature_processor/_config_uploader.py index 0646adacb4..2a35dca224 100644 --- a/src/sagemaker/feature_store/feature_processor/_config_uploader.py +++ b/src/sagemaker/feature_store/feature_processor/_config_uploader.py @@ -120,9 +120,6 @@ def _prepare_and_upload_callable( stored_function = StoredFunction( sagemaker_session=sagemaker_session, s3_base_uri=s3_base_uri, - hmac_key=self.remote_decorator_config.environment_variables[ - "REMOTE_FUNCTION_SECRET_KEY" - ], s3_kms_key=self.remote_decorator_config.s3_kms_key, ) stored_function.save(func) diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index 55b4654aa9..f70eb48c58 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -362,7 +362,6 @@ def wrapper(*args, **kwargs): s3_uri=s3_path_join( job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER ), - hmac_key=job.hmac_key, ) except ServiceError as serr: chained_e = serr.__cause__ @@ -399,7 +398,6 @@ def wrapper(*args, **kwargs): return serialization.deserialize_obj_from_s3( sagemaker_session=job_settings.sagemaker_session, s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), - hmac_key=job.hmac_key, ) if job.describe()["TrainingJobStatus"] == "Stopped": @@ -979,7 +977,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_return = serialization.deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), - hmac_key=job.hmac_key, ) except DeserializationError as e: client_exception = e @@ -991,7 +988,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_exception = serialization.deserialize_exception_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), - hmac_key=job.hmac_key, ) except ServiceError as serr: chained_e = serr.__cause__ @@ -1081,7 +1077,6 @@ def result(self, timeout: float = None) -> Any: self._return = serialization.deserialize_obj_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), - hmac_key=self._job.hmac_key, ) self._state = _FINISHED return self._return @@ -1090,7 +1085,6 @@ def result(self, timeout: float = None) -> Any: self._exception = serialization.deserialize_exception_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), - hmac_key=self._job.hmac_key, ) except ServiceError as serr: chained_e = serr.__cause__ diff --git a/src/sagemaker/remote_function/core/pipeline_variables.py b/src/sagemaker/remote_function/core/pipeline_variables.py index 952cccdb07..ad4f9e3aa3 100644 --- a/src/sagemaker/remote_function/core/pipeline_variables.py +++ b/src/sagemaker/remote_function/core/pipeline_variables.py @@ -164,7 +164,6 @@ class _DelayedReturnResolver: def __init__( self, delayed_returns: List[_DelayedReturn], - hmac_key: str, properties_resolver: _PropertiesResolver, parameter_resolver: _ParameterResolver, execution_variable_resolver: _ExecutionVariableResolver, @@ -175,7 +174,6 @@ def __init__( Args: delayed_returns: list of delayed returns to resolve. - hmac_key: key used to encrypt serialized and deserialized function and arguments. properties_resolver: resolver used to resolve step properties. parameter_resolver: resolver used to pipeline parameters. execution_variable_resolver: resolver used to resolve execution variables. @@ -197,7 +195,6 @@ def deserialization_task(uri): return uri, deserialize_obj_from_s3( sagemaker_session=settings["sagemaker_session"], s3_uri=uri, - hmac_key=hmac_key, ) with ThreadPoolExecutor() as executor: @@ -247,7 +244,6 @@ def resolve_pipeline_variables( context: Context, func_args: Tuple, func_kwargs: Dict, - hmac_key: str, s3_base_uri: str, **settings, ): @@ -257,7 +253,6 @@ def resolve_pipeline_variables( context: context for the execution. func_args: function args. func_kwargs: function kwargs. - hmac_key: key used to encrypt serialized and deserialized function and arguments. s3_base_uri: the s3 base uri of the function step that the serialized artifacts will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name. **settings: settings to pass to the deserialization function. @@ -280,7 +275,6 @@ def resolve_pipeline_variables( properties_resolver = _PropertiesResolver(context) delayed_return_resolver = _DelayedReturnResolver( delayed_returns=delayed_returns, - hmac_key=hmac_key, properties_resolver=properties_resolver, parameter_resolver=parameter_resolver, execution_variable_resolver=execution_variable_resolver, diff --git a/src/sagemaker/remote_function/core/serialization.py b/src/sagemaker/remote_function/core/serialization.py index 229cf1ed0d..ccb3eee0cf 100644 --- a/src/sagemaker/remote_function/core/serialization.py +++ b/src/sagemaker/remote_function/core/serialization.py @@ -152,7 +152,7 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any: # TODO: use dask serializer in case dask distributed is installed in users' environment. def serialize_func_to_s3( - func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None + func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None ): """Serializes function and uploads it to S3. @@ -160,7 +160,6 @@ def serialize_func_to_s3( sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. func: function to be serialized and persisted Raises: @@ -169,14 +168,13 @@ def serialize_func_to_s3( _upload_payload_and_metadata_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(func), - hmac_key=hmac_key, s3_uri=s3_uri, sagemaker_session=sagemaker_session, s3_kms_key=s3_kms_key, ) -def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable: +def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable: """Downloads from S3 and then deserializes data objects. This method downloads the serialized training job outputs to a temporary directory and @@ -186,7 +184,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. Returns : The deserialized function. Raises: @@ -199,14 +196,14 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_obj_to_s3( - obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None + obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None ): """Serializes data object and uploads it to S3. @@ -215,7 +212,6 @@ def serialize_obj_to_s3( The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. @@ -223,7 +219,6 @@ def serialize_obj_to_s3( _upload_payload_and_metadata_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(obj), - hmac_key=hmac_key, s3_uri=s3_uri, sagemaker_session=sagemaker_session, s3_kms_key=s3_kms_key, @@ -270,14 +265,13 @@ def json_serialize_obj_to_s3( ) -def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: +def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: """Downloads from S3 and then deserializes data objects. Args: sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. Returns : Deserialized python objects. Raises: @@ -291,14 +285,14 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_exception_to_s3( - exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None + exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None ): """Serializes exception with traceback and uploads it to S3. @@ -307,7 +301,6 @@ def serialize_exception_to_s3( The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. exc: Exception to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. @@ -316,7 +309,6 @@ def serialize_exception_to_s3( _upload_payload_and_metadata_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(exc), - hmac_key=hmac_key, s3_uri=s3_uri, sagemaker_session=sagemaker_session, s3_kms_key=s3_kms_key, @@ -325,7 +317,6 @@ def serialize_exception_to_s3( def _upload_payload_and_metadata_to_s3( bytes_to_upload: Union[bytes, io.BytesIO], - hmac_key: str, s3_uri: str, sagemaker_session: Session, s3_kms_key, @@ -334,7 +325,6 @@ def _upload_payload_and_metadata_to_s3( Args: bytes_to_upload (bytes): Serialized bytes to upload. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service calls are delegated to. @@ -342,7 +332,7 @@ def _upload_payload_and_metadata_to_s3( """ _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) - sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) + sha256_hash = _compute_hash(bytes_to_upload) _upload_bytes_to_s3( _MetaData(sha256_hash).to_json(), @@ -352,14 +342,13 @@ def _upload_payload_and_metadata_to_s3( ) -def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: +def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: """Downloads from S3 and then deserializes exception. Args: sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. Returns : Deserialized exception with traceback. Raises: @@ -373,7 +362,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_ bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) @@ -399,18 +388,18 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session): ) from e -def _compute_hash(buffer: bytes, secret_key: str) -> str: - """Compute the hmac-sha256 hash""" - return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest() +def _compute_hash(buffer: bytes) -> str: + """Compute the sha256 hash""" + return hashlib.sha256(buffer).hexdigest() -def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes): +def _perform_integrity_check(expected_hash_value: str, buffer: bytes): """Performs integrity checks for serialized code/arguments uploaded to s3. Verifies whether the hash read from s3 matches the hash calculated during remote function execution. """ - actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key) + actual_hash_value = _compute_hash(buffer=buffer) if not hmac.compare_digest(expected_hash_value, actual_hash_value): raise DeserializationError( "Integrity check for the serialized function or data failed. " diff --git a/src/sagemaker/remote_function/core/stored_function.py b/src/sagemaker/remote_function/core/stored_function.py index 862c67d9ee..99db688a33 100644 --- a/src/sagemaker/remote_function/core/stored_function.py +++ b/src/sagemaker/remote_function/core/stored_function.py @@ -52,7 +52,6 @@ def __init__( self, sagemaker_session: Session, s3_base_uri: str, - hmac_key: str, s3_kms_key: str = None, context: Context = Context(), ): @@ -63,13 +62,11 @@ def __init__( AWS service calls are delegated to. s3_base_uri: the base uri to which serialized artifacts will be uploaded. s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. - hmac_key: Key used to encrypt serialized and deserialized function and arguments. context: Build or run context of a pipeline step. """ self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key - self.hmac_key = hmac_key self.context = context self.func_upload_path = s3_path_join( @@ -98,7 +95,6 @@ def save(self, func, *args, **kwargs): sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), s3_kms_key=self.s3_kms_key, - hmac_key=self.hmac_key, ) logger.info( @@ -110,7 +106,6 @@ def save(self, func, *args, **kwargs): obj=(args, kwargs), sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - hmac_key=self.hmac_key, s3_kms_key=self.s3_kms_key, ) @@ -128,7 +123,6 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.func, - hmac_key=self.hmac_key, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), sagemaker_session=self.sagemaker_session, s3_kms_key=self.s3_kms_key, @@ -140,7 +134,6 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.args, - hmac_key=self.hmac_key, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), sagemaker_session=self.sagemaker_session, s3_kms_key=self.s3_kms_key, @@ -156,7 +149,6 @@ def load_and_invoke(self) -> Any: func = serialization.deserialize_func_from_s3( sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), - hmac_key=self.hmac_key, ) logger.info( @@ -166,7 +158,6 @@ def load_and_invoke(self) -> Any: args, kwargs = serialization.deserialize_obj_from_s3( sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - hmac_key=self.hmac_key, ) logger.info("Resolving pipeline variables") @@ -174,7 +165,6 @@ def load_and_invoke(self) -> Any: self.context, args, kwargs, - hmac_key=self.hmac_key, s3_base_uri=self.s3_base_uri, sagemaker_session=self.sagemaker_session, ) @@ -190,7 +180,6 @@ def load_and_invoke(self) -> Any: obj=result, sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER), - hmac_key=self.hmac_key, s3_kms_key=self.s3_kms_key, ) diff --git a/src/sagemaker/remote_function/errors.py b/src/sagemaker/remote_function/errors.py index 9c91f46061..c97d196a90 100644 --- a/src/sagemaker/remote_function/errors.py +++ b/src/sagemaker/remote_function/errors.py @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg): f.write(failure_msg) -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> int: +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: """Handle all exceptions raised during remote function execution. Args: @@ -79,7 +79,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> AWS service calls are delegated to. s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - hmac_key (str): Key used to calculate hmac hash of the serialized exception. Returns : exit_code (int): Exit code to terminate current job. """ @@ -97,7 +96,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> exc=error, sagemaker_session=sagemaker_session, s3_uri=s3_path_join(s3_base_uri, "exception"), - hmac_key=hmac_key, s3_kms_key=s3_kms_key, ) diff --git a/src/sagemaker/remote_function/invoke_function.py b/src/sagemaker/remote_function/invoke_function.py index 279bf6940b..1cf1eecec0 100644 --- a/src/sagemaker/remote_function/invoke_function.py +++ b/src/sagemaker/remote_function/invoke_function.py @@ -98,7 +98,7 @@ def _load_pipeline_context(args) -> Context: def _execute_remote_function( - sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context + sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, context ): """Execute stored remote function""" from sagemaker.remote_function.core.stored_function import StoredFunction @@ -107,7 +107,6 @@ def _execute_remote_function( sagemaker_session=sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, - hmac_key=hmac_key, context=context, ) @@ -138,15 +137,12 @@ def main(sys_args=None): run_in_context = args.run_in_context pipeline_context = _load_pipeline_context(args) - hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY") - sagemaker_session = _get_sagemaker_session(region) _execute_remote_function( sagemaker_session=sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, run_in_context=run_in_context, - hmac_key=hmac_key, context=pipeline_context, ) @@ -162,7 +158,6 @@ def main(sys_args=None): sagemaker_session=sagemaker_session, s3_base_uri=s3_uri, s3_kms_key=s3_kms_key, - hmac_key=hmac_key, ) finally: sys.exit(exit_code) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 9000ccda08..e874aa89e4 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -583,10 +583,7 @@ def __init__( {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name} ) - # The following will be overridden by the _Job.compile method. - # However, it needs to be kept here for feature store SDK. - # TODO: update the feature store SDK to set the HMAC key there. - self.environment_variables.update({"REMOTE_FUNCTION_SECRET_KEY": secrets.token_hex(32)}) + if spark_config and image_uri: raise ValueError("spark_config and image_uri cannot be specified at the same time!") @@ -799,19 +796,17 @@ def _get_default_spark_image(session): class _Job: """Helper class that interacts with the SageMaker training service.""" - def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, hmac_key: str): + def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session): """Initialize a _Job object. Args: job_name (str): The training job name. s3_uri (str): The training job output S3 uri. sagemaker_session (Session): SageMaker boto session. - hmac_key (str): Remote function secret key. """ self.job_name = job_name self.s3_uri = s3_uri self.sagemaker_session = sagemaker_session - self.hmac_key = hmac_key self._last_describe_response = None @staticmethod @@ -827,9 +822,8 @@ def from_describe_response(describe_training_job_response, sagemaker_session): """ job_name = describe_training_job_response["TrainingJobName"] s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"] - hmac_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"] - job = _Job(job_name, s3_uri, sagemaker_session, hmac_key) + job = _Job(job_name, s3_uri, sagemaker_session) job._last_describe_response = describe_training_job_response return job @@ -867,7 +861,6 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non job_name, s3_base_uri, job_settings.sagemaker_session, - training_job_request["Environment"]["REMOTE_FUNCTION_SECRET_KEY"], ) @staticmethod @@ -892,18 +885,11 @@ def compile( jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT[:] - # generate hmac key for integrity check - if step_compilation_context is None: - hmac_key = secrets.token_hex(32) - else: - hmac_key = step_compilation_context.function_step_secret_token - # serialize function and arguments if step_compilation_context is None: stored_function = StoredFunction( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, - hmac_key=hmac_key, s3_kms_key=job_settings.s3_kms_key, ) stored_function.save(func, *func_args, **func_kwargs) @@ -911,7 +897,6 @@ def compile( stored_function = StoredFunction( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, - hmac_key=hmac_key, s3_kms_key=job_settings.s3_kms_key, context=Context( step_name=step_compilation_context.step_name, @@ -1061,7 +1046,6 @@ def compile( request_dict["EnableManagedSpotTraining"] = job_settings.use_spot_instances request_dict["Environment"] = job_settings.environment_variables - request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key}) extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri) extended_request = _extend_mpirun_to_request(extended_request, job_settings) diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index f111f5e40b..0c60401d7e 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -1084,7 +1084,6 @@ def get_function_step_result( return deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_uri, - hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"], ) raise RemoteFunctionError(_ERROR_MSG_OF_STEP_INCOMPLETE) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py b/tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py index 8193e93708..8536c2d893 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py @@ -68,7 +68,6 @@ def remote_decorator_config(sagemaker_session): pre_execution_commands="some_commands", pre_execution_script="some_path", python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH, - environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"}, custom_file_filter=None, ) @@ -91,7 +90,6 @@ def remote_decorator_config_with_filter(sagemaker_session): pre_execution_commands="some_commands", pre_execution_script="some_path", python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH, - environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"}, custom_file_filter=custom_file_filter, ) @@ -103,7 +101,6 @@ def test_prepare_and_upload_callable(mock_stored_function, config_uploader, wrap assert mock_stored_function.called_once_with( s3_base_uri="s3_base_uri", s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, - hmac_key="some_secret_key", sagemaker_session=sagemaker_session, ) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py index 7b35174940..522bbbde4e 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py @@ -309,9 +309,6 @@ def test_to_pipeline( input_mode="File", environment={ "AWS_DEFAULT_REGION": "us-west-2", - "REMOTE_FUNCTION_SECRET_KEY": job_settings.environment_variables[ - "REMOTE_FUNCTION_SECRET_KEY" - ], "scheduled_time": Parameter( name="scheduled_time", parameter_type=ParameterTypeEnum.STRING ), diff --git a/tests/unit/sagemaker/remote_function/core/test_pipeline_variables.py b/tests/unit/sagemaker/remote_function/core/test_pipeline_variables.py index 422d1949af..211a703cc7 100644 --- a/tests/unit/sagemaker/remote_function/core/test_pipeline_variables.py +++ b/tests/unit/sagemaker/remote_function/core/test_pipeline_variables.py @@ -83,13 +83,12 @@ def test_resolve_delayed_returns(mock_deserializer): } ) resolver = _DelayedReturnResolver( - delayed_returns, - "1234", + delayed_returns=delayed_returns, properties_resolver=_PropertiesResolver(context), parameter_resolver=_ParameterResolver(context), execution_variable_resolver=_ExecutionVariableResolver(context), - sagemaker_session=None, s3_base_uri=f"s3://my-bucket/{PIPELINE_NAME}", + sagemaker_session=None, ) assert resolver.resolve(delayed_returns[0]) == 1 @@ -122,13 +121,12 @@ def test_deserializer_fails(mock_deserializer): ) with pytest.raises(Exception, match="Something went wrong"): _DelayedReturnResolver( - delayed_returns, - "1234", + delayed_returns=delayed_returns, properties_resolver=_PropertiesResolver(context), parameter_resolver=_ParameterResolver(context), execution_variable_resolver=_ExecutionVariableResolver(context), - sagemaker_session=None, s3_base_uri=f"s3://my-bucket/{PIPELINE_NAME}", + sagemaker_session=None, ) @@ -149,7 +147,6 @@ def test_no_pipeline_variables_to_resolve(mock_deserializer, func_args, func_kwa Context(), func_args, func_kwargs, - hmac_key="1234", s3_base_uri="s3://my-bucket", sagemaker_session=None, ) @@ -275,7 +272,6 @@ def test_resolve_pipeline_variables( context, func_args, func_kwargs, - hmac_key="1234", s3_base_uri=s3_base_uri, sagemaker_session=None, ) @@ -285,7 +281,6 @@ def test_resolve_pipeline_variables( mock_deserializer.assert_called_once_with( sagemaker_session=None, s3_uri=s3_results_uri, - hmac_key="1234", ) diff --git a/tests/unit/sagemaker/remote_function/core/test_serialization.py b/tests/unit/sagemaker/remote_function/core/test_serialization.py index e87dc39b59..5069189ea6 100644 --- a/tests/unit/sagemaker/remote_function/core/test_serialization.py +++ b/tests/unit/sagemaker/remote_function/core/test_serialization.py @@ -32,7 +32,6 @@ from tblib import pickling_support KMS_KEY = "kms-key" -HMAC_KEY = "some-hmac-key" mock_s3 = {} @@ -67,13 +66,13 @@ def square(x): s3_uri = random_s3_uri() serialize_func_to_s3( - func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) del square deserialized = deserialize_func_from_s3( - sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + sagemaker_session=Mock(), s3_uri=s3_uri ) assert deserialized(3) == 9 @@ -89,11 +88,10 @@ def test_serialize_deserialize_lambda(): sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, ) deserialized = deserialize_func_from_s3( - sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + sagemaker_session=Mock(), s3_uri=s3_uri ) assert deserialized(3) == 9 @@ -126,7 +124,6 @@ def train(x): sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, ) @@ -153,7 +150,6 @@ def func(x): sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, ) @@ -177,7 +173,6 @@ def square(x): sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, ) @@ -193,7 +188,7 @@ def square(x): s3_uri = random_s3_uri() serialize_func_to_s3( - func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) del square @@ -204,7 +199,7 @@ def square(x): + r"RuntimeError\('some failure when loads'\). " + r"NOTE: this may be caused by inconsistent sagemaker python sdk versions", ): - deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -216,14 +211,14 @@ def square(x): s3_uri = random_s3_uri() serialize_func_to_s3( - func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) mock_s3[f"{s3_uri}/metadata.json"] = b"not json serializable" del square with pytest.raises(DeserializationError, match=r"Corrupt metadata file."): - deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -234,15 +229,17 @@ def square(x): s3_uri = random_s3_uri() serialize_func_to_s3( - func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) + # Tamper with the payload to trigger integrity check failure + mock_s3[f"{s3_uri}/payload.pkl"] = b"tampered data" del square with pytest.raises( DeserializationError, match=r"Integrity check for the serialized function or data failed." ): - deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key="invalid_key") + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -256,14 +253,14 @@ def __init__(self, x): s3_uri = random_s3_uri() serialize_obj_to_s3( - my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) del my_data del MyData deserialized = deserialize_obj_from_s3( - sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + sagemaker_session=Mock(), s3_uri=s3_uri ) assert deserialized.x == 10 @@ -277,13 +274,13 @@ def test_serialize_deserialize_data_built_in_types(): s3_uri = random_s3_uri() serialize_obj_to_s3( - my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) del my_data deserialized = deserialize_obj_from_s3( - sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + sagemaker_session=Mock(), s3_uri=s3_uri ) assert deserialized == {"a": [10]} @@ -295,11 +292,11 @@ def test_serialize_deserialize_none(): s3_uri = random_s3_uri() serialize_obj_to_s3( - None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) deserialized = deserialize_obj_from_s3( - sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + sagemaker_session=Mock(), s3_uri=s3_uri ) assert deserialized is None @@ -327,7 +324,6 @@ def test_serialize_run(sagemaker_session, *args, **kwargs): sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, ) @@ -351,7 +347,6 @@ def test_serialize_pipeline_variables(pipeline_variable): sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, ) @@ -377,7 +372,6 @@ def __init__(self, x): sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, ) @@ -395,7 +389,7 @@ def __init__(self, x): s3_uri = random_s3_uri() serialize_obj_to_s3( - obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) del my_data @@ -407,7 +401,7 @@ def __init__(self, x): + r"RuntimeError\('some failure when loads'\). " + r"NOTE: this may be caused by inconsistent sagemaker python sdk versions", ): - deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) + deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload_error) @@ -427,7 +421,6 @@ def test_serialize_deserialize_service_error(): sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, ) del my_func @@ -437,7 +430,7 @@ def test_serialize_deserialize_service_error(): match=rf"Failed to read serialized bytes from {s3_uri}/metadata.json: " + r"RuntimeError\('some failure when read_bytes'\)", ): - deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -461,11 +454,11 @@ def func_b(): except Exception as e: pickling_support.install() serialize_obj_to_s3( - e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) with pytest.raises(CustomError, match="Some error") as exc_info: - raise deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) + raise deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) assert type(exc_info.value.__cause__) is TypeError @@ -489,12 +482,12 @@ def func_b(): func_b() except Exception as e: serialize_exception_to_s3( - e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) with pytest.raises(CustomError, match="Some error") as exc_info: raise deserialize_exception_from_s3( - sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + sagemaker_session=Mock(), s3_uri=s3_uri ) assert type(exc_info.value.__cause__) is TypeError @@ -519,11 +512,11 @@ def func_b(): func_b() except Exception as e: serialize_exception_to_s3( - e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY ) with pytest.raises(ServiceError, match="Some error") as exc_info: raise deserialize_exception_from_s3( - sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + sagemaker_session=Mock(), s3_uri=s3_uri ) assert type(exc_info.value.__cause__) is TypeError diff --git a/tests/unit/sagemaker/remote_function/core/test_stored_function.py b/tests/unit/sagemaker/remote_function/core/test_stored_function.py index 68a05c08a6..c4c69774a9 100644 --- a/tests/unit/sagemaker/remote_function/core/test_stored_function.py +++ b/tests/unit/sagemaker/remote_function/core/test_stored_function.py @@ -50,7 +50,6 @@ ) KMS_KEY = "kms-key" -HMAC_KEY = "some-hmac-key" FUNCTION_FOLDER = "function" ARGUMENT_FOLDER = "arguments" RESULT_FOLDER = "results" @@ -96,13 +95,13 @@ def test_save_and_load(s3_source_dir_download, s3_source_dir_upload, args, kwarg s3_base_uri = random_s3_uri() stored_function = StoredFunction( - sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY ) stored_function.save(quadratic, *args, **kwargs) stored_function.load_and_invoke() assert deserialize_obj_from_s3( - session, s3_uri=f"{s3_base_uri}/results", hmac_key=HMAC_KEY + session, s3_uri=f"{s3_base_uri}/results" ) == quadratic(*args, **kwargs) @@ -139,7 +138,7 @@ def test_save_with_parameter_of_run_type( sagemaker_session=session, ) stored_function = StoredFunction( - sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY ) with pytest.raises(SerializationError) as e: stored_function.save(log_bigger, 1, 2, run) @@ -165,7 +164,6 @@ def test_save_s3_paths_verification( sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, context=Context( step_name=step_name, execution_id=execution_id, @@ -180,13 +178,11 @@ def test_save_s3_paths_verification( sagemaker_session=session, s3_uri=(upload_path + FUNCTION_FOLDER), s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, - ) + ) serialize_obj.assert_called_once_with( obj=((3,), {}), sagemaker_session=session, s3_uri=(upload_path + ARGUMENT_FOLDER), - hmac_key=HMAC_KEY, s3_kms_key=KMS_KEY, ) @@ -226,7 +222,6 @@ def test_load_and_invoke_s3_paths_verification( sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, context=Context( step_name=step_name, execution_id=execution_id, @@ -237,13 +232,12 @@ def test_load_and_invoke_s3_paths_verification( stored_function.load_and_invoke() deserialize_func.assert_called_once_with( - sagemaker_session=session, s3_uri=(download_path + FUNCTION_FOLDER), hmac_key=HMAC_KEY + sagemaker_session=session, s3_uri=(download_path + FUNCTION_FOLDER) ) deserialize_obj.assert_called_once_with( sagemaker_session=session, s3_uri=(download_path + ARGUMENT_FOLDER), - hmac_key=HMAC_KEY, - ) + ) result = deserialize_func.return_value( *deserialize_obj.return_value[0], **deserialize_obj.return_value[1] @@ -253,7 +247,6 @@ def test_load_and_invoke_s3_paths_verification( obj=result, sagemaker_session=session, s3_uri=(upload_path + RESULT_FOLDER), - hmac_key=HMAC_KEY, s3_kms_key=KMS_KEY, ) @@ -283,7 +276,6 @@ def test_load_and_invoke_json_serialization( sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, context=Context( serialize_output_to_json=serialize_output_to_json, ), @@ -318,13 +310,12 @@ def test_save_and_load_with_pipeline_variable(monkeypatch): function_step = _FunctionStep(name="func_1", display_name=None, description=None) x = DelayedReturn(function_step=function_step) - serialize_obj_to_s3(3.0, session, func1_result_path, HMAC_KEY, KMS_KEY) + serialize_obj_to_s3(3.0, session, func1_result_path, KMS_KEY) stored_function = StoredFunction( sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, context=Context( property_references={ "Parameters.a": "1.0", @@ -356,7 +347,7 @@ def test_save_and_load_with_pipeline_variable(monkeypatch): func2_result_path = f"{s3_base_uri}/execution-id/func_2/results" assert deserialize_obj_from_s3( - session, s3_uri=func2_result_path, hmac_key=HMAC_KEY + session, s3_uri=func2_result_path ) == quadratic(3.0, a=1.0, b=2.0, c=3.0) @@ -371,7 +362,6 @@ def test_save_pipeline_step_function(mock_job_settings, upload_payload): sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, - hmac_key=HMAC_KEY, context=Context( step_name="step_name", execution_id="execution_id", diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index de8758bfad..7a70c3ede0 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -54,7 +54,7 @@ S3_URI = f"s3://{BUCKET}/keyprefix" EXPECTED_JOB_RESULT = [1, 2, 3] PATH_TO_SRC_DIR = "path/to/src/dir" -HMAC_KEY = "some-hmac-key" + ROLE_ARN = "arn:aws:iam::555555555555:role/my_execution_role_arn" @@ -69,7 +69,7 @@ def describe_training_job_response(job_status): "VolumeSizeInGB": 30, }, "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}, - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + "Environment": {}, } @@ -1027,7 +1027,7 @@ def test_future_get_result_from_completed_job(mock_start, mock_deserialize): def test_future_get_result_from_failed_job_remote_error_client_function( mock_start, mock_deserialize ): - mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI, hmac_key=HMAC_KEY) + mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI) mock_start.return_value = mock_job mock_job.describe.return_value = FAILED_TRAINING_JOB @@ -1043,7 +1043,7 @@ def test_future_get_result_from_failed_job_remote_error_client_function( assert future.done() mock_job.wait.assert_called_once() mock_deserialize.assert_called_with( - sagemaker_session=ANY, s3_uri=f"{S3_URI}/exception", hmac_key=HMAC_KEY + sagemaker_session=ANY, s3_uri=f"{S3_URI}/exception" ) @@ -1374,7 +1374,6 @@ def test_get_future_completed_job_deserialization_error(mock_session, mock_deser mock_deserialize.assert_called_with( sagemaker_session=ANY, s3_uri="s3://sagemaker-123/image_uri/output/results", - hmac_key=HMAC_KEY, ) diff --git a/tests/unit/sagemaker/remote_function/test_errors.py b/tests/unit/sagemaker/remote_function/test_errors.py index 399b1aed2e..cb49370334 100644 --- a/tests/unit/sagemaker/remote_function/test_errors.py +++ b/tests/unit/sagemaker/remote_function/test_errors.py @@ -20,7 +20,7 @@ TEST_S3_BASE_URI = "s3://my-bucket/" TEST_S3_KMS_KEY = "my-kms-key" -TEST_HMAC_KEY = "some-hmac-key" + class _InvalidErrorNumberException(Exception): @@ -76,7 +76,6 @@ def test_handle_error( sagemaker_session=sagemaker_session, s3_base_uri=TEST_S3_BASE_URI, s3_kms_key=TEST_S3_KMS_KEY, - hmac_key=TEST_HMAC_KEY, ) assert exit_code == expected_exit_code @@ -87,6 +86,5 @@ def test_handle_error( exc=err, sagemaker_session=sagemaker_session, s3_uri=TEST_S3_BASE_URI + "exception", - hmac_key=TEST_HMAC_KEY, s3_kms_key=TEST_S3_KMS_KEY, ) diff --git a/tests/unit/sagemaker/remote_function/test_invoke_function.py b/tests/unit/sagemaker/remote_function/test_invoke_function.py index b4500e04e3..23fb5a4201 100644 --- a/tests/unit/sagemaker/remote_function/test_invoke_function.py +++ b/tests/unit/sagemaker/remote_function/test_invoke_function.py @@ -25,7 +25,6 @@ TEST_S3_BASE_URI = "s3://my-bucket/" TEST_S3_KMS_KEY = "my-kms-key" TEST_RUN_IN_CONTEXT = '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}' -TEST_HMAC_KEY = "some-hmac-key" TEST_STEP_NAME = "training-step" TEST_EXECUTION_ID = "some-execution-id" FUNC_STEP_S3_DIR = sagemaker_timestamp() @@ -89,7 +88,6 @@ def mock_session(): return_value=mock_session(), ) def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object): - os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY invoke_function.main(mock_args()) _get_sagemaker_session.assert_called_with(TEST_REGION) @@ -108,7 +106,6 @@ def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process, _l def test_main_success_with_run( _get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object ): - os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY invoke_function.main(mock_args_with_run_in_context()) _get_sagemaker_session.assert_called_with(TEST_REGION) @@ -137,7 +134,6 @@ def test_main_success_with_run( def test_main_success_with_pipeline_context( _get_sagemaker_session, mock_stored_function, _exit_process, _load_run_object, args ): - os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY args_input, expected_serialize_output_to_json = args invoke_function.main(args_input) @@ -147,7 +143,6 @@ def test_main_success_with_pipeline_context( sagemaker_session=ANY, s3_base_uri=TEST_S3_BASE_URI, s3_kms_key=TEST_S3_KMS_KEY, - hmac_key=TEST_HMAC_KEY, context=Context( execution_id=TEST_EXECUTION_ID, step_name=TEST_STEP_NAME, @@ -174,7 +169,6 @@ def test_main_success_with_pipeline_context( def test_main_failure( _get_sagemaker_session, load_and_invoke, _exit_process, handle_error, _load_run_object ): - os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY ser_err = SerializationError("some failure reason") load_and_invoke.side_effect = ser_err handle_error.return_value = 1 @@ -189,7 +183,6 @@ def test_main_failure( sagemaker_session=_get_sagemaker_session(), s3_base_uri=TEST_S3_BASE_URI, s3_kms_key=TEST_S3_KMS_KEY, - hmac_key=TEST_HMAC_KEY, ) _exit_process.assert_called_with(1) @@ -205,7 +198,6 @@ def test_main_failure( def test_main_failure_with_step( _get_sagemaker_session, load_and_invoke, _exit_process, handle_error, _load_run_object ): - os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY ser_err = SerializationError("some failure reason") load_and_invoke.side_effect = ser_err handle_error.return_value = 1 @@ -221,6 +213,5 @@ def test_main_failure_with_step( sagemaker_session=_get_sagemaker_session(), s3_base_uri=s3_uri, s3_kms_key=TEST_S3_KMS_KEY, - hmac_key=TEST_HMAC_KEY, ) _exit_process.assert_called_with(1) diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index f153b5b2ca..d7fbbd7b78 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -68,7 +68,7 @@ RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap" REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies" -HMAC_KEY = "some-hmac-key" + EXPECTED_FUNCTION_URI = S3_URI + "/function.pkl" EXPECTED_OUTPUT_URI = S3_URI + "/output" @@ -252,10 +252,8 @@ "ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.c4.xlarge", - "VolumeSizeInGB": 30, - }, - "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}, -} + "VolumeSizeInGB": 30}, + "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}} TEST_JOB_NAME = "my-job-name" TEST_PIPELINE_NAME = "my-pipeline" @@ -297,20 +295,16 @@ def describe_training_job_response(job_status, disable_output_compression=False) "ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.c4.xlarge", - "VolumeSizeInGB": 30, - }, - } + "VolumeSizeInGB": 30}} if disable_output_compression: output_config = { "S3OutputPath": "s3://sagemaker-123/image_uri/output", - "CompressionType": "NONE", - } + "CompressionType": "NONE"} else: output_config = { "S3OutputPath": "s3://sagemaker-123/image_uri/output", - "CompressionType": "NONE", - } + "CompressionType": "NONE"} job_response["OutputDataConfig"] = output_config @@ -359,24 +353,20 @@ def serialized_data(): return _SerializedData(func=b"serialized_func", args=b"serialized_args") -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) -def test_sagemaker_config_job_settings(get_execution_role, session, secret_token): +def test_sagemaker_config_job_settings(get_execution_role, session): job_settings = _JobSettings(image_uri="image_uri", instance_type="ml.m5.xlarge") assert job_settings.image_uri == "image_uri" assert job_settings.s3_root_uri == f"s3://{BUCKET}" assert job_settings.role == DEFAULT_ROLE_ARN assert job_settings.environment_variables == { - "AWS_DEFAULT_REGION": "us-west-2", - "REMOTE_FUNCTION_SECRET_KEY": "some-hmac-key", - } + "AWS_DEFAULT_REGION": "us-west-2"} assert job_settings.include_local_workdir is False assert job_settings.instance_type == "ml.m5.xlarge" -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch( "sagemaker.remote_function.job._JobSettings._get_default_spark_image", return_value="some_image_uri", @@ -384,7 +374,7 @@ def test_sagemaker_config_job_settings(get_execution_role, session, secret_token @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) def test_sagemaker_config_job_settings_with_spark_config( - get_execution_role, session, mock_get_default_spark_image, secret_token + get_execution_role, session, mock_get_default_spark_image ): spark_config = SparkConfig() @@ -394,9 +384,7 @@ def test_sagemaker_config_job_settings_with_spark_config( assert job_settings.s3_root_uri == f"s3://{BUCKET}" assert job_settings.role == DEFAULT_ROLE_ARN assert job_settings.environment_variables == { - "AWS_DEFAULT_REGION": "us-west-2", - "REMOTE_FUNCTION_SECRET_KEY": "some-hmac-key", - } + "AWS_DEFAULT_REGION": "us-west-2"} assert job_settings.include_local_workdir is False assert job_settings.instance_type == "ml.m5.xlarge" assert job_settings.spark_config == spark_config @@ -434,11 +422,10 @@ def test_sagemaker_config_job_settings_with_not_supported_param_by_spark(): ) -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) def test_sagemaker_config_job_settings_with_configuration_file( - get_execution_role, session, secret_token + get_execution_role, session ): config_tags = [ {"Key": "someTagKey", "Value": "someTagValue"}, @@ -458,9 +445,7 @@ def test_sagemaker_config_job_settings_with_configuration_file( assert job_settings.pre_execution_commands == ["command_1", "command_2"] assert job_settings.environment_variables == { "AWS_DEFAULT_REGION": "us-west-2", - "REMOTE_FUNCTION_SECRET_KEY": "some-hmac-key", - "EnvVarKey": "EnvVarValue", - } + "EnvVarKey": "EnvVarValue"} assert job_settings.job_conda_env == "my_conda_env" assert job_settings.include_local_workdir is True assert job_settings.custom_file_filter.ignore_name_patterns == ["data", "test"] @@ -542,7 +527,6 @@ def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, sess @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -556,7 +540,6 @@ def test_start( mock_runtime_manager, mock_script_upload, mock_dependency_upload, - secret_token, ): job_settings = _JobSettings( @@ -575,7 +558,6 @@ def test_start( mock_stored_function.assert_called_once_with( sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", - hmac_key=HMAC_KEY, s3_kms_key=None, ) @@ -616,8 +598,7 @@ def test_start( DataSource={ "S3DataSource": { "S3Uri": mock_script_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), dict( @@ -625,8 +606,7 @@ def test_start( DataSource={ "S3DataSource": { "S3Uri": mock_dependency_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), ], @@ -662,12 +642,11 @@ def test_start( EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=True, EnableManagedSpotTraining=False, - Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + Environment={"AWS_DEFAULT_REGION": "us-west-2"}, ) @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -681,8 +660,7 @@ def test_start_with_checkpoint_location( mock_runtime_manager, mock_script_upload, mock_user_workspace_upload, - secret_token, -): + ): job_settings = _JobSettings( image_uri=IMAGE, @@ -707,7 +685,6 @@ def test_start_with_checkpoint_location( mock_stored_function.assert_called_once_with( sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", - hmac_key=HMAC_KEY, s3_kms_key=None, ) @@ -725,16 +702,14 @@ def test_start_with_checkpoint_location( RetryStrategy={"MaximumRetryAttempts": 1}, CheckpointConfig={ "LocalPath": "/opt/ml/checkpoints/", - "S3Uri": "s3://my-bucket/my-checkpoints/", - }, + "S3Uri": "s3://my-bucket/my-checkpoints/"}, InputDataConfig=[ dict( ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, DataSource={ "S3DataSource": { "S3Uri": mock_script_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), dict( @@ -742,8 +717,7 @@ def test_start_with_checkpoint_location( DataSource={ "S3DataSource": { "S3Uri": mock_user_workspace_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), ], @@ -779,7 +753,7 @@ def test_start_with_checkpoint_location( EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=True, EnableManagedSpotTraining=False, - Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + Environment={"AWS_DEFAULT_REGION": "us-west-2"}, ) @@ -819,7 +793,6 @@ def test_start_with_checkpoint_location_failed_with_multiple_checkpoint_location ) -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -833,8 +806,7 @@ def test_start_with_complete_job_settings( mock_runtime_manager, mock_bootstrap_script_upload, mock_user_workspace_upload, - secret_token, -): + ): job_settings = _JobSettings( dependencies="path/to/dependencies/req.txt", @@ -860,7 +832,6 @@ def test_start_with_complete_job_settings( mock_stored_function.assert_called_once_with( sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", - hmac_key=HMAC_KEY, s3_kms_key=KMS_KEY_ARN, ) @@ -899,8 +870,7 @@ def test_start_with_complete_job_settings( DataSource={ "S3DataSource": { "S3Uri": mock_bootstrap_script_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), dict( @@ -908,8 +878,7 @@ def test_start_with_complete_job_settings( DataSource={ "S3DataSource": { "S3Uri": mock_user_workspace_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), ], @@ -949,12 +918,12 @@ def test_start_with_complete_job_settings( EnableInterContainerTrafficEncryption=False, VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]), EnableManagedSpotTraining=False, - Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + Environment={"AWS_DEFAULT_REGION": "us-west-2"}, ) @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -@patch("secrets.token_hex", MagicMock(return_value=HMAC_KEY)) + @patch( "sagemaker.remote_function.job._prepare_dependencies_and_pre_execution_scripts", return_value="some_s3_uri", @@ -1019,15 +988,14 @@ def test_get_train_args_under_pipeline_context( func_kwargs={ "c": 3, "d": ParameterInteger(name="d", default_value=4), - "e": DelayedReturn(function_step, reference_path=("__getitem__", 1)), - }, + "e": DelayedReturn(function_step, reference_path=("__getitem__", 1))}, serialized_data=mocked_serialized_data, ) mock_stored_function_ctr.assert_called_once_with( sagemaker_session=session(), s3_base_uri=s3_base_uri, - hmac_key="token-from-pipeline", + s3_kms_key=KMS_KEY_ARN, context=Context( step_name=MOCKED_PIPELINE_CONFIG.step_name, @@ -1073,8 +1041,7 @@ def test_get_train_args_under_pipeline_context( DataSource={ "S3DataSource": { "S3Uri": mock_bootstrap_scripts_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), dict( @@ -1082,8 +1049,7 @@ def test_get_train_args_under_pipeline_context( DataSource={ "S3DataSource": { "S3Uri": mock_user_dependencies_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), dict( @@ -1091,8 +1057,7 @@ def test_get_train_args_under_pipeline_context( DataSource={ "S3DataSource": { "S3Uri": mock_user_workspace_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), ], @@ -1106,8 +1071,7 @@ def test_get_train_args_under_pipeline_context( "results", ], ), - "KmsKeyId": KMS_KEY_ARN, - }, + "KmsKeyId": KMS_KEY_ARN}, AlgorithmSpecification=dict( TrainingImage=IMAGE, TrainingInputMode="File", @@ -1161,13 +1125,10 @@ def test_get_train_args_under_pipeline_context( VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]), EnableManagedSpotTraining=False, Environment={ - "AWS_DEFAULT_REGION": "us-west-2", - "REMOTE_FUNCTION_SECRET_KEY": "token-from-pipeline", - }, + "AWS_DEFAULT_REGION": "us-west-2"}, ) -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch( "sagemaker.remote_function.job._JobSettings._get_default_spark_image", return_value="some_image_uri", @@ -1192,8 +1153,7 @@ def test_start_with_spark( mock_dependency_upload, mock_spark_dependency_upload, mock_get_default_spark_image, - secrete_token, -): + ): spark_config = SparkConfig() job_settings = _JobSettings( spark_config=spark_config, @@ -1237,8 +1197,7 @@ def test_start_with_spark( "S3DataSource": { "S3Uri": mock_script_upload.return_value, "S3DataType": "S3Prefix", - "S3DataDistributionType": "FullyReplicated", - } + "S3DataDistributionType": "FullyReplicated"} }, ), dict( @@ -1247,8 +1206,7 @@ def test_start_with_spark( "S3DataSource": { "S3Uri": mock_dependency_upload.return_value, "S3DataType": "S3Prefix", - "S3DataDistributionType": "FullyReplicated", - } + "S3DataDistributionType": "FullyReplicated"} }, ), dict( @@ -1257,8 +1215,7 @@ def test_start_with_spark( "S3DataSource": { "S3Uri": "config_file_s3_uri", "S3DataType": "S3Prefix", - "S3DataDistributionType": "FullyReplicated", - } + "S3DataDistributionType": "FullyReplicated"} }, ), ], @@ -1301,7 +1258,7 @@ def test_start_with_spark( EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=True, EnableManagedSpotTraining=False, - Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + Environment={"AWS_DEFAULT_REGION": "us-west-2"}, ) @@ -1816,8 +1773,7 @@ def test_extend_spark_config_to_request( "/opt/ml/input/data/sagemaker_remote_function_bootstrap/spark_app.py", ], "TrainingImage": "image_uri", - "TrainingInputMode": "File", - }, + "TrainingInputMode": "File"}, InputDataConfig=[ { "ChannelName": "conf", @@ -1825,16 +1781,13 @@ def test_extend_spark_config_to_request( "S3DataSource": { "S3DataType": "S3Prefix", "S3Uri": "config_file_s3_uri", - "S3DataDistributionType": "FullyReplicated", - } - }, - } + "S3DataDistributionType": "FullyReplicated"} + }} ], ) @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -1848,8 +1801,7 @@ def test_start_with_torchrun_single_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, - secret_token, -): + ): job_settings = _JobSettings( image_uri=IMAGE, @@ -1869,7 +1821,6 @@ def test_start_with_torchrun_single_node( mock_stored_function.assert_called_once_with( sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", - hmac_key=HMAC_KEY, s3_kms_key=None, ) @@ -1910,8 +1861,7 @@ def test_start_with_torchrun_single_node( DataSource={ "S3DataSource": { "S3Uri": mock_script_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), dict( @@ -1919,8 +1869,7 @@ def test_start_with_torchrun_single_node( DataSource={ "S3DataSource": { "S3Uri": mock_dependency_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), ], @@ -1958,12 +1907,11 @@ def test_start_with_torchrun_single_node( EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=True, EnableManagedSpotTraining=False, - Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + Environment={"AWS_DEFAULT_REGION": "us-west-2"}, ) @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -1977,8 +1925,7 @@ def test_start_with_torchrun_multi_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, - secret_token, -): + ): job_settings = _JobSettings( image_uri=IMAGE, @@ -1999,7 +1946,6 @@ def test_start_with_torchrun_multi_node( mock_stored_function.assert_called_once_with( sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", - hmac_key=HMAC_KEY, s3_kms_key=None, ) @@ -2041,8 +1987,7 @@ def test_start_with_torchrun_multi_node( "S3DataSource": { "S3Uri": mock_script_upload.return_value, "S3DataType": "S3Prefix", - "S3DataDistributionType": "FullyReplicated", - } + "S3DataDistributionType": "FullyReplicated"} }, ), dict( @@ -2051,8 +1996,7 @@ def test_start_with_torchrun_multi_node( "S3DataSource": { "S3Uri": mock_dependency_upload.return_value, "S3DataType": "S3Prefix", - "S3DataDistributionType": "FullyReplicated", - } + "S3DataDistributionType": "FullyReplicated"} }, ), ], @@ -2090,7 +2034,7 @@ def test_start_with_torchrun_multi_node( EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=True, EnableManagedSpotTraining=False, - Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + Environment={"AWS_DEFAULT_REGION": "us-west-2"}, ) @@ -2355,7 +2299,6 @@ def test_set_env_multi_node_multi_gpu_mpirun( @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -2369,8 +2312,7 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, - secret_token, -): + ): job_settings = _JobSettings( image_uri=IMAGE, @@ -2391,7 +2333,6 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( mock_stored_function.assert_called_once_with( sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", - hmac_key=HMAC_KEY, s3_kms_key=None, ) @@ -2432,8 +2373,7 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( DataSource={ "S3DataSource": { "S3Uri": mock_script_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), dict( @@ -2441,8 +2381,7 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( DataSource={ "S3DataSource": { "S3Uri": mock_dependency_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), ], @@ -2482,12 +2421,11 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=True, EnableManagedSpotTraining=False, - Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + Environment={"AWS_DEFAULT_REGION": "us-west-2"}, ) @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) -@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -2501,8 +2439,7 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, - secret_token, -): + ): job_settings = _JobSettings( image_uri=IMAGE, @@ -2523,7 +2460,6 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( mock_stored_function.assert_called_once_with( sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", - hmac_key=HMAC_KEY, s3_kms_key=None, ) @@ -2564,8 +2500,7 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( DataSource={ "S3DataSource": { "S3Uri": mock_script_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), dict( @@ -2573,8 +2508,7 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( DataSource={ "S3DataSource": { "S3Uri": mock_dependency_upload.return_value, - "S3DataType": "S3Prefix", - } + "S3DataType": "S3Prefix"} }, ), ], @@ -2614,7 +2548,7 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=True, EnableManagedSpotTraining=False, - Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + Environment={"AWS_DEFAULT_REGION": "us-west-2"}, ) diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index d83bebd167..21ee5e75c7 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -324,7 +324,6 @@ def test_pipeline_execution_result( }, "TrainingJobStatus": "Completed", "OutputDataConfig": {"S3OutputPath": s3_output_path}, - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "abcdefg"}, } execution.result("stepA")