diff --git a/.changelog/46892.txt b/.changelog/46892.txt new file mode 100644 index 000000000000..716e53e2a590 --- /dev/null +++ b/.changelog/46892.txt @@ -0,0 +1,7 @@ +```release-note:new-resource +aws_sagemaker_training_job +``` + +```release-note:new-list-resource +aws_sagemaker_training_job +``` \ No newline at end of file diff --git a/internal/service/sagemaker/exports_test.go b/internal/service/sagemaker/exports_test.go index 875871bb75b3..8464cb8c9940 100644 --- a/internal/service/sagemaker/exports_test.go +++ b/internal/service/sagemaker/exports_test.go @@ -34,6 +34,7 @@ var ( ResourceProject = resourceProject ResourceSpace = resourceSpace ResourceStudioLifecycleConfig = resourceStudioLifecycleConfig + ResourceTrainingJob = newResourceTrainingJob ResourceUserProfile = resourceUserProfile ResourceWorkforce = resourceWorkforce ResourceWorkteam = resourceWorkteam @@ -69,6 +70,7 @@ var ( FindServicecatalogPortfolioStatus = findServicecatalogPortfolioStatus FindSpaceByName = findSpaceByName FindStudioLifecycleConfigByName = findStudioLifecycleConfigByName + FindTrainingJobByName = findTrainingJobByName FindUserProfileByName = findUserProfileByName FindWorkforceByName = findWorkforceByName FindWorkteamByName = findWorkteamByName diff --git a/internal/service/sagemaker/service_package_gen.go b/internal/service/sagemaker/service_package_gen.go index 686876bb2a49..3d2c2c5ca3c1 100644 --- a/internal/service/sagemaker/service_package_gen.go +++ b/internal/service/sagemaker/service_package_gen.go @@ -7,6 +7,8 @@ package sagemaker import ( "context" + "iter" + "slices" "unique" "github.com/aws/aws-sdk-go-v2/aws" @@ -63,9 +65,37 @@ func (p *servicePackage) FrameworkResources(ctx context.Context) []*inttypes.Ser Name: "Model Card Export Job", Region: unique.Make(inttypes.ResourceRegionDefault()), }, + { + Factory: newResourceTrainingJob, + TypeName: "aws_sagemaker_training_job", + Name: "Training Job", + Tags: unique.Make(inttypes.ServicePackageResourceTags{ + 
IdentifierAttribute: names.AttrARN, + }), + Region: unique.Make(inttypes.ResourceRegionDefault()), + Identity: inttypes.RegionalSingleParameterIdentity("training_job_name"), + Import: inttypes.FrameworkImport{ + WrappedImport: true, + }, + }, } } +func (p *servicePackage) FrameworkListResources(ctx context.Context) iter.Seq[*inttypes.ServicePackageFrameworkListResource] { + return slices.Values([]*inttypes.ServicePackageFrameworkListResource{ + { + Factory: newTrainingJobResourceAsListResource, + TypeName: "aws_sagemaker_training_job", + Name: "Training Job", + Tags: unique.Make(inttypes.ServicePackageResourceTags{ + IdentifierAttribute: names.AttrARN, + }), + Region: unique.Make(inttypes.ResourceRegionDefault()), + Identity: inttypes.RegionalSingleParameterIdentity("training_job_name"), + }, + }) +} + func (p *servicePackage) SDKDataSources(ctx context.Context) []*inttypes.ServicePackageSDKDataSource { return []*inttypes.ServicePackageSDKDataSource{ { diff --git a/internal/service/sagemaker/sweep.go b/internal/service/sagemaker/sweep.go index 62dcb863fe46..f49e5c65dcb9 100644 --- a/internal/service/sagemaker/sweep.go +++ b/internal/service/sagemaker/sweep.go @@ -52,6 +52,7 @@ func RegisterSweepers() { awsv2.Register("aws_sagemaker_pipeline", sweepPipelines) awsv2.Register("aws_sagemaker_hub", sweepHubs) awsv2.Register("aws_sagemaker_model_card", sweepModelCards) + awsv2.Register("aws_sagemaker_training_job", sweepTrainingJobs) } func sweepAppImagesConfig(ctx context.Context, client *conns.AWSClient) ([]sweep.Sweepable, error) { @@ -731,3 +732,25 @@ func sweepModelCards(ctx context.Context, client *conns.AWSClient) ([]sweep.Swee return sweepResources, nil } + +func sweepTrainingJobs(ctx context.Context, client *conns.AWSClient) ([]sweep.Sweepable, error) { + input := sagemaker.ListTrainingJobsInput{} + conn := client.SageMakerClient(ctx) + var sweepResources []sweep.Sweepable + + pages := sagemaker.NewListTrainingJobsPaginator(conn, &input) + for 
pages.HasMorePages() { + page, err := pages.NextPage(ctx) + if err != nil { + return nil, err + } + + for _, v := range page.TrainingJobSummaries { + sweepResources = append(sweepResources, framework.NewSweepResource(newResourceTrainingJob, client, + framework.NewAttribute("training_job_name", aws.ToString(v.TrainingJobName))), + ) + } + } + + return sweepResources, nil +} diff --git a/internal/service/sagemaker/testdata/TrainingJob/basic/main_gen.tf b/internal/service/sagemaker/testdata/TrainingJob/basic/main_gen.tf new file mode 100644 index 000000000000..9ab0b5ddb3bc --- /dev/null +++ b/internal/service/sagemaker/testdata/TrainingJob/basic/main_gen.tf @@ -0,0 +1,67 @@ +# Copyright IBM Corp. 2014, 2026 +# SPDX-License-Identifier: MPL-2.0 + +resource "aws_sagemaker_training_job" "test" { + training_job_name = var.rName + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + kms_key_id = aws_kms_key.test.arn + s3_output_path = "s3://example-training-job-output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} + +data "aws_partition" "current" {} + +data "aws_sagemaker_prebuilt_ecr_image" "test" { + + repository_name = "linear-learner" + image_tag = "1" +} + +resource "aws_iam_role" "test" { + name = var.rName + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = 
"arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_kms_key" "test" { + description = "KMS key for SageMaker training job" +} + +variable "rName" { + description = "Name for resource" + type = string + nullable = false +} diff --git a/internal/service/sagemaker/testdata/TrainingJob/list_basic/main.tf b/internal/service/sagemaker/testdata/TrainingJob/list_basic/main.tf new file mode 100644 index 000000000000..64ddf8ebd796 --- /dev/null +++ b/internal/service/sagemaker/testdata/TrainingJob/list_basic/main.tf @@ -0,0 +1,75 @@ +# Copyright IBM Corp. 2014, 2026 +# SPDX-License-Identifier: MPL-2.0 + +resource "aws_sagemaker_training_job" "test" { + count = var.resource_count + + training_job_name = "${var.rName}-${count.index}" + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} + +# Resolve the partition so the managed-policy ARN is valid outside the commercial "aws" partition. +data "aws_partition" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity", "sts:TagSession"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = var.rName + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_s3_bucket" "test" { + bucket = var.rName + force_destroy = true +} + +data "aws_sagemaker_prebuilt_ecr_image" "test" { + repository_name = "linear-learner" + image_tag = "1" +} + +variable 
"rName" { + description = "Name for resource" + type = string + nullable = false +} + +variable "resource_count" { + description = "Number of resources to create" + type = number + nullable = false +} diff --git a/internal/service/sagemaker/testdata/TrainingJob/list_basic/query.tfquery.hcl b/internal/service/sagemaker/testdata/TrainingJob/list_basic/query.tfquery.hcl new file mode 100644 index 000000000000..016e48f1c090 --- /dev/null +++ b/internal/service/sagemaker/testdata/TrainingJob/list_basic/query.tfquery.hcl @@ -0,0 +1,6 @@ +# Copyright IBM Corp. 2014, 2026 +# SPDX-License-Identifier: MPL-2.0 + +list "aws_sagemaker_training_job" "test" { + provider = aws +} diff --git a/internal/service/sagemaker/testdata/TrainingJob/list_include_resource/main.tf b/internal/service/sagemaker/testdata/TrainingJob/list_include_resource/main.tf new file mode 100644 index 000000000000..16fc4e36fcca --- /dev/null +++ b/internal/service/sagemaker/testdata/TrainingJob/list_include_resource/main.tf @@ -0,0 +1,81 @@ +# Copyright IBM Corp. 
2014, 2026 +# SPDX-License-Identifier: MPL-2.0 + +resource "aws_sagemaker_training_job" "test" { + count = var.resource_count + + training_job_name = "${var.rName}-${count.index}" + role_arn = aws_iam_role.test.arn + tags = var.resource_tags + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} + +# Resolve the partition so the managed-policy ARN is valid outside the commercial "aws" partition. +data "aws_partition" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity", "sts:TagSession"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = var.rName + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_s3_bucket" "test" { + bucket = var.rName + force_destroy = true +} + +data "aws_sagemaker_prebuilt_ecr_image" "test" { + repository_name = "linear-learner" + image_tag = "1" +} + +variable "rName" { + description = "Name for resource" + type = string + nullable = false +} + +variable "resource_count" { + description = "Number of resources to create" + type = number + nullable = false +} + +variable "resource_tags" { + type = map(string) + nullable = true +} diff --git a/internal/service/sagemaker/testdata/TrainingJob/list_include_resource/query.tfquery.hcl b/internal/service/sagemaker/testdata/TrainingJob/list_include_resource/query.tfquery.hcl new file mode 100644 index 000000000000..62721b013862 --- /dev/null +++ 
b/internal/service/sagemaker/testdata/TrainingJob/list_include_resource/query.tfquery.hcl @@ -0,0 +1,8 @@ +# Copyright IBM Corp. 2014, 2026 +# SPDX-License-Identifier: MPL-2.0 + +list "aws_sagemaker_training_job" "test" { + provider = aws + + include_resource = true +} diff --git a/internal/service/sagemaker/testdata/TrainingJob/list_region_override/main.tf b/internal/service/sagemaker/testdata/TrainingJob/list_region_override/main.tf new file mode 100644 index 000000000000..6b97fb676d93 --- /dev/null +++ b/internal/service/sagemaker/testdata/TrainingJob/list_region_override/main.tf @@ -0,0 +1,84 @@ +# Copyright IBM Corp. 2014, 2026 +# SPDX-License-Identifier: MPL-2.0 + +resource "aws_sagemaker_training_job" "test" { + count = var.resource_count + region = var.region + + training_job_name = "${var.rName}-${count.index}" + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} + +# Resolve the partition so the managed-policy ARN is valid outside the commercial "aws" partition. +data "aws_partition" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity", "sts:TagSession"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = var.rName + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_s3_bucket" "test" { + region = var.region + bucket = var.rName + force_destroy = true +} + +data "aws_sagemaker_prebuilt_ecr_image" "test" { + region 
= var.region + repository_name = "linear-learner" + image_tag = "1" +} + +variable "rName" { + description = "Name for resource" + type = string + nullable = false +} + +variable "resource_count" { + description = "Number of resources to create" + type = number + nullable = false +} + +variable "region" { + description = "Region to deploy resource in" + type = string + nullable = false +} diff --git a/internal/service/sagemaker/testdata/TrainingJob/list_region_override/query.tfquery.hcl b/internal/service/sagemaker/testdata/TrainingJob/list_region_override/query.tfquery.hcl new file mode 100644 index 000000000000..2990b7999511 --- /dev/null +++ b/internal/service/sagemaker/testdata/TrainingJob/list_region_override/query.tfquery.hcl @@ -0,0 +1,10 @@ +# Copyright IBM Corp. 2014, 2026 +# SPDX-License-Identifier: MPL-2.0 + +list "aws_sagemaker_training_job" "test" { + provider = aws + + config { + region = var.region + } +} diff --git a/internal/service/sagemaker/testdata/TrainingJob/region_override/main_gen.tf b/internal/service/sagemaker/testdata/TrainingJob/region_override/main_gen.tf new file mode 100644 index 000000000000..bd0dc2e631a9 --- /dev/null +++ b/internal/service/sagemaker/testdata/TrainingJob/region_override/main_gen.tf @@ -0,0 +1,79 @@ +# Copyright IBM Corp. 
2014, 2026 +# SPDX-License-Identifier: MPL-2.0 + +resource "aws_sagemaker_training_job" "test" { + region = var.region + + training_job_name = var.rName + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + kms_key_id = aws_kms_key.test.arn + s3_output_path = "s3://example-training-job-output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} + +data "aws_partition" "current" {} + +data "aws_sagemaker_prebuilt_ecr_image" "test" { + region = var.region + + + repository_name = "linear-learner" + image_tag = "1" +} + +resource "aws_iam_role" "test" { + name = var.rName + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_kms_key" "test" { + region = var.region + + description = "KMS key for SageMaker training job" +} + +variable "rName" { + description = "Name for resource" + type = string + nullable = false +} + +variable "region" { + description = "Region to deploy resource in" + type = string + nullable = false +} diff --git a/internal/service/sagemaker/testdata/tmpl/training_job_basic.gtpl b/internal/service/sagemaker/testdata/tmpl/training_job_basic.gtpl new file mode 100644 index 000000000000..2889b88cab4d --- /dev/null +++ b/internal/service/sagemaker/testdata/tmpl/training_job_basic.gtpl @@ -0,0 +1,62 @@ 
+resource "aws_sagemaker_training_job" "test" { +{{- template "region" }} + training_job_name = var.rName + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + kms_key_id = aws_kms_key.test.arn + s3_output_path = "s3://example-training-job-output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +{{- template "tags" . }} +} + +data "aws_partition" "current" {} + +data "aws_sagemaker_prebuilt_ecr_image" "test" { +{{- template "region" }} + + repository_name = "linear-learner" + image_tag = "1" +} + +resource "aws_iam_role" "test" { + name = var.rName + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_kms_key" "test" { +{{- template "region" }} + description = "KMS key for SageMaker training job" +} diff --git a/internal/service/sagemaker/training_job.go b/internal/service/sagemaker/training_job.go new file mode 100644 index 000000000000..253a45a6c6d4 --- /dev/null +++ b/internal/service/sagemaker/training_job.go @@ -0,0 +1,2138 @@ +// Copyright IBM Corp. 
2014, 2026 +// SPDX-License-Identifier: MPL-2.0 + +package sagemaker + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/YakDriver/regexache" + "github.com/YakDriver/smarterr" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/sagemaker" + awstypes "github.com/aws/aws-sdk-go-v2/service/sagemaker/types" + "github.com/hashicorp/aws-sdk-go-base/v2/tfawserr" + "github.com/hashicorp/terraform-plugin-framework-timeouts/resource/timeouts" + "github.com/hashicorp/terraform-plugin-framework-validators/int64validator" + "github.com/hashicorp/terraform-plugin-framework-validators/listvalidator" + "github.com/hashicorp/terraform-plugin-framework-validators/mapvalidator" + "github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/resource" + "github.com/hashicorp/terraform-plugin-framework/resource/schema" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/boolplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/int64planmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/listplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/mapplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/stringplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" + "github.com/hashicorp/terraform-plugin-framework/types" + "github.com/hashicorp/terraform-provider-aws/internal/errs" + "github.com/hashicorp/terraform-provider-aws/internal/errs/fwdiag" + "github.com/hashicorp/terraform-provider-aws/internal/framework" + 
"github.com/hashicorp/terraform-provider-aws/internal/framework/flex" + fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types" + tfobjectvalidator "github.com/hashicorp/terraform-provider-aws/internal/framework/validators/objectvalidator" + "github.com/hashicorp/terraform-provider-aws/internal/retry" + tfec2 "github.com/hashicorp/terraform-provider-aws/internal/service/ec2" + "github.com/hashicorp/terraform-provider-aws/internal/smerr" + tftags "github.com/hashicorp/terraform-provider-aws/internal/tags" + "github.com/hashicorp/terraform-provider-aws/internal/tfresource" + "github.com/hashicorp/terraform-provider-aws/names" +) + +// @FrameworkResource("aws_sagemaker_training_job", name="Training Job") +// @Tags(identifierAttribute="arn") +// @IdentityAttribute("training_job_name") +// @Testing(existsType="github.com/aws/aws-sdk-go-v2/service/sagemaker;sagemaker.DescribeTrainingJobOutput") +// @Testing(plannableImportAction="NoOp") +// @Testing(importStateIdAttribute="training_job_name") +// @Testing(hasNoPreExistingResource=true) +func newResourceTrainingJob(_ context.Context) (resource.ResourceWithConfigure, error) { + r := &resourceTrainingJob{} + + r.SetDefaultCreateTimeout(45 * time.Minute) + r.SetDefaultUpdateTimeout(45 * time.Minute) + r.SetDefaultDeleteTimeout(45 * time.Minute) + + return r, nil +} + +const ( + ResNameTrainingJob = "Training Job" +) + +var ( + serverlessBaseModelARNVersionRegex = regexache.MustCompile(`/\d{1,4}\.\d{1,4}\.\d{1,4}$`) +) + +type resourceTrainingJob struct { + framework.ResourceWithModel[resourceTrainingJobModel] + framework.WithImportByIdentity + framework.WithTimeouts +} + +func (r *resourceTrainingJob) Schema(ctx context.Context, req resource.SchemaRequest, resp *resource.SchemaResponse) { + resp.Schema = schema.Schema{ + Attributes: map[string]schema.Attribute{ + names.AttrARN: framework.ARNAttributeComputedOnly(), + names.AttrRoleARN: schema.StringAttribute{ + CustomType: fwtypes.ARNType, + Required: 
true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "enable_inter_container_traffic_encryption": schema.BoolAttribute{ + Optional: true, + Computed: true, + PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.UseStateForUnknown(), + boolplanmodifier.RequiresReplace(), + }, + }, + "enable_managed_spot_training": schema.BoolAttribute{ + Optional: true, + Computed: true, + PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.UseStateForUnknown(), + boolplanmodifier.RequiresReplace(), + }, + }, + "enable_network_isolation": schema.BoolAttribute{ + Optional: true, + Computed: true, + PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.UseStateForUnknown(), + boolplanmodifier.RequiresReplace(), + }, + }, + names.AttrEnvironment: schema.MapAttribute{ + CustomType: fwtypes.MapOfStringType, + Optional: true, + Validators: []validator.Map{ + mapvalidator.SizeBetween(0, 100), + mapvalidator.KeysAre(stringvalidator.All( + stringvalidator.LengthBetween(0, 512), + stringvalidator.RegexMatches(regexache.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`), "key must start with a letter or underscore and contain only letters, digits, and underscores"), + )), + mapvalidator.ValueStringsAre(stringvalidator.All( + stringvalidator.LengthBetween(0, 512), + )), + }, + PlanModifiers: []planmodifier.Map{ + mapplanmodifier.RequiresReplace(), + }, + }, + "hyper_parameters": schema.MapAttribute{ + CustomType: fwtypes.MapOfStringType, + Optional: true, + Validators: []validator.Map{ + mapvalidator.SizeBetween(0, 100), + mapvalidator.KeysAre(stringvalidator.All( + stringvalidator.LengthBetween(0, 256), + )), + mapvalidator.ValueStringsAre(stringvalidator.All( + stringvalidator.LengthBetween(0, 2500), + )), + }, + PlanModifiers: []planmodifier.Map{ + mapplanmodifier.RequiresReplace(), + }, + }, + names.AttrTags: tftags.TagsAttribute(), + names.AttrTagsAll: tftags.TagsAttributeComputedOnly(), + "training_job_name": schema.StringAttribute{ + Required: 
true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 63), + stringvalidator.RegexMatches(regexache.MustCompile(`^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$`), "must start with a letter or number and contain only letters, numbers, and hyphens"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + Blocks: map[string]schema.Block{ + "algorithm_specification": trainingJobAlgorithmSpecificationBlock(ctx), + "checkpoint_config": checkpointConfigBlock(ctx), + "debug_hook_config": debugHookConfigBlock(ctx), + "debug_rule_configurations": debugRuleConfigurationsBlock(ctx), + "experiment_config": experimentConfigBlock(ctx), + "infra_check_config": infraCheckConfigBlock(ctx), + "input_data_config": inputDataConfigBlock(ctx), + "mlflow_config": mlflowConfigBlock(ctx), + "model_package_config": modelPackageConfigBlock(ctx), + "output_data_config": outputDataConfigBlock(ctx), + "profiler_config": profilerConfigBlock(ctx), + "profiler_rule_configurations": profilerRuleConfigurationsBlock(ctx), + "remote_debug_config": remoteDebugConfigBlock(ctx), + "resource_config": resourceConfigBlock(ctx), + "retry_strategy": retryStrategyBlock(ctx), + "serverless_job_config": serverlessJobConfigBlock(ctx), + "session_chaining_config": sessionChainingConfigBlock(ctx), + "stopping_condition": stoppingConditionBlock(ctx), + "tensor_board_output_config": tensorBoardOutputConfigBlock(ctx), + names.AttrVPCConfig: vpcConfigBlock(ctx), + names.AttrTimeouts: timeouts.Block(ctx, timeouts.Opts{ + Create: true, + Update: true, + Delete: true, + }), + }, + } +} + +func trainingJobAlgorithmSpecificationBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobAlgorithmSpecificationModel](ctx), + Validators: []validator.List{ + listvalidator.SizeBetween(1, 1), + }, + NestedObject: schema.NestedBlockObject{ + Validators: []validator.Object{ + 
tfobjectvalidator.ExactlyOneOfChildren( + path.MatchRelative().AtName("algorithm_name"), + path.MatchRelative().AtName("training_image"), + ), + }, + Attributes: map[string]schema.Attribute{ + "algorithm_name": schema.StringAttribute{ + Optional: true, + MarkdownDescription: "Name or ARN of a SageMaker algorithm resource. Exactly one of `algorithm_name` or `training_image` must be set.", + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 170), + stringvalidator.RegexMatches(regexache.MustCompile(`(arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:[a-z\-]*\/)?([a-zA-Z0-9]+(?:-[a-zA-Z0-9]+)*)?`), "must be a valid algorithm name or ARN"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "container_arguments": schema.ListAttribute{ + ElementType: types.StringType, + Optional: true, + Validators: []validator.List{ + listvalidator.SizeBetween(1, 100), + listvalidator.ValueStringsAre(stringvalidator.LengthBetween(0, 256)), + }, + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + }, + "container_entrypoint": schema.ListAttribute{ + ElementType: types.StringType, + Optional: true, + Validators: []validator.List{ + listvalidator.SizeBetween(1, 100), + listvalidator.ValueStringsAre(stringvalidator.LengthBetween(0, 256)), + }, + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + }, + "enable_sagemaker_metrics_time_series": schema.BoolAttribute{ + Optional: true, + Computed: true, + MarkdownDescription: "Whether SageMaker AI should publish time-series metrics. SageMaker enables this automatically for built-in algorithms, supported prebuilt images, and jobs with explicit `metric_definitions`.", + PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.UseStateForUnknown(), + boolplanmodifier.RequiresReplace(), + }, + }, + "training_image": schema.StringAttribute{ + Optional: true, + MarkdownDescription: "Registry path of the training image. 
Exactly one of `algorithm_name` or `training_image` must be set. Use `metric_definitions` only when you need to extract custom metrics from your own training container logs.", + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 255), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "training_input_mode": schema.StringAttribute{ + Optional: true, + CustomType: fwtypes.StringEnumType[awstypes.TrainingInputMode](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + Blocks: map[string]schema.Block{ + "metric_definitions": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobMetricDefinitionModel](ctx), + MarkdownDescription: "Metric definitions used to extract custom metrics from training container logs. SageMaker may still return built-in metric definitions for built-in algorithms or supported prebuilt images even when this block is omitted.", + Validators: []validator.List{ + listvalidator.SizeBetween(0, 40), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + names.AttrName: schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 255), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "regex": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 500), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + }, + "training_image_config": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobTrainingImageConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "training_repository_access_mode": schema.StringAttribute{ + Optional: true, + 
PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + Blocks: map[string]schema.Block{ + "training_repository_auth_config": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobTrainingRepositoryAuthConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "training_repository_credentials_provider_arn": schema.StringAttribute{ + Optional: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } +} + +func checkpointConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobCheckpointConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "local_path": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 4096), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "s3_uri": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 1024), + stringvalidator.RegexMatches(regexache.MustCompile(`^(https|s3)://([^/]+)/?(.*)$`), "must be a valid S3 or HTTPS URI"), // anchored so the scheme must begin the value + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func debugHookConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobDebugHookConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "hook_parameters": schema.MapAttribute{ + CustomType: 
fwtypes.MapOfStringType, + Optional: true, + Validators: []validator.Map{ + mapvalidator.SizeBetween(0, 20), + mapvalidator.KeysAre(stringvalidator.All( + stringvalidator.LengthBetween(1, 256), + )), + mapvalidator.ValueStringsAre(stringvalidator.All( + stringvalidator.LengthBetween(0, 256), + )), + }, + PlanModifiers: []planmodifier.Map{ + mapplanmodifier.RequiresReplace(), + }, + }, + "local_path": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 4096), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "s3_output_path": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 1024), + stringvalidator.RegexMatches(regexache.MustCompile(`^(https|s3)://([^/]+)/?(.*)$`), "must be a valid S3 or HTTPS URI"), // anchored so the scheme must begin the value + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + Blocks: map[string]schema.Block{ + "collection_configurations": schema.ListNestedBlock{ + Validators: []validator.List{ + listvalidator.SizeBetween(0, 20), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "collection_name": schema.StringAttribute{ + Optional: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "collection_parameters": schema.MapAttribute{ + CustomType: fwtypes.MapOfStringType, + Optional: true, + PlanModifiers: []planmodifier.Map{ + mapplanmodifier.RequiresReplace(), + }, + }, + }, + }, + }, + }, + }, + } +} + +func debugRuleConfigurationsBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobDebugRuleConfigurationModel](ctx), + Validators: []validator.List{ + listvalidator.SizeBetween(0, 20), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + names.AttrInstanceType: 
schema.StringAttribute{ + Optional: true, + CustomType: fwtypes.StringEnumType[awstypes.ProcessingInstanceType](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "local_path": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 4096), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "rule_configuration_name": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 256), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "rule_evaluator_image": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 255), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "rule_parameters": schema.MapAttribute{ + CustomType: fwtypes.MapOfStringType, + Optional: true, + Validators: []validator.Map{ + mapvalidator.SizeBetween(0, 100), + mapvalidator.KeysAre(stringvalidator.All( + stringvalidator.LengthBetween(1, 256), + )), + mapvalidator.ValueStringsAre(stringvalidator.All( + stringvalidator.LengthBetween(0, 256), + )), + }, + PlanModifiers: []planmodifier.Map{ + mapplanmodifier.RequiresReplace(), + }, + }, + "s3_output_path": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 1024), + stringvalidator.RegexMatches(regexache.MustCompile(`(https|s3)://([^/]+)/?(.*)`), "must be a valid S3 or HTTPS URI"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "volume_size_in_gb": schema.Int64Attribute{ + Optional: true, + Computed: true, + Validators: []validator.Int64{ + int64validator.AtLeast(0), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.UseStateForUnknown(), + int64planmodifier.RequiresReplace(), + }, + }, + }, + }, 
+ } +} + +func experimentConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobExperimentConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "experiment_name": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 120), + stringvalidator.RegexMatches(regexache.MustCompile(`^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,119}$`), "must start with a letter or number and contain only letters, numbers, and hyphens"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "run_name": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 120), + stringvalidator.RegexMatches(regexache.MustCompile(`^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,119}$`), "must start with a letter or number and contain only letters, numbers, and hyphens"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "trial_component_display_name": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 120), + stringvalidator.RegexMatches(regexache.MustCompile(`^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,119}$`), "must start with a letter or number and contain only letters, numbers, and hyphens"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "trial_name": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 120), + stringvalidator.RegexMatches(regexache.MustCompile(`^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,119}$`), "must start with a letter or number and contain only letters, numbers, and hyphens"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + 
}, + }, + } +} + +func infraCheckConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobInfraCheckConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "enable_infra_check": schema.BoolAttribute{ + Optional: true, + PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func inputDataConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobInputDataConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeBetween(1, 20), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "channel_name": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 64), + stringvalidator.RegexMatches(regexache.MustCompile(`[A-Za-z0-9\.\-_]+`), "must contain only letters, numbers, dots, hyphens, and underscores"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "compression_type": schema.StringAttribute{ + Optional: true, + Computed: true, + CustomType: fwtypes.StringEnumType[awstypes.CompressionType](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + stringplanmodifier.RequiresReplace(), + }, + }, + names.AttrContentType: schema.StringAttribute{ + Optional: true, + Computed: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 256), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + stringplanmodifier.RequiresReplace(), + }, + }, + "input_mode": schema.StringAttribute{ + Optional: true, + Computed: true, + CustomType: fwtypes.StringEnumType[awstypes.TrainingInputMode](), + PlanModifiers: 
[]planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + stringplanmodifier.RequiresReplace(), + }, + }, + "record_wrapper_type": schema.StringAttribute{ + Optional: true, + Computed: true, + CustomType: fwtypes.StringEnumType[awstypes.RecordWrapper](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + stringplanmodifier.RequiresReplace(), + }, + }, + }, + Blocks: map[string]schema.Block{ + "data_source": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobDataSourceModel](ctx), + Validators: []validator.List{ + listvalidator.SizeBetween(1, 1), + }, + NestedObject: schema.NestedBlockObject{ + Blocks: map[string]schema.Block{ + "file_system_data_source": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobFileSystemDataSourceModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "directory_path": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 4096), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "file_system_access_mode": schema.StringAttribute{ + Required: true, + CustomType: fwtypes.StringEnumType[awstypes.FileSystemAccessMode](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + names.AttrFileSystemID: schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(11, 21), + stringvalidator.RegexMatches(regexache.MustCompile(`(fs-[0-9a-f]{8,})`), ""), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "file_system_type": schema.StringAttribute{ + Required: true, + CustomType: fwtypes.StringEnumType[awstypes.FileSystemType](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), 
+ }, + }, + }, + }, + }, + "s3_data_source": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobS3DataSourceModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "attribute_names": schema.ListAttribute{ + ElementType: types.StringType, + Optional: true, + Validators: []validator.List{ + listvalidator.SizeBetween(0, 16), + listvalidator.ValueStringsAre( + stringvalidator.LengthBetween(1, 256), + stringvalidator.RegexMatches(regexache.MustCompile(`.+`), ""), + ), + }, + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + }, + "instance_group_names": schema.ListAttribute{ + ElementType: types.StringType, + Optional: true, + Validators: []validator.List{ + listvalidator.SizeBetween(0, 5), + listvalidator.ValueStringsAre( + stringvalidator.LengthBetween(1, 64), + stringvalidator.RegexMatches(regexache.MustCompile(`.+`), ""), + ), + }, + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + }, + "s3_data_distribution_type": schema.StringAttribute{ + Optional: true, + CustomType: fwtypes.StringEnumType[awstypes.S3DataDistribution](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "s3_data_type": schema.StringAttribute{ + Required: true, + CustomType: fwtypes.StringEnumType[awstypes.S3DataType](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "s3_uri": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 1024), + stringvalidator.RegexMatches(regexache.MustCompile(`(https|s3)://([^/]+)/?(.*)`), "must be a valid S3 or HTTPS URI"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + Blocks: map[string]schema.Block{ + "hub_access_config": schema.ListNestedBlock{ + CustomType: 
fwtypes.NewListNestedObjectTypeOf[trainingJobHubAccessConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "hub_content_arn": schema.StringAttribute{ + CustomType: fwtypes.ARNType, + Required: true, + }, + }, + }, + }, + "model_access_config": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobModelAccessConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "accept_eula": schema.BoolAttribute{ + Required: true, + PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.RequiresReplace(), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + "shuffle_config": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobShuffleConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "seed": schema.Int64Attribute{ + Optional: true, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.RequiresReplace(), + }, + }, + }, + }, + }, + }, + }, + } +} + +func mlflowConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobMlflowConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "mlflow_experiment_name": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 256), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "mlflow_resource_arn": schema.StringAttribute{ + CustomType: fwtypes.ARNType, + Required: true, + Validators: []validator.String{ + 
stringvalidator.LengthBetween(0, 2048), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "mlflow_run_name": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 256), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func modelPackageConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobModelPackageConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + listvalidator.AlsoRequires( + path.MatchRoot("serverless_job_config"), + ), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "model_package_group_arn": schema.StringAttribute{ + CustomType: fwtypes.ARNType, + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 2048), + stringvalidator.RegexMatches(regexache.MustCompile(`arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]{9,16}:[0-9]{12}:model-package-group/[\S]+`), "must be a valid model package group ARN"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "source_model_package_arn": schema.StringAttribute{ + CustomType: fwtypes.ARNType, + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 2048), + stringvalidator.RegexMatches(regexache.MustCompile(`arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]{9,16}:[0-9]{12}:model-package/[\S]+`), "must be a valid source model package ARN"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func outputDataConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobOutputDataConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: 
schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "compression_type": schema.StringAttribute{ + Optional: true, + Computed: true, + CustomType: fwtypes.StringEnumType[awstypes.OutputCompressionType](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + stringplanmodifier.RequiresReplace(), + }, + }, + names.AttrKMSKeyID: schema.StringAttribute{ + Optional: true, + Computed: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 2048), + stringvalidator.RegexMatches(regexache.MustCompile(`[a-zA-Z0-9:/_-]*`), "must match the KMS key ID pattern"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + stringplanmodifier.RequiresReplace(), + }, + }, + "s3_output_path": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 1024), + stringvalidator.RegexMatches(regexache.MustCompile(`(https|s3)://([^/]+)/?(.*)`), "must be a valid S3 or HTTPS URI"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func profilerConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobProfilerConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "disable_profiler": schema.BoolAttribute{ + Optional: true, + }, + "profiling_interval_in_milliseconds": schema.Int64Attribute{ + Optional: true, + Validators: []validator.Int64{ + int64validator.OneOf(100, 200, 500, 1000, 5000, 60000), + }, + }, + "profiling_parameters": schema.MapAttribute{ + CustomType: fwtypes.MapOfStringType, + Optional: true, + Validators: []validator.Map{ + mapvalidator.SizeBetween(0, 20), + mapvalidator.KeysAre(stringvalidator.All( + stringvalidator.LengthBetween(1, 256), + )), + 
mapvalidator.ValueStringsAre(stringvalidator.All( + stringvalidator.LengthBetween(0, 256), + )), + }, + }, + "s3_output_path": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 1024), + stringvalidator.RegexMatches(regexache.MustCompile(`(https|s3)://([^/]+)/?(.*)`), "must be a valid S3 or HTTPS URI"), + }, + }, + }, + }, + } +} + +func profilerRuleConfigurationsBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobProfilerRuleConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeBetween(0, 20), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + names.AttrInstanceType: schema.StringAttribute{ + Optional: true, + CustomType: fwtypes.StringEnumType[awstypes.ProcessingInstanceType](), + }, + "local_path": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 4096), + }, + }, + "rule_configuration_name": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 256), + }, + }, + "rule_evaluator_image": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 255), + }, + }, + "rule_parameters": schema.MapAttribute{ + ElementType: types.StringType, + Optional: true, + Validators: []validator.Map{ + mapvalidator.SizeBetween(0, 100), + mapvalidator.KeysAre(stringvalidator.All( + stringvalidator.LengthBetween(1, 256), + )), + mapvalidator.ValueStringsAre(stringvalidator.All( + stringvalidator.LengthBetween(0, 256), + )), + }, + }, + "s3_output_path": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 1024), + stringvalidator.RegexMatches(regexache.MustCompile(`(https|s3)://([^/]+)/?(.*)`), "must be a valid S3 or HTTPS URI"), + }, + }, + "volume_size_in_gb": 
schema.Int64Attribute{ + Optional: true, + Computed: true, + Validators: []validator.Int64{ + int64validator.AtLeast(0), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.UseStateForUnknown(), + }, + }, + }, + }, + } +} + +func remoteDebugConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobRemoteDebugConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "enable_remote_debug": schema.BoolAttribute{ + Optional: true, + }, + }, + }, + } +} + +func resourceConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobResourceConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + names.AttrInstanceCount: schema.Int64Attribute{ + Optional: true, + Computed: true, + Validators: []validator.Int64{ + int64validator.AtLeast(0), + int64validator.ConflictsWith( + path.MatchRelative().AtParent().AtName("instance_groups"), + ), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.UseStateForUnknown(), + int64planmodifier.RequiresReplace(), + }, + }, + names.AttrInstanceType: schema.StringAttribute{ + Optional: true, + Computed: true, + CustomType: fwtypes.StringEnumType[awstypes.TrainingInstanceType](), + Validators: []validator.String{ + stringvalidator.ConflictsWith( + path.MatchRelative().AtParent().AtName("instance_groups"), + ), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + stringplanmodifier.RequiresReplace(), + }, + }, + "keep_alive_period_in_seconds": schema.Int64Attribute{ + Optional: true, + Computed: true, + Validators: []validator.Int64{ + int64validator.Between(0, 3600), + 
int64validator.ConflictsWith( + path.MatchRelative().AtParent().AtName("instance_groups"), + ), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.UseStateForUnknown(), + }, + }, + "training_plan_arn": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(50, 2048), + stringvalidator.RegexMatches(regexache.MustCompile(`arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:training-plan/.*`), "must be a valid training plan ARN"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "volume_kms_key_id": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 2048), + stringvalidator.RegexMatches(regexache.MustCompile(`[a-zA-Z0-9:/_-]*`), "must match the KMS key ID pattern"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "volume_size_in_gb": schema.Int64Attribute{ + Optional: true, + Computed: true, + Validators: []validator.Int64{ + int64validator.AtLeast(0), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.UseStateForUnknown(), + int64planmodifier.RequiresReplace(), + }, + }, + }, + Blocks: map[string]schema.Block{ + "instance_groups": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobInstanceGroupModel](ctx), + Validators: []validator.List{ + listvalidator.SizeBetween(0, 5), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + names.AttrInstanceCount: schema.Int64Attribute{ + Optional: true, + Validators: []validator.Int64{ + int64validator.AtLeast(0), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.RequiresReplace(), + }, + }, + "instance_group_name": schema.StringAttribute{ + Optional: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + names.AttrInstanceType: schema.StringAttribute{ + Optional: true, + 
CustomType: fwtypes.StringEnumType[awstypes.TrainingInstanceType](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + }, + "instance_placement_config": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobInstancePlacementConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "enable_multiple_jobs": schema.BoolAttribute{ + Optional: true, + PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.RequiresReplace(), + }, + }, + }, + Blocks: map[string]schema.Block{ + "placement_specifications": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobPlacementSpecificationModel](ctx), + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + names.AttrInstanceCount: schema.Int64Attribute{ + Optional: true, + Validators: []validator.Int64{ + int64validator.AtLeast(0), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.RequiresReplace(), + }, + }, + "ultra_server_id": schema.StringAttribute{ + Optional: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } +} + +func retryStrategyBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobRetryStrategyModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "maximum_retry_attempts": schema.Int64Attribute{ + Required: true, + Validators: []validator.Int64{ + int64validator.Between(1, 30), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func serverlessJobConfigBlock(ctx context.Context) schema.Block { + return 
schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobServerlessJobConfigModel](ctx, fwtypes.WithSemanticEqualityFunc(serverlessJobConfigEqualityFunc)), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + listvalidator.ConflictsWith( + path.MatchRoot("algorithm_specification"), + path.MatchRoot("enable_managed_spot_training"), + path.MatchRoot(names.AttrEnvironment), + path.MatchRoot("retry_strategy"), + path.MatchRoot("checkpoint_config"), + path.MatchRoot("debug_hook_config"), + path.MatchRoot("experiment_config"), + path.MatchRoot("profiler_config"), + path.MatchRoot("profiler_rule_configurations"), + path.MatchRoot("tensor_board_output_config"), + ), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "accept_eula": schema.BoolAttribute{ + Optional: true, + PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.RequiresReplace(), + }, + }, + "base_model_arn": schema.StringAttribute{ + Required: true, + MarkdownDescription: "Base model ARN in SageMaker Public Hub. 
SageMaker always selects the latest version of the provided model.", + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 2048), + stringvalidator.RegexMatches(regexache.MustCompile(`(arn:[a-z0-9-\.]{1,63}:sagemaker:\w+(?:-\w+)+:(\d{12}|aws):hub-content\/)[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}\/Model\/[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}(\/\d{1,4}.\d{1,4}.\d{1,4})?`), "must be a valid SageMaker Public Hub model ARN (hub-content)"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "customization_technique": schema.StringAttribute{ + Optional: true, + CustomType: fwtypes.StringEnumType[awstypes.CustomizationTechnique](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "evaluation_type": schema.StringAttribute{ + Optional: true, + CustomType: fwtypes.StringEnumType[awstypes.EvaluationType](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "evaluator_arn": schema.StringAttribute{ + Optional: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "job_type": schema.StringAttribute{ + Required: true, + CustomType: fwtypes.StringEnumType[awstypes.ServerlessJobType](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "peft": schema.StringAttribute{ + Optional: true, + CustomType: fwtypes.StringEnumType[awstypes.Peft](), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func sessionChainingConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobSessionChainingConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "enable_session_tag_chaining": schema.BoolAttribute{ + Optional: true, + 
PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func stoppingConditionBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobStoppingConditionModel](ctx), + Validators: []validator.List{ + listvalidator.SizeBetween(1, 1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "max_pending_time_in_seconds": schema.Int64Attribute{ + Optional: true, + Computed: true, + Validators: []validator.Int64{ + int64validator.Between(7200, 2419200), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.UseStateForUnknown(), + int64planmodifier.RequiresReplace(), + }, + }, + "max_runtime_in_seconds": schema.Int64Attribute{ + Optional: true, + Computed: true, + Validators: []validator.Int64{ + int64validator.AtLeast(1), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.UseStateForUnknown(), + int64planmodifier.RequiresReplace(), + }, + }, + "max_wait_time_in_seconds": schema.Int64Attribute{ + Optional: true, + Computed: true, + Validators: []validator.Int64{ + int64validator.AtLeast(1), + }, + PlanModifiers: []planmodifier.Int64{ + int64planmodifier.UseStateForUnknown(), + int64planmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func tensorBoardOutputConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobTensorBoardOutputConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "local_path": schema.StringAttribute{ + Optional: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 4096), + stringvalidator.RegexMatches(regexache.MustCompile(`.*`), ""), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + 
"s3_output_path": schema.StringAttribute{ + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(0, 1024), + stringvalidator.RegexMatches(regexache.MustCompile(`(https|s3)://([^/]+)/?(.*)`), "must be a valid S3 or HTTPS URI"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func vpcConfigBlock(ctx context.Context) schema.Block { + return schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[trainingJobVPCConfigModel](ctx), + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + names.AttrSecurityGroupIDs: schema.ListAttribute{ + ElementType: types.StringType, + Required: true, + Validators: []validator.List{ + listvalidator.SizeBetween(1, 5), + listvalidator.ValueStringsAre( + stringvalidator.LengthBetween(0, 32), + stringvalidator.RegexMatches(regexache.MustCompile(`[-0-9a-zA-Z]+`), "must be a valid security group ID"), + ), + }, + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + }, + names.AttrSubnets: schema.ListAttribute{ + ElementType: types.StringType, + Required: true, + Validators: []validator.List{ + listvalidator.SizeBetween(1, 16), + listvalidator.ValueStringsAre( + stringvalidator.LengthBetween(0, 32), + stringvalidator.RegexMatches(regexache.MustCompile(`[-0-9a-zA-Z]+`), "must be a valid subnet ID"), + ), + }, + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + }, + }, + }, + } +} + +func (r *resourceTrainingJob) Create(ctx context.Context, req resource.CreateRequest, resp *resource.CreateResponse) { + conn := r.Meta().SageMakerClient(ctx) + + var plan resourceTrainingJobModel + smerr.AddEnrich(ctx, &resp.Diagnostics, req.Plan.Get(ctx, &plan)) + if resp.Diagnostics.HasError() { + return + } + + var input sagemaker.CreateTrainingJobInput + smerr.AddEnrich(ctx, &resp.Diagnostics, 
flex.Expand(ctx, plan, &input)) + if resp.Diagnostics.HasError() { + return + } + + input.Tags = getTagsIn(ctx) + + out, err := tfresource.RetryWhen(ctx, propagationTimeout, func(ctx context.Context) (*sagemaker.CreateTrainingJobOutput, error) { + return conn.CreateTrainingJob(ctx, &input) + }, func(err error) (bool, error) { + if errMessageContainsAny(err, ErrCodeValidationException, []string{ + "Could not assume role", + "Unauthorized to List objects under S3 URL", + "Access denied to OutputDataConfig S3 bucket", + "no identity-based policy allows the s3:ListBucket action", + "Access denied to hub content", + "Access denied for repository", + }) { + return true, err + } + return false, err + }) + if err != nil { + smerr.AddError(ctx, &resp.Diagnostics, err, smerr.ID, plan.TrainingJobName.String()) + return + } + + if out == nil || out.TrainingJobArn == nil { + smerr.AddError(ctx, &resp.Diagnostics, errors.New("empty output"), smerr.ID, plan.TrainingJobName.String()) + return + } + + createTimeout := r.CreateTimeout(ctx, plan.Timeouts) + waitOut, err := waitTrainingJobCreated(ctx, conn, plan.TrainingJobName.ValueString(), createTimeout) + if err != nil { + smerr.AddError(ctx, &resp.Diagnostics, err, smerr.ID, plan.TrainingJobName.String()) + return + } + + r.flatten(ctx, waitOut, &plan, &resp.Diagnostics) + if resp.Diagnostics.HasError() { + return + } + + smerr.AddEnrich(ctx, &resp.Diagnostics, resp.State.Set(ctx, plan)) +} + +func (r *resourceTrainingJob) Read(ctx context.Context, req resource.ReadRequest, resp *resource.ReadResponse) { + conn := r.Meta().SageMakerClient(ctx) + + var state resourceTrainingJobModel + smerr.AddEnrich(ctx, &resp.Diagnostics, req.State.Get(ctx, &state)) + if resp.Diagnostics.HasError() { + return + } + + out, err := findTrainingJobByName(ctx, conn, state.TrainingJobName.ValueString()) + + if retry.NotFound(err) { + resp.Diagnostics.Append(fwdiag.NewResourceNotFoundWarningDiagnostic(err)) + resp.State.RemoveResource(ctx) + return + } 
+ + if err != nil { + smerr.AddError(ctx, &resp.Diagnostics, err, smerr.ID, state.TrainingJobName.ValueString()) + return + } + + r.flatten(ctx, out, &state, &resp.Diagnostics) + if resp.Diagnostics.HasError() { + return + } + + smerr.AddEnrich(ctx, &resp.Diagnostics, resp.State.Set(ctx, &state)) +} + +func (r *resourceTrainingJob) Update(ctx context.Context, req resource.UpdateRequest, resp *resource.UpdateResponse) { + var plan, state resourceTrainingJobModel + smerr.AddEnrich(ctx, &resp.Diagnostics, req.Plan.Get(ctx, &plan)) + if resp.Diagnostics.HasError() { + return + } + smerr.AddEnrich(ctx, &resp.Diagnostics, req.State.Get(ctx, &state)) + if resp.Diagnostics.HasError() { + return + } + + conn := r.Meta().SageMakerClient(ctx) + + diff, d := flex.Diff(ctx, plan, state) + smerr.AddEnrich(ctx, &resp.Diagnostics, d, smerr.ID, plan.TrainingJobName) + if resp.Diagnostics.HasError() { + return + } + + if diff.HasChanges() { + var input sagemaker.UpdateTrainingJobInput + smerr.AddEnrich(ctx, &resp.Diagnostics, flex.Expand(ctx, plan, &input), smerr.ID, plan.TrainingJobName) + if resp.Diagnostics.HasError() { + return + } + + _, err := conn.UpdateTrainingJob(ctx, &input) + + if err != nil { + smerr.AddError(ctx, &resp.Diagnostics, err, smerr.ID, plan.TrainingJobName) + return + } + } + + smerr.AddEnrich(ctx, &resp.Diagnostics, resp.State.Set(ctx, &plan)) +} + +func (r *resourceTrainingJob) Delete(ctx context.Context, req resource.DeleteRequest, resp *resource.DeleteResponse) { + conn := r.Meta().SageMakerClient(ctx) + + var state resourceTrainingJobModel + smerr.AddEnrich(ctx, &resp.Diagnostics, req.State.Get(ctx, &state)) + if resp.Diagnostics.HasError() { + return + } + + job, err := findTrainingJobByName(ctx, conn, state.TrainingJobName.ValueString()) + if err != nil { + if !retry.NotFound(err) { + smerr.AddError(ctx, &resp.Diagnostics, err, smerr.ID, state.TrainingJobName.ValueString()) + } + return + } + + if job.TrainingJobStatus == 
awstypes.TrainingJobStatusInProgress { + stopInput := &sagemaker.StopTrainingJobInput{ + TrainingJobName: state.TrainingJobName.ValueStringPointer(), + } + _, err := conn.StopTrainingJob(ctx, stopInput) + if err != nil { + smerr.AddError(ctx, &resp.Diagnostics, err, smerr.ID, state.TrainingJobName.ValueString()) + return + } + + stopTimeout := r.DeleteTimeout(ctx, state.Timeouts) + _, err = waitTrainingJobStopped(ctx, conn, state.TrainingJobName.ValueString(), stopTimeout) + if err != nil { + smerr.AddError(ctx, &resp.Diagnostics, err, smerr.ID, state.TrainingJobName.ValueString()) + return + } + } + + input := sagemaker.DeleteTrainingJobInput{ + TrainingJobName: state.TrainingJobName.ValueStringPointer(), + } + + _, err = conn.DeleteTrainingJob(ctx, &input) + + if err != nil { + if errs.Contains(err, "Requested resource not found") { + return + } + + smerr.AddError(ctx, &resp.Diagnostics, err, smerr.ID, state.TrainingJobName.ValueString()) + return + } + + deleteTimeout := r.DeleteTimeout(ctx, state.Timeouts) + _, err = waitTrainingJobDeleted(ctx, conn, state.TrainingJobName.ValueString(), deleteTimeout) + if err != nil { + smerr.AddError(ctx, &resp.Diagnostics, err, smerr.ID, state.TrainingJobName.ValueString()) + return + } + + if !state.VPCConfig.IsNull() && !state.VPCConfig.IsUnknown() { + vpcConfigs, diags := state.VPCConfig.ToSlice(ctx) + resp.Diagnostics.Append(diags...) + if !resp.Diagnostics.HasError() && len(vpcConfigs) > 0 { + var securityGroupIDs []string + resp.Diagnostics.Append(vpcConfigs[0].SecurityGroupIDs.ElementsAs(ctx, &securityGroupIDs, false)...) + + var subnetIDs []string + resp.Diagnostics.Append(vpcConfigs[0].Subnets.ElementsAs(ctx, &subnetIDs, false)...) 
+ + if !resp.Diagnostics.HasError() && len(securityGroupIDs) > 0 && len(subnetIDs) > 0 { + if err := deleteTrainingJobVPCENIs(ctx, r.Meta().EC2Client(ctx), securityGroupIDs, subnetIDs, deleteTimeout); err != nil { + resp.Diagnostics.AddWarning( + "Error cleaning up VPC ENIs", + fmt.Sprintf("SageMaker training job %s was deleted, but there was an error cleaning up VPC network interfaces: %s", state.TrainingJobName.ValueString(), err), + ) + } + } + } + } + + if !state.ModelPackageConfig.IsNull() && !state.ModelPackageConfig.IsUnknown() { + mpConfigs, diags := state.ModelPackageConfig.ToSlice(ctx) + resp.Diagnostics.Append(diags...) + if !resp.Diagnostics.HasError() && len(mpConfigs) > 0 { + groupARN := mpConfigs[0].ModelPackageGroupARN.ValueString() + if groupARN != "" { + if err := deleteModelPackages(ctx, conn, groupARN); err != nil { + resp.Diagnostics.AddWarning( + "Error cleaning up Model Packages", + fmt.Sprintf("SageMaker training job %s was deleted, but there was an error cleaning up model packages in group %s: %s", state.TrainingJobName.ValueString(), groupARN, err), + ) + } + } + } + } +} + +func deleteTrainingJobVPCENIs(ctx context.Context, ec2Conn *ec2.Client, securityGroupIDs, subnetIDs []string, timeout time.Duration) error { + networkInterfaces, err := tfec2.FindNetworkInterfaces(ctx, ec2Conn, &ec2.DescribeNetworkInterfacesInput{ + Filters: []ec2types.Filter{ + tfec2.NewFilter("group-id", securityGroupIDs), + tfec2.NewFilter("subnet-id", subnetIDs), + }, + }) + if err != nil { + return fmt.Errorf("finding ENIs: %w", err) + } + + for _, ni := range networkInterfaces { + networkInterfaceID := aws.ToString(ni.NetworkInterfaceId) + + if ni.Attachment != nil { + if err := tfec2.DetachNetworkInterface(ctx, ec2Conn, networkInterfaceID, aws.ToString(ni.Attachment.AttachmentId), timeout); err != nil { + return fmt.Errorf("detaching ENI (%s): %w", networkInterfaceID, err) + } + } + + if err := tfec2.DeleteNetworkInterface(ctx, ec2Conn, networkInterfaceID); err 
!= nil { + return fmt.Errorf("deleting ENI (%s): %w", networkInterfaceID, err) + } + } + + return nil +} + +func deleteModelPackages(ctx context.Context, conn *sagemaker.Client, groupNameOrARN string) error { + pages := sagemaker.NewListModelPackagesPaginator(conn, &sagemaker.ListModelPackagesInput{ + ModelPackageGroupName: aws.String(groupNameOrARN), + }) + for pages.HasMorePages() { + page, err := pages.NextPage(ctx) + if err != nil { + return fmt.Errorf("listing SageMaker AI Model Packages for group (%s): %w", groupNameOrARN, err) + } + for _, mp := range page.ModelPackageSummaryList { + if _, err := conn.DeleteModelPackage(ctx, &sagemaker.DeleteModelPackageInput{ + ModelPackageName: mp.ModelPackageArn, + }); err != nil { + if !errs.Contains(err, "does not exist") { + return fmt.Errorf("deleting SageMaker AI Model Package (%s): %w", aws.ToString(mp.ModelPackageArn), err) + } + } + } + } + return nil +} + +// SageMaker injects metric definitions for some built-in algorithms and supported +// prebuilt images. This fixes unexpected new value errors during apply +func normalizeAlgoSpecMetricDefinitions( + ctx context.Context, + saved fwtypes.ListNestedObjectValueOf[trainingJobAlgorithmSpecificationModel], + target *fwtypes.ListNestedObjectValueOf[trainingJobAlgorithmSpecificationModel], + diags *diag.Diagnostics, +) { + if saved.IsUnknown() { + return + } + + flatSpecs, d := target.ToSlice(ctx) + diags.Append(d...) + if diags.HasError() || len(flatSpecs) == 0 { + return + } + + if saved.IsNull() || len(saved.Elements()) == 0 { + flatSpecs[0].MetricDefinitions = fwtypes.NewListNestedObjectValueOfNull[trainingJobMetricDefinitionModel](ctx) + } else { + savedSpecs, d := saved.ToSlice(ctx) + diags.Append(d...) 
+ if diags.HasError() || len(savedSpecs) == 0 { + return + } + flatSpecs[0].MetricDefinitions = savedSpecs[0].MetricDefinitions + } + + *target = fwtypes.NewListNestedObjectValueOfSliceMust(ctx, flatSpecs) +} + +// AWS injects a default stopping_condition for serverless jobs when the user omitted it. +// Only suppress that value for serverless jobs so import can still retain explicit +// stopping_condition values for non-serverless jobs. +func normalizeStoppingCondition( + ctx context.Context, + saved fwtypes.ListNestedObjectValueOf[trainingJobStoppingConditionModel], + serverlessJobConfig fwtypes.ListNestedObjectValueOf[trainingJobServerlessJobConfigModel], + target *fwtypes.ListNestedObjectValueOf[trainingJobStoppingConditionModel], +) { + if saved.IsUnknown() { + return + } + + if (saved.IsNull() || len(saved.Elements()) == 0) && !serverlessJobConfig.IsNull() && len(serverlessJobConfig.Elements()) > 0 { + *target = fwtypes.NewListNestedObjectValueOfNull[trainingJobStoppingConditionModel](ctx) + } +} + +func (r *resourceTrainingJob) flatten( + ctx context.Context, + out *sagemaker.DescribeTrainingJobOutput, + target *resourceTrainingJobModel, + diags *diag.Diagnostics, +) { + savedAlgoSpec := target.AlgorithmSpecification + savedStoppingCondition := target.StoppingCondition + + diags.Append(flex.Flatten(ctx, out, target)...) 
+ if diags.HasError() { + return + } + + normalizeAlgoSpecMetricDefinitions(ctx, savedAlgoSpec, &target.AlgorithmSpecification, diags) + if diags.HasError() { + return + } + + normalizeStoppingCondition(ctx, savedStoppingCondition, target.ServerlessJobConfig, &target.StoppingCondition) +} + +// SageMaker always selects the latest version of the provided model irrespective of user config +func serverlessJobConfigEqualityFunc( + ctx context.Context, + oldValue, newValue fwtypes.NestedCollectionValue[trainingJobServerlessJobConfigModel], +) (bool, diag.Diagnostics) { + var diags diag.Diagnostics + + oldConfig, d := oldValue.ToPtr(ctx) + diags.Append(d...) + if diags.HasError() { + return false, diags + } + + newConfig, d := newValue.ToPtr(ctx) + diags.Append(d...) + if diags.HasError() { + return false, diags + } + + if oldConfig == nil || newConfig == nil { + return oldConfig == nil && newConfig == nil, diags + } + + if !oldConfig.AcceptEULA.Equal(newConfig.AcceptEULA) || + !oldConfig.CustomizationTechnique.Equal(newConfig.CustomizationTechnique) || + !oldConfig.EvaluationType.Equal(newConfig.EvaluationType) || + !oldConfig.EvaluatorARN.Equal(newConfig.EvaluatorARN) || + !oldConfig.JobType.Equal(newConfig.JobType) || + !oldConfig.Peft.Equal(newConfig.Peft) { + return false, diags + } + + return serverlessBaseModelARNsEqual(oldConfig.BaseModelARN, newConfig.BaseModelARN), diags +} + +func errMessageContainsAny(err error, code string, messages []string) bool { + for _, message := range messages { + if tfawserr.ErrMessageContains(err, code, message) { + return true + } + } + return false +} + +func serverlessBaseModelARNsEqual(oldValue, newValue types.String) bool { + if oldValue.IsNull() || oldValue.IsUnknown() || newValue.IsNull() || newValue.IsUnknown() { + return oldValue.Equal(newValue) + } + + return normalizeServerlessBaseModelARN(oldValue.ValueString()) == normalizeServerlessBaseModelARN(newValue.ValueString()) +} + +func normalizeServerlessBaseModelARN(v 
string) string { + return serverlessBaseModelARNVersionRegex.ReplaceAllString(v, "") +} + +func statusTrainingJob(conn *sagemaker.Client, id string) retry.StateRefreshFunc { + return func(ctx context.Context) (any, string, error) { + out, err := findTrainingJobByName(ctx, conn, id) + if retry.NotFound(err) { + return nil, "", nil + } + + if err != nil { + return nil, "", smarterr.NewError(err) + } + + return out, string(out.TrainingJobStatus), nil + } +} + +func findTrainingJobByName(ctx context.Context, conn *sagemaker.Client, id string) (*sagemaker.DescribeTrainingJobOutput, error) { + input := sagemaker.DescribeTrainingJobInput{ + TrainingJobName: aws.String(id), + } + + out, err := conn.DescribeTrainingJob(ctx, &input) + if err != nil { + if errs.Contains(err, "Requested resource not found") { + return nil, smarterr.NewError(&retry.NotFoundError{ + LastError: err, + }) + } + + return nil, smarterr.NewError(err) + } + + if out == nil || out.TrainingJobArn == nil { + return nil, smarterr.NewError(tfresource.NewEmptyResultError()) + } + + return out, nil +} + +type resourceTrainingJobModel struct { + framework.WithRegionModel + AlgorithmSpecification fwtypes.ListNestedObjectValueOf[trainingJobAlgorithmSpecificationModel] `tfsdk:"algorithm_specification"` + TrainingJobARN types.String `tfsdk:"arn"` + CheckpointConfig fwtypes.ListNestedObjectValueOf[trainingJobCheckpointConfigModel] `tfsdk:"checkpoint_config"` + DebugHookConfig fwtypes.ListNestedObjectValueOf[trainingJobDebugHookConfigModel] `tfsdk:"debug_hook_config"` + DebugRuleConfigurations fwtypes.ListNestedObjectValueOf[trainingJobDebugRuleConfigurationModel] `tfsdk:"debug_rule_configurations"` + EnableInterContainerTrafficEncryption types.Bool `tfsdk:"enable_inter_container_traffic_encryption"` + EnableManagedSpotTraining types.Bool `tfsdk:"enable_managed_spot_training"` + EnableNetworkIsolation types.Bool `tfsdk:"enable_network_isolation"` + Environment fwtypes.MapOfString `tfsdk:"environment"` + 
ExperimentConfig fwtypes.ListNestedObjectValueOf[trainingJobExperimentConfigModel] `tfsdk:"experiment_config"` + HyperParameters fwtypes.MapOfString `tfsdk:"hyper_parameters"` + InfraCheckConfig fwtypes.ListNestedObjectValueOf[trainingJobInfraCheckConfigModel] `tfsdk:"infra_check_config"` + InputDataConfig fwtypes.ListNestedObjectValueOf[trainingJobInputDataConfigModel] `tfsdk:"input_data_config"` + MlflowConfig fwtypes.ListNestedObjectValueOf[trainingJobMlflowConfigModel] `tfsdk:"mlflow_config"` + ModelPackageConfig fwtypes.ListNestedObjectValueOf[trainingJobModelPackageConfigModel] `tfsdk:"model_package_config"` + OutputDataConfig fwtypes.ListNestedObjectValueOf[trainingJobOutputDataConfigModel] `tfsdk:"output_data_config"` + ProfilerConfig fwtypes.ListNestedObjectValueOf[trainingJobProfilerConfigModel] `tfsdk:"profiler_config"` + ProfilerRuleConfigurations fwtypes.ListNestedObjectValueOf[trainingJobProfilerRuleConfigModel] `tfsdk:"profiler_rule_configurations"` + RemoteDebugConfig fwtypes.ListNestedObjectValueOf[trainingJobRemoteDebugConfigModel] `tfsdk:"remote_debug_config"` + ResourceConfig fwtypes.ListNestedObjectValueOf[trainingJobResourceConfigModel] `tfsdk:"resource_config"` + RetryStrategy fwtypes.ListNestedObjectValueOf[trainingJobRetryStrategyModel] `tfsdk:"retry_strategy"` + RoleARN fwtypes.ARN `tfsdk:"role_arn"` + ServerlessJobConfig fwtypes.ListNestedObjectValueOf[trainingJobServerlessJobConfigModel] `tfsdk:"serverless_job_config"` + SessionChainingConfig fwtypes.ListNestedObjectValueOf[trainingJobSessionChainingConfigModel] `tfsdk:"session_chaining_config"` + StoppingCondition fwtypes.ListNestedObjectValueOf[trainingJobStoppingConditionModel] `tfsdk:"stopping_condition" autoflex:",omitempty"` + Tags tftags.Map `tfsdk:"tags"` + TagsAll tftags.Map `tfsdk:"tags_all"` + TensorBoardOutputConfig fwtypes.ListNestedObjectValueOf[trainingJobTensorBoardOutputConfigModel] `tfsdk:"tensor_board_output_config"` + Timeouts timeouts.Value `tfsdk:"timeouts"` + 
// trainingJobAlgorithmSpecificationModel is the Terraform model for one element
// of the algorithm_specification nested list on resourceTrainingJobModel.
//
// After an API response is flattened, normalizeAlgoSpecMetricDefinitions
// replaces MetricDefinitions on the first element with the previously saved
// value (or null when nothing was saved).
type trainingJobAlgorithmSpecificationModel struct {
	AlgorithmName                    types.String                                                          `tfsdk:"algorithm_name"`
	ContainerArguments               fwtypes.ListOfString                                                  `tfsdk:"container_arguments"`
	ContainerEntrypoint              fwtypes.ListOfString                                                  `tfsdk:"container_entrypoint"`
	EnableSageMakerMetricsTimeSeries types.Bool                                                            `tfsdk:"enable_sagemaker_metrics_time_series"`
	MetricDefinitions                fwtypes.ListNestedObjectValueOf[trainingJobMetricDefinitionModel]     `tfsdk:"metric_definitions"`
	TrainingImage                    types.String                                                          `tfsdk:"training_image"`
	TrainingImageConfig              fwtypes.ListNestedObjectValueOf[trainingJobTrainingImageConfigModel] `tfsdk:"training_image_config"`
	TrainingInputMode                types.String                                                          `tfsdk:"training_input_mode"`
}
// trainingJobS3DataSourceModel is the Terraform model for one element of the
// s3_data_source nested list declared on trainingJobDataSourceModel.
type trainingJobS3DataSourceModel struct {
	AttributeNames         fwtypes.ListOfString                                               `tfsdk:"attribute_names"`
	HubAccessConfig        fwtypes.ListNestedObjectValueOf[trainingJobHubAccessConfigModel]   `tfsdk:"hub_access_config"`
	InstanceGroupNames     fwtypes.ListOfString                                               `tfsdk:"instance_group_names"`
	ModelAccessConfig      fwtypes.ListNestedObjectValueOf[trainingJobModelAccessConfigModel] `tfsdk:"model_access_config"`
	S3DataDistributionType fwtypes.StringEnum[awstypes.S3DataDistribution]                    `tfsdk:"s3_data_distribution_type"`
	S3DataType             fwtypes.StringEnum[awstypes.S3DataType]                            `tfsdk:"s3_data_type"`
	S3URI                  types.String                                                       `tfsdk:"s3_uri"`
}
// trainingJobVPCConfigModel is the Terraform model for one element of the
// vpc_config nested list on resourceTrainingJobModel.
//
// Delete reads both fields from the first element and, when both are
// non-empty, passes them to deleteTrainingJobVPCENIs to clean up network
// interfaces after the job is removed.
type trainingJobVPCConfigModel struct {
	SecurityGroupIDs fwtypes.ListOfString `tfsdk:"security_group_ids"`
	Subnets          fwtypes.ListOfString `tfsdk:"subnets"`
}
// trainingJobModelPackageConfigModel is the Terraform model for one element of
// the model_package_config nested list on resourceTrainingJobModel.
//
// Delete reads ModelPackageGroupARN from the first element and, when it is
// non-empty, passes it to deleteModelPackages after the job is removed.
type trainingJobModelPackageConfigModel struct {
	ModelPackageGroupARN  fwtypes.ARN `tfsdk:"model_package_group_arn"`
	SourceModelPackageARN fwtypes.ARN `tfsdk:"source_model_package_arn"`
}
// trainingJobServerlessJobConfigModel is the Terraform model for one element of
// the serverless_job_config nested list on resourceTrainingJobModel.
//
// serverlessJobConfigEqualityFunc compares every field here with Equal except
// BaseModelARN, which it compares through serverlessBaseModelARNsEqual.
type trainingJobServerlessJobConfigModel struct {
	AcceptEULA             types.Bool                                          `tfsdk:"accept_eula"`
	BaseModelARN           types.String                                        `tfsdk:"base_model_arn"`
	CustomizationTechnique fwtypes.StringEnum[awstypes.CustomizationTechnique] `tfsdk:"customization_technique"`
	EvaluationType         fwtypes.StringEnum[awstypes.EvaluationType]         `tfsdk:"evaluation_type"`
	EvaluatorARN           types.String                                        `tfsdk:"evaluator_arn"`
	JobType                fwtypes.StringEnum[awstypes.ServerlessJobType]      `tfsdk:"job_type"`
	Peft                   fwtypes.StringEnum[awstypes.Peft]                   `tfsdk:"peft"`
}
Corp. 2014, 2026 +// SPDX-License-Identifier: MPL-2.0 + +// Code generated by internal/generate/identitytests/main.go; DO NOT EDIT. + +package sagemaker_test + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/service/sagemaker" + "github.com/hashicorp/terraform-plugin-testing/config" + "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/knownvalue" + "github.com/hashicorp/terraform-plugin-testing/plancheck" + "github.com/hashicorp/terraform-plugin-testing/statecheck" + "github.com/hashicorp/terraform-plugin-testing/tfjsonpath" + "github.com/hashicorp/terraform-plugin-testing/tfversion" + "github.com/hashicorp/terraform-provider-aws/internal/acctest" + tfknownvalue "github.com/hashicorp/terraform-provider-aws/internal/acctest/knownvalue" + "github.com/hashicorp/terraform-provider-aws/names" +) + +func TestAccSageMakerTrainingJob_Identity_basic(t *testing.T) { + ctx := acctest.Context(t) + + var v sagemaker.DescribeTrainingJobOutput + resourceName := "aws_sagemaker_training_job.test" + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + + acctest.ParallelTest(ctx, t, resource.TestCase{ + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.SkipBelow(tfversion.Version1_12_0), + }, + PreCheck: func() { acctest.PreCheck(ctx, t) }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + Steps: []resource.TestStep{ + // Step 1: Setup + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/basic/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + }, + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &v), + ), + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrRegion), 
knownvalue.StringExact(acctest.Region())), + statecheck.ExpectIdentity(resourceName, map[string]knownvalue.Check{ + names.AttrAccountID: tfknownvalue.AccountID(), + names.AttrRegion: knownvalue.StringExact(acctest.Region()), + "training_job_name": knownvalue.NotNull(), + }), + statecheck.ExpectIdentityValueMatchesState(resourceName, tfjsonpath.New("training_job_name")), + }, + }, + + // Step 2: Import command + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/basic/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + }, + ImportStateKind: resource.ImportCommandWithID, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + + // Step 3: Import block with Import ID + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/basic/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + }, + ResourceName: resourceName, + ImportState: true, + ImportStateKind: resource.ImportBlockWithID, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportPlanChecks: resource.ImportPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectKnownValue(resourceName, tfjsonpath.New("training_job_name"), knownvalue.NotNull()), + plancheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrRegion), knownvalue.StringExact(acctest.Region())), + }, + }, + }, + + // Step 4: Import block with Resource Identity + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/basic/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + }, + ResourceName: resourceName, + ImportState: true, + ImportStateKind: resource.ImportBlockWithResourceIdentity, + ImportPlanChecks: resource.ImportPlanChecks{ + PreApply: []plancheck.PlanCheck{ + 
plancheck.ExpectKnownValue(resourceName, tfjsonpath.New("training_job_name"), knownvalue.NotNull()), + plancheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrRegion), knownvalue.StringExact(acctest.Region())), + }, + }, + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_Identity_regionOverride(t *testing.T) { + ctx := acctest.Context(t) + + resourceName := "aws_sagemaker_training_job.test" + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + + acctest.ParallelTest(ctx, t, resource.TestCase{ + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.SkipBelow(tfversion.Version1_12_0), + }, + PreCheck: func() { acctest.PreCheck(ctx, t) }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + CheckDestroy: acctest.CheckDestroyNoop, + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + Steps: []resource.TestStep{ + // Step 1: Setup + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/region_override/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + "region": config.StringVariable(acctest.AlternateRegion()), + }, + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrRegion), knownvalue.StringExact(acctest.AlternateRegion())), + statecheck.ExpectIdentity(resourceName, map[string]knownvalue.Check{ + names.AttrAccountID: tfknownvalue.AccountID(), + names.AttrRegion: knownvalue.StringExact(acctest.AlternateRegion()), + "training_job_name": knownvalue.NotNull(), + }), + statecheck.ExpectIdentityValueMatchesState(resourceName, tfjsonpath.New("training_job_name")), + }, + }, + + // Step 2: Import command + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/region_override/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + "region": config.StringVariable(acctest.AlternateRegion()), + }, + ImportStateKind: resource.ImportCommandWithID, + 
ImportStateIdFunc: acctest.CrossRegionAttrImportStateIdFunc(resourceName, "training_job_name"), + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + + // Step 3: Import block with Import ID + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/region_override/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + "region": config.StringVariable(acctest.AlternateRegion()), + }, + ResourceName: resourceName, + ImportState: true, + ImportStateKind: resource.ImportBlockWithID, + ImportStateIdFunc: acctest.CrossRegionAttrImportStateIdFunc(resourceName, "training_job_name"), + ImportPlanChecks: resource.ImportPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectKnownValue(resourceName, tfjsonpath.New("training_job_name"), knownvalue.NotNull()), + plancheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrRegion), knownvalue.StringExact(acctest.AlternateRegion())), + }, + }, + }, + + // Step 4: Import block with Resource Identity + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/region_override/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + "region": config.StringVariable(acctest.AlternateRegion()), + }, + ResourceName: resourceName, + ImportState: true, + ImportStateKind: resource.ImportBlockWithResourceIdentity, + ImportPlanChecks: resource.ImportPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectKnownValue(resourceName, tfjsonpath.New("training_job_name"), knownvalue.NotNull()), + plancheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrRegion), knownvalue.StringExact(acctest.AlternateRegion())), + }, + }, + }, + }, + }) +} diff --git a/internal/service/sagemaker/training_job_list.go b/internal/service/sagemaker/training_job_list.go new file mode 100644 index 000000000000..b81bf5581ca1 --- /dev/null +++ 
b/internal/service/sagemaker/training_job_list.go
@@ -0,0 +1,183 @@
+// Copyright IBM Corp. 2014, 2026
+// SPDX-License-Identifier: MPL-2.0
+
+package sagemaker
+
+import (
+	"context"
+	"fmt"
+	"iter"
+	"reflect"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/service/sagemaker"
+	awstypes "github.com/aws/aws-sdk-go-v2/service/sagemaker/types"
+	"github.com/hashicorp/terraform-plugin-framework/attr"
+	"github.com/hashicorp/terraform-plugin-framework/diag"
+	"github.com/hashicorp/terraform-plugin-framework/list"
+	"github.com/hashicorp/terraform-provider-aws/internal/errs/fwdiag"
+	"github.com/hashicorp/terraform-provider-aws/internal/framework"
+	fwflex "github.com/hashicorp/terraform-provider-aws/internal/framework/flex"
+	fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types"
+)
+
+// Function annotations are used for list resource registration to the Provider. DO NOT EDIT.
+// @FrameworkListResource("aws_sagemaker_training_job")
+func newTrainingJobResourceAsListResource() list.ListResourceWithConfigure {
+	return &trainingJobListResource{}
+}
+
+var _ list.ListResource = &trainingJobListResource{}
+
+// trainingJobListResource is the list resource for SageMaker Training Jobs. It
+// embeds resourceTrainingJob together with the framework's list support.
+type trainingJobListResource struct {
+	resourceTrainingJob
+	framework.WithList
+}
+
+// List streams one result per training job summary returned by
+// listTrainingJobs. Each result is built from a model seeded with the job name
+// and uses that name as its display name; when request.IncludeResource is set,
+// the job is additionally read with findTrainingJobByName and flattened into
+// the model. The stream ends at the first listing error, at the first result
+// carrying error diagnostics, or as soon as the consumer stops accepting
+// results.
+func (l *trainingJobListResource) List(ctx context.Context, request list.ListRequest, stream *list.ListResultsStream) {
+	conn := l.Meta().SageMakerClient(ctx)
+
+	// Decode the list configuration up front so that a decoding failure is
+	// returned as stream diagnostics. The decoded model is not read again in
+	// this function.
+	var query listTrainingJobModel
+	if request.Config.Raw.IsKnown() && !request.Config.Raw.IsNull() {
+		if diags := request.Config.Get(ctx, &query); diags.HasError() {
+			stream.Results = list.ListResultsStreamDiagnostics(diags)
+			return
+		}
+	}
+
+	stream.Results = func(yield func(list.ListResult) bool) {
+		var input sagemaker.ListTrainingJobsInput
+
+		for item, err := range listTrainingJobs(ctx, conn, &input) {
+			if err != nil {
+				result := fwdiag.NewListResultErrorDiagnostic(err)
+				yield(result)
+				return
+			}
+
+			trainingJobName := aws.ToString(item.TrainingJobName)
+
+			result := request.NewListResult(ctx)
+
+			// Seed the model with the job name taken from the list summary.
+			var data resourceTrainingJobModel
+			data.TrainingJobName = fwflex.StringValueToFramework(ctx, trainingJobName)
+
+			l.SetResult(ctx, l.Meta(), request.IncludeResource, &data, &result, func() {
+				if request.IncludeResource {
+					trainingJob, err := findTrainingJobByName(ctx, conn, trainingJobName)
+					if err != nil {
+						result.Diagnostics.Append(diag.NewErrorDiagnostic("Reading SageMaker Training Job", err.Error()))
+						return
+					}
+
+					result.Diagnostics.Append(fwflex.Flatten(ctx, trainingJob, &data)...)
+					if result.Diagnostics.HasError() {
+						return
+					}
+				}
+
+				// Attribute fields still at their Go zero value are replaced with
+				// typed nulls.
+				result.Diagnostics.Append(setZeroAttrValuesToNull(ctx, &data)...)
+				if result.Diagnostics.HasError() {
+					return
+				}
+
+				result.DisplayName = trainingJobName
+			})
+
+			// Hand the failing result to the consumer, then end the stream.
+			if result.Diagnostics.HasError() {
+				yield(result)
+				return
+			}
+
+			if !yield(result) {
+				return
+			}
+		}
+	}
+}
+
+// listTrainingJobModel is the configuration model decoded by List.
+type listTrainingJobModel struct {
+	framework.WithRegionModel
+}
+
+// listTrainingJobs returns an iterator over every training job summary produced
+// by paginating ListTrainingJobs with input. A page error is yielded once,
+// wrapped and paired with a zero-value summary, and ends the iteration.
+func listTrainingJobs(ctx context.Context, conn *sagemaker.Client, input *sagemaker.ListTrainingJobsInput) iter.Seq2[awstypes.TrainingJobSummary, error] {
+	return func(yield func(awstypes.TrainingJobSummary, error) bool) {
+		pages := sagemaker.NewListTrainingJobsPaginator(conn, input)
+		for pages.HasMorePages() {
+			page, err := pages.NextPage(ctx)
+			if err != nil {
+				yield(awstypes.TrainingJobSummary{}, fmt.Errorf("listing SageMaker Training Job resources: %w", err))
+				return
+			}
+
+			for _, item := range page.TrainingJobSummaries {
+				if !yield(item, nil) {
+					return
+				}
+			}
+		}
+	}
+}
+
+// setZeroAttrValuesToNull walks the struct that target points to and replaces
+// every settable attr.Value field still holding its Go zero value with the
+// null value reported by fwtypes.NullValueOf. A target that is not a non-nil
+// pointer is left untouched and produces no diagnostics.
+func setZeroAttrValuesToNull(ctx context.Context, target any) diag.Diagnostics {
+	var diags diag.Diagnostics
+
+	value := reflect.ValueOf(target)
+	if !value.IsValid() || value.Kind() != reflect.Pointer || value.IsNil() {
+		return diags
+	}
+
+	walkStructSetZeroAttrNull(ctx, value.Elem(), &diags)
+
+	return diags
+}
+
+// walkStructSetZeroAttrNull applies the zero-to-null replacement to the fields
+// of the struct value, recursing into struct-kinded fields that do not
+// implement attr.Value. It returns immediately when diags already holds an
+// error or value is not a struct, and stops at the first error it records.
+func walkStructSetZeroAttrNull(ctx context.Context, value reflect.Value, diags *diag.Diagnostics) {
+	if diags.HasError() || !value.IsValid() || value.Kind() != reflect.Struct {
+		return
+	}
+
+	for index := range value.NumField() {
+		field := value.Field(index)
+		// Fields that cannot be set through reflection are skipped.
+		if !field.CanSet() {
+			continue
+		}
+
+		if field.Kind() != reflect.Struct {
+			continue
+		}
+
+		// attr.Value fields are leaves: null them when zero, never recurse.
+		if attrValue, ok := field.Interface().(attr.Value); ok {
+			if field.IsZero() {
+				nullValue, err := fwtypes.NullValueOf(ctx, attrValue)
+				if err != nil {
+					diags.AddError("Normalizing List Result", err.Error())
+					return
+				}
+
+				// NullValueOf returned no value; leave the field as it is.
+				if nullValue == nil {
+					continue
+				}
+
+				nullValueReflect := reflect.ValueOf(nullValue)
+				switch {
+				case nullValueReflect.Type().AssignableTo(field.Type()):
+					field.Set(nullValueReflect)
+				case nullValueReflect.Type().ConvertibleTo(field.Type()):
+					field.Set(nullValueReflect.Convert(field.Type()))
+				default:
+					diags.AddError("Normalizing List Result", fmt.Sprintf("cannot assign null value of type %T to field type %s", nullValue, field.Type()))
+					return
+				}
+			}
+
+			continue
+		}
+
+		walkStructSetZeroAttrNull(ctx, field, diags)
+		if diags.HasError() {
+			return
+		}
+	}
+}
diff --git a/internal/service/sagemaker/training_job_list_test.go b/internal/service/sagemaker/training_job_list_test.go
new file mode 100644
index 000000000000..c261f41c0ac8
--- /dev/null
+++ b/internal/service/sagemaker/training_job_list_test.go
@@ -0,0 +1,208 @@
+// Copyright IBM Corp. 
2014, 2026 +// SPDX-License-Identifier: MPL-2.0 + +package sagemaker_test + +import ( + "testing" + + "github.com/hashicorp/terraform-plugin-testing/config" + "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/knownvalue" + "github.com/hashicorp/terraform-plugin-testing/querycheck" + "github.com/hashicorp/terraform-plugin-testing/statecheck" + "github.com/hashicorp/terraform-plugin-testing/tfjsonpath" + "github.com/hashicorp/terraform-plugin-testing/tfversion" + "github.com/hashicorp/terraform-provider-aws/internal/acctest" + tfknownvalue "github.com/hashicorp/terraform-provider-aws/internal/acctest/knownvalue" + tfquerycheck "github.com/hashicorp/terraform-provider-aws/internal/acctest/querycheck" + tfqueryfilter "github.com/hashicorp/terraform-provider-aws/internal/acctest/queryfilter" + tfstatecheck "github.com/hashicorp/terraform-provider-aws/internal/acctest/statecheck" + "github.com/hashicorp/terraform-provider-aws/names" +) + +func TestAccSageMakerTrainingJob_List_basic(t *testing.T) { + ctx := acctest.Context(t) + + resourceName1 := "aws_sagemaker_training_job.test[0]" + resourceName2 := "aws_sagemaker_training_job.test[1]" + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + + identity1 := tfstatecheck.Identity() + identity2 := tfstatecheck.Identity() + + acctest.ParallelTest(ctx, t, resource.TestCase{ + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.SkipBelow(tfversion.Version1_14_0), + }, + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + Steps: []resource.TestStep{ + // Step 1: Setup + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/list_basic/"), + ConfigVariables: config.Variables{ + acctest.CtRName: 
config.StringVariable(rName), + "resource_count": config.IntegerVariable(2), + }, + ConfigStateChecks: []statecheck.StateCheck{ + identity1.GetIdentity(resourceName1), + statecheck.ExpectKnownValue(resourceName1, tfjsonpath.New(names.AttrARN), tfknownvalue.RegionalARNExact("sagemaker", "training-job/"+rName+"-0")), + + identity2.GetIdentity(resourceName2), + statecheck.ExpectKnownValue(resourceName2, tfjsonpath.New(names.AttrARN), tfknownvalue.RegionalARNExact("sagemaker", "training-job/"+rName+"-1")), + }, + }, + + // Step 2: Query + { + Query: true, + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/list_basic/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + "resource_count": config.IntegerVariable(2), + }, + QueryResultChecks: []querycheck.QueryResultCheck{ + tfquerycheck.ExpectIdentityFunc("aws_sagemaker_training_job.test", identity1.Checks()), + querycheck.ExpectResourceDisplayName("aws_sagemaker_training_job.test", tfqueryfilter.ByResourceIdentityFunc(identity1.Checks()), knownvalue.StringExact(rName+"-0")), + tfquerycheck.ExpectNoResourceObject("aws_sagemaker_training_job.test", tfqueryfilter.ByResourceIdentityFunc(identity1.Checks())), + + tfquerycheck.ExpectIdentityFunc("aws_sagemaker_training_job.test", identity2.Checks()), + querycheck.ExpectResourceDisplayName("aws_sagemaker_training_job.test", tfqueryfilter.ByResourceIdentityFunc(identity2.Checks()), knownvalue.StringExact(rName+"-1")), + tfquerycheck.ExpectNoResourceObject("aws_sagemaker_training_job.test", tfqueryfilter.ByResourceIdentityFunc(identity2.Checks())), + }, + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_List_includeResource(t *testing.T) { + ctx := acctest.Context(t) + + resourceName1 := "aws_sagemaker_training_job.test[0]" + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + + identity1 := tfstatecheck.Identity() + + acctest.ParallelTest(ctx, t, resource.TestCase{ + TerraformVersionChecks: 
[]tfversion.TerraformVersionCheck{ + tfversion.SkipBelow(tfversion.Version1_14_0), + }, + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + Steps: []resource.TestStep{ + // Step 1: Setup + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/list_include_resource/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + "resource_count": config.IntegerVariable(1), + acctest.CtResourceTags: config.MapVariable(map[string]config.Variable{ + acctest.CtKey1: config.StringVariable(acctest.CtValue1), + }), + }, + ConfigStateChecks: []statecheck.StateCheck{ + identity1.GetIdentity(resourceName1), + statecheck.ExpectKnownValue(resourceName1, tfjsonpath.New(names.AttrARN), tfknownvalue.RegionalARNExact("sagemaker", "training-job/"+rName+"-0")), + }, + }, + + // Step 2: Query + { + Query: true, + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/list_include_resource/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + "resource_count": config.IntegerVariable(1), + acctest.CtResourceTags: config.MapVariable(map[string]config.Variable{ + acctest.CtKey1: config.StringVariable(acctest.CtValue1), + }), + }, + QueryResultChecks: []querycheck.QueryResultCheck{ + tfquerycheck.ExpectIdentityFunc("aws_sagemaker_training_job.test", identity1.Checks()), + querycheck.ExpectResourceDisplayName("aws_sagemaker_training_job.test", tfqueryfilter.ByResourceIdentityFunc(identity1.Checks()), knownvalue.StringExact(rName+"-0")), + querycheck.ExpectResourceKnownValues("aws_sagemaker_training_job.test", tfqueryfilter.ByResourceIdentityFunc(identity1.Checks()), []querycheck.KnownValueCheck{ + tfquerycheck.KnownValueCheck(tfjsonpath.New(names.AttrARN), 
tfknownvalue.RegionalARNExact("sagemaker", "training-job/"+rName+"-0")), + tfquerycheck.KnownValueCheck(tfjsonpath.New(names.AttrRegion), knownvalue.StringExact(acctest.Region())), + tfquerycheck.KnownValueCheck(tfjsonpath.New("training_job_name"), knownvalue.StringExact(rName+"-0")), + tfquerycheck.KnownValueCheck(tfjsonpath.New(names.AttrTags), knownvalue.MapExact(map[string]knownvalue.Check{ + acctest.CtKey1: knownvalue.StringExact(acctest.CtValue1), + })), + tfquerycheck.KnownValueCheck(tfjsonpath.New(names.AttrTagsAll), knownvalue.MapExact(map[string]knownvalue.Check{ + acctest.CtKey1: knownvalue.StringExact(acctest.CtValue1), + })), + }), + }, + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_List_regionOverride(t *testing.T) { + ctx := acctest.Context(t) + + resourceName1 := "aws_sagemaker_training_job.test[0]" + resourceName2 := "aws_sagemaker_training_job.test[1]" + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + + identity1 := tfstatecheck.Identity() + identity2 := tfstatecheck.Identity() + + acctest.ParallelTest(ctx, t, resource.TestCase{ + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.SkipBelow(tfversion.Version1_14_0), + }, + PreCheck: func() { + acctest.PreCheck(ctx, t) + acctest.PreCheckMultipleRegion(t, 2) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + Steps: []resource.TestStep{ + // Step 1: Setup + { + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/list_region_override/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + "resource_count": config.IntegerVariable(2), + "region": config.StringVariable(acctest.AlternateRegion()), + }, + ConfigStateChecks: []statecheck.StateCheck{ + identity1.GetIdentity(resourceName1), + statecheck.ExpectKnownValue(resourceName1, 
tfjsonpath.New(names.AttrARN), tfknownvalue.RegionalARNAlternateRegionExact("sagemaker", "training-job/"+rName+"-0")), + + identity2.GetIdentity(resourceName2), + statecheck.ExpectKnownValue(resourceName2, tfjsonpath.New(names.AttrARN), tfknownvalue.RegionalARNAlternateRegionExact("sagemaker", "training-job/"+rName+"-1")), + }, + }, + + // Step 2: Query + { + Query: true, + ConfigDirectory: config.StaticDirectory("testdata/TrainingJob/list_region_override/"), + ConfigVariables: config.Variables{ + acctest.CtRName: config.StringVariable(rName), + "resource_count": config.IntegerVariable(2), + "region": config.StringVariable(acctest.AlternateRegion()), + }, + QueryResultChecks: []querycheck.QueryResultCheck{ + tfquerycheck.ExpectIdentityFunc("aws_sagemaker_training_job.test", identity1.Checks()), + + tfquerycheck.ExpectIdentityFunc("aws_sagemaker_training_job.test", identity2.Checks()), + }, + }, + }, + }) +} diff --git a/internal/service/sagemaker/training_job_test.go b/internal/service/sagemaker/training_job_test.go new file mode 100644 index 000000000000..f27aabfea3f8 --- /dev/null +++ b/internal/service/sagemaker/training_job_test.go @@ -0,0 +1,2983 @@ +// Copyright IBM Corp. 
2014, 2026 +// SPDX-License-Identifier: MPL-2.0 + +package sagemaker_test + +import ( + "context" + "fmt" + "testing" + + "github.com/YakDriver/regexache" + "github.com/aws/aws-sdk-go-v2/service/sagemaker" + "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/knownvalue" + "github.com/hashicorp/terraform-plugin-testing/plancheck" + "github.com/hashicorp/terraform-plugin-testing/statecheck" + "github.com/hashicorp/terraform-plugin-testing/terraform" + "github.com/hashicorp/terraform-plugin-testing/tfjsonpath" + "github.com/hashicorp/terraform-provider-aws/internal/acctest" + "github.com/hashicorp/terraform-provider-aws/internal/retry" + tfsagemaker "github.com/hashicorp/terraform-provider-aws/internal/service/sagemaker" + "github.com/hashicorp/terraform-provider-aws/names" +) + +const ( + trainingJobNovaModelARNEnvVar = "SAGEMAKER_TRAINING_JOB_NOVA_MODEL_ARN" + trainingJobCustomImageEnvVar = "SAGEMAKER_TRAINING_JOB_CUSTOM_IMAGE" +) + +func TestAccSageMakerTrainingJob_basic(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_basic(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + acctest.MatchResourceAttrRegionalARN(ctx, resourceName, 
names.AttrARN, "sagemaker", regexache.MustCompile(`training-job/.+`)), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.#", "1"), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.0.training_input_mode", "File"), + resource.TestCheckResourceAttrPair(resourceName, "algorithm_specification.0.training_image", "data.aws_sagemaker_prebuilt_ecr_image.test", "registry_path"), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.0.enable_sagemaker_metrics_time_series", acctest.CtTrue), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.0.s3_output_path", fmt.Sprintf("s3://%s/output/", rName)), + resource.TestCheckResourceAttr(resourceName, "resource_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "resource_config.0.instance_type", "ml.m5.large"), + resource.TestCheckResourceAttr(resourceName, "resource_config.0.instance_count", "1"), + resource.TestCheckResourceAttr(resourceName, "resource_config.0.volume_size_in_gb", "30"), + resource.TestCheckResourceAttr(resourceName, "stopping_condition.#", "1"), + resource.TestCheckResourceAttr(resourceName, "stopping_condition.0.max_runtime_in_seconds", "3600"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_disappears(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + 
testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_basic(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + acctest.CheckFrameworkResourceDisappears(ctx, t, tfsagemaker.ResourceTrainingJob, resourceName), + ), + ExpectNonEmptyPlan: true, + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_vpc(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_vpc(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "1"), + resource.TestCheckResourceAttr(resourceName, "stopping_condition.0.max_runtime_in_seconds", "3600"), + ), + }, + { + Config: testAccTrainingJobConfig_vpcUpdate(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, 
resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), + resource.TestCheckResourceAttr(resourceName, "stopping_condition.0.max_runtime_in_seconds", "7200"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_debugConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + rNameUpdated := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_debug(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "debug_hook_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "debug_hook_config.0.s3_output_path", fmt.Sprintf("s3://%s/debug/", rName)), + resource.TestCheckResourceAttr(resourceName, "debug_rule_configurations.#", "1"), + 
resource.TestCheckResourceAttr(resourceName, "debug_rule_configurations.0.s3_output_path", fmt.Sprintf("s3://%s/debug-rules/", rName)), + ), + }, + { + Config: testAccTrainingJobConfig_debugUpdate(rNameUpdated), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rNameUpdated), + resource.TestCheckResourceAttr(resourceName, "debug_hook_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "debug_hook_config.0.s3_output_path", fmt.Sprintf("s3://%s/debug-updated/", rNameUpdated)), + resource.TestCheckResourceAttr(resourceName, "debug_rule_configurations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "debug_rule_configurations.0.s3_output_path", fmt.Sprintf("s3://%s/debug-rules-updated/", rNameUpdated)), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_profilerConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + rNameUpdated := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_profiler(rName), + Check: resource.ComposeAggregateTestCheckFunc( + 
testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "profiler_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "profiler_config.0.disable_profiler", acctest.CtFalse), + resource.TestCheckResourceAttr(resourceName, "profiler_config.0.profiling_interval_in_milliseconds", "500"), + resource.TestCheckResourceAttr(resourceName, "profiler_rule_configurations.#", "1"), + ), + }, + { + Config: testAccTrainingJobConfig_profilerUpdated(rNameUpdated), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rNameUpdated), + resource.TestCheckResourceAttr(resourceName, "profiler_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "profiler_config.0.disable_profiler", acctest.CtFalse), + resource.TestCheckResourceAttr(resourceName, "profiler_config.0.profiling_interval_in_milliseconds", "1000"), + resource.TestCheckResourceAttr(resourceName, "profiler_rule_configurations.#", "1"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_environmentAndHyperParameters(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + 
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_environmentAndHyperParameters(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "environment.%", "2"), + resource.TestCheckResourceAttr(resourceName, "environment.TEST_ENV", "test_value"), + resource.TestCheckResourceAttr(resourceName, "hyper_parameters.%", "2"), + resource.TestCheckResourceAttr(resourceName, "hyper_parameters.epochs", "10"), + resource.TestCheckResourceAttr(resourceName, "enable_inter_container_traffic_encryption", acctest.CtTrue), + resource.TestCheckResourceAttr(resourceName, "stopping_condition.0.max_runtime_in_seconds", "3600"), + ), + }, + { + Config: testAccTrainingJobConfig_environmentAndHyperParametersUpdate(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "environment.%", "2"), + resource.TestCheckResourceAttr(resourceName, "environment.TEST_ENV", "updated_value"), + resource.TestCheckResourceAttr(resourceName, "hyper_parameters.%", "2"), + resource.TestCheckResourceAttr(resourceName, "hyper_parameters.epochs", "20"), + resource.TestCheckResourceAttr(resourceName, "enable_inter_container_traffic_encryption", acctest.CtFalse), + resource.TestCheckResourceAttr(resourceName, "stopping_condition.0.max_runtime_in_seconds", "7200"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + 
}, + }, + }) +} + +func TestAccSageMakerTrainingJob_checkpointConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_checkpoint(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "checkpoint_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "checkpoint_config.0.local_path", "/opt/ml/checkpoints"), + resource.TestCheckResourceAttr(resourceName, "checkpoint_config.0.s3_uri", fmt.Sprintf("s3://%s/checkpoints/", rName)), + ), + }, + { + Config: testAccTrainingJobConfig_checkpointUpdate(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "checkpoint_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "checkpoint_config.0.local_path", "/opt/ml/checkpoints"), + resource.TestCheckResourceAttr(resourceName, "checkpoint_config.0.s3_uri", fmt.Sprintf("s3://%s/checkpoints-v2/", rName)), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, 
"training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_tensorBoardOutputConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_tensorBoard(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "tensor_board_output_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "tensor_board_output_config.0.local_path", "/opt/ml/output/tensorboard"), + resource.TestCheckResourceAttr(resourceName, "tensor_board_output_config.0.s3_output_path", fmt.Sprintf("s3://%s/tensorboard/", rName)), + ), + }, + { + Config: testAccTrainingJobConfig_tensorBoardUpdate(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "tensor_board_output_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "tensor_board_output_config.0.local_path", "/opt/ml/output/tensorboard"), + resource.TestCheckResourceAttr(resourceName, "tensor_board_output_config.0.s3_output_path", 
fmt.Sprintf("s3://%s/tensorboard-v2/", rName)), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_inputDataConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_inputData(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.channel_name", "training"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.input_mode", "File"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_source.0.s3_data_source.0.s3_uri", fmt.Sprintf("s3://%s/input/", rName)), + ), + }, + { + Config: testAccTrainingJobConfig_inputDataUpdate(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", 
"1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.channel_name", "training"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.input_mode", "File"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_source.0.s3_data_source.0.s3_uri", fmt.Sprintf("s3://%s/input-v2/", rName)), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_outputDataConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_outputData(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.0.compression_type", "GZIP"), + resource.TestCheckResourceAttrSet(resourceName, "output_data_config.0.kms_key_id"), + ), + }, + { + Config: testAccTrainingJobConfig_outputDataUpdate(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + 
resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.0.compression_type", "NONE"), + resource.TestCheckResourceAttrSet(resourceName, "output_data_config.0.kms_key_id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_algorithmSpecificationMetrics(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + rNameUpdated := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + customImage := acctest.SkipIfEnvVarNotSet(t, trainingJobCustomImageEnvVar) + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_algorithmMetrics(rName, customImage), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.0.metric_definitions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.0.metric_definitions.0.name", "train:loss"), + resource.TestCheckResourceAttr(resourceName, 
"algorithm_specification.0.metric_definitions.0.regex", "loss: ([0-9\\.]+)"), + ), + }, + { + Config: testAccTrainingJobConfig_algorithmMetricsUpdate(rNameUpdated, customImage), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rNameUpdated), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.0.metric_definitions.#", "2"), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.0.metric_definitions.0.name", "train:loss"), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.0.metric_definitions.0.regex", "loss: ([0-9\\.]+)"), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.0.metric_definitions.1.name", "validation:accuracy"), + resource.TestCheckResourceAttr(resourceName, "algorithm_specification.0.metric_definitions.1.regex", "accuracy: ([0-9\\.]+)"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + ImportStateVerifyIgnore: []string{"algorithm_specification.0.metric_definitions"}, + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_retryStrategy(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: 
[]resource.TestStep{ + { + Config: testAccTrainingJobConfig_retryStrategy(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "retry_strategy.#", "1"), + resource.TestCheckResourceAttr(resourceName, "retry_strategy.0.maximum_retry_attempts", "3"), + ), + }, + { + Config: testAccTrainingJobConfig_retryStrategyUpdate(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "retry_strategy.#", "1"), + resource.TestCheckResourceAttr(resourceName, "retry_strategy.0.maximum_retry_attempts", "5"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_serverless(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_serverless(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, 
&trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "serverless_job_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "serverless_job_config.0.job_type", "FineTuning"), + resource.TestCheckResourceAttr(resourceName, "serverless_job_config.0.accept_eula", acctest.CtTrue), + resource.TestCheckResourceAttr(resourceName, "serverless_job_config.0.customization_technique", "SFT"), + resource.TestCheckResourceAttrSet(resourceName, "serverless_job_config.0.base_model_arn"), + resource.TestCheckResourceAttr(resourceName, "model_package_config.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "model_package_config.0.model_package_group_arn"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.channel_name", "train"), + ), + }, + { + Config: testAccTrainingJobConfig_serverlessUpdate(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "serverless_job_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "serverless_job_config.0.job_type", "FineTuning"), + resource.TestCheckResourceAttr(resourceName, "serverless_job_config.0.accept_eula", acctest.CtTrue), + resource.TestCheckResourceAttr(resourceName, "serverless_job_config.0.customization_technique", "DPO"), + resource.TestCheckResourceAttrSet(resourceName, "serverless_job_config.0.base_model_arn"), + resource.TestCheckResourceAttr(resourceName, "model_package_config.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "model_package_config.0.model_package_group_arn"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, 
"input_data_config.0.channel_name", "train"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + ImportStateVerifyIgnore: []string{"serverless_job_config.0.base_model_arn"}, + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_tags(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_tags1(rName, acctest.CtKey1, acctest.CtValue1), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + ), + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrTags), knownvalue.MapExact(map[string]knownvalue.Check{ + acctest.CtKey1: knownvalue.StringExact(acctest.CtValue1), + })), + statecheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrTagsAll), knownvalue.MapExact(map[string]knownvalue.Check{ + acctest.CtKey1: knownvalue.StringExact(acctest.CtValue1), + })), + }, + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + ImportStateVerifyIgnore: 
[]string{"algorithm_specification.0.metric_definitions"}, + }, + { + Config: testAccTrainingJobConfig_tags2(rName, acctest.CtKey1, acctest.CtValue1Updated, acctest.CtKey2, acctest.CtValue2), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + ), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionUpdate), + }, + }, + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrTags), knownvalue.MapExact(map[string]knownvalue.Check{ + acctest.CtKey1: knownvalue.StringExact(acctest.CtValue1Updated), + acctest.CtKey2: knownvalue.StringExact(acctest.CtValue2), + })), + }, + }, + { + Config: testAccTrainingJobConfig_tags1(rName, acctest.CtKey2, acctest.CtValue2), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + ), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionUpdate), + }, + }, + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExpectKnownValue(resourceName, tfjsonpath.New(names.AttrTags), knownvalue.MapExact(map[string]knownvalue.Check{ + acctest.CtKey2: knownvalue.StringExact(acctest.CtValue2), + })), + }, + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_infraCheckConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + rNameUpdated := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + 
ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_infraCheck(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "infra_check_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "infra_check_config.0.enable_infra_check", acctest.CtTrue), + ), + }, + { + Config: testAccTrainingJobConfig_infraCheckUpdate(rNameUpdated), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rNameUpdated), + resource.TestCheckResourceAttr(resourceName, "infra_check_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "infra_check_config.0.enable_infra_check", acctest.CtFalse), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_mlflowConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + rNameUpdated := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + novaModelARN := acctest.SkipIfEnvVarNotSet(t, trainingJobNovaModelARNEnvVar) + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + 
ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_mlflow(rName, novaModelARN), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "mlflow_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "mlflow_config.0.mlflow_experiment_name", rName), + resource.TestCheckResourceAttrSet(resourceName, "mlflow_config.0.mlflow_resource_arn"), + resource.TestCheckResourceAttr(resourceName, "mlflow_config.0.mlflow_run_name", rName), + ), + }, + { + Config: testAccTrainingJobConfig_mlflowUpdate(rNameUpdated, novaModelARN), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rNameUpdated), + resource.TestCheckResourceAttr(resourceName, "mlflow_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "mlflow_config.0.mlflow_experiment_name", rNameUpdated), + resource.TestCheckResourceAttrSet(resourceName, "mlflow_config.0.mlflow_resource_arn"), + resource.TestCheckResourceAttr(resourceName, "mlflow_config.0.mlflow_run_name", rNameUpdated), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_remoteDebugConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var trainingjob sagemaker.DescribeTrainingJobOutput + rName := 
acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + rNameUpdated := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + customImage := acctest.SkipIfEnvVarNotSet(t, trainingJobCustomImageEnvVar) + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_remoteDebug(rName, rName, customImage), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "remote_debug_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "remote_debug_config.0.enable_remote_debug", acctest.CtFalse), + ), + }, + { + Config: testAccTrainingJobConfig_remoteDebugUpdate(rName, rNameUpdated, customImage), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rNameUpdated), + resource.TestCheckResourceAttr(resourceName, "remote_debug_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "remote_debug_config.0.enable_remote_debug", acctest.CtTrue), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + }, + }, + }) +} + +func TestAccSageMakerTrainingJob_sessionChainingConfig(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var 
trainingjob sagemaker.DescribeTrainingJobOutput + rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + rNameUpdated := acctest.RandomWithPrefix(t, acctest.ResourcePrefix) + resourceName := "aws_sagemaker_training_job.test" + + acctest.Test(ctx, t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(ctx, t) + testAccPreCheckTrainingJobs(ctx, t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckTrainingJobDestroy(ctx, t), + Steps: []resource.TestStep{ + { + Config: testAccTrainingJobConfig_sessionChaining(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rName), + resource.TestCheckResourceAttr(resourceName, "session_chaining_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "session_chaining_config.0.enable_session_tag_chaining", acctest.CtTrue), + ), + }, + { + Config: testAccTrainingJobConfig_sessionChainingUpdate(rNameUpdated), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckTrainingJobExists(ctx, t, resourceName, &trainingjob), + resource.TestCheckResourceAttr(resourceName, "training_job_name", rNameUpdated), + resource.TestCheckResourceAttr(resourceName, "session_chaining_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "session_chaining_config.0.enable_session_tag_chaining", acctest.CtFalse), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateIdFunc: acctest.AttrImportStateIdFunc(resourceName, "training_job_name"), + ImportStateVerifyIdentifierAttribute: "training_job_name", + ImportStateVerifyIgnore: []string{"session_chaining_config"}, + }, + }, + }) +} + +func testAccCheckTrainingJobDestroy(ctx context.Context, t *testing.T) resource.TestCheckFunc { + return func(s *terraform.State) error { + conn 
:= acctest.ProviderMeta(ctx, t).SageMakerClient(ctx) + + for _, rs := range s.RootModule().Resources { + if rs.Type != "aws_sagemaker_training_job" { + continue + } + + trainingJobName := rs.Primary.Attributes["training_job_name"] + if trainingJobName == "" { + return fmt.Errorf("No SageMaker Training Job name is set") + } + + _, err := tfsagemaker.FindTrainingJobByName(ctx, conn, trainingJobName) + if retry.NotFound(err) { + continue + } + if err != nil { + return err + } + + return fmt.Errorf("SageMaker Training Job %s still exists", trainingJobName) + } + + return nil + } +} + +func testAccCheckTrainingJobExists(ctx context.Context, t *testing.T, name string, trainingjob *sagemaker.DescribeTrainingJobOutput) resource.TestCheckFunc { + return func(s *terraform.State) error { + rs, ok := s.RootModule().Resources[name] + if !ok { + return fmt.Errorf("Not found: %s", name) + } + + if rs.Primary.Attributes["training_job_name"] == "" { + return fmt.Errorf("No SageMaker Training Job name is set") + } + + conn := acctest.ProviderMeta(ctx, t).SageMakerClient(ctx) + + trainingJobName := rs.Primary.Attributes["training_job_name"] + if trainingJobName == "" { + return fmt.Errorf("No SageMaker Training Job name is set") + } + + output, err := tfsagemaker.FindTrainingJobByName(ctx, conn, trainingJobName) + if err != nil { + return err + } + + *trainingjob = *output + + return nil + } +} + +func testAccPreCheckTrainingJobs(ctx context.Context, t *testing.T) { + conn := acctest.ProviderMeta(ctx, t).SageMakerClient(ctx) + + input := &sagemaker.ListTrainingJobsInput{} + + _, err := conn.ListTrainingJobs(ctx, input) + + if acctest.PreCheckSkipError(err) { + t.Skipf("skipping acceptance testing: %s", err) + } + if err != nil { + t.Fatalf("unexpected PreCheck error: %s", err) + } +} + +func testAccTrainingJobConfig_base(rName string) string { + return fmt.Sprintf(` +data "aws_partition" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = 
["sts:AssumeRole", "sts:SetSourceIdentity", "sts:TagSession"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = %[1]q + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_s3_bucket" "test" { + bucket = %[1]q + force_destroy = true +} + +data "aws_sagemaker_prebuilt_ecr_image" "test" { + repository_name = "linear-learner" + image_tag = "1" +} +`, rName) +} + +func testAccTrainingJobConfig_basic(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + enable_sagemaker_metrics_time_series = true + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_vpc(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_vpc" "test" { + cidr_block = "10.0.0.0/16" +} + +resource "aws_subnet" "test" { + vpc_id = aws_vpc.test.id + cidr_block = "10.0.1.0/24" +} + +resource "aws_security_group" "test" { + vpc_id = aws_vpc.test.id + name = %[1]q +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = 
"File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + vpc_config { + security_group_ids = [aws_security_group.test.id] + subnets = [aws_subnet.test.id] + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_vpcUpdate(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_vpc" "test" { + cidr_block = "10.0.0.0/16" +} + +resource "aws_subnet" "test" { + vpc_id = aws_vpc.test.id + cidr_block = "10.0.1.0/24" +} + +resource "aws_subnet" "test2" { + vpc_id = aws_vpc.test.id + cidr_block = "10.0.2.0/24" +} + +resource "aws_security_group" "test" { + vpc_id = aws_vpc.test.id + name = %[1]q +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 7200 + } + + vpc_config { + security_group_ids = [aws_security_group.test.id] + subnets = [aws_subnet.test.id, aws_subnet.test2.id] + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_debug(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +data "aws_iam_policy_document" "s3" { + statement { + actions = [ + "s3:GetObject", + "s3:PutObject" + ] + resources = [ + 
"${aws_s3_bucket.test.arn}/*" + ] + } + statement { + actions = [ + "s3:ListBucket" + ] + resources = [ + aws_s3_bucket.test.arn + ] + } +} + +resource "aws_iam_role_policy" "test" { + role = aws_iam_role.test.name + policy = data.aws_iam_policy_document.s3.json +} + +data "aws_sagemaker_prebuilt_ecr_image" "debugger" { + repository_name = "sagemaker-debugger-rules" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + debug_hook_config { + local_path = "/opt/ml/output/tensors" + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/debug/" + } + + debug_rule_configurations { + local_path = "/opt/ml/processing/test1" + rule_configuration_name = "LossNotDecreasing" + rule_evaluator_image = data.aws_sagemaker_prebuilt_ecr_image.debugger.registry_path + rule_parameters = { + "rule_to_invoke" = "LossNotDecreasing" + } + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/debug-rules/" + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_debugUpdate(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +data "aws_iam_policy_document" "s3" { + statement { + actions = [ + "s3:GetObject", + "s3:PutObject" + ] + resources = [ + "${aws_s3_bucket.test.arn}/*" + ] + } + statement { + actions = [ + "s3:ListBucket" + ] + resources = [ + aws_s3_bucket.test.arn + ] + } +} + +resource "aws_iam_role_policy" "test" { + role = aws_iam_role.test.name + policy = data.aws_iam_policy_document.s3.json +} 
+ +data "aws_sagemaker_prebuilt_ecr_image" "debugger" { + repository_name = "sagemaker-debugger-rules" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + debug_hook_config { + local_path = "/opt/ml/output/tensors" + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/debug-updated/" + } + + debug_rule_configurations { + local_path = "/opt/ml/processing/test1" + rule_configuration_name = "LossNotDecreasing" + rule_evaluator_image = data.aws_sagemaker_prebuilt_ecr_image.debugger.registry_path + rule_parameters = { + "rule_to_invoke" = "LossNotDecreasing" + } + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/debug-rules-updated/" + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_profiler(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +data "aws_iam_policy_document" "s3" { + statement { + actions = [ + "s3:GetObject", + "s3:PutObject" + ] + resources = [ + "${aws_s3_bucket.test.arn}/*" + ] + } + statement { + actions = [ + "s3:ListBucket" + ] + resources = [ + aws_s3_bucket.test.arn + ] + } +} + +resource "aws_iam_role_policy" "test" { + role = aws_iam_role.test.name + policy = data.aws_iam_policy_document.s3.json +} + +data "aws_sagemaker_prebuilt_ecr_image" "debugger" { + repository_name = "sagemaker-debugger-rules" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + 
training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + profiler_config { + disable_profiler = false + profiling_interval_in_milliseconds = 500 + profiling_parameters = { + "profile_cpu" = "true" + } + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/profiler/" + } + + profiler_rule_configurations { + local_path = "/opt/ml/processing/test" + rule_configuration_name = "ProfilerReport" + rule_evaluator_image = data.aws_sagemaker_prebuilt_ecr_image.debugger.registry_path + rule_parameters = { + "rule_to_invoke" = "ProfilerReport" + } + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/profiler-rules/" + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_profilerUpdated(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +data "aws_iam_policy_document" "s3" { + statement { + actions = [ + "s3:GetObject", + "s3:PutObject" + ] + resources = [ + "${aws_s3_bucket.test.arn}/*" + ] + } + statement { + actions = [ + "s3:ListBucket" + ] + resources = [ + aws_s3_bucket.test.arn + ] + } +} + +resource "aws_iam_role_policy" "test" { + role = aws_iam_role.test.name + policy = data.aws_iam_policy_document.s3.json +} + +data "aws_sagemaker_prebuilt_ecr_image" "debugger" { + repository_name = "sagemaker-debugger-rules" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = 
"s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + profiler_config { + disable_profiler = false + profiling_interval_in_milliseconds = 1000 + profiling_parameters = { + "profile_cpu" = "false" + } + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/profiler/" + } + + profiler_rule_configurations { + local_path = "/opt/ml/processing/test" + rule_configuration_name = "ProfilerReport" + rule_evaluator_image = data.aws_sagemaker_prebuilt_ecr_image.debugger.registry_path + rule_parameters = { + "rule_to_invoke" = "ProfilerReport" + } + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/profiler-rules/" + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_environmentAndHyperParameters(rName string) string { + return fmt.Sprintf(` +data "aws_partition" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = %[1]q + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_s3_bucket" "test" { + bucket = %[1]q +} + +data "aws_sagemaker_prebuilt_ecr_image" "test" { + repository_name = "pytorch-training" + image_tag = "2.0.0-cpu-py310-ubuntu20.04-sagemaker" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + enable_inter_container_traffic_encryption = true + enable_managed_spot_training = true + enable_network_isolation = false + + 
environment = { + "TEST_ENV" = "test_value" + "ANOTHER_ENV" = "another_value" + } + + hyper_parameters = { + "epochs" = "10" + "batch_size" = "32" + } + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + max_wait_time_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName) +} + +func testAccTrainingJobConfig_environmentAndHyperParametersUpdate(rName string) string { + return fmt.Sprintf(` +data "aws_partition" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = %[1]q + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_s3_bucket" "test" { + bucket = %[1]q +} + +data "aws_sagemaker_prebuilt_ecr_image" "test" { + repository_name = "pytorch-training" + image_tag = "2.0.0-cpu-py310-ubuntu20.04-sagemaker" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + enable_inter_container_traffic_encryption = false + enable_managed_spot_training = true + enable_network_isolation = false + + environment = { + "TEST_ENV" = "updated_value" + "ANOTHER_ENV" = "another_value" + } + + hyper_parameters = { + "epochs" = "20" + "batch_size" = "32" + } + + algorithm_specification { + training_input_mode = "File" + 
training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 7200 + max_wait_time_in_seconds = 8000 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName) +} + +func testAccTrainingJobConfig_checkpoint(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + checkpoint_config { + local_path = "/opt/ml/checkpoints" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/checkpoints/" + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_checkpointUpdate(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + checkpoint_config { + local_path = "/opt/ml/checkpoints" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/checkpoints-v2/" + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 
1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_tensorBoard(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + tensor_board_output_config { + local_path = "/opt/ml/output/tensorboard" + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/tensorboard/" + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_tensorBoardUpdate(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + tensor_board_output_config { + local_path = "/opt/ml/output/tensorboard" + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/tensorboard-v2/" + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_inputData(rName string) 
string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +data "aws_iam_policy_document" "s3" { + statement { + actions = [ + "s3:GetObject", + "s3:PutObject" + ] + resources = [ + "${aws_s3_bucket.test.arn}/*" + ] + } + statement { + actions = [ + "s3:ListBucket" + ] + resources = [ + aws_s3_bucket.test.arn + ] + } +} + +resource "aws_iam_role_policy" "test" { + role = aws_iam_role.test.name + policy = data.aws_iam_policy_document.s3.json +} + +resource "aws_s3_object" "input" { + bucket = aws_s3_bucket.test.id + key = "input/placeholder.csv" + content = "feature1,label\n1.0,0\n" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + input_data_config { + channel_name = "training" + compression_type = "None" + content_type = "text/csv" + input_mode = "File" + record_wrapper_type = "None" + + data_source { + s3_data_source { + s3_data_distribution_type = "FullyReplicated" + s3_data_type = "S3Prefix" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/input/" + } + } + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.test, aws_s3_object.input] +} +`, rName)) +} + +func testAccTrainingJobConfig_inputDataUpdate(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +data "aws_iam_policy_document" "s3" { + statement { + actions = [ + "s3:GetObject", + "s3:PutObject" + ] + resources = [ + "${aws_s3_bucket.test.arn}/*" + ] + } + statement { + actions = [ + "s3:ListBucket" + ] + resources = [ + 
aws_s3_bucket.test.arn + ] + } +} + +resource "aws_iam_role_policy" "test" { + role = aws_iam_role.test.name + policy = data.aws_iam_policy_document.s3.json +} + +resource "aws_s3_object" "input_v2" { + bucket = aws_s3_bucket.test.id + key = "input-v2/placeholder.csv" + content = "feature1,label\n1.0,0\n" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + input_data_config { + channel_name = "training" + compression_type = "None" + content_type = "text/csv" + input_mode = "File" + record_wrapper_type = "None" + + data_source { + s3_data_source { + s3_data_distribution_type = "FullyReplicated" + s3_data_type = "S3Prefix" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/input-v2/" + } + } + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.test, aws_s3_object.input_v2] +} +`, rName)) +} + +func testAccTrainingJobConfig_outputData(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_kms_key" "test" { + description = "KMS key for SageMaker training job" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + compression_type = "GZIP" + kms_key_id = aws_kms_key.test.arn + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + 
instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_outputDataUpdate(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_kms_key" "test" { + description = "KMS key for SageMaker training job" +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + compression_type = "NONE" + kms_key_id = aws_kms_key.test.arn + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_algorithmMetrics(rName, customImage string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +data "aws_caller_identity" "current" {} +data "aws_region" "current" {} + +resource "aws_iam_role_policy" "ecr" { + name = "%[1]s-ecr" + role = aws_iam_role.test.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [ + { + Effect = "Allow" + Action = [ + "ecr:BatchCheckLayerAvailability", + "ecr:BatchGetImage", + "ecr:GetDownloadUrlForLayer", + "ecr:GetAuthorizationToken", + ] + Resource = "*" + }, + ] + }) +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = %[2]q + + metric_definitions { + name = "train:loss" + regex = "loss: ([0-9\\.]+)" + } + + } + + output_data_config { + 
s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.ecr] +} +`, rName, customImage)) +} + +func testAccTrainingJobConfig_algorithmMetricsUpdate(rName, customImage string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +data "aws_caller_identity" "current" {} +data "aws_region" "current" {} + +resource "aws_iam_role_policy" "ecr" { + name = "%[1]s-ecr" + role = aws_iam_role.test.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [ + { + Effect = "Allow" + Action = [ + "ecr:BatchCheckLayerAvailability", + "ecr:BatchGetImage", + "ecr:GetDownloadUrlForLayer", + "ecr:GetAuthorizationToken", + ] + Resource = "*" + }, + ] + }) +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = %[2]q + + metric_definitions { + name = "train:loss" + regex = "loss: ([0-9\\.]+)" + } + + metric_definitions { + name = "validation:accuracy" + regex = "accuracy: ([0-9\\.]+)" + } + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.ecr] +} +`, rName, customImage)) +} + +func testAccTrainingJobConfig_retryStrategy(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + 
training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + retry_strategy { + maximum_retry_attempts = 3 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_retryStrategyUpdate(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + retry_strategy { + maximum_retry_attempts = 5 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_serverless(rName string) string { + return fmt.Sprintf(` +data "aws_partition" "current" {} +data "aws_region" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = %[1]q + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource 
"aws_iam_role_policy" "hub_access" { + name = "%[1]s-hub" + role = aws_iam_role.test.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Action = ["sagemaker:DescribeHubContent"] + Resource = ["*"] + }] + }) +} + +resource "aws_iam_role_policy" "s3" { + name = "%[1]s-s3" + role = aws_iam_role.test.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Action = ["s3:GetObject", "s3:PutObject", "s3:ListBucket", "s3:DeleteObject"] + Resource = [ + "arn:${data.aws_partition.current.partition}:s3:::%[1]s", + "arn:${data.aws_partition.current.partition}:s3:::%[1]s/*" + ] + }] + }) +} + +resource "aws_s3_bucket" "test" { + bucket = %[1]q + force_destroy = true +} + +resource "aws_s3_object" "training" { + bucket = aws_s3_bucket.test.id + key = "train/placeholder.jsonl" + content = "{\"prompt\": \"hello\", \"completion\": \"world\"}\n" +} + +resource "aws_sagemaker_model_package_group" "test" { + model_package_group_name = %[1]q + + depends_on = [aws_iam_role_policy_attachment.test] +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + input_data_config { + channel_name = "train" + content_type = "application/jsonlines" + input_mode = "File" + + data_source { + s3_data_source { + s3_data_distribution_type = "FullyReplicated" + s3_data_type = "S3Prefix" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/train/" + } + } + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + model_package_config { + model_package_group_arn = aws_sagemaker_model_package_group.test.arn + } + + serverless_job_config { + accept_eula = true + base_model_arn = "arn:${data.aws_partition.current.partition}:sagemaker:${data.aws_region.current.name}:aws:hub-content/SageMakerPublicHub/Model/meta-textgeneration-llama-3-1-8b-instruct/2.40.0" + job_type = "FineTuning" + customization_technique = "SFT" + } + + depends_on = 
[aws_iam_role_policy_attachment.test, aws_iam_role_policy.hub_access, aws_iam_role_policy.s3, aws_s3_object.training] +} +`, rName) +} + +func testAccTrainingJobConfig_serverlessUpdate(rName string) string { + return fmt.Sprintf(` +data "aws_partition" "current" {} +data "aws_region" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = %[1]q + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_iam_role_policy" "hub_access" { + name = "%[1]s-hub" + role = aws_iam_role.test.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Action = ["sagemaker:DescribeHubContent"] + Resource = ["*"] + }] + }) +} + +resource "aws_iam_role_policy" "s3" { + name = "%[1]s-s3" + role = aws_iam_role.test.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Action = ["s3:GetObject", "s3:PutObject", "s3:ListBucket", "s3:DeleteObject"] + Resource = [ + "arn:${data.aws_partition.current.partition}:s3:::%[1]s", + "arn:${data.aws_partition.current.partition}:s3:::%[1]s/*" + ] + }] + }) +} + +resource "aws_s3_bucket" "test" { + bucket = %[1]q + force_destroy = true +} + +resource "aws_s3_object" "training" { + bucket = aws_s3_bucket.test.id + key = "train/placeholder.jsonl" + content = "{\"prompt\": \"hello\", \"completion\": \"world\"}\n" +} + +resource "aws_sagemaker_model_package_group" "test" { + model_package_group_name = %[1]q + + depends_on = [aws_iam_role_policy_attachment.test] +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = 
aws_iam_role.test.arn + + input_data_config { + channel_name = "train" + content_type = "application/jsonlines" + input_mode = "File" + + data_source { + s3_data_source { + s3_data_distribution_type = "FullyReplicated" + s3_data_type = "S3Prefix" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/train/" + } + } + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + model_package_config { + model_package_group_arn = aws_sagemaker_model_package_group.test.arn + } + + serverless_job_config { + accept_eula = true + base_model_arn = "arn:${data.aws_partition.current.partition}:sagemaker:${data.aws_region.current.name}:aws:hub-content/SageMakerPublicHub/Model/meta-textgeneration-llama-3-1-8b-instruct/2.40.0" + job_type = "FineTuning" + customization_technique = "DPO" + peft = "LORA" + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.hub_access, aws_iam_role_policy.s3, aws_s3_object.training] +} +`, rName) +} + +func testAccTrainingJobConfig_tags1(rName, tagKey1, tagValue1 string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + tags = { + %[2]q = %[3]q + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName, tagKey1, tagValue1)) +} + +func testAccTrainingJobConfig_tags2(rName, tagKey1, tagValue1, tagKey2, tagValue2 string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource 
"aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + tags = { + %[2]q = %[3]q + %[4]q = %[5]q + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName, tagKey1, tagValue1, tagKey2, tagValue2)) +} + +func testAccTrainingJobConfig_infraCheck(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + infra_check_config { + enable_infra_check = true + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_infraCheckUpdate(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = 
"ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + infra_check_config { + enable_infra_check = false + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_mlflow(rName, novaModelARN string) string { + return fmt.Sprintf(` +data "aws_partition" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = %[1]q + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_iam_role_policy" "s3" { + name = "%[1]s-s3" + role = aws_iam_role.test.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Action = ["s3:GetObject", "s3:PutObject", "s3:ListBucket", "s3:DeleteObject"] + Resource = [ + "arn:${data.aws_partition.current.partition}:s3:::%[1]s", + "arn:${data.aws_partition.current.partition}:s3:::%[1]s/*" + ] + }] + }) +} + +resource "aws_s3_bucket" "test" { + bucket = %[1]q + force_destroy = true +} + +resource "aws_s3_object" "training" { + bucket = aws_s3_bucket.test.id + key = "train/placeholder.jsonl" + content = "{\"prompt\": \"hello\", \"completion\": \"world\"}\n" +} + +resource "aws_sagemaker_model_package_group" "test" { + model_package_group_name = %[1]q + + depends_on = [aws_iam_role_policy_attachment.test] +} + +resource "aws_sagemaker_mlflow_tracking_server" "test" { + tracking_server_name = %[1]q + artifact_store_uri = "s3://${aws_s3_bucket.test.bucket}/mlflow/" + role_arn = aws_iam_role.test.arn +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name 
= %[1]q + role_arn = aws_iam_role.test.arn + + input_data_config { + channel_name = "train" + content_type = "application/jsonlines" + input_mode = "File" + + data_source { + s3_data_source { + s3_data_distribution_type = "FullyReplicated" + s3_data_type = "S3Prefix" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/train/" + } + } + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + serverless_job_config { + accept_eula = true + base_model_arn = %[2]q + job_type = "FineTuning" + customization_technique = "SFT" + } + + model_package_config { + model_package_group_arn = aws_sagemaker_model_package_group.test.arn + } + + mlflow_config { + mlflow_experiment_name = %[1]q + mlflow_resource_arn = aws_sagemaker_mlflow_tracking_server.test.arn + mlflow_run_name = %[1]q + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.s3, aws_s3_object.training] +} +`, rName, novaModelARN) +} + +func testAccTrainingJobConfig_mlflowUpdate(rName, novaModelARN string) string { + return fmt.Sprintf(` +data "aws_partition" "current" {} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = ["sts:AssumeRole", "sts:SetSourceIdentity"] + principals { + type = "Service" + identifiers = ["sagemaker.amazonaws.com"] + } + } +} + +resource "aws_iam_role" "test" { + name = %[1]q + assume_role_policy = data.aws_iam_policy_document.assume_role.json +} + +resource "aws_iam_role_policy_attachment" "test" { + role = aws_iam_role.test.name + policy_arn = "arn:${data.aws_partition.current.partition}:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_iam_role_policy" "s3" { + name = "%[1]s-s3" + role = aws_iam_role.test.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Action = ["s3:GetObject", "s3:PutObject", "s3:ListBucket", "s3:DeleteObject"] + Resource = [ + "arn:${data.aws_partition.current.partition}:s3:::%[1]s", + 
"arn:${data.aws_partition.current.partition}:s3:::%[1]s/*" + ] + }] + }) +} + +resource "aws_s3_bucket" "test" { + bucket = %[1]q + force_destroy = true +} + +resource "aws_s3_object" "training" { + bucket = aws_s3_bucket.test.id + key = "train/placeholder.jsonl" + content = "{\"prompt\": \"hello\", \"completion\": \"world\"}\n" +} + +resource "aws_sagemaker_model_package_group" "test" { + model_package_group_name = %[1]q + + depends_on = [aws_iam_role_policy_attachment.test] +} + +resource "aws_sagemaker_mlflow_tracking_server" "test" { + tracking_server_name = %[1]q + artifact_store_uri = "s3://${aws_s3_bucket.test.bucket}/mlflow/" + role_arn = aws_iam_role.test.arn +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + input_data_config { + channel_name = "train" + content_type = "application/jsonlines" + input_mode = "File" + + data_source { + s3_data_source { + s3_data_distribution_type = "FullyReplicated" + s3_data_type = "S3Prefix" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/train/" + } + } + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + serverless_job_config { + accept_eula = true + base_model_arn = %[2]q + job_type = "FineTuning" + customization_technique = "SFT" + } + + model_package_config { + model_package_group_arn = aws_sagemaker_model_package_group.test.arn + } + + mlflow_config { + mlflow_experiment_name = %[1]q + mlflow_resource_arn = aws_sagemaker_mlflow_tracking_server.test.arn + mlflow_run_name = %[1]q + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.s3, aws_s3_object.training] +} +`, rName, novaModelARN) +} + +func testAccTrainingJobConfig_remoteDebug(rName, jobName, customImage string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_iam_role_policy" "ecr" { + name = "%[1]s-ecr" + role = aws_iam_role.test.id + + policy = jsonencode({ + 
Version = "2012-10-17" + Statement = [ + { + Effect = "Allow" + Action = [ + "ecr:BatchCheckLayerAvailability", + "ecr:BatchGetImage", + "ecr:GetDownloadUrlForLayer", + "ecr:GetAuthorizationToken", + ] + Resource = "*" + }, + ] + }) +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[2]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = %[3]q + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + remote_debug_config { + enable_remote_debug = false + } + + depends_on = [aws_iam_role_policy_attachment.test, aws_iam_role_policy.ecr] +} +`, rName, jobName, customImage)) +} + +func testAccTrainingJobConfig_remoteDebugUpdate(rName, jobName, customImage string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_iam_role_policy" "ecr" { + name = "%[1]s-ecr" + role = aws_iam_role.test.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [ + { + Effect = "Allow" + Action = [ + "ecr:BatchCheckLayerAvailability", + "ecr:BatchGetImage", + "ecr:GetDownloadUrlForLayer", + "ecr:GetAuthorizationToken", + ] + Resource = "*" + }, + ] + }) +} + +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[2]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = %[3]q + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + remote_debug_config { + enable_remote_debug = true + } + + depends_on = [aws_iam_role_policy_attachment.test, 
aws_iam_role_policy.ecr] +} +`, rName, jobName, customImage)) +} + +func testAccTrainingJobConfig_sessionChaining(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + session_chaining_config { + enable_session_tag_chaining = true + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} + +func testAccTrainingJobConfig_sessionChainingUpdate(rName string) string { + return acctest.ConfigCompose(testAccTrainingJobConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_training_job" "test" { + training_job_name = %[1]q + role_arn = aws_iam_role.test.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + session_chaining_config { + enable_session_tag_chaining = false + } + + depends_on = [aws_iam_role_policy_attachment.test] +} +`, rName)) +} diff --git a/internal/service/sagemaker/wait.go b/internal/service/sagemaker/wait.go index e5eb612c1e3a..89158dfd9054 100644 --- a/internal/service/sagemaker/wait.go +++ b/internal/service/sagemaker/wait.go @@ -751,3 +751,53 @@ func waitMlflowAppDeleted(ctx context.Context, conn *sagemaker.Client, arn strin 
return err } + +func waitTrainingJobCreated(ctx context.Context, conn *sagemaker.Client, id string, timeout time.Duration) (*sagemaker.DescribeTrainingJobOutput, error) { + stateConf := &retry.StateChangeConf{ + Pending: []string{}, + Target: enum.Slice(awstypes.TrainingJobStatusInProgress), + Refresh: statusTrainingJob(conn, id), + Timeout: timeout, + ContinuousTargetOccurence: 2, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + if out, ok := outputRaw.(*sagemaker.DescribeTrainingJobOutput); ok { + return out, err + } + + return nil, err +} + +func waitTrainingJobDeleted(ctx context.Context, conn *sagemaker.Client, id string, timeout time.Duration) (*sagemaker.DescribeTrainingJobOutput, error) { + stateConf := &retry.StateChangeConf{ + Pending: enum.Slice(awstypes.TrainingJobStatusDeleting, awstypes.TrainingJobStatusInProgress, awstypes.TrainingJobStatusStopping), + Target: []string{}, + Refresh: statusTrainingJob(conn, id), + Timeout: timeout, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + if out, ok := outputRaw.(*sagemaker.DescribeTrainingJobOutput); ok { + return out, err + } + + return nil, err +} + +func waitTrainingJobStopped(ctx context.Context, conn *sagemaker.Client, id string, timeout time.Duration) (*sagemaker.DescribeTrainingJobOutput, error) { + stateConf := &retry.StateChangeConf{ + Pending: enum.Slice(awstypes.TrainingJobStatusInProgress, awstypes.TrainingJobStatusStopping), + Target: enum.Slice(awstypes.TrainingJobStatusCompleted, awstypes.TrainingJobStatusFailed, awstypes.TrainingJobStatusStopped), + Refresh: statusTrainingJob(conn, id), + Timeout: timeout, + ContinuousTargetOccurence: 2, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + if out, ok := outputRaw.(*sagemaker.DescribeTrainingJobOutput); ok { + return out, err + } + + return nil, err +} diff --git a/website/docs/list-resources/sagemaker_training_job.html.markdown b/website/docs/list-resources/sagemaker_training_job.html.markdown new file 
mode 100644 index 000000000000..b584684b83da --- /dev/null +++ b/website/docs/list-resources/sagemaker_training_job.html.markdown @@ -0,0 +1,25 @@ +--- +subcategory: "SageMaker AI" +layout: "aws" +page_title: "AWS: aws_sagemaker_training_job" +description: |- + Lists SageMaker AI Training Job resources. +--- + +# List Resource: aws_sagemaker_training_job + +Lists SageMaker AI Training Job resources. + +## Example Usage + +```terraform +list "aws_sagemaker_training_job" "example" { + provider = aws +} +``` + +## Argument Reference + +This list resource supports the following arguments: + +* `region` - (Optional) Region to query. Defaults to provider region. diff --git a/website/docs/r/sagemaker_training_job.html.markdown b/website/docs/r/sagemaker_training_job.html.markdown new file mode 100644 index 000000000000..36c863fb7fa0 --- /dev/null +++ b/website/docs/r/sagemaker_training_job.html.markdown @@ -0,0 +1,590 @@ +--- +subcategory: "SageMaker AI" +layout: "aws" +page_title: "AWS: aws_sagemaker_training_job" +description: |- + Manages an AWS SageMaker AI Training Job. +--- + +# Resource: aws_sagemaker_training_job + +Manages an AWS SageMaker AI Training Job. 
+ +## Example Usage + +### Basic Usage + +```terraform +resource "aws_sagemaker_training_job" "example" { + training_job_name = "example" + role_arn = aws_iam_role.example.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.example.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.example.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } +} +``` + +### With VPC Configuration + +```terraform +resource "aws_sagemaker_training_job" "example" { + training_job_name = "example" + role_arn = aws_iam_role.example.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.example.registry_path + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.example.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + vpc_config { + security_group_ids = [aws_security_group.example.id] + subnets = [aws_subnet.example.id] + } +} +``` + +### With Input Data and Hyperparameters + +```terraform +resource "aws_sagemaker_training_job" "example" { + training_job_name = "example" + role_arn = aws_iam_role.example.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.example.registry_path + enable_sagemaker_metrics_time_series = true + } + + hyper_parameters = { + "mini_batch_size" = "200" + "epochs" = "10" + } + + input_data_config { + channel_name = "train" + + data_source { + s3_data_source { + s3_data_type = "S3Prefix" + s3_uri = "s3://${aws_s3_bucket.example.bucket}/train/" + } + } + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.example.bucket}/output/" + 
} + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } +} +``` + +### With Encrypted Output, Checkpoints, and TensorBoard + +```terraform +resource "aws_sagemaker_training_job" "example" { + training_job_name = "example" + role_arn = aws_iam_role.example.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.example.registry_path + } + + checkpoint_config { + local_path = "/opt/ml/checkpoints" + s3_uri = "s3://${aws_s3_bucket.example.bucket}/checkpoints/" + } + + output_data_config { + compression_type = "GZIP" + kms_key_id = aws_kms_key.example.arn + s3_output_path = "s3://${aws_s3_bucket.example.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + volume_kms_key_id = aws_kms_key.example.arn + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } + + tensor_board_output_config { + local_path = "/opt/ml/output/tensorboard" + s3_output_path = "s3://${aws_s3_bucket.example.bucket}/tensorboard/" + } +} +``` + +### With Managed Spot Training and Custom Metrics + +```terraform +resource "aws_sagemaker_training_job" "example" { + training_job_name = "example" + role_arn = aws_iam_role.example.arn + enable_managed_spot_training = true + enable_network_isolation = true + enable_inter_container_traffic_encryption = true + + algorithm_specification { + training_input_mode = "File" + training_image = var.training_image + container_entrypoint = ["python", "/opt/ml/code/train.py"] + container_arguments = ["--epochs", "10", "--batch-size", "128"] + + metric_definitions { + name = "train:loss" + regex = "loss: ([0-9\\.]+)" + } + + metric_definitions { + name = "validation:accuracy" + regex = "accuracy: ([0-9\\.]+)" + } + } + + environment = { + MODEL_DIR = "/opt/ml/model" + SM_LOG_LEVEL = "20" + } + + hyper_parameters = { + 
epochs = "10" + batch_size = "128" + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.example.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.xlarge" + instance_count = 1 + volume_size_in_gb = 50 + keep_alive_period_in_seconds = 600 + } + + retry_strategy { + maximum_retry_attempts = 3 + } + + stopping_condition { + max_runtime_in_seconds = 3600 + max_wait_time_in_seconds = 7200 + } + + tags = { + Environment = "test" + Workload = "training" + } +} +``` + +### With Multiple Input Channels, Infrastructure Checks, and Session Tag Chaining + +```terraform +resource "aws_sagemaker_training_job" "example" { + training_job_name = "example" + role_arn = aws_iam_role.example.arn + + algorithm_specification { + training_input_mode = "File" + training_image = data.aws_sagemaker_prebuilt_ecr_image.example.registry_path + } + + input_data_config { + channel_name = "train" + content_type = "text/csv" + input_mode = "File" + + data_source { + s3_data_source { + s3_data_distribution_type = "FullyReplicated" + s3_data_type = "S3Prefix" + s3_uri = "s3://${aws_s3_bucket.example.bucket}/train/" + } + } + } + + input_data_config { + channel_name = "validation" + content_type = "text/csv" + input_mode = "File" + + data_source { + s3_data_source { + s3_data_distribution_type = "FullyReplicated" + s3_data_type = "S3Prefix" + s3_uri = "s3://${aws_s3_bucket.example.bucket}/validation/" + } + } + } + + infra_check_config { + enable_infra_check = true + } + + output_data_config { + s3_output_path = "s3://${aws_s3_bucket.example.bucket}/output/" + } + + resource_config { + instance_type = "ml.m5.large" + instance_count = 1 + volume_size_in_gb = 30 + } + + session_chaining_config { + enable_session_tag_chaining = true + } + + stopping_condition { + max_runtime_in_seconds = 3600 + } +} +``` + +## Argument Reference + +The following arguments are required: + +* `role_arn` - (Required) ARN of the IAM role that SageMaker AI assumes to perform tasks on your 
behalf during training. +* `training_job_name` - (Required) Name of the training job. Must be between 1 and 63 characters, start with a letter or number, and contain only letters, numbers, and hyphens. +* `output_data_config` - (Required) Location of the output data from the training job. See [`output_data_config`](#output_data_config) below. + +The following arguments are optional: + +* `algorithm_specification` - (Optional) Algorithm-related parameters of the training job. See [`algorithm_specification`](#algorithm_specification) below. Conflicts with `serverless_job_config`. +* `checkpoint_config` - (Optional) Location of checkpoints during training. See [`checkpoint_config`](#checkpoint_config) below. Conflicts with `serverless_job_config`. +* `debug_hook_config` - (Optional) Configuration for debugging rules. See [`debug_hook_config`](#debug_hook_config) below. Conflicts with `serverless_job_config`. +* `debug_rule_configurations` - (Optional) List of debug rule configurations. Maximum of 20. See [`debug_rule_configurations`](#debug_rule_configurations) below. +* `enable_inter_container_traffic_encryption` - (Optional) Whether to encrypt inter-container traffic. When enabled, communications between containers are encrypted. +* `enable_managed_spot_training` - (Optional) Whether to use managed spot training. Optimizes the cost of training by using Amazon EC2 Spot Instances. Conflicts with `serverless_job_config`. +* `enable_network_isolation` - (Optional) Whether to isolate the training container from the network. No inbound or outbound network calls can be made. +* `environment` - (Optional) Map of environment variables to set in the training container. Maximum of 100 entries. Conflicts with `serverless_job_config`. +* `experiment_config` - (Optional) Associates a SageMaker AI Experiment or Trial to the training job. See [`experiment_config`](#experiment_config) below. Conflicts with `serverless_job_config`. 
+* `hyper_parameters` - (Optional) Map of hyperparameters for the training algorithm. Maximum of 100 entries. +* `infra_check_config` - (Optional) Infrastructure health check configuration. See [`infra_check_config`](#infra_check_config) below. +* `input_data_config` - (Optional) List of input data channel configurations for the training job. Maximum of 20. See [`input_data_config`](#input_data_config) below. +* `mlflow_config` - (Optional) MLflow integration configuration. See [`mlflow_config`](#mlflow_config) below. +* `model_package_config` - (Optional) Model package configuration. Requires `serverless_job_config`. See [`model_package_config`](#model_package_config) below. +* `profiler_config` - (Optional) Configuration for the profiler. See [`profiler_config`](#profiler_config) below. Conflicts with `serverless_job_config`. +* `profiler_rule_configurations` - (Optional) List of profiler rule configurations. Maximum of 20. See [`profiler_rule_configurations`](#profiler_rule_configurations) below. Conflicts with `serverless_job_config`. +* `region` - (Optional) Region where this resource will be [managed](https://docs.aws.amazon.com/general/latest/gr/rande.html#regional-endpoints). Defaults to the Region set in the [provider configuration](https://registry.terraform.io/providers/hashicorp/aws/latest/docs#aws-configuration-reference). +* `remote_debug_config` - (Optional) Configuration for remote debugging. See [`remote_debug_config`](#remote_debug_config) below. +* `resource_config` - (Optional) Resources for the training job, including compute instances and storage volumes. See [`resource_config`](#resource_config) below. +* `retry_strategy` - (Optional) Number of times to retry the job if it fails. See [`retry_strategy`](#retry_strategy) below. Conflicts with `serverless_job_config`. +* `serverless_job_config` - (Optional) Configuration for serverless training jobs using foundation models. 
Conflicts with `algorithm_specification`, `enable_managed_spot_training`, `environment`, `retry_strategy`, `checkpoint_config`, `debug_hook_config`, `experiment_config`, `profiler_config`, `profiler_rule_configurations`, and `tensor_board_output_config`. See [`serverless_job_config`](#serverless_job_config) below. +* `session_chaining_config` - (Optional) Configuration for session tag chaining. See [`session_chaining_config`](#session_chaining_config) below. +* `stopping_condition` - (Optional) Time limits that control how long the training job can stay pending, run, or wait for managed spot capacity. See [`stopping_condition`](#stopping_condition) below. +* `tags` - (Optional) Map of tags to assign to the resource. If configured with a provider [`default_tags` configuration block](https://registry.terraform.io/providers/hashicorp/aws/latest/docs#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level. +* `tensor_board_output_config` - (Optional) Configuration for TensorBoard output. See [`tensor_board_output_config`](#tensor_board_output_config) below. Conflicts with `serverless_job_config`. +* `vpc_config` - (Optional) VPC configuration for the training job. See [`vpc_config`](#vpc_config) below. + +### `algorithm_specification` + +* `algorithm_name` - (Optional) Name or ARN of the algorithm resource to use for the training job. +* `container_arguments` - (Optional) List of arguments for the container entrypoint. Maximum of 100 entries. +* `container_entrypoint` - (Optional) List of entrypoint commands for the container. Maximum of 100 entries. +* `enable_sagemaker_metrics_time_series` - (Optional) Whether to enable SageMaker AI metrics time series collection. +* `metric_definitions` - (Optional) List of metric definitions for the training job. Maximum of 40. Use this to extract custom metrics from your own training container logs. SageMaker can still publish built-in metrics for built-in algorithms and supported prebuilt images when this block is omitted. See [`metric_definitions`](#metric_definitions) below.
+* `training_image` - (Optional) Registry path of the Docker image that contains the training algorithm. +* `training_image_config` - (Optional) Training image configuration. See [`training_image_config`](#training_image_config) below. +* `training_input_mode` - (Optional) Input mode for the training data. Valid values: `File`, `Pipe`, `FastFile`. + +### `metric_definitions` + +Use `metric_definitions` when you need to parse custom metrics from your training container logs. For SageMaker built-in algorithms and supported prebuilt images, SageMaker can populate default metrics without this block. + +* `name` - (Required) Name of the metric. +* `regex` - (Required) Regular expression that searches the output of the training job and captures the value of the metric. + +### `training_image_config` + +* `training_repository_access_mode` - (Optional) Access mode for the training image repository. +* `training_repository_auth_config` - (Optional) Authentication configuration for the training image repository. See [`training_repository_auth_config`](#training_repository_auth_config) below. + +### `training_repository_auth_config` + +* `training_repository_credentials_provider_arn` - (Optional) ARN of the Lambda function that provides credentials to authenticate to the private Docker registry. + +### `checkpoint_config` + +* `local_path` - (Optional) Local path where checkpoints are written. +* `s3_uri` - (Required) S3 URI where checkpoints are stored. + +### `debug_hook_config` + +* `collection_configurations` - (Optional) List of tensor collections to configure for the debug hook. Maximum of 20. See [`collection_configurations`](#collection_configurations) below. +* `hook_parameters` - (Optional) Map of parameters for the debug hook. Maximum of 20 entries. +* `local_path` - (Optional) Local path where debug output is written. +* `s3_output_path` - (Required) S3 URI where debug output is stored. 
+ +### `collection_configurations` + +* `collection_name` - (Optional) Name of the tensor collection. +* `collection_parameters` - (Optional) Map of parameters for the tensor collection. + +### `debug_rule_configurations` + +* `instance_type` - (Optional) Instance type to deploy for the debug rule evaluation. Valid values are SageMaker AI processing instance types. +* `local_path` - (Optional) Local path where debug rule output is written. +* `rule_configuration_name` - (Required) Name of the rule configuration. Must be between 1 and 256 characters. +* `rule_evaluator_image` - (Required) Docker image URI for the rule evaluator. +* `rule_parameters` - (Optional) Map of parameters for the rule configuration. Maximum of 100 entries. +* `s3_output_path` - (Optional) S3 URI where rule output is stored. +* `volume_size_in_gb` - (Optional) Size of the storage volume for the rule evaluator, in GB. + +### `experiment_config` + +* `experiment_name` - (Optional) Name of the SageMaker AI Experiment to associate with. +* `run_name` - (Optional) Name of the Experiment Run to associate with. +* `trial_component_display_name` - (Optional) Display name for the trial component. +* `trial_name` - (Optional) Name of the SageMaker AI Trial to associate with. + +### `infra_check_config` + +* `enable_infra_check` - (Optional) Whether to enable infrastructure health checks before training. + +### `input_data_config` + +* `channel_name` - (Required) Name of the channel. Must be between 1 and 64 characters. +* `compression_type` - (Optional) Compression type for the input data. Valid values: `None`, `Gzip`. +* `content_type` - (Optional) MIME type of the input data. +* `data_source` - (Required) Location of the channel data. See [`data_source`](#data_source) below. +* `input_mode` - (Optional) Input mode for the channel data. Valid values: `File`, `Pipe`, `FastFile`. +* `record_wrapper_type` - (Optional) Record wrapper type. Valid values: `None`, `RecordIO`. 
+* `shuffle_config` - (Optional) Configuration for shuffling data in the channel. See [`shuffle_config`](#shuffle_config) below. + +### `data_source` + +* `file_system_data_source` - (Optional) File system data source. See [`file_system_data_source`](#file_system_data_source) below. +* `s3_data_source` - (Optional) S3 data source. See [`s3_data_source`](#s3_data_source) below. + +### `file_system_data_source` + +* `directory_path` - (Required) Full path to the directory on the file system. +* `file_system_access_mode` - (Required) Access mode for the file system. Valid values: `ro`, `rw`. +* `file_system_id` - (Required) File system ID. +* `file_system_type` - (Required) File system type. Valid values: `EFS`, `FSxLustre`. + +### `s3_data_source` + +* `attribute_names` - (Optional) List of attribute names to include in the training dataset. Maximum of 16. +* `hub_access_config` - (Optional) SageMaker AI Hub access configuration. See [`hub_access_config`](#hub_access_config) below. +* `instance_group_names` - (Optional) List of instance group names for the training data distribution. Maximum of 5. +* `model_access_config` - (Optional) Model access configuration. See [`model_access_config`](#model_access_config) below. +* `s3_data_distribution_type` - (Optional) Distribution type for S3 data. Valid values: `FullyReplicated`, `ShardedByS3Key`. +* `s3_data_type` - (Required) S3 data type. Valid values: `ManifestFile`, `S3Prefix`, `AugmentedManifestFile`. +* `s3_uri` - (Required) S3 URI of the data. + +### `hub_access_config` + +* `hub_content_arn` - (Required) ARN of the hub content. + +### `model_access_config` + +* `accept_eula` - (Required) Whether to accept the model EULA. + +### `shuffle_config` + +* `seed` - (Optional) Seed value used to shuffle the training data. + +### `mlflow_config` + +* `mlflow_experiment_name` - (Optional) Name of the MLflow experiment. +* `mlflow_resource_arn` - (Required) ARN of the MLflow tracking server. 
+* `mlflow_run_name` - (Optional) Name of the MLflow run. + +### `model_package_config` + +* `model_package_group_arn` - (Required) ARN of the model package group. +* `source_model_package_arn` - (Optional) ARN of the source model package. + +### `output_data_config` + +* `compression_type` - (Optional) Output compression type. Valid values: `GZIP`, `NONE`. +* `kms_key_id` - (Optional) KMS key ID used to encrypt the output data. +* `s3_output_path` - (Required) S3 URI where output data is stored. + +### `profiler_config` + +* `disable_profiler` - (Optional) Whether to disable the profiler. +* `profiling_interval_in_milliseconds` - (Optional) Time interval in milliseconds for capturing system metrics. Valid values: `100`, `200`, `500`, `1000`, `5000`, `60000`. +* `profiling_parameters` - (Optional) Map of profiling parameters. Maximum of 20 entries. +* `s3_output_path` - (Optional) S3 URI where profiler output is stored. + +### `profiler_rule_configurations` + +* `instance_type` - (Optional) Instance type to deploy for the profiler rule evaluation. Valid values are SageMaker AI processing instance types. +* `local_path` - (Optional) Local path where profiler rule output is written. +* `rule_configuration_name` - (Required) Name of the profiler rule configuration. Must be between 1 and 256 characters. +* `rule_evaluator_image` - (Required) Docker image URI for the profiler rule evaluator. +* `rule_parameters` - (Optional) Map of parameters for the profiler rule. Maximum of 100 entries. +* `s3_output_path` - (Optional) S3 URI where profiler rule output is stored. +* `volume_size_in_gb` - (Optional) Size of the storage volume for the profiler rule evaluator, in GB. + +### `remote_debug_config` + +* `enable_remote_debug` - (Optional) Whether to enable remote debugging for the training job. + +### `resource_config` + +* `instance_count` - (Optional) Number of ML compute instances to use. Conflicts with `instance_groups`. 
+* `instance_groups` - (Optional) List of instance groups for heterogeneous cluster training. Maximum of 5. Conflicts with `instance_count`, `instance_type`, and `keep_alive_period_in_seconds`. See [`instance_groups`](#instance_groups) below. +* `instance_placement_config` - (Optional) Instance placement configuration. See [`instance_placement_config`](#instance_placement_config) below. +* `instance_type` - (Optional) ML compute instance type. Conflicts with `instance_groups`. +* `keep_alive_period_in_seconds` - (Optional) Time in seconds to keep instances alive after training completes, for warm pool reuse. Valid values: 0–3600. Conflicts with `instance_groups`. +* `training_plan_arn` - (Optional) ARN of the training plan to use. +* `volume_kms_key_id` - (Optional) KMS key ID used to encrypt data on the storage volume. +* `volume_size_in_gb` - (Optional) Size of the storage volume attached to each instance, in GB. + +### `instance_groups` + +* `instance_count` - (Optional) Number of instances in the group. +* `instance_group_name` - (Optional) Name of the instance group. +* `instance_type` - (Optional) ML compute instance type for the group. + +### `instance_placement_config` + +* `enable_multiple_jobs` - (Optional) Whether to enable multiple jobs on the same instance. +* `placement_specifications` - (Optional) Placement specifications for instance placement. See [`placement_specifications`](#placement_specifications) below. + +### `placement_specifications` + +* `instance_count` - (Optional) Number of instances in the placement. +* `ultra_server_id` - (Optional) Ultra server ID for the placement. + +### `retry_strategy` + +* `maximum_retry_attempts` - (Required) Maximum number of retry attempts. Valid values: 1–30. + +### `serverless_job_config` + +* `accept_eula` - (Optional) Whether to accept the model EULA. +* `base_model_arn` - (Required) ARN of the base foundation model from the SageMaker AI Public Hub. 
+* `customization_technique` - (Optional) Customization technique to apply. Valid values: `SFT`, `DPO`, `RLVR`, `RLAIF`. +* `evaluation_type` - (Optional) Evaluation type. Valid values: `LLMAJEvaluation`, `CustomScorerEvaluation`, `BenchmarkEvaluation`. +* `evaluator_arn` - (Optional) ARN of the evaluator. +* `job_type` - (Required) Serverless job type. Valid values: `FineTuning`, `Evaluation`. +* `peft` - (Optional) Parameter-Efficient Fine-Tuning (PEFT) method. Valid values: `LORA`. + +### `session_chaining_config` + +* `enable_session_tag_chaining` - (Optional) Whether to enable session tag chaining for the training job. + +### `stopping_condition` + +* `max_pending_time_in_seconds` - (Optional) Maximum time in seconds a training job can be pending before it is stopped. Valid values: 7200–2419200. +* `max_runtime_in_seconds` - (Optional) Maximum time in seconds the training job can run before it is stopped. +* `max_wait_time_in_seconds` - (Optional) Maximum time in seconds to wait for a managed spot training job to complete. + +### `tensor_board_output_config` + +* `local_path` - (Optional) Local path where TensorBoard output is written. +* `s3_output_path` - (Required) S3 URI where TensorBoard output is stored. + +### `vpc_config` + +* `security_group_ids` - (Required) List of VPC security group IDs. Maximum of 5. +* `subnets` - (Required) List of subnet IDs. Maximum of 16. + +## Attribute Reference + +This resource exports the following attributes in addition to the arguments above: + +* `arn` - ARN of the Training Job. +* `tags_all` - Map of tags assigned to the resource, including those inherited from the provider [`default_tags` configuration block](https://registry.terraform.io/providers/hashicorp/aws/latest/docs#default_tags-configuration-block).
+ +## Timeouts + +[Configuration options](https://developer.hashicorp.com/terraform/language/resources/syntax#operation-timeouts): + +* `create` - (Default `45m`) +* `update` - (Default `45m`) +* `delete` - (Default `45m`) + +## Import + +In Terraform v1.12.0 and later, the [`import` block](https://developer.hashicorp.com/terraform/language/import) can be used with the `identity` attribute. For example: + +```terraform +import { + to = aws_sagemaker_training_job.example + identity = { + training_job_name = "my-training-job" + } +} + +resource "aws_sagemaker_training_job" "example" { + ### Configuration omitted for brevity ### +} +``` + +### Identity Schema + +#### Required + +* `training_job_name` - (String) Name of the Training Job. + +#### Optional + +* `account_id` (String) AWS Account where this resource is managed. +* `region` (String) Region where this resource is managed. + +In Terraform v1.5.0 and later, use an [`import` block](https://developer.hashicorp.com/terraform/language/import) to import SageMaker AI Training Job using the `training_job_name`. For example: + +```terraform +import { + to = aws_sagemaker_training_job.example + id = "my-training-job" +} +``` + +Using `terraform import`, import SageMaker AI Training Job using the `training_job_name`. For example: + +```console +% terraform import aws_sagemaker_training_job.example my-training-job +```