From 1d6c559e70506891e68e3c1cdcca08f4569dc4c4 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 11 Dec 2025 11:48:03 -0800 Subject: [PATCH 1/6] Add aws batch implementation (works with example notebook) --- .../sagemaker/core/helper/session_helper.py | 26 +- .../src/sagemaker/train/aws_batch/__init__.py | 0 .../train/aws_batch/batch_api_helper.py | 186 +++++++++ .../sagemaker/train/aws_batch/boto_client.py | 33 ++ .../sagemaker/train/aws_batch/constants.py | 34 ++ .../sagemaker/train/aws_batch/exception.py | 52 +++ .../train/aws_batch/training_queue.py | 203 ++++++++++ .../train/aws_batch/training_queued_job.py | 354 ++++++++++++++++++ .../src/sagemaker/train/model_trainer.py | 227 ++++++----- sagemaker-train/src/sagemaker/train/utils.py | 16 + .../tests/unit/train/aws_batch/constants.py | 166 ++++++++ 11 files changed, 1200 insertions(+), 97 deletions(-) create mode 100644 sagemaker-train/src/sagemaker/train/aws_batch/__init__.py create mode 100644 sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py create mode 100644 sagemaker-train/src/sagemaker/train/aws_batch/boto_client.py create mode 100644 sagemaker-train/src/sagemaker/train/aws_batch/constants.py create mode 100644 sagemaker-train/src/sagemaker/train/aws_batch/exception.py create mode 100644 sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py create mode 100644 sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py create mode 100644 sagemaker-train/tests/unit/train/aws_batch/constants.py diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index bc35c7eb9d..2d6ecf80d5 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -1865,6 +1865,30 @@ def expand_role(self, role): if "/" in role: return role return self.boto_session.resource("iam").Role(role).arn + + + def logs_for_job(self, job_name, wait=False, poll=10, log_type="All", timeout=None): + """Display logs for a given training job, optionally tailing them until job is complete. + + If the output is a tty or a Jupyter cell, it will be color-coded + based on which instance the log entry is from. + + Args: + job_name (str): Name of the training job to display the logs for. + wait (bool): Whether to keep looking for new log entries until the job completes + (default: False). + poll (int): The interval in seconds between polling for new log entries and job + completion (default: 5). + log_type ([str]): A list of strings specifying which logs to print. Acceptable + strings are "All", "None", "Training", or "Rules". To maintain backwards + compatibility, boolean values are also accepted and converted to strings. + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. + Raises: + exceptions.CapacityError: If the training job fails with CapacityError. + exceptions.UnexpectedStatusException: If waiting and the training job fails. 
+ """ + _logs_for_job(self, job_name, wait, poll, log_type, timeout) def _expand_container_def(c_def): @@ -2962,4 +2986,4 @@ def container_def( c_def["Mode"] = container_mode if image_config: c_def["ImageConfig"] = image_config - return c_def + return c_def \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/__init__.py b/sagemaker-train/src/sagemaker/train/aws_batch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py b/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py new file mode 100644 index 0000000000..4cb88f65ea --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py @@ -0,0 +1,186 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The module provides helper function for Batch Submit/Describe/Terminal job APIs.""" +from __future__ import absolute_import + +import json +from typing import List, Dict, Optional +from sagemaker.train.aws_batch.constants import ( + SAGEMAKER_TRAINING, + DEFAULT_TIMEOUT, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, +) +from sagemaker.train.aws_batch.boto_client import get_batch_boto_client + + +def submit_service_job( + training_payload: Dict, + job_name: str, + job_queue: str, + retry_config: Optional[Dict] = None, + scheduling_priority: Optional[int] = None, + timeout: Optional[Dict] = None, + share_identifier: Optional[str] = None, + tags: Optional[Dict] = None, +) -> Dict: + """Batch submit_service_job API helper function. + + Args: + training_payload: a dict containing a dict of arguments for Training job. + job_name: Batch job name. + job_queue: Batch job queue ARN. + retry_config: Batch job retry configuration. + scheduling_priority: An integer representing scheduling priority. + timeout: Set with value of timeout if specified, else default to 1 day. + share_identifier: value of shareIdentifier if specified. + tags: A dict of string to string representing Batch tags. + + Returns: + A dict containing jobArn, jobName and jobId. + """ + if timeout is None: + timeout = DEFAULT_TIMEOUT + client = get_batch_boto_client() + training_payload_tags = training_payload.pop("Tags", None) + payload = { + "jobName": job_name, + "jobQueue": job_queue, + "retryStrategy": DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + "serviceJobType": SAGEMAKER_TRAINING, + "serviceRequestPayload": json.dumps(training_payload), + "timeoutConfig": timeout, + } + if retry_config: + payload["retryStrategy"] = retry_config + if scheduling_priority: + payload["schedulingPriority"] = scheduling_priority + if share_identifier: + payload["shareIdentifier"] = share_identifier + if tags or training_payload_tags: + payload["tags"] = __merge_tags(tags, training_payload_tags) + return client.submit_service_job(**payload) + + +def describe_service_job(job_id: str) -> Dict: + """Batch describe_service_job API helper function. + + Args: + job_id: Job ID used. + + Returns: a dict. 
See the sample below + { + 'attempts': [ + { + 'serviceResourceId': { + 'name': 'string', + 'value': 'string' + }, + 'startedAt': 123, + 'stoppedAt': 123, + 'statusReason': 'string' + }, + ], + 'createdAt': 123, + 'isTerminated': True|False, + 'jobArn': 'string', + 'jobId': 'string', + 'jobName': 'string', + 'jobQueue': 'string', + 'retryStrategy': { + 'attempts': 123 + }, + 'schedulingPriority': 123, + 'serviceRequestPayload': 'string', + 'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING', + 'shareIdentifier': 'string', + 'startedAt': 123, + 'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED', + 'statusReason': 'string', + 'stoppedAt': 123, + 'tags': { + 'string': 'string' + }, + 'timeout': { + 'attemptDurationSeconds': 123 + } + } + """ + client = get_batch_boto_client() + return client.describe_service_job(jobId=job_id) + + +def terminate_service_job(job_id: str, reason: Optional[str] = "default terminate reason") -> Dict: + """Batch terminate_service_job API helper function. + + Args: + job_id: Job ID + reason: A string representing terminate reason. + + Returns: an empty dict + """ + client = get_batch_boto_client() + return client.terminate_service_job(jobId=job_id, reason=reason) + + +def list_service_job( + job_queue: str, + job_status: Optional[str] = None, + filters: Optional[List] = None, + next_token: Optional[str] = None, +) -> Dict: + """Batch list_service_job API helper function. + + Args: + job_queue: Batch job queue ARN. + job_status: Batch job status. + filters: A list of Dict. Each contains a filter. + next_token: Used to retrieve data in next page. + + Returns: A generator containing list results. + + """ + client = get_batch_boto_client() + payload = {"jobQueue": job_queue} + if filters: + payload["filters"] = filters + if next_token: + payload["nextToken"] = next_token + if job_status: + payload["jobStatus"] = job_status + part_of_jobs = client.list_service_jobs(**payload) + next_token = part_of_jobs.get("nextToken") + yield part_of_jobs + if next_token: + yield from list_service_job(job_queue, job_status, filters, next_token) + + +def __merge_tags(batch_tags: Optional[Dict], training_tags: Optional[List]) -> Optional[Dict]: + """Merges Batch and training payload tags. + + Returns a copy of Batch tags merged with training payload tags. Training payload tags take + precedence in the case of key conflicts. + + :param batch_tags: A dict of string to string representing Batch tags. + :param training_tags: A list of `{"Key": "string", "Value": "string"}` objects representing + training payload tags. + :return: A dict of string to string representing batch tags merged with training tags. + batch_tags is returned unaltered if training_tags is None or empty. + """ + if not training_tags: + return batch_tags + + training_tags_to_merge = {tag["Key"]: tag["Value"] for tag in training_tags} + batch_tags_copy = batch_tags.copy() if batch_tags else {} + batch_tags_copy.update(training_tags_to_merge) + + return batch_tags_copy diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/boto_client.py b/sagemaker-train/src/sagemaker/train/aws_batch/boto_client.py new file mode 100644 index 0000000000..87f3486887 --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/boto_client.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The file provides helper function for getting Batch boto client.""" +from __future__ import absolute_import + +from typing import Optional +import boto3 + + +def get_batch_boto_client( + region: Optional[str] = None, + endpoint: Optional[str] = None, +) -> boto3.session.Session.client: + """Helper function for getting Batch boto3 client. + + Args: + region: Region specified + endpoint: Batch API endpoint. + + Returns: Batch boto3 client. + + """ + return boto3.client("batch", region_name=region, endpoint_url=endpoint) diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/constants.py b/sagemaker-train/src/sagemaker/train/aws_batch/constants.py new file mode 100644 index 0000000000..ee41d3a413 --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/constants.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The file defines constants used for Batch API helper functions.""" + +from __future__ import absolute_import + +SAGEMAKER_TRAINING = "SAGEMAKER_TRAINING" +DEFAULT_ATTEMPT_DURATION_IN_SECONDS = 86400 # 1 day in seconds. +DEFAULT_TIMEOUT = {"attemptDurationSeconds": DEFAULT_ATTEMPT_DURATION_IN_SECONDS} +POLL_IN_SECONDS = 5 +JOB_STATUS_RUNNING = "RUNNING" +JOB_STATUS_COMPLETED = "SUCCEEDED" +JOB_STATUS_FAILED = "FAILED" +DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = { + "attempts": 1, + "evaluateOnExit": [ + { + "action": "RETRY", + "onStatusReason": "Received status from SageMaker:InternalServerError: " + "We encountered an internal error. Please try again.", + }, + {"action": "EXIT", "onStatusReason": "*"}, + ], +} diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/exception.py b/sagemaker-train/src/sagemaker/train/aws_batch/exception.py new file mode 100644 index 0000000000..94318bbce4 --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/exception.py @@ -0,0 +1,52 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The file Defines customized exception for Batch queueing""" +from __future__ import absolute_import + + +class NoTrainingJob(Exception): + """Define NoTrainingJob Exception. + + It means no Training job has been created by AWS Batch service. 
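+    Typically raised by ``TrainingQueuedJob.get_model_trainer`` while the Batch job
+    is still SUBMITTED, PENDING or RUNNABLE, so no SageMaker training job exists yet.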
+ """ + + def __init__(self, value): + super().__init__(value) + self.value = value + + def __str__(self): + """Convert Exception to string. + + Returns: a String containing exception error messages. + + """ + return repr(self.value) + + +class MissingRequiredArgument(Exception): + """Define MissingRequiredArgument exception. + + It means some required arguments are missing. + """ + + def __init__(self, value): + super().__init__(value) + self.value = value + + def __str__(self): + """Convert Exception to string. + + Returns: a String containing exception error messages. + + """ + return repr(self.value) diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py b/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py new file mode 100644 index 0000000000..c464f0b382 --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py @@ -0,0 +1,203 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Define Queue class for AWS Batch service""" +from __future__ import absolute_import + +from typing import Dict, Optional, List +import logging +from sagemaker.train.model_trainer import ModelTrainer, Mode +from .training_queued_job import TrainingQueuedJob +from .batch_api_helper import submit_service_job, list_service_job +from .exception import MissingRequiredArgument +from .constants import DEFAULT_TIMEOUT, JOB_STATUS_RUNNING + + +class TrainingQueue: + """TrainingQueue class for AWS Batch service + + With this class, customers are able to create a new queue and submit jobs to AWS Batch Service. + """ + + def __init__(self, queue_name: str): + self.queue_name = queue_name + + def submit( + self, + training_job: ModelTrainer, + inputs, + job_name: Optional[str] = None, + retry_config: Optional[Dict] = None, + priority: Optional[int] = None, + share_identifier: Optional[str] = None, + timeout: Optional[Dict] = None, + tags: Optional[Dict] = None, + experiment_config: Optional[Dict] = None, + ) -> TrainingQueuedJob: + """Submit a queued job and return a QueuedJob object. + + Args: + training_job: Training job ModelTrainer object. + inputs: Training job inputs. + job_name: Batch job name. + retry_config: Retry configuration for Batch job. + priority: Scheduling priority for Batch job. + share_identifier: Share identifier for Batch job. + timeout: Timeout configuration for Batch job. + tags: Tags apply to Batch job. These tags are for Batch job only. + experiment_config: Experiment management configuration. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. + + Returns: a TrainingQueuedJob object with Batch job ARN and job name. 
+ + """ + if not isinstance(training_job, ModelTrainer): + raise TypeError( + "training_job must be an instance of ModelTrainer, " + f"but got {type(training_job)}" + ) + + if training_job.training_mode != Mode.SAGEMAKER_TRAINING_JOB: + raise ValueError( + "TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB" + ) + if experiment_config is not None: + logging.warning( + "ExperimentConfig is not supported for ModelTrainer. " + "It will be ignored when submitting the job." + ) + training_payload = training_job._create_training_job_args( + input_data_config=inputs, boto3=True + ) + + if timeout is None: + timeout = DEFAULT_TIMEOUT + if job_name is None: + job_name = training_payload["TrainingJobName"] + + resp = submit_service_job( + training_payload, + job_name, + self.queue_name, + retry_config, + priority, + timeout, + share_identifier, + tags, + ) + if "jobArn" not in resp or "jobName" not in resp: + raise MissingRequiredArgument( + "jobArn or jobName is missing in response from Batch submit_service_job API" + ) + return TrainingQueuedJob(resp["jobArn"], resp["jobName"]) + + def map( + self, + training_job: ModelTrainer, + inputs, + job_names: Optional[List[str]] = None, + retry_config: Optional[Dict] = None, + priority: Optional[int] = None, + share_identifier: Optional[str] = None, + timeout: Optional[Dict] = None, + tags: Optional[Dict] = None, + experiment_config: Optional[Dict] = None, + ) -> List[TrainingQueuedJob]: + """Submit queued jobs to the provided estimator and return a list of TrainingQueuedJob objects. + + Args: + training_job: Training job ModelTrainer object. + inputs: List of Training job inputs. + job_names: List of Batch job names. + retry_config: Retry config for the Batch jobs. + priority: Scheduling priority for the Batch jobs. + share_identifier: Share identifier for the Batch jobs. + timeout: Timeout configuration for the Batch jobs. + tags: Tags apply to Batch job. These tags are for Batch job only. + experiment_config: Experiment management configuration. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. + + Returns: a list of TrainingQueuedJob objects with each Batch job ARN and job name. + + """ + if experiment_config is None: + experiment_config = {} + + if job_names is not None: + if len(job_names) != len(inputs): + raise ValueError( + "When specified, the number of job names must match the number of inputs" + ) + else: + job_names = [None] * len(inputs) + + queued_batch_job_list = [] + for index, value in enumerate(inputs): + queued_batch_job = self.submit( + training_job, + value, + job_names[index], + retry_config, + priority, + share_identifier, + timeout, + tags, + experiment_config, + ) + queued_batch_job_list.append(queued_batch_job) + + return queued_batch_job_list + + def list_jobs( + self, job_name: Optional[str] = None, status: Optional[str] = JOB_STATUS_RUNNING + ) -> List[TrainingQueuedJob]: + """List Batch jobs according to job_name or status. + + Args: + job_name: Batch job name. + status: Batch job status. + + Returns: A list of QueuedJob. + + """ + filters = None + if job_name: + filters = [{"name": "JOB_NAME", "values": [job_name]}] + status = None # job_status is ignored when job_name is specified. 
+ jobs_to_return = [] + next_token = None + for job_result_dict in list_service_job(self.queue_name, status, filters, next_token): + for job_result in job_result_dict.get("jobSummaryList", []): + if "jobArn" in job_result and "jobName" in job_result: + jobs_to_return.append( + TrainingQueuedJob(job_result["jobArn"], job_result["jobName"]) + ) + else: + logging.warning("Missing JobArn or JobName in Batch ListJobs API") + continue + return jobs_to_return + + def get_job(self, job_name): + """Get a Batch job according to job_name. + + Args: + job_name: Batch job name. + + Returns: The QueuedJob with name matching job_name. + + """ + jobs_to_return = self.list_jobs(job_name) + if len(jobs_to_return) == 0: + raise ValueError(f"Cannot find job: {job_name}") + return jobs_to_return[0] diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py b/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py new file mode 100644 index 0000000000..d7c42ea7ad --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py @@ -0,0 +1,354 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Define QueuedJob class for AWS Batch service""" +from __future__ import absolute_import + +import logging +import time +import asyncio +import re +from typing import Optional, Dict +import nest_asyncio +from sagemaker.core.resources import TrainingJob +from sagemaker.core.shapes import Unassigned +from sagemaker.train.model_trainer import ModelTrainer +from sagemaker.train.configs import ( + Compute, + Networking, + StoppingCondition, + SourceCode, + TrainingImageConfig, +) +from .batch_api_helper import terminate_service_job, describe_service_job +from .exception import NoTrainingJob, MissingRequiredArgument +from ..utils import get_training_job_name_from_training_job_arn +from .constants import JOB_STATUS_COMPLETED, JOB_STATUS_FAILED, POLL_IN_SECONDS + +logging.basicConfig( + format="%(asctime)s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" +) + + +class TrainingQueuedJob: + """TrainingQueuedJob class for AWS Batch service. + + With this class, customers are able to attach the latest training job to a ModelTrainer. + """ + + def __init__(self, job_arn: str, job_name: str): + self.job_arn = job_arn + self.job_name = job_name + self._no_training_job_status = {"SUBMITTED", "PENDING", "RUNNABLE"} + + def get_model_trainer(self) -> ModelTrainer: + """Attach the latest training job to a ModelTrainer and return. + + Returns: a ModelTrainer instance. 
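+        Raises:
+            NoTrainingJob: If the Batch job is still waiting in the queue
+                (SUBMITTED, PENDING or RUNNABLE), so no training job exists yet.
+            MissingRequiredArgument: If the describe response for a started job
+                contains no ``latestAttempt`` entry.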
+ + """ + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + if self._training_job_created(job_status): + if "latestAttempt" not in describe_resp: + raise MissingRequiredArgument("No LatestAttempt in describe call") + new_training_job_name = _get_new_training_job_name_from_latest_attempt( + describe_resp["latestAttempt"] + ) + output_model_trainer = _construct_model_trainer_from_training_job_name( + new_training_job_name + ) + _remove_system_tags_in_place_in_model_trainer_object(output_model_trainer) + return output_model_trainer + + _output_attempt_history(describe_resp) + raise NoTrainingJob("No Training job created. Job is still waiting in queue") + + def terminate(self, reason: Optional[str] = "Default terminate reason") -> None: + """Terminate Batch job. + + Args: + reason: Reason for terminating a job. + + Returns: None + + """ + terminate_service_job(self.job_arn, reason) + + def describe(self) -> Dict: + """Describe Batch job. + + Returns: A dict which includes job parameters, job status, attempts and so on. + + """ + return describe_service_job(self.job_arn) + + def _training_job_created(self, status: str) -> bool: + """Return True if a Training job has been created + + Args: + status: Job status returned from Batch API. + + Returns: a boolean indicating whether a Training job has been created. + + """ + return status not in self._no_training_job_status + + def result(self, timeout: int = None) -> Dict: + """Fetch the terminal result of the Batch job. + + Args: + timeout: The time to wait for the Batch job to complete. Defaults to ``None``. + + Returns: The results of the Batch job, represented as a Dict. + + """ + nest_asyncio.apply() + loop = asyncio.get_event_loop() + task = loop.create_task(self.fetch_job_results(timeout)) + resp = loop.run_until_complete(task) + return resp + + async def fetch_job_results(self, timeout: int = None) -> Dict: + """Async method that waits for the Batch job to complete or until timeout. + + Args: + timeout: The time to wait for the Batch job to complete. Defaults to ``None``. + + Returns: The results of the Batch job, represented as a Dict, or an Error. + + """ + self.wait(timeout) + + describe_resp = self.describe() + if describe_resp.get("status", "") == JOB_STATUS_COMPLETED: + return describe_resp + if describe_resp.get("status", "") == JOB_STATUS_FAILED: + raise RuntimeError(describe_resp["statusReason"]) + raise TimeoutError("Reached timeout before the Batch job reached a terminal status") + + def wait(self, timeout: int = None) -> Dict: + """Wait for the Batch job to finish. + + This method blocks on the job completing for up to the timeout value (if specified). + If timeout is ``None``, this method will block until the job is completed. + + Args: + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. + + Returns: The last describe_service_job response for the Batch job. + """ + request_end_time = time.time() + timeout if timeout else None + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + job_completed = job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED) + + while not job_completed: + if timeout and time.time() > request_end_time: + logging.info( + "Timeout exceeded: %d seconds elapsed. 
Returning current results", timeout + ) + break + if job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED): + break + + time.sleep(POLL_IN_SECONDS) + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + job_completed = job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED) + + return describe_resp + + +def _construct_model_trainer_from_training_job_name(training_job_name: str) -> ModelTrainer: + """Build ModelTrainer instance from training job name. + + Args: + training_job_name: Training job name. + + Returns: a ModelTrainer instance with _latest_training_job set. + + """ + # Step 1: Get the TrainingJob resource + training_job = TrainingJob.get(training_job_name=training_job_name) + + # Step 2: Extract parameters from training_job to reconstruct ModelTrainer + init_params = {} + + # Required/common parameters + init_params["role"] = training_job.role_arn + init_params["base_job_name"] = _extract_base_job_name(training_job_name) + + # Training image or algorithm + if training_job.algorithm_specification and not isinstance(training_job.algorithm_specification, Unassigned): + if (training_job.algorithm_specification.training_image and + not isinstance(training_job.algorithm_specification.training_image, Unassigned)): + init_params["training_image"] = training_job.algorithm_specification.training_image + if (training_job.algorithm_specification.algorithm_name and + not isinstance(training_job.algorithm_specification.algorithm_name, Unassigned)): + init_params["algorithm_name"] = training_job.algorithm_specification.algorithm_name + if (training_job.algorithm_specification.training_input_mode and + not isinstance(training_job.algorithm_specification.training_input_mode, Unassigned)): + init_params["training_input_mode"] = training_job.algorithm_specification.training_input_mode + + # Compute config + if training_job.resource_config and not isinstance(training_job.resource_config, Unassigned): + compute_params = {} + + if (training_job.resource_config.instance_type and + not isinstance(training_job.resource_config.instance_type, Unassigned)): + compute_params["instance_type"] = training_job.resource_config.instance_type + if (training_job.resource_config.instance_count and + not isinstance(training_job.resource_config.instance_count, Unassigned)): + compute_params["instance_count"] = training_job.resource_config.instance_count + if (training_job.resource_config.volume_size_in_gb and + not isinstance(training_job.resource_config.volume_size_in_gb, Unassigned)): + compute_params["volume_size_in_gb"] = training_job.resource_config.volume_size_in_gb + + # Add managed spot training if enabled (available directly on TrainingJob) + if training_job.enable_managed_spot_training and not isinstance(training_job.enable_managed_spot_training, Unassigned): + compute_params["enable_managed_spot_training"] = training_job.enable_managed_spot_training + + if compute_params: # Only create Compute if we have valid params + init_params["compute"] = Compute(**compute_params) + + # Output config - pass the raw training job output config directly + if training_job.output_data_config and not isinstance(training_job.output_data_config, Unassigned): + init_params["output_data_config"] = training_job.output_data_config + + # Stopping condition + if training_job.stopping_condition and not isinstance(training_job.stopping_condition, Unassigned): + if (training_job.stopping_condition.max_runtime_in_seconds and + not isinstance(training_job.stopping_condition.max_runtime_in_seconds, 
Unassigned)): + init_params["stopping_condition"] = StoppingCondition( + max_runtime_in_seconds=training_job.stopping_condition.max_runtime_in_seconds, + ) + + # Networking + if training_job.vpc_config and not isinstance(training_job.vpc_config, Unassigned): + networking_params = {} + + if (training_job.vpc_config.subnets and + not isinstance(training_job.vpc_config.subnets, Unassigned)): + networking_params["subnets"] = training_job.vpc_config.subnets + if (training_job.vpc_config.security_group_ids and + not isinstance(training_job.vpc_config.security_group_ids, Unassigned)): + networking_params["security_group_ids"] = training_job.vpc_config.security_group_ids + + # Add network isolation if present (available directly on TrainingJob) + if training_job.enable_network_isolation and not isinstance(training_job.enable_network_isolation, Unassigned): + networking_params["enable_network_isolation"] = training_job.enable_network_isolation + + # Add inter-container traffic encryption if present (available directly on TrainingJob) + if training_job.enable_inter_container_traffic_encryption and not isinstance(training_job.enable_inter_container_traffic_encryption, Unassigned): + networking_params["enable_inter_container_traffic_encryption"] = training_job.enable_inter_container_traffic_encryption + + if networking_params: # Only create Networking if we have valid params + init_params["networking"] = Networking(**networking_params) + + # Hyperparameters + if training_job.hyper_parameters and not isinstance(training_job.hyper_parameters, Unassigned): + init_params["hyperparameters"] = training_job.hyper_parameters + + # Environment + if training_job.environment and not isinstance(training_job.environment, Unassigned): + init_params["environment"] = training_job.environment + + # Checkpoint config + if training_job.checkpoint_config and not isinstance(training_job.checkpoint_config, Unassigned): + init_params["checkpoint_config"] = training_job.checkpoint_config + + # Step 3: Create ModelTrainer + model_trainer = ModelTrainer(**init_params) + + # Step 4: Set _latest_training_job (key insight!) + model_trainer._latest_training_job = training_job + + return model_trainer + + +def _extract_base_job_name(training_job_name: str) -> str: + """Extract base job name from full training job name. + + Args: + training_job_name: Full training job name. + + Returns: Base job name. + + """ + # Use the same regex pattern as PySDK V2's base_from_name() function + # Matches timestamps like: YYYY-MM-DD-HH-MM-SS-SSS or YYMMDD-HHMM + match = re.match(r"^(.+)-(\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-\d{3}|\d{6}-\d{4})", training_job_name) + return match.group(1) if match else training_job_name + + +def _output_attempt_history(describe_resp: Dict) -> None: + """Print attempt history if no Training job created. + + Args: + describe_resp: Describe response from Batch API. + + Returns: None + + """ + has_seen_status_reason = False + for i, attempt_dict in enumerate(describe_resp.get("attempts", [])): + if "statusReason" in attempt_dict: + logging.info("Attempt %d - %s", i + 1, attempt_dict["statusReason"]) + has_seen_status_reason = True + if not has_seen_status_reason: + logging.info("No attempts found or no statusReason found.") + + +def _get_new_training_job_name_from_latest_attempt(latest_attempt: Dict) -> str: + """Extract new Training job name from latest attempt in Batch Describe response. + + Args: + latest_attempt: a Dict containing Training job arn. + + Returns: new Training job name or None if not found. 
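+    Example (illustrative input, mirroring the ``latestAttempt`` shape shown in the
+    Batch describe response)::
+
+        _get_new_training_job_name_from_latest_attempt(
+            {"serviceResourceId": {"value": "arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job"}}
+        )  # returns "my-job"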
+ + """ + training_job_arn = latest_attempt.get("serviceResourceId", {}).get("value", None) + return get_training_job_name_from_training_job_arn(training_job_arn) + + +def _remove_system_tags_in_place_in_model_trainer_object(model_trainer: ModelTrainer) -> None: + """Remove system tags in place. + + Args: + model_trainer: input ModelTrainer object. + + Returns: None. Remove system tags in place. + + """ + if model_trainer.tags: + filtered_tags = [] + for tag in model_trainer.tags: + # Handle both V2 dict format {"Key": "...", "Value": "..."} and V3 object format with .key attribute + if isinstance(tag, dict): + # V2 format + if not tag.get("Key", "").startswith("aws:"): + filtered_tags.append(tag) + else: + # V3 format - assume it has .key attribute + if hasattr(tag, 'key') and not tag.key.startswith("aws:"): + filtered_tags.append(tag) + elif hasattr(tag, 'Key') and not tag.Key.startswith("aws:"): + # Fallback for other formats + filtered_tags.append(tag) + else: + # If we can't determine the key, keep the tag to be safe + filtered_tags.append(tag) + + model_trainer.tags = filtered_tags diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index c2fbd5da46..b0a55f3d71 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -27,6 +27,8 @@ from sagemaker.core.resources import TrainingJob from sagemaker.core import shapes from sagemaker.core.shapes import AlgorithmSpecification +from sagemaker.core.utils.utils import serialize +from sagemaker.core.apiutils._boto_functions import to_pascal_case from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call from sagemaker.core.config.config_schema import ( @@ -250,6 +252,9 @@ class ModelTrainer(BaseModel): # Private Attributes for JumpStart _jumpstart_config: Optional[JumpStartConfig] = PrivateAttr(default=None) + # Private Attributes for AWS_Batch + _temp_code_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) + CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [ "role", "base_job_name", @@ -386,6 +391,8 @@ def __del__(self): if hasattr(self, "__pydantic_fields_set__"): if self._temp_recipe_train_dir is not None: self._temp_recipe_train_dir.cleanup() + if self._temp_code_dir is not None: + self._temp_code_dir.cleanup() def _validate_training_image_and_algorithm_name( self, training_image: Optional[str], algorithm_name: Optional[str] @@ -525,30 +532,25 @@ def model_post_init(self, __context: Any): if self.training_image: logger.info(f"Training image URI: {self.training_image}") + - @_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train") - @runnable_by_pipeline - @validate_call - def train( + def _create_training_job_args( self, input_data_config: Optional[List[Union[Channel, InputData]]] = None, - wait: Optional[bool] = True, - logs: Optional[bool] = True, - ): - """Train a model using AWS SageMaker. - + boto3: bool = False, + ) -> Dict[str, Any]: + """Create the training job arguments. Args: + input_data_config (Optional[List[Union[Channel, InputData]]]): input_data_config (Optional[List[Union[Channel, InputData]]]): The input data config for the training job. Takes a list of Channel objects or a dictionary of channel names to DataSourceType. DataSourceType can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object. - wait (Optional[bool]): - Whether to wait for the training job to complete before returning. 
- Defaults to True. - logs (Optional[bool]): - Whether to display the training container logs while training. - Defaults to True. + boto3 (bool): Whether to return the arguments in boto3 format. Defaults to False. + By default, the arguments are returned in the format used by the SageMaker Core. + Returns: + Dict[str, Any]: The training job arguments. """ self._populate_intelligent_defaults() current_training_job_name = _get_unique_name(self.base_job_name) @@ -593,16 +595,16 @@ def train( container_arguments = None if self.source_code: if self.training_mode == Mode.LOCAL_CONTAINER: - tmp_dir = TemporaryDirectory(prefix=os.path.join(self.local_container_root + "/")) + self._temp_code_dir = TemporaryDirectory(prefix=os.path.join(self.local_container_root + "/")) else: - tmp_dir = TemporaryDirectory() + self._temp_code_dir = TemporaryDirectory() # Copy everything under container_drivers/ to a temporary directory - shutil.copytree(SM_DRIVERS_LOCAL_PATH, tmp_dir.name, dirs_exist_ok=True) + shutil.copytree(SM_DRIVERS_LOCAL_PATH, self._temp_code_dir.name, dirs_exist_ok=True) # If distributed is provided, overwrite code under /drivers if self.distributed: distributed_driver_dir = self.distributed.driver_dir - driver_dir = os.path.join(tmp_dir.name, "distributed_drivers") + driver_dir = os.path.join(self._temp_code_dir.name, "distributed_drivers") shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True) # If source code is provided, create a channel for the source code @@ -616,7 +618,7 @@ def train( final_input_data_config.append(source_code_channel) self._prepare_train_script( - tmp_dir=tmp_dir, + tmp_dir=self._temp_code_dir, source_code=self.source_code, distributed=self.distributed, ) @@ -625,13 +627,13 @@ def train( mp_parameters = self.distributed.smp._to_mp_hyperparameters() string_hyper_parameters.update(mp_parameters) - self._write_source_code_json(tmp_dir=tmp_dir, source_code=self.source_code) - self._write_distributed_json(tmp_dir=tmp_dir, distributed=self.distributed) + self._write_source_code_json(tmp_dir=self._temp_code_dir, source_code=self.source_code) + self._write_distributed_json(tmp_dir=self._temp_code_dir, distributed=self.distributed) # Create an input channel for drivers packaged by the sdk sm_drivers_channel = self.create_input_data_channel( channel_name=SM_DRIVERS, - data_source=tmp_dir.name, + data_source=self._temp_code_dir.name, key_prefix=input_data_key_prefix, ) final_input_data_config.append(sm_drivers_channel) @@ -656,90 +658,123 @@ def train( resource_config = self.compute._to_resource_config() vpc_config = self.networking._to_vpc_config() if self.networking else None - if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: - # Convert tags to dictionaries if they are Tag objects - tags_as_dicts = None - if self.tags: - tags_as_dicts = [] - for tag in self.tags: - if hasattr(tag, 'model_dump'): - tags_as_dicts.append(tag.model_dump()) - elif isinstance(tag, dict): - tags_as_dicts.append(tag) - else: - # Fallback for any other tag-like object - tags_as_dicts.append({"key": getattr(tag, 'key', ''), "value": getattr(tag, 'value', '')}) - - # Build training request with snake_case keys (Python SDK convention) - training_request = { - "training_job_name": current_training_job_name, - "algorithm_specification": algorithm_specification, - "hyper_parameters": string_hyper_parameters, - "input_data_config": final_input_data_config, - "resource_config": resource_config, - "vpc_config": vpc_config, - "role_arn": self.role, - "tags": tags_as_dicts, - 
"stopping_condition": self.stopping_condition, - "output_data_config": self.output_data_config, - "checkpoint_config": self.checkpoint_config, - "environment": self.environment, - "enable_managed_spot_training": self.compute.enable_managed_spot_training, - "enable_inter_container_traffic_encryption": ( - self.networking.enable_inter_container_traffic_encryption - if self.networking - else None - ), - "enable_network_isolation": ( - self.networking.enable_network_isolation if self.networking else None - ), - "remote_debug_config": self._remote_debug_config, - "tensor_board_output_config": self._tensorboard_output_config, - "retry_strategy": self._retry_strategy, - "infra_check_config": self._infra_check_config, - "session_chaining_config": self._session_chaining_config, - } - - # Handle PipelineSession + # Convert tags to dictionaries if they are Tag objects + tags_as_dicts = None + if self.tags: + tags_as_dicts = [] + for tag in self.tags: + if hasattr(tag, 'model_dump'): + tags_as_dicts.append(tag.model_dump()) + elif isinstance(tag, dict): + tags_as_dicts.append(tag) + else: + # Fallback for any other tag-like object + tags_as_dicts.append({"key": getattr(tag, 'key', ''), "value": getattr(tag, 'value', '')}) + + # Build training request with snake_case keys (Python SDK convention) + training_request = { + "training_job_name": current_training_job_name, + "algorithm_specification": algorithm_specification, + "hyper_parameters": string_hyper_parameters, + "input_data_config": final_input_data_config, + "resource_config": resource_config, + "vpc_config": vpc_config, + "role_arn": self.role, + "tags": tags_as_dicts, + "stopping_condition": self.stopping_condition, + "output_data_config": self.output_data_config, + "checkpoint_config": self.checkpoint_config, + "environment": self.environment, + "enable_managed_spot_training": self.compute.enable_managed_spot_training, + "enable_inter_container_traffic_encryption": ( + self.networking.enable_inter_container_traffic_encryption + if self.networking + else None + ), + "enable_network_isolation": ( + self.networking.enable_network_isolation if self.networking else None + ), + "remote_debug_config": self._remote_debug_config, + "tensor_board_output_config": self._tensorboard_output_config, + "retry_strategy": self._retry_strategy, + "infra_check_config": self._infra_check_config, + "session_chaining_config": self._session_chaining_config, + } + + if boto3 or isinstance(self.sagemaker_session, PipelineSession): if isinstance(self.sagemaker_session, PipelineSession): - from sagemaker.core.utils.utils import serialize - from sagemaker.core.apiutils._boto_functions import to_pascal_case - - # Remove training_job_name for pipeline as it's auto-generated at execution time training_request.pop("training_job_name", None) - # Convert snake_case to PascalCase for AWS API - pipeline_request = {to_pascal_case(k): v for k, v in training_request.items()} - serialized_request = serialize(pipeline_request) - self.sagemaker_session._intercept_create_request(serialized_request, None, "train") - return + # Convert snake_case to PascalCase for AWS API + pipeline_request = {to_pascal_case(k): v for k, v in training_request.items()} + serialized_request = serialize(pipeline_request) + return serialized_request + + return training_request + + + @_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train") + @runnable_by_pipeline + @validate_call + def train( + self, + input_data_config: Optional[List[Union[Channel, InputData]]] = None, + wait: 
Optional[bool] = True, + logs: Optional[bool] = True, + ): + """Train a model using AWS SageMaker. + + Args: + input_data_config (Optional[List[Union[Channel, InputData]]]): + The input data config for the training job. + Takes a list of Channel objects or a dictionary of channel names to DataSourceType. + DataSourceType can be an S3 URI string, local file path string, + S3DataSource object, or FileSystemDataSource object. + wait (Optional[bool]): + Whether to wait for the training job to complete before returning. + Defaults to True. + logs (Optional[bool]): + Whether to display the training container logs while training. + Defaults to True. + """ + training_request = self._create_training_job_args(input_data_config=input_data_config) - training_job = TrainingJob.create( - session=self.sagemaker_session.boto_session, - **training_request + # Handle PipelineSession + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: + if isinstance(self.sagemaker_session, PipelineSession): + self.sagemaker_session._intercept_create_request(training_request, None, "train") + return + + training_job = TrainingJob.create( + session=self.sagemaker_session.boto_session, + **training_request + ) + self._latest_training_job = training_job + + if wait: + training_job.wait(logs=logs) + if logs and not wait: + logger.warning( + "Not displaing the training container logs as 'wait' is set to False." ) - self._latest_training_job = training_job - if wait: - training_job.wait(logs=logs) - if logs and not wait: - logger.warning( - "Not displaing the training container logs as 'wait' is set to False." - ) else: local_container = _LocalContainer( - training_job_name=_get_unique_name(self.base_job_name), - instance_type=resource_config.instance_type, - instance_count=resource_config.instance_count, - image=algorithm_specification.training_image, + training_job_name=training_request["training_job_name"], + instance_type=training_request["resource_config"].instance_type, + instance_count=training_request["resource_config"].instance_count, + image=training_request["algorithm_specification"].training_image, container_root=self.local_container_root, sagemaker_session=self.sagemaker_session, - container_entrypoint=algorithm_specification.container_entrypoint, - container_arguments=algorithm_specification.container_arguments, - input_data_config=final_input_data_config, - hyper_parameters=string_hyper_parameters, - environment=self.environment, + container_entrypoint=training_request["algorithm_specification"].container_entrypoint, + container_arguments=training_request["algorithm_specification"].container_arguments, + input_data_config=training_request["input_data_config"], + hyper_parameters=training_request["hyper_parameters"], + environment=training_request["environment"], ) local_container.train(wait) + if self._temp_code_dir is not None: + self._temp_code_dir.cleanup() + def create_input_data_channel( self, channel_name: str, data_source: DataSourceType, key_prefix: Optional[str] = None diff --git a/sagemaker-train/src/sagemaker/train/utils.py b/sagemaker-train/src/sagemaker/train/utils.py index 6e0f8c4ff7..cebd4338ee 100644 --- a/sagemaker-train/src/sagemaker/train/utils.py +++ b/sagemaker-train/src/sagemaker/train/utils.py @@ -13,6 +13,7 @@ """Utils module.""" from __future__ import absolute_import +import re import os import json import subprocess @@ -246,3 +247,18 @@ def _get_studio_tags(model_id: str, hub_name: str): "value": hub_name } ] + + +def get_training_job_name_from_training_job_arn(training_job_arn: str) -> 
str: + """Extract Training job name from Training job arn. + Args: + training_job_arn: Training job arn. + Returns: Training job name. + """ + if training_job_arn is None: + return None + pattern = "arn:aws[a-z-]*:sagemaker:[a-z0-9-]*:[0-9]{12}:training-job/(.+)" + match = re.match(pattern, training_job_arn) + if match: + return match.group(1) + return None \ No newline at end of file diff --git a/sagemaker-train/tests/unit/train/aws_batch/constants.py b/sagemaker-train/tests/unit/train/aws_batch/constants.py new file mode 100644 index 0000000000..c33baa0752 --- /dev/null +++ b/sagemaker-train/tests/unit/train/aws_batch/constants.py @@ -0,0 +1,166 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Test constants for AWS Batch unit tests""" + +# Job identifiers +JOB_NAME = "test-training-job" +JOB_ARN = "arn:aws:batch:us-west-2:123456789012:job/test-job-id" +JOB_ID = "test-job-id" +JOB_QUEUE = "test-queue" + +# Training job identifiers +TRAINING_JOB_NAME = "training-job-20251211" +TRAINING_JOB_ARN = "arn:aws:sagemaker:us-west-2:123456789012:training-job/training-job-20251211" + +# Job statuses +JOB_STATUS_SUBMITTED = "SUBMITTED" +JOB_STATUS_PENDING = "PENDING" +JOB_STATUS_RUNNABLE = "RUNNABLE" +JOB_STATUS_STARTING = "STARTING" +JOB_STATUS_RUNNING = "RUNNING" +JOB_STATUS_SUCCEEDED = "SUCCEEDED" +JOB_STATUS_FAILED = "FAILED" + +# Configuration values +INSTANCE_TYPE = "ml.m5.xlarge" +INSTANCE_COUNT = 1 +VOLUME_SIZE_IN_GB = 30 +MAX_RUNTIME_IN_SECONDS = 3600 +TRAINING_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.5-gpu-py311" +EXECUTION_ROLE = "arn:aws:iam::123456789012:role/SageMakerRole" +S3_OUTPUT_PATH = "s3://my-bucket/output" + +# Batch configuration +SCHEDULING_PRIORITY = 1 +SHARE_IDENTIFIER = "test-share-id" +ATTEMPT_DURATION_IN_SECONDS = 86400 +REASON = "Test termination reason" +NEXT_TOKEN = "test-next-token" + +# Tags +BATCH_TAGS = {"batch-key": "batch-value", "environment": "test"} +TRAINING_TAGS = [ + {"Key": "training-key", "Value": "training-value"}, + {"Key": "project", "Value": "test-project"}, +] +TRAINING_TAGS_CONVERTED = {"training-key": "training-value", "project": "test-project"} +MERGED_TAGS = {**BATCH_TAGS, **TRAINING_TAGS_CONVERTED} + +# Retry configuration +DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = { + "attempts": 1, + "evaluateOnExit": [ + { + "action": "RETRY", + "onStatusReason": "Received status from SageMaker:InternalServerError", + }, + {"action": "EXIT", "onStatusReason": "*"}, + ], +} + +# Timeout configuration +TIMEOUT_CONFIG = {"attemptDurationSeconds": ATTEMPT_DURATION_IN_SECONDS} + +# API responses +SUBMIT_SERVICE_JOB_RESP = { + "jobArn": JOB_ARN, + "jobName": JOB_NAME, + "jobId": JOB_ID, +} + +DESCRIBE_SERVICE_JOB_RESP_RUNNING = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_RUNNING, + "createdAt": 1702300000, + "startedAt": 1702300100, +} + +DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + 
"jobQueue": JOB_QUEUE, + "status": JOB_STATUS_SUCCEEDED, + "createdAt": 1702300000, + "startedAt": 1702300100, + "stoppedAt": 1702300400, + "latestAttempt": { + "serviceResourceId": { + "name": "trainingJobArn", + "value": TRAINING_JOB_ARN, + }, + "startedAt": 1702300100, + "stoppedAt": 1702300400, + }, +} + +DESCRIBE_SERVICE_JOB_RESP_FAILED = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_FAILED, + "statusReason": "Task failed", + "createdAt": 1702300000, + "startedAt": 1702300100, + "stoppedAt": 1702300200, +} + +DESCRIBE_SERVICE_JOB_RESP_PENDING = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_PENDING, + "createdAt": 1702300000, +} + +LIST_SERVICE_JOB_RESP_EMPTY = { + "jobSummaryList": [], + "nextToken": None, +} + +LIST_SERVICE_JOB_RESP_WITH_JOBS = { + "jobSummaryList": [ + {"jobName": JOB_NAME, "jobArn": JOB_ARN, "jobId": JOB_ID}, + {"jobName": "another-job", "jobArn": "arn:aws:batch:us-west-2:123456789012:job/another-id", "jobId": "another-id"}, + ], + "nextToken": None, +} + +LIST_SERVICE_JOB_RESP_WITH_NEXT_TOKEN = { + "jobSummaryList": [ + {"jobName": JOB_NAME, "jobArn": JOB_ARN, "jobId": JOB_ID}, + ], + "nextToken": NEXT_TOKEN, +} + +# Training payload +TRAINING_JOB_PAYLOAD = { + "TrainingJobName": TRAINING_JOB_NAME, + "RoleArn": EXECUTION_ROLE, + "OutputDataConfig": {"S3OutputPath": S3_OUTPUT_PATH}, + "ResourceConfig": { + "InstanceType": INSTANCE_TYPE, + "InstanceCount": INSTANCE_COUNT, + "VolumeSizeInGB": VOLUME_SIZE_IN_GB, + }, + "StoppingCondition": {"MaxRuntimeInSeconds": MAX_RUNTIME_IN_SECONDS}, + "AlgorithmSpecification": { + "TrainingImage": TRAINING_IMAGE, + "TrainingInputMode": "File", + }, +} From 2cdb2d47be687ea592af2062946818b2a1a3d97d Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 11 Dec 2025 13:10:56 -0800 Subject: [PATCH 2/6] fixing unit tests and adding integration test --- .../data/train/script_mode/custom_script.py | 191 ++++++++++ .../train/script_mode/data/test/x_test.npy | Bin 0 -> 19456 bytes .../train/script_mode/data/test/y_test.npy | Bin 0 -> 2544 bytes .../train/script_mode/data/train/x_train.npy | Bin 0 -> 77184 bytes .../train/script_mode/data/train/y_train.npy | Bin 0 -> 9760 bytes .../train/script_mode/pytorch_model_def.py | 23 ++ .../data/train/script_mode/requirements.txt | 3 + sagemaker-train/tests/integ/__init__.py | 4 + .../tests/integ/train/aws_batch/__init__.py | 13 + .../tests/integ/train/aws_batch/manager.py | 133 +++++++ .../tests/integ/train/aws_batch/test_queue.py | 93 +++++ .../tests/unit/train/aws_batch/__init__.py | 13 + .../tests/unit/train/aws_batch/conftest.py | 166 +++++++++ .../train/aws_batch/test_batch_api_helper.py | 260 ++++++++++++++ .../train/aws_batch/test_training_queue.py | 328 ++++++++++++++++++ .../aws_batch/test_training_queued_job.py | 265 ++++++++++++++ 16 files changed, 1492 insertions(+) create mode 100644 sagemaker-train/tests/data/train/script_mode/custom_script.py create mode 100644 sagemaker-train/tests/data/train/script_mode/data/test/x_test.npy create mode 100644 sagemaker-train/tests/data/train/script_mode/data/test/y_test.npy create mode 100644 sagemaker-train/tests/data/train/script_mode/data/train/x_train.npy create mode 100644 sagemaker-train/tests/data/train/script_mode/data/train/y_train.npy create mode 100644 sagemaker-train/tests/data/train/script_mode/pytorch_model_def.py create mode 100644 
sagemaker-train/tests/data/train/script_mode/requirements.txt create mode 100644 sagemaker-train/tests/integ/train/aws_batch/__init__.py create mode 100644 sagemaker-train/tests/integ/train/aws_batch/manager.py create mode 100644 sagemaker-train/tests/integ/train/aws_batch/test_queue.py create mode 100644 sagemaker-train/tests/unit/train/aws_batch/__init__.py create mode 100644 sagemaker-train/tests/unit/train/aws_batch/conftest.py create mode 100644 sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py create mode 100644 sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py create mode 100644 sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py diff --git a/sagemaker-train/tests/data/train/script_mode/custom_script.py b/sagemaker-train/tests/data/train/script_mode/custom_script.py new file mode 100644 index 0000000000..a57ddee743 --- /dev/null +++ b/sagemaker-train/tests/data/train/script_mode/custom_script.py @@ -0,0 +1,191 @@ +# flake8: noqa +import argparse +import numpy as np +import os +import sys +import logging +import json +import shutil +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset +from pytorch_model_def import get_model + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) +current_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_train_data(train_dir): + """ + Get the training data and convert to tensors + """ + + x_train = np.load(os.path.join(train_dir, "x_train.npy")) + y_train = np.load(os.path.join(train_dir, "y_train.npy")) + logger.info(f"x train: {x_train.shape}, y train: {y_train.shape}") + + return torch.from_numpy(x_train), torch.from_numpy(y_train) + + +def get_test_data(test_dir): + """ + Get the testing data and convert to tensors + """ + + x_test = np.load(os.path.join(test_dir, "x_test.npy")) + y_test = np.load(os.path.join(test_dir, "y_test.npy")) + logger.info(f"x test: {x_test.shape}, y test: {y_test.shape}") + + return torch.from_numpy(x_test), torch.from_numpy(y_test) + + +def model_fn(model_dir): + """ + Load the model for inference + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model() + model.load_state_dict(torch.load(model_dir + "/model.pth")) + model.eval() + return model.to(device) + + +def input_fn(request_body, request_content_type): + """ + Deserialize and prepare the prediction input + """ + + if request_content_type == "application/json": + request = json.loads(request_body) + train_inputs = torch.tensor(request) + return train_inputs + + +def predict_fn(input_data, model): + """ + Apply model to the incoming request + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + with torch.no_grad(): + return model(input_data.float()).numpy()[0] + + +def parse_args(): + """ + Parse the command line arguments + """ + + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-dir", + type=str, + default=os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model")), + help="Directory to save the model", + ) + parser.add_argument( + "--train-dir", + type=str, + default=os.environ.get("SM_CHANNEL_TRAIN", os.path.join(current_dir, "data/train")), + help="Directory containing training data", + ) + parser.add_argument( + "--test-dir", + type=str, + default=os.environ.get("SM_CHANNEL_TEST", os.path.join(current_dir, "data/test")), + help="Directory containing 
testing data", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + help="Batch size for training", + ) + parser.add_argument( + "--epochs", + type=int, + default=1, + help="Number of epochs for training", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=0.1, + help="Learning rate for training", + ) + return parser.parse_args() + + +def train(): + """ + Train the PyTorch model + """ + args = parse_args() + # Directories: train, test and model + train_dir = args.train_dir + test_dir = args.test_dir + model_dir = args.model_dir + + # Load the training and testing data + x_train, y_train = get_train_data(train_dir) + x_test, y_test = get_test_data(test_dir) + train_ds = TensorDataset(x_train, y_train) + + # Training parameters - used to configure the training loop + batch_size = args.batch_size + epochs = args.epochs + learning_rate = args.learning_rate + logger.info( + "batch_size = {}, epochs = {}, learning rate = {}".format(batch_size, epochs, learning_rate) + ) + + train_dl = DataLoader(train_ds, batch_size, shuffle=True) + + # Define the model, loss function and optimizer + model = get_model() + model = model.to(device) + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + + # Train the model + for epoch in range(epochs): + for x_train_batch, y_train_batch in train_dl: + y = model(x_train_batch.float()) + loss = criterion(y.flatten(), y_train_batch.float()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + epoch += 1 + logger.info(f"epoch: {epoch} -> loss: {loss}") + + # Test the model + with torch.no_grad(): + y = model(x_test.float()).flatten() + mse = ((y - y_test) ** 2).sum() / y_test.shape[0] + print("\nTest MSE:", mse.numpy()) + + # Save the model + os.makedirs(model_dir, exist_ok=True) + torch.save(model.state_dict(), model_dir + "/model.pth") + inference_code_path = model_dir + "/code/" + + if not os.path.exists(inference_code_path): + os.mkdir(inference_code_path) + logger.info("Created a folder at {}!".format(inference_code_path)) + + shutil.copy("custom_script.py", inference_code_path) + shutil.copy("pytorch_model_def.py", inference_code_path) + logger.info("Saving models files to {}".format(inference_code_path)) + + +if __name__ == "__main__": + print("Running the training job ...\n") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + train() diff --git a/sagemaker-train/tests/data/train/script_mode/data/test/x_test.npy b/sagemaker-train/tests/data/train/script_mode/data/test/x_test.npy new file mode 100644 index 0000000000000000000000000000000000000000..a9977e39c042722d4103758e390eaffa6653b937 GIT binary patch literal 19456 zcmbW7c~sBo`^Hll`<5&jhJz zk7E7#4LaAjXm2B9-L%`T3%5s>>DPfT>%o`#;eQ29c;UJ5k>DPnXZ`Fm`^0&W=RCPT z>?8i{hw-pLn{mGvO8T?D^5E5fdR={^zki$?^RN!~i$1KId6|oS#+Un27`pvYJ=bih zBJ-OySY4X7xKiRrKYXcU-K-NG?}_d!+IcN5fcfdZey#=aJT654}SiTFX_WRpKlYz_xi{=T`%OQ-L+Cgc_K|+vEB1>%teZKwujKh&@DBNo@J^VM_emY|;s<9A z?iKZL=4CysgWoIql4svIZ}x>b>CgGGe$Ej;^61$Q)4qT{Wu4ZcLt(o4&FbU{bRkXi*>NCtcN*R58n&w*zA1B`W#r`_n~;8P~w+usV&y~~t7 z5dkymbTiaV^lq@D&OJljZ}-nGns1RJ0+MI#?pXCw%-D9vEI+$s&sEhrAXGle|XUd2?Ml1k`(M)a9nU(JVUD zdPS-zT-ww)*`ri`GkWown{(fZbcfxGjHg@|-}>LPZ1zX8wA^-J&WWidVr=DCcYi9+ z5L=Bp*86$oS=l!6_@ASH$`j%F6Vo0Kxg#EmsSat`VY0tTiA~ zgD$?>kS}sOjocO)|3M1tS_PrA3Pt&^ONtu@N6H<^>GxdcCW{dRC&oG4%n)YDlRJ!= zlPo5U?=@rB(iCaC?ODr5H>Dh&X!+F_tse@vML8A0JBoyvOWb$PV?*U94H}fVuB?!5 zZX32#S-llE>o>nPL6;=MOux9}w>(WO`*Oc5PJbi6zdQZEhZ@w-*(^9>+vxmDQR!WG zy8Forv3lGX|JRKZOc_eDwXz8~|sQ 
[GIT binary patch payloads omitted: the binary contents of the new test fixture files
x_test.npy, y_test.npy, x_train.npy, and y_train.npy under
sagemaker-train/tests/data/train/script_mode/data/ are not human-readable; the added
files are listed in the patch summary above.]
z5FQ@BYW~69RATMd;J>E_n-P#+DQxdrg*c~?I$TJP`|^tXrZp$qV!(cag{)EyxVYZ$8k2TxLRY2?v^bGPWu)@2yEV6w|8-*O~Op zmsbh7q!;)2GQj5X=l5c{(*X17=_f`hEBH~&8L=H7hcJDN6 zfd({f7vHvWftCrI9Xf1{Ihr4LT#u!PnjfudqeLIC*Ujn$8?gE2zS|Fy>_tDUQHuFd zV*#fp?rSw-mj;{qP_pRc1aD?+eFV`LhHH^vA;8#3pC|YW3v?EBU8Vug?hR?1ldA<= z|8bNvl(pqe;djj-R3OGxq5f!=G^*fAvGeS@mcFEmYXgP3@4Cpc`Vo2T9Y9uT7Yg!Cl zH-5vm2JY}ZbR;-znIPsm20a>&8hx%+JszbGT;10Ulp*Weni@MU=lYpS*(yYdI>PRiK?0A>Bir_`H8;QK&GNRyT6c*-+ClhUTJF zlGZrQL_pd~N-1PMa-=x{dO^+1V^;6TsH?^EV#XiPhiTv;n2qyJU_u zcP#Hr)Z$TlQc`nyFxj5bMvX4-Pha+`rvr;eG&RYu(QVznOupgKrU)V41v$*Nl7697 zH%IZv6?1M4{U{_F!0dV{&5=FHrhy^U3sH=8{G%E96K){tw5r@Mnb0b3V_*;Zy(8MQ zNmh$^66_52XFE|oVa+81jQHG=e$ZDB_88-{)L`~>z#hs$D!e-O@{$pCz0>`De<;9G zpBWuESY;sXvo$X`Y!dsbc@q46yiB{I1&40^uH=W4uCnU1i6xdfBwI!u*NM1Ysz$7m z6vsEpznEUhbH$2jzkNrlWTMWhO?Z**e6he+4i+!n{vx3Nu1|d<7SMiDOqr5R&-ED* zx`!O?SfCcI68*xmAC?O4Oi@TM?4$6Y=PVAlcAnh#k^Hi?w`eWFr#1_(91e0qkvT2k z$3x-;mW%q@f?O*}SRT&J*LU5#-wB_4wDf;p!sGqp1EswVn!(mVVK(77rS?8Uep@JI zrI9iqeDi+V1Y0m#?GeL+Qmq4fW+>3uI=B)z{Ofsj=@JJ}`C<6h+=vez^_&k7?-P18 zGE;xmvL{_(zB8Jz04QC8sn<7nExKwJbB>xSv?DoPoBWbB$GORi{%VXAerGx-RZ>aI z^KnNJpO$Mt?&5B@bf#}zxD?9Q!6 z=Vj&Dl@6FwK!AKnh3pw|fvZT@Fm&#eO*q1WyLb&^vrp7GN^@q_f#i3sx|nb;s0B5C z(E|3(?9^`KOAhOM5FvZIig=!;5plS{4v8uv_5&O?7M=Zed>-vBqgB$iQ;*(U^1;&` zl%6fGZNnjV=lYB8Un`;)og#X<;4G!IRpvP3@aRD+K25O`b8V$Y%-x#H@pwuvmztNg zBA!Mo!Qzw0QhQ7ene^h9neuCjg)>6Upf;pg$X%6)L;gxUYw@xDR_e2y7;*nmmiJvc z(O24NapISj5Q!7@_6nV-H`Y0W(v3_{@(EL_lLC~T&mcWG%HH3yn0hp8pIM+5^<*PQQEzP2q5JY^s*sdou0nd~ zmWdR}X$?q}PL3vCcty48x_VMnn`F8)I+#O#Vv1rEyK;PoOtj#Z0!i%yQCEj&z3 zYf)dsbn^-Fkt~;>p6@)xTa#N+K42wsJ0$cX#kfQ-;veyZs88IGkaX=SztjQGoG4ho z?}WlHB#7qg@FR!@Imrn_nIWBJ6!kp$pUAQM-|a|MfKtgs4+)2WMZ@0x_WWb9_5&3- zkV;DNe58jV^{&nH3vzKD?FEq8B|+{Cw*95Vx7>%+L-v^*m^@rRjmKg8cZ+K*;o*Nr zrij=m;Qr>nZTY4qk>6Ju5z2}hjSl_IbC=C;$BTN0al>&eZ!-Y@f28c$*2LhI)iC0&-X4(|r;mK`L15*B9i^nEwpZgV6BO@G%Op zZ=l?s*{xT*%wQ_0kB6q5k9d4pMuvH>5eJ&0ZkuJK{~1x#*1)sZvI{?)wdMVq3GITZ zxubm}ciGCdmPtJ43ldz*XnAn(CMD_$l+e84(NfxA_`ho$hB4z>#efRSOnVmpbHXV1 zs`54QG7;Z7jrhKSizAAww2tgf^K8|vHPl$iUhnwif;m&}9E~_meA9}WJ-VJQlI;n; z=XH>pm--)q2RlC~Nr-p2y!6;mBg1dUYV z9ZXifT_`bp#6IgAepF)0`^=*;Ppm5rLLV*UE z14+^M)k8%wU5B3H91mKN>=LW{PLG2?m+ycPlxF`DUsPOimZf(TMMvRIw5H zD~boT4KJ)LIFYOjsI9}bFiL)xN<^K%MT4{F4IW+Q$i<#x zrUBipfY7}EJwVMvL<^< zuMEqNhg5ur+n{$z?c1}ODba)-%vTkbexM?k`r}jen)hEoJ=n-w^Y1-(mtr>6aAjk4 zSbw0C`fd%v9p;lVZJ=ib*g!p&^G) 2 else None + + # Verify filters contain the job name + assert filters is not None, "Filters should be passed to list_service_job" + assert filters[0]["name"] == "JOB_NAME", "JOB_NAME filter should be present" + assert filters[0]["values"] == [JOB_NAME], "Filter values should contain the job name" + + @patch("sagemaker.train.aws_batch.training_queue.list_service_job") + def test_list_jobs_empty(self, mock_list_service_job): + """Test list_jobs returns empty list""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_EMPTY]) + + queue = TrainingQueue(JOB_QUEUE) + jobs = queue.list_jobs() + + assert len(jobs) == 0 + + +class TestTrainingQueueGet: + """Tests for TrainingQueue.get_job method""" + + @patch("sagemaker.train.aws_batch.training_queue.list_service_job") + def test_get_job_found(self, mock_list_service_job): + """Test get_job returns job when found""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_WITH_JOBS]) + + queue = TrainingQueue(JOB_QUEUE) + job = queue.get_job(JOB_NAME) + + assert job.job_name == JOB_NAME + assert job.job_arn == JOB_ARN + + @patch("sagemaker.train.aws_batch.training_queue.list_service_job") + def 
test_get_job_not_found(self, mock_list_service_job): + """Test get_job raises error when job not found""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_EMPTY]) + + queue = TrainingQueue(JOB_QUEUE) + + with pytest.raises(ValueError, match="Cannot find job"): + queue.get_job(JOB_NAME) diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py b/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py new file mode 100644 index 0000000000..2e3ca916d2 --- /dev/null +++ b/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py @@ -0,0 +1,265 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for training_queued_job module""" + +import pytest +import time +import asyncio +from unittest.mock import Mock, patch, MagicMock + +from sagemaker.train.aws_batch.training_queued_job import TrainingQueuedJob +from sagemaker.train.aws_batch.exception import NoTrainingJob, MissingRequiredArgument +from .conftest import ( + JOB_NAME, + JOB_ARN, + REASON, + TRAINING_JOB_NAME, + TRAINING_JOB_ARN, + JOB_STATUS_PENDING, + JOB_STATUS_RUNNING, + JOB_STATUS_SUCCEEDED, + JOB_STATUS_FAILED, + DESCRIBE_SERVICE_JOB_RESP_RUNNING, + DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED, + DESCRIBE_SERVICE_JOB_RESP_FAILED, + DESCRIBE_SERVICE_JOB_RESP_PENDING, +) + + +class TestTrainingQueuedJobInit: + """Tests for TrainingQueuedJob initialization""" + + def test_training_queued_job_init(self): + """Test TrainingQueuedJob initialization""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job.job_arn == JOB_ARN + assert queued_job.job_name == JOB_NAME + + +class TestTrainingQueuedJobDescribe: + """Tests for TrainingQueuedJob.describe method""" + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_describe(self, mock_describe_service_job): + """Test describe returns job details""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.describe() + + assert result["status"] == JOB_STATUS_RUNNING + mock_describe_service_job.assert_called_once_with(JOB_ARN) + + +class TestTrainingQueuedJobTerminate: + """Tests for TrainingQueuedJob.terminate method""" + + @patch("sagemaker.train.aws_batch.training_queued_job.terminate_service_job") + def test_terminate(self, mock_terminate_service_job): + """Test terminate calls terminate API""" + mock_terminate_service_job.return_value = {} + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.terminate(REASON) + + mock_terminate_service_job.assert_called_once_with(JOB_ARN, REASON) + + @patch("sagemaker.train.aws_batch.training_queued_job.terminate_service_job") + def test_terminate_default_reason(self, mock_terminate_service_job): + """Test terminate with default reason""" + mock_terminate_service_job.return_value = {} + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.terminate() + + call_kwargs = 
mock_terminate_service_job.call_args[0] + assert call_kwargs[0] == JOB_ARN + + +class TestTrainingQueuedJobWait: + """Tests for TrainingQueuedJob.wait method""" + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_wait_immediate_completion(self, mock_describe_service_job): + """Test wait returns immediately when job is completed""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.wait() + + assert result["status"] == JOB_STATUS_SUCCEEDED + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_wait_with_polling(self, mock_describe_service_job): + """Test wait polls until job completes""" + mock_describe_service_job.side_effect = [ + DESCRIBE_SERVICE_JOB_RESP_RUNNING, + DESCRIBE_SERVICE_JOB_RESP_RUNNING, + DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED, + ] + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.wait() + + assert result["status"] == JOB_STATUS_SUCCEEDED + assert mock_describe_service_job.call_count == 3 + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_wait_with_timeout(self, mock_describe_service_job): + """Test wait respects timeout""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + start_time = time.time() + result = queued_job.wait(timeout=2) + end_time = time.time() + + # Should timeout after approximately 2 seconds + assert end_time - start_time >= 2 + assert result["status"] == JOB_STATUS_RUNNING + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_wait_job_failed(self, mock_describe_service_job): + """Test wait returns failed status""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_FAILED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.wait() + + assert result["status"] == JOB_STATUS_FAILED + + +class TestTrainingQueuedJobGetModelTrainer: + """Tests for TrainingQueuedJob.get_model_trainer method""" + + @patch("sagemaker.train.aws_batch.training_queued_job._remove_system_tags_in_place_in_model_trainer_object") + @patch("sagemaker.train.aws_batch.training_queued_job._construct_model_trainer_from_training_job_name") + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_get_model_trainer_success(self, mock_describe_service_job, mock_construct_trainer, mock_remove_tags): + """Test get_model_trainer returns ModelTrainer when training job created""" + # Return a real dict (not a mock) so nested dict access works + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED + + mock_trainer = Mock() + mock_construct_trainer.return_value = mock_trainer + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.get_model_trainer() + + assert result == mock_trainer + mock_construct_trainer.assert_called_once() + mock_remove_tags.assert_called_once_with(mock_trainer) + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_get_model_trainer_no_training_job_pending(self, mock_describe_service_job): + """Test get_model_trainer raises error when job still pending""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_PENDING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(NoTrainingJob): + queued_job.get_model_trainer() + + 
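+    # Note: per the fixtures in this change, a successful describe_service_job response
+    # carries a "latestAttempt" -> "serviceResourceId" entry pointing at the training job;
+    # the test below drops that key and expects MissingRequiredArgument, while the test
+    # above covers the still-PENDING (NoTrainingJob) path.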
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_get_model_trainer_no_latest_attempt(self, mock_describe_service_job): + """Test get_model_trainer raises error when latestAttempt missing""" + resp = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED.copy() + del resp["latestAttempt"] + mock_describe_service_job.return_value = resp + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(MissingRequiredArgument): + queued_job.get_model_trainer() + + +class TestTrainingQueuedJobResult: + """Tests for TrainingQueuedJob.result method""" + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_result_success(self, mock_describe_service_job): + """Test result returns job result when completed""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.result(timeout=100) + + assert result["status"] == JOB_STATUS_SUCCEEDED + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_result_timeout(self, mock_describe_service_job): + """Test result raises TimeoutError when timeout exceeded""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(TimeoutError): + queued_job.result(timeout=1) + + +class TestTrainingQueuedJobAsync: + """Tests for TrainingQueuedJob async methods""" + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_fetch_job_results_success(self, mock_describe_service_job): + """Test fetch_job_results returns result when job succeeds""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = asyncio.run(queued_job.fetch_job_results()) + + assert result["status"] == JOB_STATUS_SUCCEEDED + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_fetch_job_results_failed(self, mock_describe_service_job): + """Test fetch_job_results raises error when job fails""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_FAILED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(RuntimeError): + asyncio.run(queued_job.fetch_job_results()) + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_fetch_job_results_timeout(self, mock_describe_service_job): + """Test fetch_job_results raises TimeoutError when timeout exceeded""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(TimeoutError): + asyncio.run(queued_job.fetch_job_results(timeout=1)) + + +class TestTrainingQueuedJobTrainingJobCreated: + """Tests for TrainingQueuedJob._training_job_created method""" + + def test_training_job_created_running(self): + """Test _training_job_created returns True for RUNNING status""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job._training_job_created(JOB_STATUS_RUNNING) is True + + def test_training_job_created_succeeded(self): + """Test _training_job_created returns True for SUCCEEDED status""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job._training_job_created(JOB_STATUS_SUCCEEDED) is True + + def test_training_job_created_failed(self): + """Test _training_job_created returns True for FAILED status""" + queued_job = 
TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job._training_job_created(JOB_STATUS_FAILED) is True + + def test_training_job_created_pending(self): + """Test _training_job_created returns False for PENDING status""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job._training_job_created(JOB_STATUS_PENDING) is False From a60459ca1067ad3cd5d425779c646f50f762a6ec Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 11 Dec 2025 13:30:53 -0800 Subject: [PATCH 3/6] add example notebook --- ...s_getting_started_with_model_trainer.ipynb | 1005 +++++++++++++++++ .../utils/aws_batch_resource_management.py | 566 ++++++++++ 2 files changed, 1571 insertions(+) create mode 100644 v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb create mode 100644 v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py diff --git a/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb b/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb new file mode 100644 index 0000000000..7d09133599 --- /dev/null +++ b/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb @@ -0,0 +1,1005 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f6d21c4c-e5fe-4992-9fd9-a33f36e4db2d", + "metadata": {}, + "source": [ + "# Getting Started with AWS Batch for SageMaker Training jobs\n", + "\n", + "---\n", + "\n", + "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n", + "\n", + "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "---\n", + "\n", + "This sample notebook will demonstrate how to submit some simple 'hello world' jobs to an [AWS Batch job queue](https://aws.amazon.com/batch/) using a [ModelTrainer](https://sagemaker.readthedocs.io/en/stable/api/training/model_trainer.html). You can run any of the cells in this notebook interactively to experiment with using your queue. Batch will take care of ensuring your jobs run automatically as your service environment capacity becomes available. " + ] + }, + { + "cell_type": "markdown", + "id": "10e12b35-3dc2-4376-b90e-54c00c70a607", + "metadata": { + "tags": [] + }, + "source": [ + "## Setup and Configure Training Job Variables\n", + "We will need a single instance for a short duration for the sample jobs. Change any of the constant variables below to adjust the example to your liking. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6316085c-262d-4437-8987-9ca7eca94965", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "INSTANCE_TYPE = \"ml.g5.xlarge\"\n", + "INSTANCE_COUNT = 1\n", + "MAX_RUN_TIME = 300\n", + "TRAINING_JOB_NAME = \"hello-world-simple-job\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4edef56-49f3-4729-afdd-5345c5710363", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "logging.basicConfig(\n", + " level=logging.INFO, format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n", + ")\n", + "logging.getLogger(\"botocore.client\").setLevel(level=logging.WARN)\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "from sagemaker.core.helper.session_helper import Session\n", + "from sagemaker.core import image_uris\n", + "\n", + "session = Session()\n", + "\n", + "image_uri = image_uris.retrieve(\n", + " framework=\"pytorch\",\n", + " region=session.boto_session.region_name,\n", + " version=\"2.5\",\n", + " instance_type=INSTANCE_TYPE,\n", + " image_scope=\"training\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6ad14b9b-360f-446f-94ec-5c7cbd6e0818", + "metadata": {}, + "source": [ + "## Create Sample Resources\n", + "The diagram belows shows the Batch resources we'll create for this example.\n", + "\n", + "![The Resources to Create](batch_getting_started_resources.png \"Example Job Queue and Service Environment Resources\")\n", + "\n", + "You can use [Batch Console](https://console.aws.amazon.com/batch) to create these resources, or you can run the cell below. The ```create_resources``` function below will skip creating any resources that already exist." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e325ddb0-aa86-4f3b-9820-753f4bdadb19", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from sagemaker.train.aws_batch.boto_client import get_batch_boto_client\n", + "from utils.aws_batch_resource_management import AwsBatchResourceManager, create_resources\n", + "\n", + "# This job queue name needs to match the Job Queue created in AWS Batch.\n", + "JOB_QUEUE_NAME = \"my-sm-training-fifo-jq\"\n", + "SERVICE_ENVIRONMENT_NAME = \"my-sm-training-fifo-se\"\n", + "\n", + "# Create ServiceEnvironment and JobQueue\n", + "resource_manager = AwsBatchResourceManager(get_batch_boto_client())\n", + "resources = create_resources(\n", + " resource_manager, JOB_QUEUE_NAME, SERVICE_ENVIRONMENT_NAME, max_capacity=1\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2a5d5e2e-b266-41c4-9d17-ce6c93a42db3", + "metadata": {}, + "source": [ + "## Create Hello World Model Trainer\n", + "Now that our resources are created, we'll construct a simple ModelTrainer." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d71f7b99-63a3-4fd0-b735-6140fe1489f6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sagemaker.train.model_trainer import ModelTrainer\n",
+    "from sagemaker.train.configs import SourceCode, Compute, StoppingCondition\n",
+    "\n",
+    "source_code = SourceCode(command=\"echo 'Hello World'\")\n",
+    "\n",
+    "model_trainer = ModelTrainer(\n",
+    "    training_image=image_uri,\n",
+    "    source_code=source_code,\n",
+    "    base_job_name=TRAINING_JOB_NAME,\n",
+    "    compute=Compute(instance_type=INSTANCE_TYPE, instance_count=INSTANCE_COUNT),\n",
+    "    stopping_condition=StoppingCondition(max_runtime_in_seconds=MAX_RUN_TIME),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "48f030c1-3e32-4265-a1fe-cdd6927e82ac",
+   "metadata": {},
+   "source": [
+    "## Create TrainingQueue object\n",
+    "Using our queue is as easy as referring to it by name in the TrainingQueue constructor. The TrainingQueue class within the SageMaker Python SDK provides built-in support for working with Batch queues."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5d90c9a4-ff38-492e-b446-61674701d9ca",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sagemaker.train.aws_batch.training_queue import TrainingQueue, TrainingQueuedJob\n",
+    "\n",
+    "# Construct the queue object using the SageMaker Python SDK\n",
+    "queue = TrainingQueue(JOB_QUEUE_NAME)\n",
+    "logger.info(f\"Using queue: {queue.queue_name}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "734b7fe6-0bf7-460e-aa95-7608421e900c",
+   "metadata": {},
+   "source": [
+    "## Submit Some Training Jobs\n",
+    "Submitting your job to the queue is done by calling queue.submit. This particular job doesn't require any data, but in general, data should be provided by specifying inputs."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "60f2ca52-b1ea-4af2-a143-8202ce34d5e6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Submit first job\n",
+    "training_queued_job_1: TrainingQueuedJob = queue.submit(training_job=model_trainer, inputs=None)\n",
+    "logger.info(\n",
+    "    f\"Submitted job '{training_queued_job_1.job_name}' to TrainingQueue '{queue.queue_name}'\"\n",
+    ")\n",
+    "\n",
+    "# Submit second job\n",
+    "training_queued_job_2: TrainingQueuedJob = queue.submit(training_job=model_trainer, inputs=None)\n",
+    "logger.info(\n",
+    "    f\"Submitted job '{training_queued_job_2.job_name}' to TrainingQueue '{queue.queue_name}'\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "da85757f-d17c-402f-87b8-164149ea9165",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "## Terminate a Job in the Queue\n",
+    "This next cell shows how to terminate an in-queue job."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b467767b-21b4-4cc0-987d-d07ef7f15ca0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logger.info(f\"Terminating job: {training_queued_job_2.job_name}\")\n",
+    "training_queued_job_2.terminate()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d578124b-8067-47cf-895c-3076a242d6a6",
+   "metadata": {},
+   "source": [
+    "## Monitor Job Status\n",
+    "This next cell shows how to list the jobs that have been submitted to the TrainingQueue. The TrainingQueue can list jobs by status, and each job can be described individually for more details. Once a TrainingQueuedJob has reached the STARTING status, the logs can be printed from the underlying SageMaker training job.\n",
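+    "\n",
+    "A TrainingQueuedJob also exposes blocking helpers, `wait` and `result` (see the unit tests in this patch). The snippet below is a minimal sketch of using them, reusing `training_queued_job_1` and `MAX_RUN_TIME` from earlier cells:\n",
+    "\n",
+    "```python\n",
+    "# Minimal sketch (based on the unit tests in this change):\n",
+    "# wait() polls describe_service_job until the job completes or the optional timeout\n",
+    "# elapses and returns the latest describe response; result() instead raises\n",
+    "# TimeoutError if the job has not completed within the given timeout.\n",
+    "final_state = training_queued_job_1.wait(timeout=MAX_RUN_TIME)\n",
+    "print(final_state.get(\"status\"))\n",
+    "```"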
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e8de2fe-ab1d-4703-ace3-1e44ae1b2d7c", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "\n", + "def list_jobs_in_training_queue(training_queue: TrainingQueue):\n", + " \"\"\"\n", + " Lists all jobs in a TrainingQueue grouped by their status.\n", + "\n", + " This function retrieves jobs with different statuses (SUBMITTED, PENDING, RUNNABLE,\n", + " SCHEDULED, STARTING, RUNNING, SUCCEEDED, FAILED) from the specified TrainingQueue\n", + " and logs their names and current status.\n", + "\n", + " Args:\n", + " training_queue (TrainingQueue): The TrainingQueue to query for jobs.\n", + "\n", + " Returns:\n", + " None: This function doesn't return a value but logs job information.\n", + " \"\"\"\n", + " submitted_jobs = training_queue.list_jobs(status=\"SUBMITTED\")\n", + " pending_jobs = training_queue.list_jobs(status=\"PENDING\")\n", + " runnable_jobs = training_queue.list_jobs(status=\"RUNNABLE\")\n", + " scheduled_jobs = training_queue.list_jobs(status=\"SCHEDULED\")\n", + " starting_jobs = training_queue.list_jobs(status=\"STARTING\")\n", + " running_jobs = training_queue.list_jobs(status=\"RUNNING\")\n", + " completed_jobs = training_queue.list_jobs(status=\"SUCCEEDED\")\n", + " failed_jobs = training_queue.list_jobs(status=\"FAILED\")\n", + "\n", + " all_jobs = (\n", + " submitted_jobs\n", + " + pending_jobs\n", + " + runnable_jobs\n", + " + scheduled_jobs\n", + " + starting_jobs\n", + " + running_jobs\n", + " + completed_jobs\n", + " + failed_jobs\n", + " )\n", + "\n", + " for job in all_jobs:\n", + " job_status = job.describe().get(\"status\", \"\")\n", + " logger.info(f\"Job : {job.job_name} is {job_status}\")\n", + "\n", + "\n", + "def monitor_training_queued_job(job: TrainingQueuedJob):\n", + " \"\"\"\n", + " Monitors a TrainingQueuedJob until it reaches an active or terminal state.\n", + "\n", + " This function continuously polls the status of the specified TrainingQueuedJob\n", + " until it transitions to one of the following states: STARTING, RUNNING,\n", + " SUCCEEDED, or FAILED. Once the job reaches one of these states, the function\n", + " retrieves and displays the job's logs.\n", + "\n", + " Args:\n", + " job (TrainingQueuedJob): The TrainingQueuedJob to monitor.\n", + "\n", + " Returns:\n", + " None: This function doesn't return a value but displays job logs.\n", + " \"\"\"\n", + " while True:\n", + " job_status = job.describe().get(\"status\", \"\")\n", + "\n", + " if job_status in {\"STARTING\", \"RUNNING\", \"SUCCEEDED\", \"FAILED\"}:\n", + " break\n", + "\n", + " logger.info(f\"Job : {job.job_name} is {job_status}\")\n", + " time.sleep(5)\n", + "\n", + " # Print training job logs\n", + " # job.get_estimator().logs()\n", + " model_trainer = job.get_model_trainer()\n", + " model_trainer.sagemaker_session.logs_for_job(model_trainer._latest_training_job.training_job_name, wait=True)\n", + "\n", + "\n", + "logger.info(f\"Listing all jobs in queue '{queue.queue_name}'...\")\n", + "list_jobs_in_training_queue(queue)\n", + "\n", + "logger.info(f\"Polling job status for '{training_queued_job_1.job_name}'\")\n", + "monitor_training_queued_job(training_queued_job_1)" + ] + }, + { + "cell_type": "markdown", + "id": "3ca39fca-6bb6-4e0b-841e-7f66d97b6074", + "metadata": {}, + "source": [ + "# Optional: Delete AWS Batch Resources\n", + "This shows how to delete the AWS Batch ServiceEnvironment and JobQueue. 
This step is completely optional; uncomment the code below to delete the resources created a few steps above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8d745e2d-40a8-45ef-b231-e8acd9b5e8eb",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from utils.aws_batch_resource_management import delete_resources\n",
+    "\n",
+    "delete_resources(resource_manager, resources)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7070dcc9",
+   "metadata": {},
+   "source": [
+    "## Notebook CI Test Results\n",
+    "\n",
+    "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n",
+    "\n",
+    "\n",
+    "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+    "\n",
+    "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+    "\n",
+    "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+    "\n",
+    "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+    "\n",
+    "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+    "\n",
+    "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+    "\n",
+    "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+    "\n",
+    "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+    "\n",
+    "![This eu-central-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-south-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n" + ] + } + ], + "metadata": { + "availableInstances": [ + { + "_defaultOrder": 0, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.t3.medium", + "vcpuNum": 2 + }, + { + "_defaultOrder": 1, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.t3.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 2, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.t3.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 3, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.t3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 4, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 5, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 6, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 7, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 8, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 9, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 10, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 11, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 12, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5d.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 13, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5d.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 14, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5d.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 15, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5d.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 16, + "_isFastLaunch": false, + "category": "General purpose", + 
"gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5d.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 17, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5d.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 18, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5d.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 19, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 20, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": true, + "memoryGiB": 0, + "name": "ml.geospatial.interactive", + "supportedImageNames": [ + "sagemaker-geospatial-v1-0" + ], + "vcpuNum": 0 + }, + { + "_defaultOrder": 21, + "_isFastLaunch": true, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.c5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 22, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.c5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 23, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.c5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 24, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.c5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 25, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 72, + "name": "ml.c5.9xlarge", + "vcpuNum": 36 + }, + { + "_defaultOrder": 26, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 96, + "name": "ml.c5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 27, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 144, + "name": "ml.c5.18xlarge", + "vcpuNum": 72 + }, + { + "_defaultOrder": 28, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.c5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 29, + "_isFastLaunch": true, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g4dn.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 30, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g4dn.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 31, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g4dn.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 32, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g4dn.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 33, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 
192, + "name": "ml.g4dn.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 34, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g4dn.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 35, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 61, + "name": "ml.p3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 36, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 244, + "name": "ml.p3.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 37, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 488, + "name": "ml.p3.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 38, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.p3dn.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 39, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.r5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 40, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.r5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 41, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.r5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 42, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.r5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 43, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.r5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 44, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.r5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 45, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.r5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 46, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.r5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 47, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 48, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 49, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 50, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 51, + "_isFastLaunch": false, + 
"category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 52, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 53, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.g5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 54, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.g5.48xlarge", + "vcpuNum": 192 + }, + { + "_defaultOrder": 55, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 56, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4de.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 57, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.trn1.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 58, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.trn1.32xlarge", + "vcpuNum": 128 + }, + { + "_defaultOrder": 59, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.trn1n.32xlarge", + "vcpuNum": 128 + } + ], + "instance_type": "ml.t3.medium", + "kernelspec": { + "display_name": "venv-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py b/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py new file mode 100644 index 0000000000..19f51e0bdc --- /dev/null +++ b/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py @@ -0,0 +1,566 @@ +import json +import logging +import time +from dataclasses import dataclass + +import boto3 +from botocore.exceptions import ClientError + +# Configure logging +logger = logging.getLogger(__name__) + + +@dataclass +class Resource: + """ + Represents an AWS resource with a name and ARN. + + Attributes: + name (str): The name of the AWS resource. + arn (str): The Amazon Resource Name (ARN) of the resource. + """ + + name: str + arn: str + + +@dataclass +class Resources: + """ + Container for AWS Batch resources used in the application. + + Attributes: + job_queue (Resource): The AWS Batch job queue resource. + service_environment (Resource): The AWS Batch service environment resource. 
+    """
+
+    job_queue: Resource = None
+    service_environment: Resource = None
+    batch_role: Resource = None
+    sagemaker_exeuction_role: Resource = None
+
+
+class AwsBatchResourceManager:
+    """
+    Manager for AWS Batch resources including service environments and job queues.
+
+    This class provides methods to create, update, delete, and monitor AWS Batch resources.
+
+    Attributes:
+        TERMINAL_JOB_STATUSES (set): Set of job statuses considered terminal.
+    """
+
+    TERMINAL_JOB_STATUSES = {"SUCCEEDED", "FAILED"}
+
+    def __init__(self, batch_client):
+        """
+        Initialize the AWS Batch Resource Manager.
+
+        Args:
+            batch_client: The boto3 Batch client to use for AWS operations.
+        """
+        self._batch_client = batch_client
+
+    def create_service_environment(self, create_se_request: dict):
+        """
+        Create a new AWS Batch service environment.
+
+        If the service environment already exists, returns the existing environment details.
+
+        Args:
+            create_se_request (dict): Request parameters for creating a service environment.
+                Must contain 'serviceEnvironmentName' key.
+
+        Returns:
+            dict: Response containing the service environment name and ARN.
+
+        Raises:
+            ClientError: If there's an error creating the service environment.
+        """
+        try:
+            return self._batch_client.create_service_environment(**create_se_request)
+        except ClientError as error:
+            if error.response.get("message") == "Object already exists":
+                logger.info("ServiceEnvironment already exists, skipping creation.")
+                desc_resp = self._batch_client.describe_service_environments(
+                    serviceEnvironments=[create_se_request["serviceEnvironmentName"]]
+                )
+                return {
+                    "serviceEnvironmentName": desc_resp["serviceEnvironments"][0][
+                        "serviceEnvironmentName"
+                    ],
+                    "serviceEnvironmentArn": desc_resp["serviceEnvironments"][0][
+                        "serviceEnvironmentArn"
+                    ],
+                }
+
+            logger.error(f"Error: {json.dumps(error.response, indent=4)}")
+            raise error
+
+    def update_service_environment(self, update_se_request):
+        """
+        Update an existing AWS Batch service environment.
+
+        Args:
+            update_se_request (dict): Request parameters for updating a service environment.
+
+        Returns:
+            dict: Response from the update operation.
+
+        Raises:
+            ClientError: If there's an error updating the service environment.
+        """
+        try:
+            return self._batch_client.update_service_environment(**update_se_request)
+        except ClientError as error:
+            logger.error(f"Error: {json.dumps(error.response, indent=4)}")
+            raise error
+
+    def await_service_environment_update(
+        self, service_environment_name: str, expected_status: str, expected_state: str
+    ):
+        """
+        Wait for a service environment to reach the expected status and state.
+
+        This method polls the service environment until it reaches the expected status and
+        state, or until it no longer exists when the expected status is "DELETED".
+
+        Args:
+            service_environment_name (str): Name of the service environment to monitor.
+            expected_status (str): The expected status to wait for (e.g., "VALID", "DELETED").
+            expected_state (str): The expected state to wait for (e.g., "ENABLED", "DISABLED").
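+
+        Raises:
+            ValueError: If the service environment enters an INVALID status.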
+        """
+        while True:
+            describe_response = self._batch_client.describe_service_environments(
+                serviceEnvironments=[service_environment_name]
+            )
+            if describe_response["serviceEnvironments"]:
+                se = describe_response["serviceEnvironments"][0]
+
+                state = se["state"]
+                status = se["status"]
+
+                if status == expected_status and state == expected_state:
+                    break
+                if status == "INVALID":
+                    raise ValueError(f"Something went wrong! {json.dumps(se, indent=4)}")
+            elif expected_status == "DELETED":
+                logger.info(f"ServiceEnvironment {service_environment_name} has been deleted")
+                break
+
+            time.sleep(5)
+
+    def delete_service_environment(self, service_environment_name: str):
+        """
+        Delete an AWS Batch service environment.
+
+        This method follows the proper deletion workflow:
+        1. Disable the service environment
+        2. Wait for the disable operation to complete
+        3. Delete the service environment
+        4. Wait for the deletion to complete
+
+        Args:
+            service_environment_name (str): Name of the service environment to delete.
+        """
+        logger.info(f"Setting ServiceEnvironment {service_environment_name} to DISABLED")
+        self._batch_client.update_service_environment(
+            serviceEnvironment=service_environment_name, state="DISABLED"
+        )
+
+        logger.info("Waiting for ServiceEnvironment update to finish...")
+        self.await_service_environment_update(service_environment_name, "VALID", "DISABLED")
+
+        logger.info(f"Deleting ServiceEnvironment {service_environment_name}")
+        self._batch_client.delete_service_environment(serviceEnvironment=service_environment_name)
+
+        logger.info("Waiting for ServiceEnvironment update to finish...")
+        self.await_service_environment_update(service_environment_name, "DELETED", "DISABLED")
+
+    def create_job_queue(self, create_jq_request: dict):
+        """
+        Create a new AWS Batch job queue.
+
+        If the job queue already exists, returns the existing job queue details.
+
+        Args:
+            create_jq_request (dict): Request parameters for creating a job queue.
+                Must contain 'jobQueueName' key.
+
+        Returns:
+            dict: Response containing the job queue name and ARN.
+
+        Raises:
+            ClientError: If there's an error creating the job queue.
+        """
+        try:
+            return self._batch_client.create_job_queue(**create_jq_request)
+        except ClientError as error:
+            if error.response.get("message") == "Object already exists":
+                logger.info("JobQueue already exists, skipping creation")
+                desc_resp = self._batch_client.describe_job_queues(
+                    jobQueues=[create_jq_request["jobQueueName"]]
+                )
+                return {
+                    "jobQueueName": desc_resp["jobQueues"][0]["jobQueueName"],
+                    "jobQueueArn": desc_resp["jobQueues"][0]["jobQueueArn"],
+                }
+
+            logger.error(f"Error: {json.dumps(error.response, indent=4)}")
+            raise error
+
+    def delete_job_queue(self, job_queue_name: str):
+        """
+        Delete an AWS Batch job queue.
+
+        This method follows the proper deletion workflow:
+        1. Disable the job queue
+        2. Wait for the disable operation to complete
+        3. Delete the job queue
+        4. Wait for the deletion to complete
+
+        Args:
+            job_queue_name (str): Name of the job queue to delete.
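+
+        Raises:
+            ValueError: If the job queue enters an INVALID status while waiting for updates.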
+        """
+        logger.info(f"Disabling JobQueue {job_queue_name}")
+        self._batch_client.update_job_queue(jobQueue=job_queue_name, state="DISABLED")
+
+        logger.info("Waiting for JobQueue update to finish...")
+        self.await_job_queue_update(job_queue_name, "VALID", "DISABLED")
+
+        logger.info(f"Deleting JobQueue {job_queue_name}")
+        self._batch_client.delete_job_queue(jobQueue=job_queue_name)
+
+        logger.info("Waiting for JobQueue update to finish...")
+        self.await_job_queue_update(job_queue_name, "DELETED", "DISABLED")
+
+    def await_job_queue_update(
+        self, job_queue_name: str, expected_status: str, expected_state: str
+    ):
+        """
+        Wait for a job queue to reach the expected status and state.
+
+        This method polls the job queue until it reaches the expected status and state,
+        or until it no longer exists when the expected status is "DELETED".
+
+        Args:
+            job_queue_name (str): Name of the job queue to monitor.
+            expected_status (str): The expected status to wait for (e.g., "VALID", "DELETED").
+            expected_state (str): The expected state to wait for (e.g., "ENABLED", "DISABLED").
+
+        Raises:
+            ValueError: If the job queue enters an INVALID status.
+        """
+        while True:
+            describe_jq_response = self._batch_client.describe_job_queues(
+                jobQueues=[job_queue_name]
+            )
+            if describe_jq_response["jobQueues"]:
+                jq = describe_jq_response["jobQueues"][0]
+
+                state = jq["state"]
+                status = jq["status"]
+
+                if status == expected_status and state == expected_state:
+                    break
+                if status == "INVALID":
+                    raise ValueError(f"Something went wrong! {json.dumps(jq, indent=4)}")
+            elif expected_status == "DELETED":
+                logger.info(f"JobQueue {job_queue_name} has been deleted")
+                break
+
+            time.sleep(5)
+
+
+class RoleManager:
+    """
+    Manager for creating and managing IAM roles required for SageMaker training jobs with AWS Batch.
+
+    This class provides methods to create the necessary IAM roles for AWS Batch to interact with
+    SageMaker training jobs, including the batch role and SageMaker execution role.
+
+    Attributes:
+        iam_client: The boto3 IAM client to use for creating roles.
+        sts_client: The boto3 STS client to use for getting account information.
+    """
+
+    def __init__(self, iam_client, sts_client):
+        """
+        Initialize the RoleManager with IAM and STS clients.
+
+        Args:
+            iam_client: The boto3 IAM client to use.
+            sts_client: The boto3 STS client to use.
+        """
+        self.iam_client = iam_client
+        self.sts_client = sts_client
+
+    def create_batch_role(self, batch_role_name: str):
+        """
+        Create an IAM role for AWS Batch to interact with SageMaker training jobs.
+
+        This method creates a role with permissions for AWS Batch to manage SageMaker
+        training jobs, including the ability to create service-linked roles and pass roles
+        to SageMaker.
+
+        Args:
+            batch_role_name (str): The name to use for the IAM role.
+        Returns:
+            Resource: A Resource object containing the name and ARN of the created role.
+
+        Raises:
+            ClientError: If there's an error creating the role (except when the role already exists).
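+
+        Example:
+            A minimal sketch of a call; the clients and the role name below are
+            illustrative values only:
+
+                role_manager = RoleManager(boto3.client("iam"), boto3.client("sts"))
+                batch_role = role_manager.create_batch_role("AWSBatchForSageMakerTrainingRole")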
+ """ + get_caller_id_resp = self.sts_client.get_caller_identity() + account_id = get_caller_id_resp["Account"] + + try: + create_role_resp = self.iam_client.create_role( + RoleName=batch_role_name, + AssumeRolePolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"AWS": f"arn:aws:iam::{account_id}:root"}, + "Action": "sts:AssumeRole", + } + ], + } + ), + Description="Role for AWS Batch for SageMaker Training jobs.", + MaxSessionDuration=3600, + ) + self.iam_client.put_role_policy( + RoleName=batch_role_name, + PolicyName="AWSBatchForSageMakerTrainingJobsPolicy", + PolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Action": ["batch:*"], "Resource": "*"}, + { + "Effect": "Allow", + "Action": ["iam:CreateServiceLinkedRole"], + "Resource": "arn:aws:iam::*:role/*AWSServiceRoleForAWSBatchWithSagemaker", + "Condition": { + "StringEquals": { + "iam:AWSServiceName": "sagemaker-queuing.batch.amazonaws.com" + } + }, + }, + { + "Effect": "Allow", + "Action": "iam:PassRole", + "Resource": "*", + "Condition": { + "StringEquals": { + "iam:PassedToService": ["sagemaker.amazonaws.com"] + } + }, + }, + ], + } + ), + ) + + return Resource( + name=create_role_resp["Role"]["RoleName"], arn=create_role_resp["Role"]["Arn"] + ) + except ClientError as error: + if error.response["Error"]["Code"] == "EntityAlreadyExists": + print(error.response["Error"]["Message"]) + get_resp = self.iam_client.get_role(RoleName=batch_role_name) + + return Resource(name=get_resp["Role"]["RoleName"], arn=get_resp["Role"]["Arn"]) + + logger.error( + f"Error creating {batch_role_name}: {json.dumps(error.__dict__, indent=4)}" + ) + raise error + + def create_sagemaker_execution_role(self, sagemaker_execution_role_name: str): + """ + Create an IAM role for SageMaker to execute training jobs. + + This method creates a role with the AmazonSageMakerFullAccess policy attached, + allowing SageMaker to access necessary resources for training jobs. + + Args: + sagemaker_execution_role_name (str): The name to use for the IAM role. + Returns: + Resource: A Resource object containing the name and ARN of the created role. + + Raises: + ClientError: If there's an error creating the role (except when the role already exists). 
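+
+        Example:
+            A minimal sketch, assuming role_manager is a RoleManager instance; the role
+            name below is an illustrative value only:
+
+                execution_role = role_manager.create_sagemaker_execution_role(
+                    "SageMakerTrainingExecutionRole"
+                )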
+ """ + try: + create_role_resp = self.iam_client.create_role( + RoleName=sagemaker_execution_role_name, + AssumeRolePolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": ["sagemaker.amazonaws.com"]}, + "Action": "sts:AssumeRole", + } + ], + } + ), + Description="SageMaker training execution role.", + MaxSessionDuration=3600, + ) + self.iam_client.attach_role_policy( + RoleName=sagemaker_execution_role_name, + PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess", + ) + + return Resource( + name=create_role_resp["Role"]["RoleName"], arn=create_role_resp["Role"]["Arn"] + ) + except ClientError as error: + if error.response["Error"]["Code"] == "EntityAlreadyExists": + print(error.response["Error"]["Message"]) + get_resp = self.iam_client.get_role(RoleName=sagemaker_execution_role_name) + + return Resource(name=get_resp["Role"]["RoleName"], arn=get_resp["Role"]["Arn"]) + + logger.error( + f"Error creating {sagemaker_execution_role_name}: {json.dumps(error.__dict__, indent=4)}" + ) + raise error + + +def create_roles( + role_manager: RoleManager, batch_role_name: str, sagemaker_execution_role_name: str +): + """ + Create all required IAM roles for SageMaker training jobs with AWS Batch. + + This function creates both the AWS Batch role and the SageMaker execution role + using the current AWS account ID. + + Returns: + Resources: A Resources class containing the created roles + """ + logger.info("Creating batch role") + batch_role = role_manager.create_batch_role(batch_role_name) + + logger.info("Creating sagemaker execution role") + sagemaker_execution_role = role_manager.create_sagemaker_execution_role( + sagemaker_execution_role_name + ) + + resources = Resources(batch_role=batch_role, sagemaker_exeuction_role=sagemaker_execution_role) + + logger.info(f"Role creation complete: {resources}") + return resources + + +def assume_role_and_get_session(role: Resource, sts_client): + """ + Assumes the specified IAM role and returns a boto3 session with the assumed credentials. + + Args: + role: The IAM role resource to assume + sts_client: The boto3 STS client + + Returns: + A boto3 session configured with the assumed role credentials + """ + response = sts_client.assume_role(RoleArn=role.arn, RoleSessionName="AssumeRoleSession") + + credentials = response["Credentials"] + + return boto3.Session( + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + +def create_resources( + resource_manager: AwsBatchResourceManager, + job_queue_name: str, + service_environment_name: str, + max_capacity: int = 1, +): + """ + Create AWS Batch resources including a service environment and job queue. + + This function creates a SageMaker training service environment and a corresponding + job queue, waiting for each resource to reach a VALID state before proceeding. + + Args: + resource_manager (AwsBatchResourceManager): The resource manager to use for creating resources. + + Returns: + Resources: A Resources object containing the created service environment and job queue. 
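+
+    Example:
+        A minimal sketch; the queue and environment names below are illustrative values only:
+
+            manager = AwsBatchResourceManager(boto3.client("batch"))
+            resources = create_resources(
+                manager,
+                job_queue_name="my-training-job-queue",
+                service_environment_name="my-training-service-environment",
+                max_capacity=1,
+            )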
+ """ + # Create ServiceEnvironment + logger.info(f"Creating ServiceEnvironment: {service_environment_name}") + create_se_resp = resource_manager.create_service_environment( + { + "serviceEnvironmentName": service_environment_name, + "serviceEnvironmentType": "SAGEMAKER_TRAINING", + "state": "ENABLED", + "capacityLimits": [{"maxCapacity": max_capacity, "capacityUnit": "NUM_INSTANCES"}], + } + ) + logger.info("Waiting for ServiceEnvironment to transition to VALID...") + resource_manager.await_service_environment_update(service_environment_name, "VALID", "ENABLED") + + # Create JobQueue + logger.info(f"Creating JobQueue: {job_queue_name}") + create_jq_response = resource_manager.create_job_queue( + { + "jobQueueName": job_queue_name, + "jobQueueType": "SAGEMAKER_TRAINING", + "state": "ENABLED", + "priority": 1, + "serviceEnvironmentOrder": [ + {"order": 1, "serviceEnvironment": create_se_resp["serviceEnvironmentName"]}, + ], + } + ) + logger.info("Waiting for JobQueue to transition to VALID...") + resource_manager.await_job_queue_update(job_queue_name, "VALID", "ENABLED") + + resources = Resources( + service_environment=Resource( + name=create_se_resp["serviceEnvironmentName"], + arn=create_se_resp["serviceEnvironmentArn"], + ), + job_queue=Resource( + name=create_jq_response["jobQueueName"], arn=create_jq_response["jobQueueArn"] + ), + ) + + logger.info(f"Resource creation complete: {resources}") + return resources + + +def delete_resources(resource_manager: AwsBatchResourceManager, resources: Resources): + """ + Delete AWS Batch resources. + + This function deletes the job queue first and then the service environment, + following the proper order for resource cleanup. + + Args: + resource_manager (AwsBatchResourceManager): The resource manager to use for deleting resources. + resources (Resources): The Resources object containing the resources to delete. 
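+
+    Example:
+        Typically called with the resource manager and the Resources object returned by
+        create_resources (names below are illustrative only):
+
+            delete_resources(manager, resources)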
+ """ + if resources.job_queue: + logger.info(f"Deleting JobQueue: {resources.job_queue.name}") + resource_manager.delete_job_queue(resources.job_queue.name) + + if resources.service_environment: + logger.info(f"Deleting ServiceEnvironment: {resources.service_environment.name}") + resource_manager.delete_service_environment(resources.service_environment.name) From 41f71d6fb89ed8271d51cb76a5041ce355c3c931 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 11 Dec 2025 14:11:04 -0800 Subject: [PATCH 4/6] Adding missing dependencies for aws_batch --- sagemaker-train/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/sagemaker-train/pyproject.toml b/sagemaker-train/pyproject.toml index 2718330404..9b88399b7d 100644 --- a/sagemaker-train/pyproject.toml +++ b/sagemaker-train/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "jinja2>=3.0,<4.0", "sagemaker-mlflow>=0.0.1,<1.0.0", "mlflow>=3.0.0,<4.0.0", + "nest_asyncio>=1.5.0", ] [project.urls] From 93268210df49c847bd51cff4b8f8bf7083bd47f6 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 12 Dec 2025 09:43:06 -0800 Subject: [PATCH 5/6] Fixing indentation bug in source code --- .../src/sagemaker/train/model_trainer.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index b0a55f3d71..2e5ed75fe6 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -744,18 +744,18 @@ def train( self.sagemaker_session._intercept_create_request(training_request, None, "train") return - training_job = TrainingJob.create( - session=self.sagemaker_session.boto_session, - **training_request - ) - self._latest_training_job = training_job - - if wait: - training_job.wait(logs=logs) - if logs and not wait: - logger.warning( - "Not displaing the training container logs as 'wait' is set to False." + training_job = TrainingJob.create( + session=self.sagemaker_session.boto_session, + **training_request ) + self._latest_training_job = training_job + + if wait: + training_job.wait(logs=logs) + if logs and not wait: + logger.warning( + "Not displaing the training container logs as 'wait' is set to False." + ) else: local_container = _LocalContainer( From 711b5a3f4b3b340d9bccd98b8ca3c24dad1c3157 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 12 Dec 2025 09:48:50 -0800 Subject: [PATCH 6/6] comment out delete resources in example notebook --- .../sm-training-queues_getting_started_with_model_trainer.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb b/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb index 7d09133599..b7f7c72cd7 100644 --- a/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb +++ b/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb @@ -330,7 +330,7 @@ "source": [ "from utils.aws_batch_resource_management import delete_resources\n", "\n", - "delete_resources(resource_manager, resources)" + "# delete_resources(resource_manager, resources)" ] }, {