diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index bc35c7eb9d..2d6ecf80d5 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -1865,6 +1865,30 @@ def expand_role(self, role): if "/" in role: return role return self.boto_session.resource("iam").Role(role).arn + + + def logs_for_job(self, job_name, wait=False, poll=10, log_type="All", timeout=None): + """Display logs for a given training job, optionally tailing them until job is complete. + + If the output is a tty or a Jupyter cell, it will be color-coded + based on which instance the log entry is from. + + Args: + job_name (str): Name of the training job to display the logs for. + wait (bool): Whether to keep looking for new log entries until the job completes + (default: False). + poll (int): The interval in seconds between polling for new log entries and job + completion (default: 5). + log_type ([str]): A list of strings specifying which logs to print. Acceptable + strings are "All", "None", "Training", or "Rules". To maintain backwards + compatibility, boolean values are also accepted and converted to strings. + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. + Raises: + exceptions.CapacityError: If the training job fails with CapacityError. + exceptions.UnexpectedStatusException: If waiting and the training job fails. + """ + _logs_for_job(self, job_name, wait, poll, log_type, timeout) def _expand_container_def(c_def): @@ -2962,4 +2986,4 @@ def container_def( c_def["Mode"] = container_mode if image_config: c_def["ImageConfig"] = image_config - return c_def + return c_def \ No newline at end of file diff --git a/sagemaker-train/pyproject.toml b/sagemaker-train/pyproject.toml index 2718330404..9b88399b7d 100644 --- a/sagemaker-train/pyproject.toml +++ b/sagemaker-train/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "jinja2>=3.0,<4.0", "sagemaker-mlflow>=0.0.1,<1.0.0", "mlflow>=3.0.0,<4.0.0", + "nest_asyncio>=1.5.0", ] [project.urls] diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/__init__.py b/sagemaker-train/src/sagemaker/train/aws_batch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py b/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py new file mode 100644 index 0000000000..4cb88f65ea --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py @@ -0,0 +1,186 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
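For reference, a minimal usage sketch of the new `Session.logs_for_job` wrapper added to `session_helper.py` above. The session construction, region, and job name below are illustrative placeholders, not part of this patch; the method itself simply delegates to the module-level `_logs_for_job` helper.

import boto3
from sagemaker.core.helper.session_helper import Session

# Hypothetical setup; region and job name are placeholders.
session = Session(boto_session=boto3.Session(region_name="us-west-2"))

# Tail CloudWatch logs until the training job reaches a terminal state,
# polling every 10 seconds. Per the docstring above, this raises
# UnexpectedStatusException if the job fails while waiting.
session.logs_for_job(
    job_name="my-training-job-2025-01-01-00-00-00-000",
    wait=True,
    poll=10,
    log_type="All",
)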
+"""The module provides helper function for Batch Submit/Describe/Terminal job APIs.""" +from __future__ import absolute_import + +import json +from typing import List, Dict, Optional +from sagemaker.train.aws_batch.constants import ( + SAGEMAKER_TRAINING, + DEFAULT_TIMEOUT, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, +) +from sagemaker.train.aws_batch.boto_client import get_batch_boto_client + + +def submit_service_job( + training_payload: Dict, + job_name: str, + job_queue: str, + retry_config: Optional[Dict] = None, + scheduling_priority: Optional[int] = None, + timeout: Optional[Dict] = None, + share_identifier: Optional[str] = None, + tags: Optional[Dict] = None, +) -> Dict: + """Batch submit_service_job API helper function. + + Args: + training_payload: a dict containing a dict of arguments for Training job. + job_name: Batch job name. + job_queue: Batch job queue ARN. + retry_config: Batch job retry configuration. + scheduling_priority: An integer representing scheduling priority. + timeout: Set with value of timeout if specified, else default to 1 day. + share_identifier: value of shareIdentifier if specified. + tags: A dict of string to string representing Batch tags. + + Returns: + A dict containing jobArn, jobName and jobId. + """ + if timeout is None: + timeout = DEFAULT_TIMEOUT + client = get_batch_boto_client() + training_payload_tags = training_payload.pop("Tags", None) + payload = { + "jobName": job_name, + "jobQueue": job_queue, + "retryStrategy": DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + "serviceJobType": SAGEMAKER_TRAINING, + "serviceRequestPayload": json.dumps(training_payload), + "timeoutConfig": timeout, + } + if retry_config: + payload["retryStrategy"] = retry_config + if scheduling_priority: + payload["schedulingPriority"] = scheduling_priority + if share_identifier: + payload["shareIdentifier"] = share_identifier + if tags or training_payload_tags: + payload["tags"] = __merge_tags(tags, training_payload_tags) + return client.submit_service_job(**payload) + + +def describe_service_job(job_id: str) -> Dict: + """Batch describe_service_job API helper function. + + Args: + job_id: Job ID used. + + Returns: a dict. See the sample below + { + 'attempts': [ + { + 'serviceResourceId': { + 'name': 'string', + 'value': 'string' + }, + 'startedAt': 123, + 'stoppedAt': 123, + 'statusReason': 'string' + }, + ], + 'createdAt': 123, + 'isTerminated': True|False, + 'jobArn': 'string', + 'jobId': 'string', + 'jobName': 'string', + 'jobQueue': 'string', + 'retryStrategy': { + 'attempts': 123 + }, + 'schedulingPriority': 123, + 'serviceRequestPayload': 'string', + 'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING', + 'shareIdentifier': 'string', + 'startedAt': 123, + 'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED', + 'statusReason': 'string', + 'stoppedAt': 123, + 'tags': { + 'string': 'string' + }, + 'timeout': { + 'attemptDurationSeconds': 123 + } + } + """ + client = get_batch_boto_client() + return client.describe_service_job(jobId=job_id) + + +def terminate_service_job(job_id: str, reason: Optional[str] = "default terminate reason") -> Dict: + """Batch terminate_service_job API helper function. + + Args: + job_id: Job ID + reason: A string representing terminate reason. 
+ + Returns: an empty dict + """ + client = get_batch_boto_client() + return client.terminate_service_job(jobId=job_id, reason=reason) + + +def list_service_job( + job_queue: str, + job_status: Optional[str] = None, + filters: Optional[List] = None, + next_token: Optional[str] = None, +) -> Dict: + """Batch list_service_job API helper function. + + Args: + job_queue: Batch job queue ARN. + job_status: Batch job status. + filters: A list of Dict. Each contains a filter. + next_token: Used to retrieve data in next page. + + Returns: A generator containing list results. + + """ + client = get_batch_boto_client() + payload = {"jobQueue": job_queue} + if filters: + payload["filters"] = filters + if next_token: + payload["nextToken"] = next_token + if job_status: + payload["jobStatus"] = job_status + part_of_jobs = client.list_service_jobs(**payload) + next_token = part_of_jobs.get("nextToken") + yield part_of_jobs + if next_token: + yield from list_service_job(job_queue, job_status, filters, next_token) + + +def __merge_tags(batch_tags: Optional[Dict], training_tags: Optional[List]) -> Optional[Dict]: + """Merges Batch and training payload tags. + + Returns a copy of Batch tags merged with training payload tags. Training payload tags take + precedence in the case of key conflicts. + + :param batch_tags: A dict of string to string representing Batch tags. + :param training_tags: A list of `{"Key": "string", "Value": "string"}` objects representing + training payload tags. + :return: A dict of string to string representing batch tags merged with training tags. + batch_tags is returned unaltered if training_tags is None or empty. + """ + if not training_tags: + return batch_tags + + training_tags_to_merge = {tag["Key"]: tag["Value"] for tag in training_tags} + batch_tags_copy = batch_tags.copy() if batch_tags else {} + batch_tags_copy.update(training_tags_to_merge) + + return batch_tags_copy diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/boto_client.py b/sagemaker-train/src/sagemaker/train/aws_batch/boto_client.py new file mode 100644 index 0000000000..87f3486887 --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/boto_client.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The file provides helper function for getting Batch boto client.""" +from __future__ import absolute_import + +from typing import Optional +import boto3 + + +def get_batch_boto_client( + region: Optional[str] = None, + endpoint: Optional[str] = None, +) -> boto3.session.Session.client: + """Helper function for getting Batch boto3 client. + + Args: + region: Region specified + endpoint: Batch API endpoint. + + Returns: Batch boto3 client. 
+ + """ + return boto3.client("batch", region_name=region, endpoint_url=endpoint) diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/constants.py b/sagemaker-train/src/sagemaker/train/aws_batch/constants.py new file mode 100644 index 0000000000..ee41d3a413 --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/constants.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The file defines constants used for Batch API helper functions.""" + +from __future__ import absolute_import + +SAGEMAKER_TRAINING = "SAGEMAKER_TRAINING" +DEFAULT_ATTEMPT_DURATION_IN_SECONDS = 86400 # 1 day in seconds. +DEFAULT_TIMEOUT = {"attemptDurationSeconds": DEFAULT_ATTEMPT_DURATION_IN_SECONDS} +POLL_IN_SECONDS = 5 +JOB_STATUS_RUNNING = "RUNNING" +JOB_STATUS_COMPLETED = "SUCCEEDED" +JOB_STATUS_FAILED = "FAILED" +DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = { + "attempts": 1, + "evaluateOnExit": [ + { + "action": "RETRY", + "onStatusReason": "Received status from SageMaker:InternalServerError: " + "We encountered an internal error. Please try again.", + }, + {"action": "EXIT", "onStatusReason": "*"}, + ], +} diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/exception.py b/sagemaker-train/src/sagemaker/train/aws_batch/exception.py new file mode 100644 index 0000000000..94318bbce4 --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/exception.py @@ -0,0 +1,52 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The file Defines customized exception for Batch queueing""" +from __future__ import absolute_import + + +class NoTrainingJob(Exception): + """Define NoTrainingJob Exception. + + It means no Training job has been created by AWS Batch service. + """ + + def __init__(self, value): + super().__init__(value) + self.value = value + + def __str__(self): + """Convert Exception to string. + + Returns: a String containing exception error messages. + + """ + return repr(self.value) + + +class MissingRequiredArgument(Exception): + """Define MissingRequiredArgument exception. + + It means some required arguments are missing. + """ + + def __init__(self, value): + super().__init__(value) + self.value = value + + def __str__(self): + """Convert Exception to string. + + Returns: a String containing exception error messages. 
+ + """ + return repr(self.value) diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py b/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py new file mode 100644 index 0000000000..c464f0b382 --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py @@ -0,0 +1,203 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Define Queue class for AWS Batch service""" +from __future__ import absolute_import + +from typing import Dict, Optional, List +import logging +from sagemaker.train.model_trainer import ModelTrainer, Mode +from .training_queued_job import TrainingQueuedJob +from .batch_api_helper import submit_service_job, list_service_job +from .exception import MissingRequiredArgument +from .constants import DEFAULT_TIMEOUT, JOB_STATUS_RUNNING + + +class TrainingQueue: + """TrainingQueue class for AWS Batch service + + With this class, customers are able to create a new queue and submit jobs to AWS Batch Service. + """ + + def __init__(self, queue_name: str): + self.queue_name = queue_name + + def submit( + self, + training_job: ModelTrainer, + inputs, + job_name: Optional[str] = None, + retry_config: Optional[Dict] = None, + priority: Optional[int] = None, + share_identifier: Optional[str] = None, + timeout: Optional[Dict] = None, + tags: Optional[Dict] = None, + experiment_config: Optional[Dict] = None, + ) -> TrainingQueuedJob: + """Submit a queued job and return a QueuedJob object. + + Args: + training_job: Training job ModelTrainer object. + inputs: Training job inputs. + job_name: Batch job name. + retry_config: Retry configuration for Batch job. + priority: Scheduling priority for Batch job. + share_identifier: Share identifier for Batch job. + timeout: Timeout configuration for Batch job. + tags: Tags apply to Batch job. These tags are for Batch job only. + experiment_config: Experiment management configuration. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. + + Returns: a TrainingQueuedJob object with Batch job ARN and job name. + + """ + if not isinstance(training_job, ModelTrainer): + raise TypeError( + "training_job must be an instance of ModelTrainer, " + f"but got {type(training_job)}" + ) + + if training_job.training_mode != Mode.SAGEMAKER_TRAINING_JOB: + raise ValueError( + "TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB" + ) + if experiment_config is not None: + logging.warning( + "ExperimentConfig is not supported for ModelTrainer. " + "It will be ignored when submitting the job." 
+ ) + training_payload = training_job._create_training_job_args( + input_data_config=inputs, boto3=True + ) + + if timeout is None: + timeout = DEFAULT_TIMEOUT + if job_name is None: + job_name = training_payload["TrainingJobName"] + + resp = submit_service_job( + training_payload, + job_name, + self.queue_name, + retry_config, + priority, + timeout, + share_identifier, + tags, + ) + if "jobArn" not in resp or "jobName" not in resp: + raise MissingRequiredArgument( + "jobArn or jobName is missing in response from Batch submit_service_job API" + ) + return TrainingQueuedJob(resp["jobArn"], resp["jobName"]) + + def map( + self, + training_job: ModelTrainer, + inputs, + job_names: Optional[List[str]] = None, + retry_config: Optional[Dict] = None, + priority: Optional[int] = None, + share_identifier: Optional[str] = None, + timeout: Optional[Dict] = None, + tags: Optional[Dict] = None, + experiment_config: Optional[Dict] = None, + ) -> List[TrainingQueuedJob]: + """Submit queued jobs to the provided estimator and return a list of TrainingQueuedJob objects. + + Args: + training_job: Training job ModelTrainer object. + inputs: List of Training job inputs. + job_names: List of Batch job names. + retry_config: Retry config for the Batch jobs. + priority: Scheduling priority for the Batch jobs. + share_identifier: Share identifier for the Batch jobs. + timeout: Timeout configuration for the Batch jobs. + tags: Tags apply to Batch job. These tags are for Batch job only. + experiment_config: Experiment management configuration. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. + + Returns: a list of TrainingQueuedJob objects with each Batch job ARN and job name. + + """ + if experiment_config is None: + experiment_config = {} + + if job_names is not None: + if len(job_names) != len(inputs): + raise ValueError( + "When specified, the number of job names must match the number of inputs" + ) + else: + job_names = [None] * len(inputs) + + queued_batch_job_list = [] + for index, value in enumerate(inputs): + queued_batch_job = self.submit( + training_job, + value, + job_names[index], + retry_config, + priority, + share_identifier, + timeout, + tags, + experiment_config, + ) + queued_batch_job_list.append(queued_batch_job) + + return queued_batch_job_list + + def list_jobs( + self, job_name: Optional[str] = None, status: Optional[str] = JOB_STATUS_RUNNING + ) -> List[TrainingQueuedJob]: + """List Batch jobs according to job_name or status. + + Args: + job_name: Batch job name. + status: Batch job status. + + Returns: A list of QueuedJob. + + """ + filters = None + if job_name: + filters = [{"name": "JOB_NAME", "values": [job_name]}] + status = None # job_status is ignored when job_name is specified. + jobs_to_return = [] + next_token = None + for job_result_dict in list_service_job(self.queue_name, status, filters, next_token): + for job_result in job_result_dict.get("jobSummaryList", []): + if "jobArn" in job_result and "jobName" in job_result: + jobs_to_return.append( + TrainingQueuedJob(job_result["jobArn"], job_result["jobName"]) + ) + else: + logging.warning("Missing JobArn or JobName in Batch ListJobs API") + continue + return jobs_to_return + + def get_job(self, job_name): + """Get a Batch job according to job_name. + + Args: + job_name: Batch job name. + + Returns: The QueuedJob with name matching job_name. 
+ + """ + jobs_to_return = self.list_jobs(job_name) + if len(jobs_to_return) == 0: + raise ValueError(f"Cannot find job: {job_name}") + return jobs_to_return[0] diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py b/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py new file mode 100644 index 0000000000..d7c42ea7ad --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py @@ -0,0 +1,354 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Define QueuedJob class for AWS Batch service""" +from __future__ import absolute_import + +import logging +import time +import asyncio +import re +from typing import Optional, Dict +import nest_asyncio +from sagemaker.core.resources import TrainingJob +from sagemaker.core.shapes import Unassigned +from sagemaker.train.model_trainer import ModelTrainer +from sagemaker.train.configs import ( + Compute, + Networking, + StoppingCondition, + SourceCode, + TrainingImageConfig, +) +from .batch_api_helper import terminate_service_job, describe_service_job +from .exception import NoTrainingJob, MissingRequiredArgument +from ..utils import get_training_job_name_from_training_job_arn +from .constants import JOB_STATUS_COMPLETED, JOB_STATUS_FAILED, POLL_IN_SECONDS + +logging.basicConfig( + format="%(asctime)s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" +) + + +class TrainingQueuedJob: + """TrainingQueuedJob class for AWS Batch service. + + With this class, customers are able to attach the latest training job to a ModelTrainer. + """ + + def __init__(self, job_arn: str, job_name: str): + self.job_arn = job_arn + self.job_name = job_name + self._no_training_job_status = {"SUBMITTED", "PENDING", "RUNNABLE"} + + def get_model_trainer(self) -> ModelTrainer: + """Attach the latest training job to a ModelTrainer and return. + + Returns: a ModelTrainer instance. + + """ + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + if self._training_job_created(job_status): + if "latestAttempt" not in describe_resp: + raise MissingRequiredArgument("No LatestAttempt in describe call") + new_training_job_name = _get_new_training_job_name_from_latest_attempt( + describe_resp["latestAttempt"] + ) + output_model_trainer = _construct_model_trainer_from_training_job_name( + new_training_job_name + ) + _remove_system_tags_in_place_in_model_trainer_object(output_model_trainer) + return output_model_trainer + + _output_attempt_history(describe_resp) + raise NoTrainingJob("No Training job created. Job is still waiting in queue") + + def terminate(self, reason: Optional[str] = "Default terminate reason") -> None: + """Terminate Batch job. + + Args: + reason: Reason for terminating a job. + + Returns: None + + """ + terminate_service_job(self.job_arn, reason) + + def describe(self) -> Dict: + """Describe Batch job. + + Returns: A dict which includes job parameters, job status, attempts and so on. 
+ + """ + return describe_service_job(self.job_arn) + + def _training_job_created(self, status: str) -> bool: + """Return True if a Training job has been created + + Args: + status: Job status returned from Batch API. + + Returns: a boolean indicating whether a Training job has been created. + + """ + return status not in self._no_training_job_status + + def result(self, timeout: int = None) -> Dict: + """Fetch the terminal result of the Batch job. + + Args: + timeout: The time to wait for the Batch job to complete. Defaults to ``None``. + + Returns: The results of the Batch job, represented as a Dict. + + """ + nest_asyncio.apply() + loop = asyncio.get_event_loop() + task = loop.create_task(self.fetch_job_results(timeout)) + resp = loop.run_until_complete(task) + return resp + + async def fetch_job_results(self, timeout: int = None) -> Dict: + """Async method that waits for the Batch job to complete or until timeout. + + Args: + timeout: The time to wait for the Batch job to complete. Defaults to ``None``. + + Returns: The results of the Batch job, represented as a Dict, or an Error. + + """ + self.wait(timeout) + + describe_resp = self.describe() + if describe_resp.get("status", "") == JOB_STATUS_COMPLETED: + return describe_resp + if describe_resp.get("status", "") == JOB_STATUS_FAILED: + raise RuntimeError(describe_resp["statusReason"]) + raise TimeoutError("Reached timeout before the Batch job reached a terminal status") + + def wait(self, timeout: int = None) -> Dict: + """Wait for the Batch job to finish. + + This method blocks on the job completing for up to the timeout value (if specified). + If timeout is ``None``, this method will block until the job is completed. + + Args: + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. + + Returns: The last describe_service_job response for the Batch job. + """ + request_end_time = time.time() + timeout if timeout else None + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + job_completed = job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED) + + while not job_completed: + if timeout and time.time() > request_end_time: + logging.info( + "Timeout exceeded: %d seconds elapsed. Returning current results", timeout + ) + break + if job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED): + break + + time.sleep(POLL_IN_SECONDS) + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + job_completed = job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED) + + return describe_resp + + +def _construct_model_trainer_from_training_job_name(training_job_name: str) -> ModelTrainer: + """Build ModelTrainer instance from training job name. + + Args: + training_job_name: Training job name. + + Returns: a ModelTrainer instance with _latest_training_job set. 
+ + """ + # Step 1: Get the TrainingJob resource + training_job = TrainingJob.get(training_job_name=training_job_name) + + # Step 2: Extract parameters from training_job to reconstruct ModelTrainer + init_params = {} + + # Required/common parameters + init_params["role"] = training_job.role_arn + init_params["base_job_name"] = _extract_base_job_name(training_job_name) + + # Training image or algorithm + if training_job.algorithm_specification and not isinstance(training_job.algorithm_specification, Unassigned): + if (training_job.algorithm_specification.training_image and + not isinstance(training_job.algorithm_specification.training_image, Unassigned)): + init_params["training_image"] = training_job.algorithm_specification.training_image + if (training_job.algorithm_specification.algorithm_name and + not isinstance(training_job.algorithm_specification.algorithm_name, Unassigned)): + init_params["algorithm_name"] = training_job.algorithm_specification.algorithm_name + if (training_job.algorithm_specification.training_input_mode and + not isinstance(training_job.algorithm_specification.training_input_mode, Unassigned)): + init_params["training_input_mode"] = training_job.algorithm_specification.training_input_mode + + # Compute config + if training_job.resource_config and not isinstance(training_job.resource_config, Unassigned): + compute_params = {} + + if (training_job.resource_config.instance_type and + not isinstance(training_job.resource_config.instance_type, Unassigned)): + compute_params["instance_type"] = training_job.resource_config.instance_type + if (training_job.resource_config.instance_count and + not isinstance(training_job.resource_config.instance_count, Unassigned)): + compute_params["instance_count"] = training_job.resource_config.instance_count + if (training_job.resource_config.volume_size_in_gb and + not isinstance(training_job.resource_config.volume_size_in_gb, Unassigned)): + compute_params["volume_size_in_gb"] = training_job.resource_config.volume_size_in_gb + + # Add managed spot training if enabled (available directly on TrainingJob) + if training_job.enable_managed_spot_training and not isinstance(training_job.enable_managed_spot_training, Unassigned): + compute_params["enable_managed_spot_training"] = training_job.enable_managed_spot_training + + if compute_params: # Only create Compute if we have valid params + init_params["compute"] = Compute(**compute_params) + + # Output config - pass the raw training job output config directly + if training_job.output_data_config and not isinstance(training_job.output_data_config, Unassigned): + init_params["output_data_config"] = training_job.output_data_config + + # Stopping condition + if training_job.stopping_condition and not isinstance(training_job.stopping_condition, Unassigned): + if (training_job.stopping_condition.max_runtime_in_seconds and + not isinstance(training_job.stopping_condition.max_runtime_in_seconds, Unassigned)): + init_params["stopping_condition"] = StoppingCondition( + max_runtime_in_seconds=training_job.stopping_condition.max_runtime_in_seconds, + ) + + # Networking + if training_job.vpc_config and not isinstance(training_job.vpc_config, Unassigned): + networking_params = {} + + if (training_job.vpc_config.subnets and + not isinstance(training_job.vpc_config.subnets, Unassigned)): + networking_params["subnets"] = training_job.vpc_config.subnets + if (training_job.vpc_config.security_group_ids and + not isinstance(training_job.vpc_config.security_group_ids, Unassigned)): + 
networking_params["security_group_ids"] = training_job.vpc_config.security_group_ids + + # Add network isolation if present (available directly on TrainingJob) + if training_job.enable_network_isolation and not isinstance(training_job.enable_network_isolation, Unassigned): + networking_params["enable_network_isolation"] = training_job.enable_network_isolation + + # Add inter-container traffic encryption if present (available directly on TrainingJob) + if training_job.enable_inter_container_traffic_encryption and not isinstance(training_job.enable_inter_container_traffic_encryption, Unassigned): + networking_params["enable_inter_container_traffic_encryption"] = training_job.enable_inter_container_traffic_encryption + + if networking_params: # Only create Networking if we have valid params + init_params["networking"] = Networking(**networking_params) + + # Hyperparameters + if training_job.hyper_parameters and not isinstance(training_job.hyper_parameters, Unassigned): + init_params["hyperparameters"] = training_job.hyper_parameters + + # Environment + if training_job.environment and not isinstance(training_job.environment, Unassigned): + init_params["environment"] = training_job.environment + + # Checkpoint config + if training_job.checkpoint_config and not isinstance(training_job.checkpoint_config, Unassigned): + init_params["checkpoint_config"] = training_job.checkpoint_config + + # Step 3: Create ModelTrainer + model_trainer = ModelTrainer(**init_params) + + # Step 4: Set _latest_training_job (key insight!) + model_trainer._latest_training_job = training_job + + return model_trainer + + +def _extract_base_job_name(training_job_name: str) -> str: + """Extract base job name from full training job name. + + Args: + training_job_name: Full training job name. + + Returns: Base job name. + + """ + # Use the same regex pattern as PySDK V2's base_from_name() function + # Matches timestamps like: YYYY-MM-DD-HH-MM-SS-SSS or YYMMDD-HHMM + match = re.match(r"^(.+)-(\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-\d{3}|\d{6}-\d{4})", training_job_name) + return match.group(1) if match else training_job_name + + +def _output_attempt_history(describe_resp: Dict) -> None: + """Print attempt history if no Training job created. + + Args: + describe_resp: Describe response from Batch API. + + Returns: None + + """ + has_seen_status_reason = False + for i, attempt_dict in enumerate(describe_resp.get("attempts", [])): + if "statusReason" in attempt_dict: + logging.info("Attempt %d - %s", i + 1, attempt_dict["statusReason"]) + has_seen_status_reason = True + if not has_seen_status_reason: + logging.info("No attempts found or no statusReason found.") + + +def _get_new_training_job_name_from_latest_attempt(latest_attempt: Dict) -> str: + """Extract new Training job name from latest attempt in Batch Describe response. + + Args: + latest_attempt: a Dict containing Training job arn. + + Returns: new Training job name or None if not found. + + """ + training_job_arn = latest_attempt.get("serviceResourceId", {}).get("value", None) + return get_training_job_name_from_training_job_arn(training_job_arn) + + +def _remove_system_tags_in_place_in_model_trainer_object(model_trainer: ModelTrainer) -> None: + """Remove system tags in place. + + Args: + model_trainer: input ModelTrainer object. + + Returns: None. Remove system tags in place. 
+ + """ + if model_trainer.tags: + filtered_tags = [] + for tag in model_trainer.tags: + # Handle both V2 dict format {"Key": "...", "Value": "..."} and V3 object format with .key attribute + if isinstance(tag, dict): + # V2 format + if not tag.get("Key", "").startswith("aws:"): + filtered_tags.append(tag) + else: + # V3 format - assume it has .key attribute + if hasattr(tag, 'key') and not tag.key.startswith("aws:"): + filtered_tags.append(tag) + elif hasattr(tag, 'Key') and not tag.Key.startswith("aws:"): + # Fallback for other formats + filtered_tags.append(tag) + else: + # If we can't determine the key, keep the tag to be safe + filtered_tags.append(tag) + + model_trainer.tags = filtered_tags diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index c2fbd5da46..2e5ed75fe6 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -27,6 +27,8 @@ from sagemaker.core.resources import TrainingJob from sagemaker.core import shapes from sagemaker.core.shapes import AlgorithmSpecification +from sagemaker.core.utils.utils import serialize +from sagemaker.core.apiutils._boto_functions import to_pascal_case from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call from sagemaker.core.config.config_schema import ( @@ -250,6 +252,9 @@ class ModelTrainer(BaseModel): # Private Attributes for JumpStart _jumpstart_config: Optional[JumpStartConfig] = PrivateAttr(default=None) + # Private Attributes for AWS_Batch + _temp_code_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) + CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [ "role", "base_job_name", @@ -386,6 +391,8 @@ def __del__(self): if hasattr(self, "__pydantic_fields_set__"): if self._temp_recipe_train_dir is not None: self._temp_recipe_train_dir.cleanup() + if self._temp_code_dir is not None: + self._temp_code_dir.cleanup() def _validate_training_image_and_algorithm_name( self, training_image: Optional[str], algorithm_name: Optional[str] @@ -525,30 +532,25 @@ def model_post_init(self, __context: Any): if self.training_image: logger.info(f"Training image URI: {self.training_image}") + - @_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train") - @runnable_by_pipeline - @validate_call - def train( + def _create_training_job_args( self, input_data_config: Optional[List[Union[Channel, InputData]]] = None, - wait: Optional[bool] = True, - logs: Optional[bool] = True, - ): - """Train a model using AWS SageMaker. - + boto3: bool = False, + ) -> Dict[str, Any]: + """Create the training job arguments. Args: + input_data_config (Optional[List[Union[Channel, InputData]]]): input_data_config (Optional[List[Union[Channel, InputData]]]): The input data config for the training job. Takes a list of Channel objects or a dictionary of channel names to DataSourceType. DataSourceType can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object. - wait (Optional[bool]): - Whether to wait for the training job to complete before returning. - Defaults to True. - logs (Optional[bool]): - Whether to display the training container logs while training. - Defaults to True. + boto3 (bool): Whether to return the arguments in boto3 format. Defaults to False. + By default, the arguments are returned in the format used by the SageMaker Core. + Returns: + Dict[str, Any]: The training job arguments. 
""" self._populate_intelligent_defaults() current_training_job_name = _get_unique_name(self.base_job_name) @@ -593,16 +595,16 @@ def train( container_arguments = None if self.source_code: if self.training_mode == Mode.LOCAL_CONTAINER: - tmp_dir = TemporaryDirectory(prefix=os.path.join(self.local_container_root + "/")) + self._temp_code_dir = TemporaryDirectory(prefix=os.path.join(self.local_container_root + "/")) else: - tmp_dir = TemporaryDirectory() + self._temp_code_dir = TemporaryDirectory() # Copy everything under container_drivers/ to a temporary directory - shutil.copytree(SM_DRIVERS_LOCAL_PATH, tmp_dir.name, dirs_exist_ok=True) + shutil.copytree(SM_DRIVERS_LOCAL_PATH, self._temp_code_dir.name, dirs_exist_ok=True) # If distributed is provided, overwrite code under /drivers if self.distributed: distributed_driver_dir = self.distributed.driver_dir - driver_dir = os.path.join(tmp_dir.name, "distributed_drivers") + driver_dir = os.path.join(self._temp_code_dir.name, "distributed_drivers") shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True) # If source code is provided, create a channel for the source code @@ -616,7 +618,7 @@ def train( final_input_data_config.append(source_code_channel) self._prepare_train_script( - tmp_dir=tmp_dir, + tmp_dir=self._temp_code_dir, source_code=self.source_code, distributed=self.distributed, ) @@ -625,13 +627,13 @@ def train( mp_parameters = self.distributed.smp._to_mp_hyperparameters() string_hyper_parameters.update(mp_parameters) - self._write_source_code_json(tmp_dir=tmp_dir, source_code=self.source_code) - self._write_distributed_json(tmp_dir=tmp_dir, distributed=self.distributed) + self._write_source_code_json(tmp_dir=self._temp_code_dir, source_code=self.source_code) + self._write_distributed_json(tmp_dir=self._temp_code_dir, distributed=self.distributed) # Create an input channel for drivers packaged by the sdk sm_drivers_channel = self.create_input_data_channel( channel_name=SM_DRIVERS, - data_source=tmp_dir.name, + data_source=self._temp_code_dir.name, key_prefix=input_data_key_prefix, ) final_input_data_config.append(sm_drivers_channel) @@ -656,63 +658,92 @@ def train( resource_config = self.compute._to_resource_config() vpc_config = self.networking._to_vpc_config() if self.networking else None - if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: - # Convert tags to dictionaries if they are Tag objects - tags_as_dicts = None - if self.tags: - tags_as_dicts = [] - for tag in self.tags: - if hasattr(tag, 'model_dump'): - tags_as_dicts.append(tag.model_dump()) - elif isinstance(tag, dict): - tags_as_dicts.append(tag) - else: - # Fallback for any other tag-like object - tags_as_dicts.append({"key": getattr(tag, 'key', ''), "value": getattr(tag, 'value', '')}) - - # Build training request with snake_case keys (Python SDK convention) - training_request = { - "training_job_name": current_training_job_name, - "algorithm_specification": algorithm_specification, - "hyper_parameters": string_hyper_parameters, - "input_data_config": final_input_data_config, - "resource_config": resource_config, - "vpc_config": vpc_config, - "role_arn": self.role, - "tags": tags_as_dicts, - "stopping_condition": self.stopping_condition, - "output_data_config": self.output_data_config, - "checkpoint_config": self.checkpoint_config, - "environment": self.environment, - "enable_managed_spot_training": self.compute.enable_managed_spot_training, - "enable_inter_container_traffic_encryption": ( - self.networking.enable_inter_container_traffic_encryption - if 
self.networking - else None - ), - "enable_network_isolation": ( - self.networking.enable_network_isolation if self.networking else None - ), - "remote_debug_config": self._remote_debug_config, - "tensor_board_output_config": self._tensorboard_output_config, - "retry_strategy": self._retry_strategy, - "infra_check_config": self._infra_check_config, - "session_chaining_config": self._session_chaining_config, - } - - # Handle PipelineSession + # Convert tags to dictionaries if they are Tag objects + tags_as_dicts = None + if self.tags: + tags_as_dicts = [] + for tag in self.tags: + if hasattr(tag, 'model_dump'): + tags_as_dicts.append(tag.model_dump()) + elif isinstance(tag, dict): + tags_as_dicts.append(tag) + else: + # Fallback for any other tag-like object + tags_as_dicts.append({"key": getattr(tag, 'key', ''), "value": getattr(tag, 'value', '')}) + + # Build training request with snake_case keys (Python SDK convention) + training_request = { + "training_job_name": current_training_job_name, + "algorithm_specification": algorithm_specification, + "hyper_parameters": string_hyper_parameters, + "input_data_config": final_input_data_config, + "resource_config": resource_config, + "vpc_config": vpc_config, + "role_arn": self.role, + "tags": tags_as_dicts, + "stopping_condition": self.stopping_condition, + "output_data_config": self.output_data_config, + "checkpoint_config": self.checkpoint_config, + "environment": self.environment, + "enable_managed_spot_training": self.compute.enable_managed_spot_training, + "enable_inter_container_traffic_encryption": ( + self.networking.enable_inter_container_traffic_encryption + if self.networking + else None + ), + "enable_network_isolation": ( + self.networking.enable_network_isolation if self.networking else None + ), + "remote_debug_config": self._remote_debug_config, + "tensor_board_output_config": self._tensorboard_output_config, + "retry_strategy": self._retry_strategy, + "infra_check_config": self._infra_check_config, + "session_chaining_config": self._session_chaining_config, + } + + if boto3 or isinstance(self.sagemaker_session, PipelineSession): if isinstance(self.sagemaker_session, PipelineSession): - from sagemaker.core.utils.utils import serialize - from sagemaker.core.apiutils._boto_functions import to_pascal_case - - # Remove training_job_name for pipeline as it's auto-generated at execution time training_request.pop("training_job_name", None) - # Convert snake_case to PascalCase for AWS API - pipeline_request = {to_pascal_case(k): v for k, v in training_request.items()} - serialized_request = serialize(pipeline_request) - self.sagemaker_session._intercept_create_request(serialized_request, None, "train") - return + # Convert snake_case to PascalCase for AWS API + pipeline_request = {to_pascal_case(k): v for k, v in training_request.items()} + serialized_request = serialize(pipeline_request) + return serialized_request + + return training_request + + + @_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train") + @runnable_by_pipeline + @validate_call + def train( + self, + input_data_config: Optional[List[Union[Channel, InputData]]] = None, + wait: Optional[bool] = True, + logs: Optional[bool] = True, + ): + """Train a model using AWS SageMaker. + + Args: + input_data_config (Optional[List[Union[Channel, InputData]]]): + The input data config for the training job. + Takes a list of Channel objects or a dictionary of channel names to DataSourceType. 
+ DataSourceType can be an S3 URI string, local file path string, + S3DataSource object, or FileSystemDataSource object. + wait (Optional[bool]): + Whether to wait for the training job to complete before returning. + Defaults to True. + logs (Optional[bool]): + Whether to display the training container logs while training. + Defaults to True. + """ + training_request = self._create_training_job_args(input_data_config=input_data_config) + # Handle PipelineSession + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: + if isinstance(self.sagemaker_session, PipelineSession): + self.sagemaker_session._intercept_create_request(training_request, None, "train") + return + training_job = TrainingJob.create( session=self.sagemaker_session.boto_session, **training_request @@ -725,21 +756,25 @@ def train( logger.warning( "Not displaing the training container logs as 'wait' is set to False." ) + else: local_container = _LocalContainer( - training_job_name=_get_unique_name(self.base_job_name), - instance_type=resource_config.instance_type, - instance_count=resource_config.instance_count, - image=algorithm_specification.training_image, + training_job_name=training_request["training_job_name"], + instance_type=training_request["resource_config"].instance_type, + instance_count=training_request["resource_config"].instance_count, + image=training_request["algorithm_specification"].training_image, container_root=self.local_container_root, sagemaker_session=self.sagemaker_session, - container_entrypoint=algorithm_specification.container_entrypoint, - container_arguments=algorithm_specification.container_arguments, - input_data_config=final_input_data_config, - hyper_parameters=string_hyper_parameters, - environment=self.environment, + container_entrypoint=training_request["algorithm_specification"].container_entrypoint, + container_arguments=training_request["algorithm_specification"].container_arguments, + input_data_config=training_request["input_data_config"], + hyper_parameters=training_request["hyper_parameters"], + environment=training_request["environment"], ) local_container.train(wait) + if self._temp_code_dir is not None: + self._temp_code_dir.cleanup() + def create_input_data_channel( self, channel_name: str, data_source: DataSourceType, key_prefix: Optional[str] = None diff --git a/sagemaker-train/src/sagemaker/train/utils.py b/sagemaker-train/src/sagemaker/train/utils.py index 6e0f8c4ff7..cebd4338ee 100644 --- a/sagemaker-train/src/sagemaker/train/utils.py +++ b/sagemaker-train/src/sagemaker/train/utils.py @@ -13,6 +13,7 @@ """Utils module.""" from __future__ import absolute_import +import re import os import json import subprocess @@ -246,3 +247,18 @@ def _get_studio_tags(model_id: str, hub_name: str): "value": hub_name } ] + + +def get_training_job_name_from_training_job_arn(training_job_arn: str) -> str: + """Extract Training job name from Training job arn. + Args: + training_job_arn: Training job arn. + Returns: Training job name. 
+ """ + if training_job_arn is None: + return None + pattern = "arn:aws[a-z-]*:sagemaker:[a-z0-9-]*:[0-9]{12}:training-job/(.+)" + match = re.match(pattern, training_job_arn) + if match: + return match.group(1) + return None \ No newline at end of file diff --git a/sagemaker-train/tests/data/train/script_mode/custom_script.py b/sagemaker-train/tests/data/train/script_mode/custom_script.py new file mode 100644 index 0000000000..a57ddee743 --- /dev/null +++ b/sagemaker-train/tests/data/train/script_mode/custom_script.py @@ -0,0 +1,191 @@ +# flake8: noqa +import argparse +import numpy as np +import os +import sys +import logging +import json +import shutil +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset +from pytorch_model_def import get_model + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) +current_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_train_data(train_dir): + """ + Get the training data and convert to tensors + """ + + x_train = np.load(os.path.join(train_dir, "x_train.npy")) + y_train = np.load(os.path.join(train_dir, "y_train.npy")) + logger.info(f"x train: {x_train.shape}, y train: {y_train.shape}") + + return torch.from_numpy(x_train), torch.from_numpy(y_train) + + +def get_test_data(test_dir): + """ + Get the testing data and convert to tensors + """ + + x_test = np.load(os.path.join(test_dir, "x_test.npy")) + y_test = np.load(os.path.join(test_dir, "y_test.npy")) + logger.info(f"x test: {x_test.shape}, y test: {y_test.shape}") + + return torch.from_numpy(x_test), torch.from_numpy(y_test) + + +def model_fn(model_dir): + """ + Load the model for inference + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model() + model.load_state_dict(torch.load(model_dir + "/model.pth")) + model.eval() + return model.to(device) + + +def input_fn(request_body, request_content_type): + """ + Deserialize and prepare the prediction input + """ + + if request_content_type == "application/json": + request = json.loads(request_body) + train_inputs = torch.tensor(request) + return train_inputs + + +def predict_fn(input_data, model): + """ + Apply model to the incoming request + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + with torch.no_grad(): + return model(input_data.float()).numpy()[0] + + +def parse_args(): + """ + Parse the command line arguments + """ + + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-dir", + type=str, + default=os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model")), + help="Directory to save the model", + ) + parser.add_argument( + "--train-dir", + type=str, + default=os.environ.get("SM_CHANNEL_TRAIN", os.path.join(current_dir, "data/train")), + help="Directory containing training data", + ) + parser.add_argument( + "--test-dir", + type=str, + default=os.environ.get("SM_CHANNEL_TEST", os.path.join(current_dir, "data/test")), + help="Directory containing testing data", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + help="Batch size for training", + ) + parser.add_argument( + "--epochs", + type=int, + default=1, + help="Number of epochs for training", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=0.1, + help="Learning rate for training", + ) + return parser.parse_args() + + +def train(): + """ + Train the PyTorch model + """ + args = 
parse_args() + # Directories: train, test and model + train_dir = args.train_dir + test_dir = args.test_dir + model_dir = args.model_dir + + # Load the training and testing data + x_train, y_train = get_train_data(train_dir) + x_test, y_test = get_test_data(test_dir) + train_ds = TensorDataset(x_train, y_train) + + # Training parameters - used to configure the training loop + batch_size = args.batch_size + epochs = args.epochs + learning_rate = args.learning_rate + logger.info( + "batch_size = {}, epochs = {}, learning rate = {}".format(batch_size, epochs, learning_rate) + ) + + train_dl = DataLoader(train_ds, batch_size, shuffle=True) + + # Define the model, loss function and optimizer + model = get_model() + model = model.to(device) + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + + # Train the model + for epoch in range(epochs): + for x_train_batch, y_train_batch in train_dl: + y = model(x_train_batch.float()) + loss = criterion(y.flatten(), y_train_batch.float()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + epoch += 1 + logger.info(f"epoch: {epoch} -> loss: {loss}") + + # Test the model + with torch.no_grad(): + y = model(x_test.float()).flatten() + mse = ((y - y_test) ** 2).sum() / y_test.shape[0] + print("\nTest MSE:", mse.numpy()) + + # Save the model + os.makedirs(model_dir, exist_ok=True) + torch.save(model.state_dict(), model_dir + "/model.pth") + inference_code_path = model_dir + "/code/" + + if not os.path.exists(inference_code_path): + os.mkdir(inference_code_path) + logger.info("Created a folder at {}!".format(inference_code_path)) + + shutil.copy("custom_script.py", inference_code_path) + shutil.copy("pytorch_model_def.py", inference_code_path) + logger.info("Saving models files to {}".format(inference_code_path)) + + +if __name__ == "__main__": + print("Running the training job ...\n") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + train() diff --git a/sagemaker-train/tests/data/train/script_mode/data/test/x_test.npy b/sagemaker-train/tests/data/train/script_mode/data/test/x_test.npy new file mode 100644 index 0000000000..a9977e39c0 Binary files /dev/null and b/sagemaker-train/tests/data/train/script_mode/data/test/x_test.npy differ diff --git a/sagemaker-train/tests/data/train/script_mode/data/test/y_test.npy b/sagemaker-train/tests/data/train/script_mode/data/test/y_test.npy new file mode 100644 index 0000000000..a7191945ee Binary files /dev/null and b/sagemaker-train/tests/data/train/script_mode/data/test/y_test.npy differ diff --git a/sagemaker-train/tests/data/train/script_mode/data/train/x_train.npy b/sagemaker-train/tests/data/train/script_mode/data/train/x_train.npy new file mode 100644 index 0000000000..d267502e65 Binary files /dev/null and b/sagemaker-train/tests/data/train/script_mode/data/train/x_train.npy differ diff --git a/sagemaker-train/tests/data/train/script_mode/data/train/y_train.npy b/sagemaker-train/tests/data/train/script_mode/data/train/y_train.npy new file mode 100644 index 0000000000..b8c17c4972 Binary files /dev/null and b/sagemaker-train/tests/data/train/script_mode/data/train/y_train.npy differ diff --git a/sagemaker-train/tests/data/train/script_mode/pytorch_model_def.py b/sagemaker-train/tests/data/train/script_mode/pytorch_model_def.py new file mode 100644 index 0000000000..2440b22f88 --- /dev/null +++ b/sagemaker-train/tests/data/train/script_mode/pytorch_model_def.py @@ -0,0 +1,23 @@ +# flake8: noqa +import torch +import torch.nn as nn + + 
+class NeuralNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(8, 8) + self.fc2 = nn.Linear(8, 6) + self.fc3 = nn.Linear(6, 1) + + def forward(self, x): + x = torch.tanh(self.fc1(x)) + x = torch.sigmoid(self.fc2(x)) + x = self.fc3(x) + return x + + +def get_model(): + + model = NeuralNet() + return model diff --git a/sagemaker-train/tests/data/train/script_mode/requirements.txt b/sagemaker-train/tests/data/train/script_mode/requirements.txt new file mode 100644 index 0000000000..f7b8ccf0cc --- /dev/null +++ b/sagemaker-train/tests/data/train/script_mode/requirements.txt @@ -0,0 +1,3 @@ +numpy +-f https://download.pytorch.org/whl/torch_stable.html +torch==2.7.0 diff --git a/sagemaker-train/tests/integ/__init__.py b/sagemaker-train/tests/integ/__init__.py index bd6c65c770..aca26431cb 100644 --- a/sagemaker-train/tests/integ/__init__.py +++ b/sagemaker-train/tests/integ/__init__.py @@ -12,3 +12,7 @@ # language governing permissions and limitations under the License. """This module contains the Integ Tests for SageMaker PySDK Training.""" from __future__ import absolute_import + +import os + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") diff --git a/sagemaker-train/tests/integ/train/aws_batch/__init__.py b/sagemaker-train/tests/integ/train/aws_batch/__init__.py new file mode 100644 index 0000000000..b8aa447909 --- /dev/null +++ b/sagemaker-train/tests/integ/train/aws_batch/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""AWS Batch integration tests""" diff --git a/sagemaker-train/tests/integ/train/aws_batch/manager.py b/sagemaker-train/tests/integ/train/aws_batch/manager.py new file mode 100644 index 0000000000..b417f86b53 --- /dev/null +++ b/sagemaker-train/tests/integ/train/aws_batch/manager.py @@ -0,0 +1,133 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
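Before the integration-test plumbing below, a compact sketch of the end-to-end workflow that the new `TrainingQueue` / `TrainingQueuedJob` classes are intended to support; it mirrors the integration test further down. The queue name, image URI, source directory, and S3 paths are placeholders, not values from this patch.

from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import SourceCode, Compute, InputData
from sagemaker.train.aws_batch.training_queue import TrainingQueue

# Placeholder training image, source code, and compute configuration.
trainer = ModelTrainer(
    training_image="<account>.dkr.ecr.us-west-2.amazonaws.com/<repo>:<tag>",
    source_code=SourceCode(source_dir="./src", entry_script="train.py"),
    compute=Compute(instance_type="ml.m5.xlarge"),
    base_job_name="queued-training",
)

queue = TrainingQueue(queue_name="my-sagemaker-training-queue")

# submit() builds the training payload via
# ModelTrainer._create_training_job_args(boto3=True) and passes it to the
# Batch submit_service_job API; the job waits in the queue until scheduled.
queued_job = queue.submit(
    training_job=trainer,
    inputs=[InputData(channel_name="train", data_source="s3://<bucket>/train/")],
)

# Block until Batch reports a terminal status, then attach the underlying
# SageMaker training job as a reconstructed ModelTrainer.
queued_job.wait()
attached_trainer = queued_job.get_model_trainer()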
+from __future__ import absolute_import + +import time + + +class BatchTestResourceManager: + + def __init__( + self, + batch_client, + queue_name="pysdk-test-queue", + service_env_name="pysdk-test-queue-service-environment", + ): + self.batch_client = batch_client + self.queue_name = queue_name + self.service_environment_name = service_env_name + + def _create_or_get_service_environment(self, service_environment_name): + print(f"Creating service environment: {service_environment_name}") + try: + response = self.batch_client.create_service_environment( + serviceEnvironmentName=service_environment_name, + serviceEnvironmentType="SAGEMAKER_TRAINING", + capacityLimits=[{"maxCapacity": 10, "capacityUnit": "NUM_INSTANCES"}], + ) + print(f"Service environment {service_environment_name} created successfully.") + return response + except Exception as e: + if "Object already exists" in str(e): + print("Resource already exists. Fetching existing resource.") + response = self.batch_client.describe_service_environments( + serviceEnvironments=[service_environment_name] + ) + return response["serviceEnvironments"][0] + else: + print(f"Error creating service environment: {e}") + raise + + def _create_or_get_queue(self, queue_name, service_environment_arn): + + print(f"Creating job queue: {queue_name}") + try: + response = self.batch_client.create_job_queue( + jobQueueName=queue_name, + priority=1, + computeEnvironmentOrder=[], + serviceEnvironmentOrder=[ + { + "order": 1, + "serviceEnvironment": service_environment_arn, + }, + ], + jobQueueType="SAGEMAKER_TRAINING", + ) + print(f"Job queue {queue_name} created successfully.") + return response + except Exception as e: + if "Object already exists" in str(e): + print("Resource already exists. Fetching existing resource.") + response = self.batch_client.describe_job_queues(jobQueues=[queue_name]) + return response["jobQueues"][0] + else: + print(f"Error creating job queue: {e}") + raise + + def _update_queue_state(self, queue_name, state): + try: + print(f"Updating queue {queue_name} to state {state}") + response = self.batch_client.update_job_queue(jobQueue=queue_name, state=state) + return response + except Exception as e: + print(f"Error updating queue: {e}") + + def _update_service_environment_state(self, service_environment_name, state): + print(f"Updating service environment {service_environment_name} to state {state}") + try: + response = self.batch_client.update_service_environment( + serviceEnvironment=service_environment_name, state=state + ) + return response + except Exception as e: + print(f"Error updating service environment: {e}") + + def _wait_for_queue_state(self, queue_name, state): + print(f"Waiting for queue {queue_name} to be {state}...") + while True: + response = self.batch_client.describe_job_queues(jobQueues=[queue_name]) + print(f"Current state: {response}") + if response["jobQueues"][0]["state"] == state: + break + time.sleep(5) + print(f"Queue {queue_name} is now {state}.") + + def _wait_for_service_environment_state(self, service_environment_name, state): + print(f"Waiting for service environment {service_environment_name} to be {state}...") + while True: + response = self.batch_client.describe_service_environments( + serviceEnvironments=[service_environment_name] + ) + print(f"Current state: {response}") + if response["serviceEnvironments"][0]["state"] == state: + break + time.sleep(5) + print(f"Service environment {service_environment_name} is now {state}.") + + def get_or_create_resources(self, queue_name=None, 
service_environment_name=None): + queue_name = queue_name or self.queue_name + service_environment_name = service_environment_name or self.service_environment_name + + service_environment = self._create_or_get_service_environment(service_environment_name) + if service_environment.get("state") != "ENABLED": + self._update_service_environment_state(service_environment_name, "ENABLED") + self._wait_for_service_environment_state(service_environment_name, "ENABLED") + time.sleep(10) + + queue = self._create_or_get_queue(queue_name, service_environment["serviceEnvironmentArn"]) + if queue.get("state") != "ENABLED": + self._update_queue_state(queue_name, "ENABLED") + self._wait_for_queue_state(queue_name, "ENABLED") + time.sleep(10) + return queue, service_environment diff --git a/sagemaker-train/tests/integ/train/aws_batch/test_queue.py b/sagemaker-train/tests/integ/train/aws_batch/test_queue.py new file mode 100644 index 0000000000..7333acddca --- /dev/null +++ b/sagemaker-train/tests/integ/train/aws_batch/test_queue.py @@ -0,0 +1,93 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import boto3 +import botocore +import pytest + +from sagemaker.train.model_trainer import ModelTrainer +from sagemaker.train.configs import SourceCode, InputData, Compute + +from sagemaker.train.aws_batch.training_queue import TrainingQueue + +from tests.integ import DATA_DIR +from tests.integ.train.conftest import sagemaker_session # noqa: F401 +from tests.integ.train.test_model_trainer import ( + DEFAULT_CPU_IMAGE, +) +from .manager import BatchTestResourceManager + + +@pytest.fixture(scope="module") +def batch_client(): + return boto3.client("batch", region_name="us-west-2") + + +@pytest.fixture(scope="function") +def batch_test_resource_manager(batch_client): + resource_manager = BatchTestResourceManager(batch_client=batch_client) + resource_manager.get_or_create_resources() + return resource_manager + + +def test_model_trainer_submit(batch_test_resource_manager, sagemaker_session): # noqa: F811 + queue_name = batch_test_resource_manager.queue_name + + source_code = SourceCode( + source_dir=f"{DATA_DIR}/train/script_mode/", + requirements="requirements.txt", + entry_script="custom_script.py", + ) + hyperparameters = { + "batch-size": 32, + "epochs": 1, + "learning-rate": 0.01, + } + compute = Compute(instance_type="ml.m5.2xlarge") + model_trainer = ModelTrainer( + sagemaker_session=sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + source_code=source_code, + compute=compute, + hyperparameters=hyperparameters, + base_job_name="test-batch-model-trainer", + ) + train_data = InputData( + channel_name="train", + data_source=f"{DATA_DIR}/train/script_mode/data/train/", + ) + test_data = InputData( + channel_name="test", + data_source=f"{DATA_DIR}/train/script_mode/data/test/", + ) + + training_queue = TrainingQueue(queue_name=queue_name) + + try: + queued_job = training_queue.submit( + training_job=model_trainer, + inputs=[train_data, test_data], + ) + except 
botocore.exceptions.ClientError as e: + print(e.response["ResponseMetadata"]) + print(e.response["Error"]["Message"]) + raise e + res = queued_job.describe() + assert res is not None + assert res["status"] == "SUBMITTED" + + queued_job.wait(timeout=1800) + res = queued_job.describe() + assert res is not None + assert res["status"] == "SUCCEEDED" diff --git a/sagemaker-train/tests/unit/train/aws_batch/__init__.py b/sagemaker-train/tests/unit/train/aws_batch/__init__.py new file mode 100644 index 0000000000..61dd4913ae --- /dev/null +++ b/sagemaker-train/tests/unit/train/aws_batch/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""AWS Batch unit tests""" diff --git a/sagemaker-train/tests/unit/train/aws_batch/conftest.py b/sagemaker-train/tests/unit/train/aws_batch/conftest.py new file mode 100644 index 0000000000..c33baa0752 --- /dev/null +++ b/sagemaker-train/tests/unit/train/aws_batch/conftest.py @@ -0,0 +1,166 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Test constants for AWS Batch unit tests""" + +# Job identifiers +JOB_NAME = "test-training-job" +JOB_ARN = "arn:aws:batch:us-west-2:123456789012:job/test-job-id" +JOB_ID = "test-job-id" +JOB_QUEUE = "test-queue" + +# Training job identifiers +TRAINING_JOB_NAME = "training-job-20251211" +TRAINING_JOB_ARN = "arn:aws:sagemaker:us-west-2:123456789012:training-job/training-job-20251211" + +# Job statuses +JOB_STATUS_SUBMITTED = "SUBMITTED" +JOB_STATUS_PENDING = "PENDING" +JOB_STATUS_RUNNABLE = "RUNNABLE" +JOB_STATUS_STARTING = "STARTING" +JOB_STATUS_RUNNING = "RUNNING" +JOB_STATUS_SUCCEEDED = "SUCCEEDED" +JOB_STATUS_FAILED = "FAILED" + +# Configuration values +INSTANCE_TYPE = "ml.m5.xlarge" +INSTANCE_COUNT = 1 +VOLUME_SIZE_IN_GB = 30 +MAX_RUNTIME_IN_SECONDS = 3600 +TRAINING_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.5-gpu-py311" +EXECUTION_ROLE = "arn:aws:iam::123456789012:role/SageMakerRole" +S3_OUTPUT_PATH = "s3://my-bucket/output" + +# Batch configuration +SCHEDULING_PRIORITY = 1 +SHARE_IDENTIFIER = "test-share-id" +ATTEMPT_DURATION_IN_SECONDS = 86400 +REASON = "Test termination reason" +NEXT_TOKEN = "test-next-token" + +# Tags +BATCH_TAGS = {"batch-key": "batch-value", "environment": "test"} +TRAINING_TAGS = [ + {"Key": "training-key", "Value": "training-value"}, + {"Key": "project", "Value": "test-project"}, +] +TRAINING_TAGS_CONVERTED = {"training-key": "training-value", "project": "test-project"} +MERGED_TAGS = {**BATCH_TAGS, **TRAINING_TAGS_CONVERTED} + +# Retry configuration +DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = { + "attempts": 1, + "evaluateOnExit": [ + { + "action": "RETRY", + "onStatusReason": "Received status from SageMaker:InternalServerError", + }, + {"action": "EXIT", "onStatusReason": "*"}, + ], +} + +# Timeout configuration +TIMEOUT_CONFIG = {"attemptDurationSeconds": ATTEMPT_DURATION_IN_SECONDS} + +# API responses +SUBMIT_SERVICE_JOB_RESP = { + "jobArn": JOB_ARN, + "jobName": JOB_NAME, + "jobId": JOB_ID, +} + +DESCRIBE_SERVICE_JOB_RESP_RUNNING = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_RUNNING, + "createdAt": 1702300000, + "startedAt": 1702300100, +} + +DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_SUCCEEDED, + "createdAt": 1702300000, + "startedAt": 1702300100, + "stoppedAt": 1702300400, + "latestAttempt": { + "serviceResourceId": { + "name": "trainingJobArn", + "value": TRAINING_JOB_ARN, + }, + "startedAt": 1702300100, + "stoppedAt": 1702300400, + }, +} + +DESCRIBE_SERVICE_JOB_RESP_FAILED = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_FAILED, + "statusReason": "Task failed", + "createdAt": 1702300000, + "startedAt": 1702300100, + "stoppedAt": 1702300200, +} + +DESCRIBE_SERVICE_JOB_RESP_PENDING = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_PENDING, + "createdAt": 1702300000, +} + +LIST_SERVICE_JOB_RESP_EMPTY = { + "jobSummaryList": [], + "nextToken": None, +} + +LIST_SERVICE_JOB_RESP_WITH_JOBS = { + "jobSummaryList": [ + {"jobName": JOB_NAME, "jobArn": JOB_ARN, "jobId": JOB_ID}, + {"jobName": "another-job", "jobArn": "arn:aws:batch:us-west-2:123456789012:job/another-id", "jobId": "another-id"}, + ], + "nextToken": None, +} + +LIST_SERVICE_JOB_RESP_WITH_NEXT_TOKEN = { + "jobSummaryList": [ + {"jobName": JOB_NAME, "jobArn": 
JOB_ARN, "jobId": JOB_ID}, + ], + "nextToken": NEXT_TOKEN, +} + +# Training payload +TRAINING_JOB_PAYLOAD = { + "TrainingJobName": TRAINING_JOB_NAME, + "RoleArn": EXECUTION_ROLE, + "OutputDataConfig": {"S3OutputPath": S3_OUTPUT_PATH}, + "ResourceConfig": { + "InstanceType": INSTANCE_TYPE, + "InstanceCount": INSTANCE_COUNT, + "VolumeSizeInGB": VOLUME_SIZE_IN_GB, + }, + "StoppingCondition": {"MaxRuntimeInSeconds": MAX_RUNTIME_IN_SECONDS}, + "AlgorithmSpecification": { + "TrainingImage": TRAINING_IMAGE, + "TrainingInputMode": "File", + }, +} diff --git a/sagemaker-train/tests/unit/train/aws_batch/constants.py b/sagemaker-train/tests/unit/train/aws_batch/constants.py new file mode 100644 index 0000000000..c33baa0752 --- /dev/null +++ b/sagemaker-train/tests/unit/train/aws_batch/constants.py @@ -0,0 +1,166 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Test constants for AWS Batch unit tests""" + +# Job identifiers +JOB_NAME = "test-training-job" +JOB_ARN = "arn:aws:batch:us-west-2:123456789012:job/test-job-id" +JOB_ID = "test-job-id" +JOB_QUEUE = "test-queue" + +# Training job identifiers +TRAINING_JOB_NAME = "training-job-20251211" +TRAINING_JOB_ARN = "arn:aws:sagemaker:us-west-2:123456789012:training-job/training-job-20251211" + +# Job statuses +JOB_STATUS_SUBMITTED = "SUBMITTED" +JOB_STATUS_PENDING = "PENDING" +JOB_STATUS_RUNNABLE = "RUNNABLE" +JOB_STATUS_STARTING = "STARTING" +JOB_STATUS_RUNNING = "RUNNING" +JOB_STATUS_SUCCEEDED = "SUCCEEDED" +JOB_STATUS_FAILED = "FAILED" + +# Configuration values +INSTANCE_TYPE = "ml.m5.xlarge" +INSTANCE_COUNT = 1 +VOLUME_SIZE_IN_GB = 30 +MAX_RUNTIME_IN_SECONDS = 3600 +TRAINING_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.5-gpu-py311" +EXECUTION_ROLE = "arn:aws:iam::123456789012:role/SageMakerRole" +S3_OUTPUT_PATH = "s3://my-bucket/output" + +# Batch configuration +SCHEDULING_PRIORITY = 1 +SHARE_IDENTIFIER = "test-share-id" +ATTEMPT_DURATION_IN_SECONDS = 86400 +REASON = "Test termination reason" +NEXT_TOKEN = "test-next-token" + +# Tags +BATCH_TAGS = {"batch-key": "batch-value", "environment": "test"} +TRAINING_TAGS = [ + {"Key": "training-key", "Value": "training-value"}, + {"Key": "project", "Value": "test-project"}, +] +TRAINING_TAGS_CONVERTED = {"training-key": "training-value", "project": "test-project"} +MERGED_TAGS = {**BATCH_TAGS, **TRAINING_TAGS_CONVERTED} + +# Retry configuration +DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = { + "attempts": 1, + "evaluateOnExit": [ + { + "action": "RETRY", + "onStatusReason": "Received status from SageMaker:InternalServerError", + }, + {"action": "EXIT", "onStatusReason": "*"}, + ], +} + +# Timeout configuration +TIMEOUT_CONFIG = {"attemptDurationSeconds": ATTEMPT_DURATION_IN_SECONDS} + +# API responses +SUBMIT_SERVICE_JOB_RESP = { + "jobArn": JOB_ARN, + "jobName": JOB_NAME, + "jobId": JOB_ID, +} + +DESCRIBE_SERVICE_JOB_RESP_RUNNING = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_RUNNING, + "createdAt": 
1702300000, + "startedAt": 1702300100, +} + +DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_SUCCEEDED, + "createdAt": 1702300000, + "startedAt": 1702300100, + "stoppedAt": 1702300400, + "latestAttempt": { + "serviceResourceId": { + "name": "trainingJobArn", + "value": TRAINING_JOB_ARN, + }, + "startedAt": 1702300100, + "stoppedAt": 1702300400, + }, +} + +DESCRIBE_SERVICE_JOB_RESP_FAILED = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_FAILED, + "statusReason": "Task failed", + "createdAt": 1702300000, + "startedAt": 1702300100, + "stoppedAt": 1702300200, +} + +DESCRIBE_SERVICE_JOB_RESP_PENDING = { + "jobId": JOB_ID, + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobQueue": JOB_QUEUE, + "status": JOB_STATUS_PENDING, + "createdAt": 1702300000, +} + +LIST_SERVICE_JOB_RESP_EMPTY = { + "jobSummaryList": [], + "nextToken": None, +} + +LIST_SERVICE_JOB_RESP_WITH_JOBS = { + "jobSummaryList": [ + {"jobName": JOB_NAME, "jobArn": JOB_ARN, "jobId": JOB_ID}, + {"jobName": "another-job", "jobArn": "arn:aws:batch:us-west-2:123456789012:job/another-id", "jobId": "another-id"}, + ], + "nextToken": None, +} + +LIST_SERVICE_JOB_RESP_WITH_NEXT_TOKEN = { + "jobSummaryList": [ + {"jobName": JOB_NAME, "jobArn": JOB_ARN, "jobId": JOB_ID}, + ], + "nextToken": NEXT_TOKEN, +} + +# Training payload +TRAINING_JOB_PAYLOAD = { + "TrainingJobName": TRAINING_JOB_NAME, + "RoleArn": EXECUTION_ROLE, + "OutputDataConfig": {"S3OutputPath": S3_OUTPUT_PATH}, + "ResourceConfig": { + "InstanceType": INSTANCE_TYPE, + "InstanceCount": INSTANCE_COUNT, + "VolumeSizeInGB": VOLUME_SIZE_IN_GB, + }, + "StoppingCondition": {"MaxRuntimeInSeconds": MAX_RUNTIME_IN_SECONDS}, + "AlgorithmSpecification": { + "TrainingImage": TRAINING_IMAGE, + "TrainingInputMode": "File", + }, +} diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py b/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py new file mode 100644 index 0000000000..16391161db --- /dev/null +++ b/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py @@ -0,0 +1,260 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Unit tests for batch_api_helper module""" + +import json +import pytest +from unittest.mock import Mock, patch, MagicMock + +from sagemaker.train.aws_batch.batch_api_helper import ( + submit_service_job, + describe_service_job, + terminate_service_job, + list_service_job, +) +from .conftest import ( + JOB_NAME, + JOB_QUEUE, + JOB_ID, + REASON, + BATCH_TAGS, + TRAINING_TAGS, + TRAINING_TAGS_CONVERTED, + MERGED_TAGS, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + TIMEOUT_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + SUBMIT_SERVICE_JOB_RESP, + DESCRIBE_SERVICE_JOB_RESP_RUNNING, + LIST_SERVICE_JOB_RESP_EMPTY, + LIST_SERVICE_JOB_RESP_WITH_JOBS, + LIST_SERVICE_JOB_RESP_WITH_NEXT_TOKEN, + TRAINING_JOB_PAYLOAD, + NEXT_TOKEN, + JOB_STATUS_RUNNING, +) + + +class TestSubmitServiceJob: + """Tests for submit_service_job function""" + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_submit_service_job_basic(self, mock_get_client): + """Test basic submit_service_job call""" + mock_client = Mock() + mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + mock_get_client.return_value = mock_client + + result = submit_service_job( + TRAINING_JOB_PAYLOAD, + JOB_NAME, + JOB_QUEUE, + ) + + assert result["jobArn"] == SUBMIT_SERVICE_JOB_RESP["jobArn"] + assert result["jobName"] == SUBMIT_SERVICE_JOB_RESP["jobName"] + assert result["jobId"] == SUBMIT_SERVICE_JOB_RESP["jobId"] + mock_client.submit_service_job.assert_called_once() + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_submit_service_job_with_all_params(self, mock_get_client): + """Test submit_service_job with all optional parameters""" + mock_client = Mock() + mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + mock_get_client.return_value = mock_client + + result = submit_service_job( + TRAINING_JOB_PAYLOAD, + JOB_NAME, + JOB_QUEUE, + retry_config=DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + scheduling_priority=SCHEDULING_PRIORITY, + timeout=TIMEOUT_CONFIG, + share_identifier=SHARE_IDENTIFIER, + tags=BATCH_TAGS, + ) + + assert result["jobArn"] == SUBMIT_SERVICE_JOB_RESP["jobArn"] + call_kwargs = mock_client.submit_service_job.call_args[1] + assert call_kwargs["jobName"] == JOB_NAME + assert call_kwargs["jobQueue"] == JOB_QUEUE + assert call_kwargs["schedulingPriority"] == SCHEDULING_PRIORITY + assert call_kwargs["shareIdentifier"] == SHARE_IDENTIFIER + assert call_kwargs["timeoutConfig"] == TIMEOUT_CONFIG + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_submit_service_job_with_tags(self, mock_get_client): + """Test submit_service_job merges batch and training tags""" + mock_client = Mock() + mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + mock_get_client.return_value = mock_client + + payload = TRAINING_JOB_PAYLOAD.copy() + payload["Tags"] = TRAINING_TAGS + + result = submit_service_job( + payload, + JOB_NAME, + JOB_QUEUE, + tags=BATCH_TAGS, + ) + + assert result["jobArn"] == SUBMIT_SERVICE_JOB_RESP["jobArn"] + call_kwargs = mock_client.submit_service_job.call_args[1] + assert "tags" in call_kwargs + # Verify tags were merged + merged = call_kwargs["tags"] + assert merged["batch-key"] == "batch-value" + assert merged["training-key"] == "training-value" + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_submit_service_job_payload_serialized(self, mock_get_client): + """Test that training payload is JSON serialized""" + mock_client 
= Mock() + mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + mock_get_client.return_value = mock_client + + submit_service_job( + TRAINING_JOB_PAYLOAD, + JOB_NAME, + JOB_QUEUE, + ) + + call_kwargs = mock_client.submit_service_job.call_args[1] + payload_str = call_kwargs["serviceRequestPayload"] + # Verify it's a JSON string + parsed = json.loads(payload_str) + assert parsed["TrainingJobName"] == TRAINING_JOB_PAYLOAD["TrainingJobName"] + + +class TestDescribeServiceJob: + """Tests for describe_service_job function""" + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_describe_service_job(self, mock_get_client): + """Test describe_service_job returns job details""" + mock_client = Mock() + mock_client.describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING + mock_get_client.return_value = mock_client + + result = describe_service_job(JOB_ID) + + assert result["jobId"] == JOB_ID + assert result["status"] == "RUNNING" + mock_client.describe_service_job.assert_called_once_with(jobId=JOB_ID) + + +class TestTerminateServiceJob: + """Tests for terminate_service_job function""" + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_terminate_service_job(self, mock_get_client): + """Test terminate_service_job calls terminate API""" + mock_client = Mock() + mock_client.terminate_service_job.return_value = {} + mock_get_client.return_value = mock_client + + result = terminate_service_job(JOB_ID, REASON) + + assert result == {} + mock_client.terminate_service_job.assert_called_once_with( + jobId=JOB_ID, reason=REASON + ) + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_terminate_service_job_default_reason(self, mock_get_client): + """Test terminate_service_job with default reason""" + mock_client = Mock() + mock_client.terminate_service_job.return_value = {} + mock_get_client.return_value = mock_client + + terminate_service_job(JOB_ID) + + call_kwargs = mock_client.terminate_service_job.call_args[1] + assert call_kwargs["jobId"] == JOB_ID + assert "reason" in call_kwargs + + +class TestListServiceJob: + """Tests for list_service_job function""" + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_list_service_job_empty(self, mock_get_client): + """Test list_service_job with no jobs""" + mock_client = Mock() + mock_client.list_service_jobs.return_value = LIST_SERVICE_JOB_RESP_EMPTY + mock_get_client.return_value = mock_client + + gen = list_service_job(JOB_QUEUE) + result = next(gen) + + assert result["jobSummaryList"] == [] + assert result["nextToken"] is None + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_list_service_job_with_jobs(self, mock_get_client): + """Test list_service_job returns jobs""" + mock_client = Mock() + mock_client.list_service_jobs.return_value = LIST_SERVICE_JOB_RESP_WITH_JOBS + mock_get_client.return_value = mock_client + + gen = list_service_job(JOB_QUEUE) + result = next(gen) + + assert len(result["jobSummaryList"]) == 2 + assert result["jobSummaryList"][0]["jobName"] == JOB_NAME + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_list_service_job_with_pagination(self, mock_get_client): + """Test list_service_job handles pagination""" + mock_client = Mock() + mock_client.list_service_jobs.side_effect = [ + LIST_SERVICE_JOB_RESP_WITH_NEXT_TOKEN, + LIST_SERVICE_JOB_RESP_EMPTY, + ] + mock_get_client.return_value = 
mock_client + + gen = list_service_job(JOB_QUEUE) + first_result = next(gen) + assert first_result["nextToken"] == NEXT_TOKEN + + second_result = next(gen) + assert second_result["jobSummaryList"] == [] + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_list_service_job_with_filters(self, mock_get_client): + """Test list_service_job with filters""" + mock_client = Mock() + mock_client.list_service_jobs.return_value = LIST_SERVICE_JOB_RESP_WITH_JOBS + mock_get_client.return_value = mock_client + + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + gen = list_service_job(JOB_QUEUE, filters=filters) + result = next(gen) + + call_kwargs = mock_client.list_service_jobs.call_args[1] + assert call_kwargs["filters"] == filters + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_list_service_job_with_status(self, mock_get_client): + """Test list_service_job with job status filter""" + mock_client = Mock() + mock_client.list_service_jobs.return_value = LIST_SERVICE_JOB_RESP_WITH_JOBS + mock_get_client.return_value = mock_client + + gen = list_service_job(JOB_QUEUE, job_status=JOB_STATUS_RUNNING) + result = next(gen) + + call_kwargs = mock_client.list_service_jobs.call_args[1] + assert call_kwargs["jobStatus"] == JOB_STATUS_RUNNING diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py b/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py new file mode 100644 index 0000000000..8cc19a6d7a --- /dev/null +++ b/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py @@ -0,0 +1,328 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Unit tests for training_queue module""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + +from sagemaker.train.aws_batch.training_queue import TrainingQueue +from sagemaker.train.model_trainer import ModelTrainer, Mode +from .conftest import ( + JOB_NAME, + JOB_QUEUE, + JOB_ARN, + JOB_ID, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SUBMIT_SERVICE_JOB_RESP, + LIST_SERVICE_JOB_RESP_WITH_JOBS, + LIST_SERVICE_JOB_RESP_EMPTY, + TRAINING_JOB_PAYLOAD, +) + + +class TestTrainingQueueInit: + """Tests for TrainingQueue initialization""" + + def test_training_queue_init(self): + """Test TrainingQueue initialization""" + queue = TrainingQueue(JOB_QUEUE) + assert queue.queue_name == JOB_QUEUE + + +class TestTrainingQueueSubmit: + """Tests for TrainingQueue.submit method""" + + @patch("sagemaker.train.aws_batch.training_queue.submit_service_job") + def test_submit_model_trainer(self, mock_submit_service_job): + """Test submit with ModelTrainer""" + mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + queued_job = queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + ) + + assert queued_job.job_name == JOB_NAME + assert queued_job.job_arn == JOB_ARN + mock_submit_service_job.assert_called_once() + + @patch("sagemaker.train.aws_batch.training_queue.submit_service_job") + def test_submit_with_default_timeout(self, mock_submit_service_job): + """Test submit uses default timeout when not provided""" + mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + None, # No timeout + BATCH_TAGS, + ) + + call_kwargs = mock_submit_service_job.call_args[0] + # Timeout should be set to default + assert call_kwargs[5] is not None + + @patch("sagemaker.train.aws_batch.training_queue.submit_service_job") + def test_submit_with_generated_job_name(self, mock_submit_service_job): + """Test submit generates job name from payload if not provided""" + mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + trainer, + [], + None, # No job name provided + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + ) + + call_kwargs = mock_submit_service_job.call_args[0] + # Job name should come from payload + assert call_kwargs[1] == TRAINING_JOB_PAYLOAD["TrainingJobName"] + + def test_submit_invalid_training_job_type(self): + """Test submit raises error for invalid training job type""" + queue = TrainingQueue(JOB_QUEUE) + + with pytest.raises(TypeError, match="training_job must be an instance of ModelTrainer"): + queue.submit( + "not-a-trainer", + [], + JOB_NAME, + 
DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + ) + + def test_submit_invalid_training_mode(self): + """Test submit raises error for invalid training mode""" + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.LOCAL_CONTAINER + + queue = TrainingQueue(JOB_QUEUE) + + with pytest.raises(ValueError, match="Mode.SAGEMAKER_TRAINING_JOB"): + queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + ) + + @patch("sagemaker.train.aws_batch.training_queue.submit_service_job") + def test_submit_missing_job_arn_in_response(self, mock_submit_service_job): + """Test submit raises error when jobArn missing from response""" + mock_submit_service_job.return_value = {"jobName": JOB_NAME} # Missing jobArn + + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + + with pytest.raises(Exception): # MissingRequiredArgument + queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + ) + + +class TestTrainingQueueMap: + """Tests for TrainingQueue.map method""" + + @patch("sagemaker.train.aws_batch.training_queue.submit_service_job") + def test_map_multiple_inputs(self, mock_submit_service_job): + """Test map submits multiple jobs""" + mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + inputs = ["input1", "input2", "input3"] + queued_jobs = queue.map( + trainer, + inputs, + None, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + ) + + assert len(queued_jobs) == 3 + assert mock_submit_service_job.call_count == 3 + + @patch("sagemaker.train.aws_batch.training_queue.submit_service_job") + def test_map_with_job_names(self, mock_submit_service_job): + """Test map with explicit job names""" + mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + inputs = ["input1", "input2"] + job_names = ["job1", "job2"] + queued_jobs = queue.map( + trainer, + inputs, + job_names, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + ) + + assert len(queued_jobs) == 2 + + def test_map_mismatched_job_names_length(self): + """Test map raises error when job names length doesn't match inputs""" + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + + queue = TrainingQueue(JOB_QUEUE) + inputs = ["input1", "input2"] + job_names = ["job1"] # Mismatch + + with pytest.raises(ValueError, match="number of job names must match"): + queue.map( + trainer, + inputs, + job_names, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + ) + + +class TestTrainingQueueList: + """Tests for TrainingQueue.list_jobs method""" + + 
@patch("sagemaker.train.aws_batch.training_queue.list_service_job") + def test_list_jobs_default(self, mock_list_service_job): + """Test list_jobs with default parameters""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_WITH_JOBS]) + + queue = TrainingQueue(JOB_QUEUE) + jobs = queue.list_jobs() + + assert len(jobs) == 2 + assert jobs[0].job_name == JOB_NAME + + @patch("sagemaker.train.aws_batch.training_queue.list_service_job") + def test_list_jobs_with_name_filter(self, mock_list_service_job): + """Test list_jobs with job name filter""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_WITH_JOBS]) + + queue = TrainingQueue(JOB_QUEUE) + jobs = queue.list_jobs(job_name=JOB_NAME) + + # Verify list_service_job was called + mock_list_service_job.assert_called_once() + + # Get the call arguments - list_service_job is called with positional args: + # list_service_job(queue_name, status, filters, next_token) + call_args = mock_list_service_job.call_args[0] + + # The 3rd positional argument (index 2) is filters + filters = call_args[2] if len(call_args) > 2 else None + + # Verify filters contain the job name + assert filters is not None, "Filters should be passed to list_service_job" + assert filters[0]["name"] == "JOB_NAME", "JOB_NAME filter should be present" + assert filters[0]["values"] == [JOB_NAME], "Filter values should contain the job name" + + @patch("sagemaker.train.aws_batch.training_queue.list_service_job") + def test_list_jobs_empty(self, mock_list_service_job): + """Test list_jobs returns empty list""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_EMPTY]) + + queue = TrainingQueue(JOB_QUEUE) + jobs = queue.list_jobs() + + assert len(jobs) == 0 + + +class TestTrainingQueueGet: + """Tests for TrainingQueue.get_job method""" + + @patch("sagemaker.train.aws_batch.training_queue.list_service_job") + def test_get_job_found(self, mock_list_service_job): + """Test get_job returns job when found""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_WITH_JOBS]) + + queue = TrainingQueue(JOB_QUEUE) + job = queue.get_job(JOB_NAME) + + assert job.job_name == JOB_NAME + assert job.job_arn == JOB_ARN + + @patch("sagemaker.train.aws_batch.training_queue.list_service_job") + def test_get_job_not_found(self, mock_list_service_job): + """Test get_job raises error when job not found""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_EMPTY]) + + queue = TrainingQueue(JOB_QUEUE) + + with pytest.raises(ValueError, match="Cannot find job"): + queue.get_job(JOB_NAME) diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py b/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py new file mode 100644 index 0000000000..2e3ca916d2 --- /dev/null +++ b/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py @@ -0,0 +1,265 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Unit tests for training_queued_job module""" + +import pytest +import time +import asyncio +from unittest.mock import Mock, patch, MagicMock + +from sagemaker.train.aws_batch.training_queued_job import TrainingQueuedJob +from sagemaker.train.aws_batch.exception import NoTrainingJob, MissingRequiredArgument +from .conftest import ( + JOB_NAME, + JOB_ARN, + REASON, + TRAINING_JOB_NAME, + TRAINING_JOB_ARN, + JOB_STATUS_PENDING, + JOB_STATUS_RUNNING, + JOB_STATUS_SUCCEEDED, + JOB_STATUS_FAILED, + DESCRIBE_SERVICE_JOB_RESP_RUNNING, + DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED, + DESCRIBE_SERVICE_JOB_RESP_FAILED, + DESCRIBE_SERVICE_JOB_RESP_PENDING, +) + + +class TestTrainingQueuedJobInit: + """Tests for TrainingQueuedJob initialization""" + + def test_training_queued_job_init(self): + """Test TrainingQueuedJob initialization""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job.job_arn == JOB_ARN + assert queued_job.job_name == JOB_NAME + + +class TestTrainingQueuedJobDescribe: + """Tests for TrainingQueuedJob.describe method""" + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_describe(self, mock_describe_service_job): + """Test describe returns job details""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.describe() + + assert result["status"] == JOB_STATUS_RUNNING + mock_describe_service_job.assert_called_once_with(JOB_ARN) + + +class TestTrainingQueuedJobTerminate: + """Tests for TrainingQueuedJob.terminate method""" + + @patch("sagemaker.train.aws_batch.training_queued_job.terminate_service_job") + def test_terminate(self, mock_terminate_service_job): + """Test terminate calls terminate API""" + mock_terminate_service_job.return_value = {} + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.terminate(REASON) + + mock_terminate_service_job.assert_called_once_with(JOB_ARN, REASON) + + @patch("sagemaker.train.aws_batch.training_queued_job.terminate_service_job") + def test_terminate_default_reason(self, mock_terminate_service_job): + """Test terminate with default reason""" + mock_terminate_service_job.return_value = {} + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.terminate() + + call_kwargs = mock_terminate_service_job.call_args[0] + assert call_kwargs[0] == JOB_ARN + + +class TestTrainingQueuedJobWait: + """Tests for TrainingQueuedJob.wait method""" + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_wait_immediate_completion(self, mock_describe_service_job): + """Test wait returns immediately when job is completed""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.wait() + + assert result["status"] == JOB_STATUS_SUCCEEDED + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_wait_with_polling(self, mock_describe_service_job): + """Test wait polls until job completes""" + mock_describe_service_job.side_effect = [ + DESCRIBE_SERVICE_JOB_RESP_RUNNING, + DESCRIBE_SERVICE_JOB_RESP_RUNNING, + DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED, + ] + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.wait() + + assert result["status"] == JOB_STATUS_SUCCEEDED + assert mock_describe_service_job.call_count == 3 + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def 
test_wait_with_timeout(self, mock_describe_service_job): + """Test wait respects timeout""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + start_time = time.time() + result = queued_job.wait(timeout=2) + end_time = time.time() + + # Should timeout after approximately 2 seconds + assert end_time - start_time >= 2 + assert result["status"] == JOB_STATUS_RUNNING + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_wait_job_failed(self, mock_describe_service_job): + """Test wait returns failed status""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_FAILED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.wait() + + assert result["status"] == JOB_STATUS_FAILED + + +class TestTrainingQueuedJobGetModelTrainer: + """Tests for TrainingQueuedJob.get_model_trainer method""" + + @patch("sagemaker.train.aws_batch.training_queued_job._remove_system_tags_in_place_in_model_trainer_object") + @patch("sagemaker.train.aws_batch.training_queued_job._construct_model_trainer_from_training_job_name") + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_get_model_trainer_success(self, mock_describe_service_job, mock_construct_trainer, mock_remove_tags): + """Test get_model_trainer returns ModelTrainer when training job created""" + # Return a real dict (not a mock) so nested dict access works + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED + + mock_trainer = Mock() + mock_construct_trainer.return_value = mock_trainer + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.get_model_trainer() + + assert result == mock_trainer + mock_construct_trainer.assert_called_once() + mock_remove_tags.assert_called_once_with(mock_trainer) + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_get_model_trainer_no_training_job_pending(self, mock_describe_service_job): + """Test get_model_trainer raises error when job still pending""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_PENDING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(NoTrainingJob): + queued_job.get_model_trainer() + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_get_model_trainer_no_latest_attempt(self, mock_describe_service_job): + """Test get_model_trainer raises error when latestAttempt missing""" + resp = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED.copy() + del resp["latestAttempt"] + mock_describe_service_job.return_value = resp + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(MissingRequiredArgument): + queued_job.get_model_trainer() + + +class TestTrainingQueuedJobResult: + """Tests for TrainingQueuedJob.result method""" + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_result_success(self, mock_describe_service_job): + """Test result returns job result when completed""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.result(timeout=100) + + assert result["status"] == JOB_STATUS_SUCCEEDED + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_result_timeout(self, mock_describe_service_job): + """Test result raises TimeoutError when timeout exceeded""" + 
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(TimeoutError): + queued_job.result(timeout=1) + + +class TestTrainingQueuedJobAsync: + """Tests for TrainingQueuedJob async methods""" + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_fetch_job_results_success(self, mock_describe_service_job): + """Test fetch_job_results returns result when job succeeds""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = asyncio.run(queued_job.fetch_job_results()) + + assert result["status"] == JOB_STATUS_SUCCEEDED + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_fetch_job_results_failed(self, mock_describe_service_job): + """Test fetch_job_results raises error when job fails""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_FAILED + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(RuntimeError): + asyncio.run(queued_job.fetch_job_results()) + + @patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job") + def test_fetch_job_results_timeout(self, mock_describe_service_job): + """Test fetch_job_results raises TimeoutError when timeout exceeded""" + mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + with pytest.raises(TimeoutError): + asyncio.run(queued_job.fetch_job_results(timeout=1)) + + +class TestTrainingQueuedJobTrainingJobCreated: + """Tests for TrainingQueuedJob._training_job_created method""" + + def test_training_job_created_running(self): + """Test _training_job_created returns True for RUNNING status""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job._training_job_created(JOB_STATUS_RUNNING) is True + + def test_training_job_created_succeeded(self): + """Test _training_job_created returns True for SUCCEEDED status""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job._training_job_created(JOB_STATUS_SUCCEEDED) is True + + def test_training_job_created_failed(self): + """Test _training_job_created returns True for FAILED status""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job._training_job_created(JOB_STATUS_FAILED) is True + + def test_training_job_created_pending(self): + """Test _training_job_created returns False for PENDING status""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + assert queued_job._training_job_created(JOB_STATUS_PENDING) is False diff --git a/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb b/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb new file mode 100644 index 0000000000..b7f7c72cd7 --- /dev/null +++ b/v3-examples/training-examples/aws_batch/sm-training-queues_getting_started_with_model_trainer.ipynb @@ -0,0 +1,1005 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f6d21c4c-e5fe-4992-9fd9-a33f36e4db2d", + "metadata": {}, + "source": [ + "# Getting Started with AWS Batch for SageMaker Training jobs\n", + "\n", + "---\n", + "\n", + "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n", + "\n", + "![This us-west-2 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+ "\n",
+ "---\n",
+ "\n",
+ "This sample notebook demonstrates how to submit simple 'hello world' jobs to an [AWS Batch job queue](https://aws.amazon.com/batch/) using a [ModelTrainer](https://sagemaker.readthedocs.io/en/stable/api/training/model_trainer.html). You can run any of the cells in this notebook interactively to experiment with using your queue. Batch will take care of ensuring your jobs run automatically as your service environment capacity becomes available. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "10e12b35-3dc2-4376-b90e-54c00c70a607",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## Setup and Configure Training Job Variables\n",
+ "The sample jobs only need a single instance for a short duration. Change any of the constant variables below to adjust the example to your liking. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6316085c-262d-4437-8987-9ca7eca94965",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "INSTANCE_TYPE = \"ml.g5.xlarge\"\n",
+ "INSTANCE_COUNT = 1\n",
+ "MAX_RUN_TIME = 300\n",
+ "TRAINING_JOB_NAME = \"hello-world-simple-job\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b4edef56-49f3-4729-afdd-5345c5710363",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import logging\n",
+ "\n",
+ "logging.basicConfig(\n",
+ " level=logging.INFO, format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n",
+ ")\n",
+ "logging.getLogger(\"botocore.client\").setLevel(level=logging.WARN)\n",
+ "logger = logging.getLogger(__name__)\n",
+ "\n",
+ "from sagemaker.core.helper.session_helper import Session\n",
+ "from sagemaker.core import image_uris\n",
+ "\n",
+ "session = Session()\n",
+ "\n",
+ "image_uri = image_uris.retrieve(\n",
+ " framework=\"pytorch\",\n",
+ " region=session.boto_session.region_name,\n",
+ " version=\"2.5\",\n",
+ " instance_type=INSTANCE_TYPE,\n",
+ " image_scope=\"training\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6ad14b9b-360f-446f-94ec-5c7cbd6e0818",
+ "metadata": {},
+ "source": [
+ "## Create Sample Resources\n",
+ "The diagram below shows the Batch resources we'll create for this example.\n",
+ "\n",
+ "![The Resources to Create](batch_getting_started_resources.png \"Example Job Queue and Service Environment Resources\")\n",
+ "\n",
+ "You can use the [Batch Console](https://console.aws.amazon.com/batch) to create these resources, or you can run the cell below. The ```create_resources``` function below will skip creating any resources that already exist."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e325ddb0-aa86-4f3b-9820-753f4bdadb19",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from sagemaker.train.aws_batch.boto_client import get_batch_boto_client\n",
+ "from utils.aws_batch_resource_management import AwsBatchResourceManager, create_resources\n",
+ "\n",
+ "# This job queue name needs to match the Job Queue created in AWS Batch.\n",
+ "JOB_QUEUE_NAME = \"my-sm-training-fifo-jq\"\n",
+ "SERVICE_ENVIRONMENT_NAME = \"my-sm-training-fifo-se\"\n",
+ "\n",
+ "# Create ServiceEnvironment and JobQueue\n",
+ "resource_manager = AwsBatchResourceManager(get_batch_boto_client())\n",
+ "resources = create_resources(\n",
+ " resource_manager, JOB_QUEUE_NAME, SERVICE_ENVIRONMENT_NAME, max_capacity=1\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2a5d5e2e-b266-41c4-9d17-ce6c93a42db3",
+ "metadata": {},
+ "source": [
+ "## Create Hello World Model Trainer\n",
+ "Now that our resources are created, we'll construct a simple ModelTrainer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d71f7b99-63a3-4fd0-b735-6140fe1489f6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sagemaker.train.model_trainer import ModelTrainer\n",
+ "from sagemaker.train.configs import SourceCode, Compute, StoppingCondition\n",
+ "\n",
+ "source_code = SourceCode(command=\"echo 'Hello World'\")\n",
+ "\n",
+ "model_trainer = ModelTrainer(\n",
+ " training_image=image_uri,\n",
+ " source_code=source_code,\n",
+ " base_job_name=TRAINING_JOB_NAME,\n",
+ " compute=Compute(instance_type=INSTANCE_TYPE, instance_count=INSTANCE_COUNT),\n",
+ " stopping_condition=StoppingCondition(max_runtime_in_seconds=MAX_RUN_TIME),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "48f030c1-3e32-4265-a1fe-cdd6927e82ac",
+ "metadata": {},
+ "source": [
+ "## Create a TrainingQueue Object\n",
+ "Using our queue is as easy as referring to it by name in the TrainingQueue constructor. The TrainingQueue class in the SageMaker Python SDK provides built-in support for working with Batch queues."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5d90c9a4-ff38-492e-b446-61674701d9ca",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sagemaker.train.aws_batch.training_queue import TrainingQueue, TrainingQueuedJob\n",
+ "\n",
+ "# Construct the queue object using the SageMaker Python SDK\n",
+ "queue = TrainingQueue(JOB_QUEUE_NAME)\n",
+ "logger.info(f\"Using queue: {queue.queue_name}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "734b7fe6-0bf7-460e-aa95-7608421e900c",
+ "metadata": {},
+ "source": [
+ "## Submit Some Training Jobs\n",
+ "Submit a job to the queue by calling queue.submit. This particular job doesn't require any data, but in general, you provide data by specifying inputs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "60f2ca52-b1ea-4af2-a143-8202ce34d5e6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Submit first job\n",
+ "training_queued_job_1: TrainingQueuedJob = queue.submit(training_job=model_trainer, inputs=None)\n",
+ "logger.info(\n",
+ " f\"Submitted job '{training_queued_job_1.job_name}' to TrainingQueue '{queue.queue_name}'\"\n",
+ ")\n",
+ "\n",
+ "# Submit second job\n",
+ "training_queued_job_2: TrainingQueuedJob = queue.submit(training_job=model_trainer, inputs=None)\n",
+ "logger.info(\n",
+ " f\"Submitted job '{training_queued_job_2.job_name}' to TrainingQueue '{queue.queue_name}'\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "da85757f-d17c-402f-87b8-164149ea9165",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## Terminate a Job in the Queue\n",
+ "This next cell shows how to terminate a job that is in the queue."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b467767b-21b4-4cc0-987d-d07ef7f15ca0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "logger.info(f\"Terminating job: {training_queued_job_2.job_name}\")\n",
+ "training_queued_job_2.terminate()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d578124b-8067-47cf-895c-3076a242d6a6",
+ "metadata": {},
+ "source": [
+ "## Monitor Job Status\n",
+ "This next cell shows how to list the jobs that have been submitted to the TrainingQueue. The TrainingQueue can list jobs by status, and each job can be described individually for more details. Once a TrainingQueuedJob has reached the STARTING status, the logs can be printed from the underlying SageMaker training job."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1e8de2fe-ab1d-4703-ace3-1e44ae1b2d7c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "\n",
+ "\n",
+ "def list_jobs_in_training_queue(training_queue: TrainingQueue):\n",
+ " \"\"\"\n",
+ " Lists all jobs in a TrainingQueue grouped by their status.\n",
+ "\n",
+ " This function retrieves jobs with different statuses (SUBMITTED, PENDING, RUNNABLE,\n",
+ " SCHEDULED, STARTING, RUNNING, SUCCEEDED, FAILED) from the specified TrainingQueue\n",
+ " and logs their names and current status.\n",
+ "\n",
+ " Args:\n",
+ " training_queue (TrainingQueue): The TrainingQueue to query for jobs.\n",
+ "\n",
+ " Returns:\n",
+ " None: This function doesn't return a value but logs job information.\n",
+ " \"\"\"\n",
+ " submitted_jobs = training_queue.list_jobs(status=\"SUBMITTED\")\n",
+ " pending_jobs = training_queue.list_jobs(status=\"PENDING\")\n",
+ " runnable_jobs = training_queue.list_jobs(status=\"RUNNABLE\")\n",
+ " scheduled_jobs = training_queue.list_jobs(status=\"SCHEDULED\")\n",
+ " starting_jobs = training_queue.list_jobs(status=\"STARTING\")\n",
+ " running_jobs = training_queue.list_jobs(status=\"RUNNING\")\n",
+ " completed_jobs = training_queue.list_jobs(status=\"SUCCEEDED\")\n",
+ " failed_jobs = training_queue.list_jobs(status=\"FAILED\")\n",
+ "\n",
+ " all_jobs = (\n",
+ " submitted_jobs\n",
+ " + pending_jobs\n",
+ " + runnable_jobs\n",
+ " + scheduled_jobs\n",
+ " + starting_jobs\n",
+ " + running_jobs\n",
+ " + completed_jobs\n",
+ " + failed_jobs\n",
+ " )\n",
+ "\n",
+ " for job in all_jobs:\n",
+ " job_status = job.describe().get(\"status\", \"\")\n",
+ " logger.info(f\"Job : {job.job_name} is {job_status}\")\n",
+ "\n",
+ "\n",
+ "def monitor_training_queued_job(job: TrainingQueuedJob):\n",
+ " \"\"\"\n",
+ " Monitors a TrainingQueuedJob until it reaches an
active or terminal state.\n",
+ "\n",
+ " This function continuously polls the status of the specified TrainingQueuedJob\n",
+ " until it transitions to one of the following states: STARTING, RUNNING,\n",
+ " SUCCEEDED, or FAILED. Once the job reaches one of these states, the function\n",
+ " retrieves and displays the job's logs.\n",
+ "\n",
+ " Args:\n",
+ " job (TrainingQueuedJob): The TrainingQueuedJob to monitor.\n",
+ "\n",
+ " Returns:\n",
+ " None: This function doesn't return a value but displays job logs.\n",
+ " \"\"\"\n",
+ " while True:\n",
+ " job_status = job.describe().get(\"status\", \"\")\n",
+ "\n",
+ " if job_status in {\"STARTING\", \"RUNNING\", \"SUCCEEDED\", \"FAILED\"}:\n",
+ " break\n",
+ "\n",
+ " logger.info(f\"Job : {job.job_name} is {job_status}\")\n",
+ " time.sleep(5)\n",
+ "\n",
+ " # Print training job logs\n",
+ " model_trainer = job.get_model_trainer()\n",
+ " model_trainer.sagemaker_session.logs_for_job(model_trainer._latest_training_job.training_job_name, wait=True)\n",
+ "\n",
+ "\n",
+ "logger.info(f\"Listing all jobs in queue '{queue.queue_name}'...\")\n",
+ "list_jobs_in_training_queue(queue)\n",
+ "\n",
+ "logger.info(f\"Polling job status for '{training_queued_job_1.job_name}'\")\n",
+ "monitor_training_queued_job(training_queued_job_1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3ca39fca-6bb6-4e0b-841e-7f66d97b6074",
+ "metadata": {},
+ "source": [
+ "## Optional: Delete AWS Batch Resources\n",
+ "This shows how to delete the AWS Batch ServiceEnvironment and JobQueue. This step is completely optional; uncomment the code below to delete the resources created a few steps above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8d745e2d-40a8-45ef-b231-e8acd9b5e8eb",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from utils.aws_batch_resource_management import delete_resources\n",
+ "\n",
+ "# delete_resources(resource_manager, resources)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7070dcc9",
+ "metadata": {},
+ "source": [
+ "## Notebook CI Test Results\n",
+ "\n",
+ "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2, which is shown at the top of the notebook.\n",
+ "\n",
+ "\n",
+ "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+ "\n",
+ "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+ "\n",
+ "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n",
+ "\n",
+ "![This ca-central-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n", + "\n", + "![This ap-south-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/build_and_train_models|sm-training-queues|sm-training-queues_getting_started_with_model_trainer.ipynb)\n" + ] + } + ], + "metadata": { + "availableInstances": [ + { + "_defaultOrder": 0, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.t3.medium", + "vcpuNum": 2 + }, + { + "_defaultOrder": 1, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.t3.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 2, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.t3.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 3, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.t3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 4, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 5, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 6, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 7, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 8, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 9, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 10, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 11, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 12, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5d.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 13, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5d.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 14, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5d.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 15, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5d.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 16, + "_isFastLaunch": false, + "category": "General purpose", + 
"gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5d.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 17, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5d.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 18, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5d.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 19, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 20, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": true, + "memoryGiB": 0, + "name": "ml.geospatial.interactive", + "supportedImageNames": [ + "sagemaker-geospatial-v1-0" + ], + "vcpuNum": 0 + }, + { + "_defaultOrder": 21, + "_isFastLaunch": true, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.c5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 22, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.c5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 23, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.c5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 24, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.c5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 25, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 72, + "name": "ml.c5.9xlarge", + "vcpuNum": 36 + }, + { + "_defaultOrder": 26, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 96, + "name": "ml.c5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 27, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 144, + "name": "ml.c5.18xlarge", + "vcpuNum": 72 + }, + { + "_defaultOrder": 28, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.c5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 29, + "_isFastLaunch": true, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g4dn.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 30, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g4dn.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 31, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g4dn.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 32, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g4dn.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 33, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 
192, + "name": "ml.g4dn.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 34, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g4dn.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 35, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 61, + "name": "ml.p3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 36, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 244, + "name": "ml.p3.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 37, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 488, + "name": "ml.p3.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 38, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.p3dn.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 39, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.r5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 40, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.r5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 41, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.r5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 42, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.r5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 43, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.r5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 44, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.r5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 45, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.r5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 46, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.r5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 47, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 48, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 49, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 50, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 51, + "_isFastLaunch": false, + 
"category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 52, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 53, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.g5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 54, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.g5.48xlarge", + "vcpuNum": 192 + }, + { + "_defaultOrder": 55, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 56, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4de.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 57, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.trn1.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 58, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.trn1.32xlarge", + "vcpuNum": 128 + }, + { + "_defaultOrder": 59, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.trn1n.32xlarge", + "vcpuNum": 128 + } + ], + "instance_type": "ml.t3.medium", + "kernelspec": { + "display_name": "venv-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py b/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py new file mode 100644 index 0000000000..19f51e0bdc --- /dev/null +++ b/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py @@ -0,0 +1,566 @@ +import json +import logging +import time +from dataclasses import dataclass + +import boto3 +from botocore.exceptions import ClientError + +# Configure logging +logger = logging.getLogger(__name__) + + +@dataclass +class Resource: + """ + Represents an AWS resource with a name and ARN. + + Attributes: + name (str): The name of the AWS resource. + arn (str): The Amazon Resource Name (ARN) of the resource. + """ + + name: str + arn: str + + +@dataclass +class Resources: + """ + Container for AWS Batch resources used in the application. + + Attributes: + job_queue (Resource): The AWS Batch job queue resource. + service_environment (Resource): The AWS Batch service environment resource. 
+    """
+
+    job_queue: Resource = None
+    service_environment: Resource = None
+    batch_role: Resource = None
+    sagemaker_exeuction_role: Resource = None
+
+
+class AwsBatchResourceManager:
+    """
+    Manager for AWS Batch resources including service environments and job queues.
+
+    This class provides methods to create, update, delete, and monitor AWS Batch resources.
+
+    Attributes:
+        TERMINAL_JOB_STATUSES (set): Set of job statuses considered terminal.
+    """
+
+    TERMINAL_JOB_STATUSES = {"SUCCEEDED", "FAILED"}
+
+    def __init__(self, batch_client):
+        """
+        Initialize the AWS Batch Resource Manager.
+
+        Args:
+            batch_client: The boto3 Batch client to use for AWS operations.
+        """
+        self._batch_client = batch_client
+
+    def create_service_environment(self, create_se_request: dict):
+        """
+        Create a new AWS Batch service environment.
+
+        If the service environment already exists, returns the existing environment details.
+
+        Args:
+            create_se_request (dict): Request parameters for creating a service environment.
+                Must contain 'serviceEnvironmentName' key.
+
+        Returns:
+            dict: Response containing the service environment name and ARN.
+
+        Raises:
+            ClientError: If there's an error creating the service environment.
+        """
+        try:
+            return self._batch_client.create_service_environment(**create_se_request)
+        except ClientError as error:
+            if error.response["message"] == "Object already exists":
+                logger.info("ServiceEnvironment already exists, skipping creation.")
+                desc_resp = self._batch_client.describe_service_environments(
+                    serviceEnvironments=[create_se_request["serviceEnvironmentName"]]
+                )
+                return {
+                    "serviceEnvironmentName": desc_resp["serviceEnvironments"][0][
+                        "serviceEnvironmentName"
+                    ],
+                    "serviceEnvironmentArn": desc_resp["serviceEnvironments"][0][
+                        "serviceEnvironmentArn"
+                    ],
+                }
+
+            logger.error(f"Error: {json.dumps(error.response, indent=4)}")
+            raise error
+
+    def update_service_environment(self, update_se_request: dict):
+        """
+        Update an existing AWS Batch service environment.
+
+        Args:
+            update_se_request (dict): Request parameters for updating a service environment.
+
+        Returns:
+            dict: Response from the update operation.
+
+        Raises:
+            ClientError: If there's an error updating the service environment.
+        """
+        try:
+            return self._batch_client.update_service_environment(**update_se_request)
+        except ClientError as error:
+            logger.error(f"Error: {json.dumps(error.response, indent=4)}")
+            raise error
+
+    def await_service_environment_update(
+        self, service_environment_name: str, expected_status: str, expected_state: str
+    ):
+        """
+        Wait for a service environment to reach the expected status and state.
+
+        This method polls the service environment status until it reaches the expected state,
+        or until the environment is deleted if that is the expected status.
+
+        Args:
+            service_environment_name (str): Name of the service environment to monitor.
+            expected_status (str): The expected status to wait for (e.g., "VALID", "DELETED").
+            expected_state (str): The expected state to wait for (e.g., "ENABLED", "DISABLED").
+
+        Raises:
+            ValueError: If the service environment enters an INVALID status.
+        """
+        while True:
+            describe_response = self._batch_client.describe_service_environments(
+                serviceEnvironments=[service_environment_name]
+            )
+            if describe_response["serviceEnvironments"]:
+                se = describe_response["serviceEnvironments"][0]
+
+                state = se["state"]
+                status = se["status"]
+
+                if status == expected_status and state == expected_state:
+                    break
+                if status == "INVALID":
+                    raise ValueError(f"Something went wrong! {json.dumps(se, indent=4)}")
+            elif expected_status == "DELETED":
+                logger.info(f"ServiceEnvironment {service_environment_name} has been deleted")
+                break
+
+            time.sleep(5)
+
+    def delete_service_environment(self, service_environment_name: str):
+        """
+        Delete an AWS Batch service environment.
+
+        This method follows the proper deletion workflow:
+        1. Disable the service environment
+        2. Wait for the disable operation to complete
+        3. Delete the service environment
+        4. Wait for the deletion to complete
+
+        Args:
+            service_environment_name (str): Name of the service environment to delete.
+        """
+        logger.info(f"Setting ServiceEnvironment {service_environment_name} to DISABLED")
+        self._batch_client.update_service_environment(
+            serviceEnvironment=service_environment_name, state="DISABLED"
+        )
+
+        logger.info("Waiting for ServiceEnvironment update to finish...")
+        self.await_service_environment_update(service_environment_name, "VALID", "DISABLED")
+
+        logger.info(f"Deleting ServiceEnvironment {service_environment_name}")
+        self._batch_client.delete_service_environment(serviceEnvironment=service_environment_name)
+
+        logger.info("Waiting for ServiceEnvironment update to finish...")
+        self.await_service_environment_update(service_environment_name, "DELETED", "DISABLED")
+
+    def create_job_queue(self, create_jq_request: dict):
+        """
+        Create a new AWS Batch job queue.
+
+        If the job queue already exists, returns the existing job queue details.
+
+        Args:
+            create_jq_request (dict): Request parameters for creating a job queue.
+                Must contain 'jobQueueName' key.
+
+        Returns:
+            dict: Response containing the job queue name and ARN.
+
+        Raises:
+            ClientError: If there's an error creating the job queue.
+        """
+        try:
+            return self._batch_client.create_job_queue(**create_jq_request)
+        except ClientError as error:
+            if error.response["message"] == "Object already exists":
+                logger.info("JobQueue already exists, skipping creation")
+                desc_resp = self._batch_client.describe_job_queues(
+                    jobQueues=[create_jq_request["jobQueueName"]]
+                )
+                return {
+                    "jobQueueName": desc_resp["jobQueues"][0]["jobQueueName"],
+                    "jobQueueArn": desc_resp["jobQueues"][0]["jobQueueArn"],
+                }
+
+            logger.error(f"Error: {json.dumps(error.response, indent=4)}")
+            raise error
+
+    def delete_job_queue(self, job_queue_name: str):
+        """
+        Delete an AWS Batch job queue.
+
+        This method follows the proper deletion workflow:
+        1. Disable the job queue
+        2. Wait for the disable operation to complete
+        3. Delete the job queue
+        4. Wait for the deletion to complete
+
+        Args:
+            job_queue_name (str): Name of the job queue to delete.
+ """ + logger.info(f"Disabling JobQueue {job_queue_name}") + self._batch_client.update_job_queue(jobQueue=job_queue_name, state="DISABLED") + + logger.info("Waiting for JobQueue update to finish...") + self.await_job_queue_update(job_queue_name, "VALID", "DISABLED") + + logger.info(f"Deleting JobQueue {job_queue_name}") + self._batch_client.delete_job_queue(jobQueue=job_queue_name) + + logger.info("Waiting for JobQueue update to finish...") + self.await_job_queue_update(job_queue_name, "DELETED", "DISABLED") + + def await_job_queue_update( + self, job_queue_name: str, expected_status: str, expected_state: str + ): + """ + Wait for a job queue to reach the expected status and state. + + This method polls the job queue status until it reaches the expected state and status + or is deleted if that's the expected status. + + Args: + jq_name (str): Name of the job queue to monitor. + expected_status (str): The expected status to wait for (e.g., "VALID", "DELETED"). + expected_state (str, optional): The expected state to wait for (e.g., "ENABLED", "DISABLED"). + + Raises: + ValueError: If the job queue enters an INVALID status. + """ + while True: + describe_jq_response = self._batch_client.describe_job_queues( + jobQueues=[job_queue_name] + ) + if describe_jq_response["jobQueues"]: + jq = describe_jq_response["jobQueues"][0] + + state = jq["state"] + status = jq["status"] + + if status == expected_status and state == expected_state: + break + if status == "INVALID": + raise ValueError(f"Something went wrong! {json.dumps(jq, indent=4)}") + elif expected_status == "DELETED": + logger.info(f"JobQueue {job_queue_name} has been deleted") + break + + time.sleep(5) + + +class RoleManager: + """ + Manager for creating and managing IAM roles required for SageMaker training jobs with AWS Batch. + + This class provides methods to create the necessary IAM roles for AWS Batch to interact with + SageMaker training jobs, including the batch role and SageMaker execution role. + + Attributes: + iam_client: The boto3 IAM client to use for creating roles + sts_client: The boto3 STS client to use for getting account information + """ + + def __init__(self, iam_client, sts_client): + """ + Initialize the RoleManager with IAM and STS clients. + + Args: + iam_client: The boto3 IAM client to use. + sts_client: The boto3 STS client to use. + """ + self.iam_client = iam_client + self.sts_client = sts_client + + def create_batch_role(self, batch_role_name: str): + """ + Create an IAM role for AWS Batch to interact with SageMaker training jobs. + + This method creates a role with permissions for AWS Batch to manage SageMaker + training jobs, including the ability to create service-linked roles and pass roles + to SageMaker. + + Args: + batch_role_name (str): The name to use for the IAM role. + Returns: + Resource: A Resource object containing the name and ARN of the created role. + + Raises: + ClientError: If there's an error creating the role (except when the role already exists). 
+ """ + get_caller_id_resp = self.sts_client.get_caller_identity() + account_id = get_caller_id_resp["Account"] + + try: + create_role_resp = self.iam_client.create_role( + RoleName=batch_role_name, + AssumeRolePolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"AWS": f"arn:aws:iam::{account_id}:root"}, + "Action": "sts:AssumeRole", + } + ], + } + ), + Description="Role for AWS Batch for SageMaker Training jobs.", + MaxSessionDuration=3600, + ) + self.iam_client.put_role_policy( + RoleName=batch_role_name, + PolicyName="AWSBatchForSageMakerTrainingJobsPolicy", + PolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Action": ["batch:*"], "Resource": "*"}, + { + "Effect": "Allow", + "Action": ["iam:CreateServiceLinkedRole"], + "Resource": "arn:aws:iam::*:role/*AWSServiceRoleForAWSBatchWithSagemaker", + "Condition": { + "StringEquals": { + "iam:AWSServiceName": "sagemaker-queuing.batch.amazonaws.com" + } + }, + }, + { + "Effect": "Allow", + "Action": "iam:PassRole", + "Resource": "*", + "Condition": { + "StringEquals": { + "iam:PassedToService": ["sagemaker.amazonaws.com"] + } + }, + }, + ], + } + ), + ) + + return Resource( + name=create_role_resp["Role"]["RoleName"], arn=create_role_resp["Role"]["Arn"] + ) + except ClientError as error: + if error.response["Error"]["Code"] == "EntityAlreadyExists": + print(error.response["Error"]["Message"]) + get_resp = self.iam_client.get_role(RoleName=batch_role_name) + + return Resource(name=get_resp["Role"]["RoleName"], arn=get_resp["Role"]["Arn"]) + + logger.error( + f"Error creating {batch_role_name}: {json.dumps(error.__dict__, indent=4)}" + ) + raise error + + def create_sagemaker_execution_role(self, sagemaker_execution_role_name: str): + """ + Create an IAM role for SageMaker to execute training jobs. + + This method creates a role with the AmazonSageMakerFullAccess policy attached, + allowing SageMaker to access necessary resources for training jobs. + + Args: + sagemaker_execution_role_name (str): The name to use for the IAM role. + Returns: + Resource: A Resource object containing the name and ARN of the created role. + + Raises: + ClientError: If there's an error creating the role (except when the role already exists). 
+ """ + try: + create_role_resp = self.iam_client.create_role( + RoleName=sagemaker_execution_role_name, + AssumeRolePolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": ["sagemaker.amazonaws.com"]}, + "Action": "sts:AssumeRole", + } + ], + } + ), + Description="SageMaker training execution role.", + MaxSessionDuration=3600, + ) + self.iam_client.attach_role_policy( + RoleName=sagemaker_execution_role_name, + PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess", + ) + + return Resource( + name=create_role_resp["Role"]["RoleName"], arn=create_role_resp["Role"]["Arn"] + ) + except ClientError as error: + if error.response["Error"]["Code"] == "EntityAlreadyExists": + print(error.response["Error"]["Message"]) + get_resp = self.iam_client.get_role(RoleName=sagemaker_execution_role_name) + + return Resource(name=get_resp["Role"]["RoleName"], arn=get_resp["Role"]["Arn"]) + + logger.error( + f"Error creating {sagemaker_execution_role_name}: {json.dumps(error.__dict__, indent=4)}" + ) + raise error + + +def create_roles( + role_manager: RoleManager, batch_role_name: str, sagemaker_execution_role_name: str +): + """ + Create all required IAM roles for SageMaker training jobs with AWS Batch. + + This function creates both the AWS Batch role and the SageMaker execution role + using the current AWS account ID. + + Returns: + Resources: A Resources class containing the created roles + """ + logger.info("Creating batch role") + batch_role = role_manager.create_batch_role(batch_role_name) + + logger.info("Creating sagemaker execution role") + sagemaker_execution_role = role_manager.create_sagemaker_execution_role( + sagemaker_execution_role_name + ) + + resources = Resources(batch_role=batch_role, sagemaker_exeuction_role=sagemaker_execution_role) + + logger.info(f"Role creation complete: {resources}") + return resources + + +def assume_role_and_get_session(role: Resource, sts_client): + """ + Assumes the specified IAM role and returns a boto3 session with the assumed credentials. + + Args: + role: The IAM role resource to assume + sts_client: The boto3 STS client + + Returns: + A boto3 session configured with the assumed role credentials + """ + response = sts_client.assume_role(RoleArn=role.arn, RoleSessionName="AssumeRoleSession") + + credentials = response["Credentials"] + + return boto3.Session( + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + +def create_resources( + resource_manager: AwsBatchResourceManager, + job_queue_name: str, + service_environment_name: str, + max_capacity: int = 1, +): + """ + Create AWS Batch resources including a service environment and job queue. + + This function creates a SageMaker training service environment and a corresponding + job queue, waiting for each resource to reach a VALID state before proceeding. + + Args: + resource_manager (AwsBatchResourceManager): The resource manager to use for creating resources. + + Returns: + Resources: A Resources object containing the created service environment and job queue. 
+ """ + # Create ServiceEnvironment + logger.info(f"Creating ServiceEnvironment: {service_environment_name}") + create_se_resp = resource_manager.create_service_environment( + { + "serviceEnvironmentName": service_environment_name, + "serviceEnvironmentType": "SAGEMAKER_TRAINING", + "state": "ENABLED", + "capacityLimits": [{"maxCapacity": max_capacity, "capacityUnit": "NUM_INSTANCES"}], + } + ) + logger.info("Waiting for ServiceEnvironment to transition to VALID...") + resource_manager.await_service_environment_update(service_environment_name, "VALID", "ENABLED") + + # Create JobQueue + logger.info(f"Creating JobQueue: {job_queue_name}") + create_jq_response = resource_manager.create_job_queue( + { + "jobQueueName": job_queue_name, + "jobQueueType": "SAGEMAKER_TRAINING", + "state": "ENABLED", + "priority": 1, + "serviceEnvironmentOrder": [ + {"order": 1, "serviceEnvironment": create_se_resp["serviceEnvironmentName"]}, + ], + } + ) + logger.info("Waiting for JobQueue to transition to VALID...") + resource_manager.await_job_queue_update(job_queue_name, "VALID", "ENABLED") + + resources = Resources( + service_environment=Resource( + name=create_se_resp["serviceEnvironmentName"], + arn=create_se_resp["serviceEnvironmentArn"], + ), + job_queue=Resource( + name=create_jq_response["jobQueueName"], arn=create_jq_response["jobQueueArn"] + ), + ) + + logger.info(f"Resource creation complete: {resources}") + return resources + + +def delete_resources(resource_manager: AwsBatchResourceManager, resources: Resources): + """ + Delete AWS Batch resources. + + This function deletes the job queue first and then the service environment, + following the proper order for resource cleanup. + + Args: + resource_manager (AwsBatchResourceManager): The resource manager to use for deleting resources. + resources (Resources): The Resources object containing the resources to delete. + """ + if resources.job_queue: + logger.info(f"Deleting JobQueue: {resources.job_queue.name}") + resource_manager.delete_job_queue(resources.job_queue.name) + + if resources.service_environment: + logger.info(f"Deleting ServiceEnvironment: {resources.service_environment.name}") + resource_manager.delete_service_environment(resources.service_environment.name)
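
For reference, a minimal sketch of how the helpers in `utils/aws_batch_resource_management.py` might be wired together end to end: create the IAM roles, stand up the Batch service environment and job queue (the `create_*` helpers block until the resources reach `VALID`), and finally tear everything down with `delete_resources`. The role, queue, and service-environment names, the region, and the logging setup are illustrative assumptions; only the class and function names come from the module above. The resulting job-queue ARN or name would typically be handed to the notebook's `TrainingQueue` before submitting ModelTrainer jobs.

```python
# Hypothetical usage of the helpers in utils/aws_batch_resource_management.py.
# Resource names and the region are placeholders, not values from the notebook.
import logging

import boto3

from utils.aws_batch_resource_management import (
    AwsBatchResourceManager,
    RoleManager,
    create_resources,
    create_roles,
    delete_resources,
)

logging.basicConfig(level=logging.INFO)  # surface the module's logger.info output

# Clients for the account/region the notebook is assumed to run in.
session = boto3.Session(region_name="us-west-2")
batch_client = session.client("batch")
iam_client = session.client("iam")
sts_client = session.client("sts")

# Create (or reuse) the IAM roles Batch and SageMaker need.
role_manager = RoleManager(iam_client, sts_client)
roles = create_roles(
    role_manager,
    batch_role_name="example-batch-for-sagemaker-role",          # assumed name
    sagemaker_execution_role_name="example-sagemaker-exec-role",  # assumed name
)

# Create the Batch service environment and job queue, waiting until both are VALID.
resource_manager = AwsBatchResourceManager(batch_client)
resources = create_resources(
    resource_manager,
    job_queue_name="example-sm-training-queue",         # assumed name
    service_environment_name="example-sm-training-se",  # assumed name
    max_capacity=1,
)
print(f"Job queue ARN for the TrainingQueue: {resources.job_queue.arn}")

# ... submit ModelTrainer jobs through the TrainingQueue as shown in the notebook ...

# Optional cleanup: the job queue is deleted first, then the service environment.
delete_resources(resource_manager, resources)
```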