AWS SageMaker¶
- class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerScheduler(session_name: str, client: Optional[Any] = None, docker_client: Optional[DockerClient] = None)[源代码]¶
基类:
DockerWorkspaceMixin
,Scheduler
[AWSSageMakerOpts
]AWSSageMakerScheduler 是 AWS SageMaker 的 TorchX 调度接口。
$ torchx run -s aws_sagemaker utils.echo --image alpine:latest --msg hello aws_batch://torchx_user/1234 $ torchx status aws_batch://torchx_user/1234 ...
使用
boto3
凭据处理从环境中加载身份验证。配置选项
usage: role=ROLE,instance_type=INSTANCE_TYPE,[instance_count=INSTANCE_COUNT],[user=USER],[keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS],[volume_size=VOLUME_SIZE],[volume_kms_key=VOLUME_KMS_KEY],[max_run=MAX_RUN],[input_mode=INPUT_MODE],[output_path=OUTPUT_PATH],[output_kms_key=OUTPUT_KMS_KEY],[base_job_name=BASE_JOB_NAME],[tags=TAGS],[subnets=SUBNETS],[security_group_ids=SECURITY_GROUP_IDS],[model_uri=MODEL_URI],[model_channel_name=MODEL_CHANNEL_NAME],[metric_definitions=METRIC_DEFINITIONS],[encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC],[use_spot_instances=USE_SPOT_INSTANCES],[max_wait=MAX_WAIT],[checkpoint_s3_uri=CHECKPOINT_S3_URI],[checkpoint_local_path=CHECKPOINT_LOCAL_PATH],[debugger_hook_config=DEBUGGER_HOOK_CONFIG],[enable_sagemaker_metrics=ENABLE_SAGEMAKER_METRICS],[enable_network_isolation=ENABLE_NETWORK_ISOLATION],[disable_profiler=DISABLE_PROFILER],[environment=ENVIRONMENT],[max_retry_attempts=MAX_RETRY_ATTEMPTS],[source_dir=SOURCE_DIR],[git_config=GIT_CONFIG],[hyperparameters=HYPERPARAMETERS],[container_log_level=CONTAINER_LOG_LEVEL],[code_location=CODE_LOCATION],[dependencies=DEPENDENCIES],[training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE],[training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN],[disable_output_compression=DISABLE_OUTPUT_COMPRESSION],[enable_infra_check=ENABLE_INFRA_CHECK],[image_repo=IMAGE_REPO],[quiet=QUIET] required arguments: role=ROLE (str) an AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. instance_type=INSTANCE_TYPE (str) type of EC2 instance to use for training, for example, 'ml.c4.xlarge' optional arguments: instance_count=INSTANCE_COUNT (int, 1) number of Amazon EC2 instances to use for training. Required if instance_groups is not set. user=USER (str, ec2-user) the username to tag the job with. `getpass.getuser()` if not specified. keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS (int, None) the duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs. volume_size=VOLUME_SIZE (int, None) size in GB of the storage volume to use for storing input and output data during training (default: 30). volume_kms_key=VOLUME_KMS_KEY (str, None) KMS key ID for encrypting EBS volume attached to the training instance. max_run=MAX_RUN (int, None) timeout in seconds for training (default: 24 * 60 * 60). input_mode=INPUT_MODE (str, None) the input mode that the algorithm supports (default: ‘File’). output_path=OUTPUT_PATH (str, None) S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket. If the bucket with the specific name does not exist, the estimator creates the bucket during the fit() method execution. output_kms_key=OUTPUT_KMS_KEY (str, None) KMS key ID for encrypting the training output (default: Your IAM role’s KMS key for Amazon S3). base_job_name=BASE_JOB_NAME (str, None) prefix for training job name when the fit() method launches. If not specified, the estimator generates a default job name based on the training image name and current timestamp. tags=TAGS (typing.List[typing.Dict[str, str]], None) list of tags for labeling a training job. subnets=SUBNETS (typing.List[str], None) list of subnet ids. If not specified training job will be created without VPC config. security_group_ids=SECURITY_GROUP_IDS (typing.List[str], None) list of security group ids. If not specified training job will be created without VPC config. model_uri=MODEL_URI (str, None) URI where a pre-trained model is stored, either locally or in S3. model_channel_name=MODEL_CHANNEL_NAME (str, None) name of the channel where ‘model_uri’ will be downloaded (default: ‘model’). metric_definitions=METRIC_DEFINITIONS (typing.List[typing.Dict[str, str]], None) list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: ‘Name’ for the name of the metric, and ‘Regex’ for the regular expression used to extract the metric from the logs. encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC (bool, None) specifies whether traffic between training containers is encrypted for the training job (default: False). use_spot_instances=USE_SPOT_INSTANCES (bool, None) specifies whether to use SageMaker Managed Spot instances for training. If enabled then the max_wait arg should also be set. max_wait=MAX_WAIT (int, None) timeout in seconds waiting for spot training job. checkpoint_s3_uri=CHECKPOINT_S3_URI (str, None) S3 URI in which to persist checkpoints that the algorithm persists (if any) during training. checkpoint_local_path=CHECKPOINT_LOCAL_PATH (str, None) local path that the algorithm writes its checkpoints to. debugger_hook_config=DEBUGGER_HOOK_CONFIG (bool, None) configuration for how debugging information is emitted with SageMaker Debugger. If not specified, a default one is created using the estimator’s output_path, unless the region does not support SageMaker Debugger. To disable SageMaker Debugger, set this parameter to False. enable_sagemaker_metrics=ENABLE_SAGEMAKER_METRICS (bool, None) enable SageMaker Metrics Time Series. enable_network_isolation=ENABLE_NETWORK_ISOLATION (bool, None) specifies whether container will run in network isolation mode (default: False). disable_profiler=DISABLE_PROFILER (bool, None) specifies whether Debugger monitoring and profiling will be disabled (default: False). environment=ENVIRONMENT (typing.Dict[str, str], None) environment variables to be set for use during training job max_retry_attempts=MAX_RETRY_ATTEMPTS (int, None) number of times to move a job to the STARTING status. You can specify between 1 and 30 attempts. source_dir=SOURCE_DIR (str, None) absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: current working directory) git_config=GIT_CONFIG (typing.Dict[str, str], None) git configurations used for cloning files, including repo, branch, commit, 2FA_enabled, username, password, and token. hyperparameters=HYPERPARAMETERS (typing.Dict[str, str], None) dictionary containing the hyperparameters to initialize this estimator with. container_log_level=CONTAINER_LOG_LEVEL (int, None) log level to use within the container (default: logging.INFO). code_location=CODE_LOCATION (str, None) S3 prefix URI where custom code is uploaded. dependencies=DEPENDENCIES (typing.List[str], None) list of absolute or relative paths to directories with any additional libraries that should be exported to the container. training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE (str, None) specifies how SageMaker accesses the Docker image that contains the training algorithm. training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN (str, None) Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your training image is hosted. disable_output_compression=DISABLE_OUTPUT_COMPRESSION (bool, None) when set to true, Model is uploaded to Amazon S3 without compression after training finishes. enable_infra_check=ENABLE_INFRA_CHECK (bool, None) specifies whether it is running Sagemaker built-in infra check jobs. image_repo=IMAGE_REPO (str, None) (remote jobs) the image repository to use when pushing patched images, must have push access. Ex: example.com/your/container quiet=QUIET (bool, False) whether to suppress verbose output for image building. Defaults to ``False``.
兼容性
功能
调度器支持
获取日志
❌
分布式作业
✔️
取消作业
✔️
描述作业
部分支持。SageMakerScheduler 将返回作业和副本状态,但不提供完整的原始 AppSpec。
工作区 / 修补
✔️
挂载
❌
弹性
❌
- describe(app_id: str) Optional[DescribeAppResponse] [源代码]¶
描述指定的应用程序。
- 返回:
AppDef 描述,如果应用程序不存在,则返回
None
。
- list() List[ListAppResponse] [源代码]¶
对于在调度器上启动的应用程序,此 API 返回一个 ListAppResponse 对象列表,每个对象都包含应用程序 ID 及其状态。注意:此 API 处于原型阶段,如有更改,恕不另行通知。
- log_iter(app_id: 字符串, role_name: 字符串, k: 整数 = 0, regex: 可选[字符串] = 无, since: 可选[日期时间] = 无, until: 可选[日期时间] = 无, should_tail: 布尔值 = 假, streams: 可选[流] = 无) 可迭代[字符串] [源代码]¶
返回一个迭代器,指向
``角色'' 的 ``第 k 个 副本''
的日志行。当所有符合条件的日志行都被读取后,迭代器结束。如果调度程序支持基于时间的游标来获取自定义时间范围内的日志行,则会遵循
since
和until
字段,否则将忽略它们。不指定since
和until
等同于获取所有可用的日志行。如果until
为空,则迭代器的行为类似于tail -f
,跟踪日志输出,直到作业达到终端状态。构成日志的确切定义因调度程序而异。一些调度程序可能将 stderr 或 stdout 视为日志,而另一些调度程序可能从日志文件中读取日志。
行为和假设
如果对不存在的应用程序调用,则会产生未定义的行为。调用者应在调用此方法之前使用
exists(app_id)
检查应用程序是否存在。不是有状态的,使用相同的参数调用此方法两次会返回一个新的迭代器。之前的迭代进度会丢失。
并非始终支持日志跟踪。并非所有调度程序都支持实时日志迭代(例如,在应用程序运行时跟踪日志)。有关迭代器行为的详细信息,请参阅特定调度程序的文档。
- 3.1 如果调度程序支持日志跟踪,则应由
should_tail
参数控制。
不保证日志保留。在调用此方法时,底层调度程序可能已经清除了此应用程序的日志记录。如果是这样,此方法会引发任意异常。
如果
should_tail
为 True,则该方法仅在可访问的日志行已完全耗尽且应用程序已达到最终状态时才会引发StopIteration
异常。例如,如果应用程序卡住并且不生成任何日志行,则迭代器会阻塞,直到应用程序最终被终止(通过超时或手动终止),此时它会引发StopIteration
。如果
should_tail
为 False,则该方法在没有更多日志时引发StopIteration
。并非所有调度程序都需要支持。
某些调度程序可能通过支持
__getitem__
来支持行游标(例如,iter[50]
查找第 50 行日志)。- 保留空格,每行新行都应包含
\n
。为了 支持交互式进度条,返回的行不需要包含
\n
,但应该在打印时不带换行符,以便正确处理\r
回车符。
- 保留空格,每行新行都应包含
- 参数:
streams - 要选择的 IO 输出流。以下之一:组合、stdout、stderr。如果调度程序不支持所选流,则会抛出 ValueError。
- 返回:
指定角色副本的日志行的
迭代器
- 引发:
NotImplementedError - 如果调度程序不支持日志迭代
- schedule(dryrun_info: AppDryRunInfo[AWSSageMakerJob]) 字符串 [源代码]¶
与
submit
相同,只是它接受一个AppDryRunInfo
。鼓励实现者实现此方法,而不是直接实现submit
,因为submit
可以通过以下方式轻松实现dryrun_info = self.submit_dryrun(app, cfg) return schedule(dryrun_info)
- 类 torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerJob(job_name: 字符串, job_def: 字典[字符串, 任何], images_to_push: 字典[字符串, 元组[字符串, 字符串]])[源代码]¶
作业定义了在 SageMaker 上调度作业所需的关键值。这将是 AppDryRunInfo 对象中 request 的值。
job_name:定义显示在 SageMaker 中的作业名称
job_def:定义将用于在 SageMaker 上调度作业的作业描述
images_to_push:由 torchx 用于推送到 image_repo
参考¶
- torchx.schedulers.aws_sagemaker_scheduler.create_scheduler(session_name: str, **kwargs: object) AWSSageMakerScheduler [源代码]¶