Creates a Task state that runs a SageMaker training job (https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html). By default, the TrainingStep also creates a model, which shares the same name as the training job.

Super classes

stepfunctions::Block -> stepfunctions::State -> stepfunctions::Task -> TrainingStep

Methods

Method new()

Initialize TrainingStep class

Usage

TrainingStep$new(
  state_id,
  estimator,
  job_name,
  data = NULL,
  hyperparameters = NULL,
  mini_batch_size = NULL,
  experiment_config = NULL,
  wait_for_completion = TRUE,
  tags = NULL,
  ...
)

Arguments

state_id

(str): State name whose length **must be** less than or equal to 128 Unicode characters. State names **must be** unique within the scope of the whole state machine.
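The two constraints on state names can be checked before building a workflow. The following is a minimal sketch in base R; `is_valid_state_id` and its `existing_ids` argument are hypothetical helpers for illustration and are not part of the package.

```r
# Hypothetical helper (not part of the package): checks the two documented
# constraints on a state name, length and uniqueness.
is_valid_state_id <- function(state_id, existing_ids = character()) {
  is.character(state_id) &&
    length(state_id) == 1 &&
    !is.na(state_id) &&
    # type = "chars" counts characters rather than bytes
    nchar(state_id, type = "chars") <= 128 &&
    # the name must not already be used by another state
    !(state_id %in% existing_ids)
}

is_valid_state_id("Train model")                                # TRUE
is_valid_state_id(strrep("a", 129))                             # FALSE, too long
is_valid_state_id("Train model", existing_ids = "Train model")  # FALSE, duplicate
```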

estimator

(sagemaker.estimator.EstimatorBase): The estimator for the training step. Can be a BYO estimator or Framework estimator (https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms.html), or an Amazon built-in algorithm estimator (https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html).

job_name

(str or Placeholder): The training job name. This is required for the training job to run. We recommend using an `ExecutionInput` placeholder collection to pass the value dynamically in each execution.

data

Information about the training data. See the `fit()` method of the associated estimator for details. This argument can take any of the following forms:

  • (str) - The S3 location where training data is saved.

  • (list[str, str] or list[str, sagemaker.inputs.TrainingInput]) - If using multiple channels for training data, you can specify a list mapping channel names to strings or `sagemaker.inputs.TrainingInput` objects.

  • (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can provide additional information about the training dataset. See `sagemaker.inputs.TrainingInput` for full details.

  • (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of Amazon `Record` objects serialized and stored in S3. For use with an estimator for an Amazon algorithm.

  • (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data.
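As an illustration of the first two forms, the sketch below builds a single S3 location and a multi-channel mapping. It assumes the channel mapping is expressed as a named R list; the bucket name, prefixes, and channel names are placeholders, not values from the package.

```r
# Placeholder S3 URIs; the bucket and prefixes are hypothetical.
single_location <- "s3://example-bucket/training/"

# Multiple channels (assumed form): a named list mapping each channel name
# to the S3 location of its data.
channels <- list(
  train = "s3://example-bucket/training/",
  validation = "s3://example-bucket/validation/"
)

# The list names are the channel names.
names(channels)
```

Either value would then be supplied as the `data` argument of `TrainingStep$new()`.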

hyperparameters

(list, optional): The hyperparameters for the training job. (default: NULL)

mini_batch_size

(int): Specify this argument only when the estimator is a built-in estimator of an Amazon algorithm. For other estimators, the batch size should be specified in the estimator.

experiment_config

(list, optional): The experiment config for the training job. (default: NULL)

wait_for_completion

(bool, optional): Set to `TRUE` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `FALSE` if the Task state should submit the training job and proceed to the next step immediately. (default: TRUE)
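In AWS Step Functions, waiting for a job to finish corresponds to the `.sync` service integration pattern, so this flag presumably selects between the two resource ARNs shown below. This is a sketch of that mapping only; `training_resource_arn` is a hypothetical helper, and the way the package builds the ARN internally (for example, the partition it uses) is an assumption.

```r
# Hypothetical helper: returns the Step Functions resource ARN that matches
# the chosen integration pattern for SageMaker CreateTrainingJob.
training_resource_arn <- function(wait_for_completion = TRUE) {
  base <- "arn:aws:states:::sagemaker:createTrainingJob"
  if (isTRUE(wait_for_completion)) {
    # Run a job (.sync): the state waits until the training job completes.
    paste0(base, ".sync")
  } else {
    # Request response: the state continues once the job is submitted.
    base
  }
}

training_resource_arn(TRUE)
training_resource_arn(FALSE)
```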

tags

(list[list], optional): List of tags (https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html) to associate with the resource.
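The SageMaker Tag API describes each tag as a `Key` and `Value` pair. Assuming the R interface mirrors that shape as a list of named lists, a tag set might look like the sketch below; the keys and values are made up for illustration.

```r
# Assumed shape: one named list per tag, with the Key and Value fields
# defined by the SageMaker Tag API. The tag contents are hypothetical.
tags <- list(
  list(Key = "project", Value = "demo"),
  list(Key = "stage", Value = "training")
)

# Quick structural check: every tag has exactly the fields Key and Value.
all(vapply(
  tags,
  function(tag) identical(names(tag), c("Key", "Value")),
  logical(1)
))
```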

...

Extra fields passed to the Task class.


Method get_expected_model()

Builds a SageMaker model representation of the expected trained model from the training step. The result can be passed to ModelStep to save the trained model in SageMaker.

Usage

TrainingStep$get_expected_model(model_name = NULL)

Arguments

model_name

(str, optional): The model name. If not provided, the training job name is used as the model name.

Returns

sagemaker.model.Model: SageMaker model representation of the expected trained model.


Method clone()

The objects of this class are cloneable with this method.

Usage

TrainingStep$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.