Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KEP-2170: Adding validation webhook for v2 trainjob #2307

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ const (
// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"

// TorchEnvNamePrefix is the env name prefix for the distributed envs for torchrun.
TorchEnvNamePrefix = "PET_"

Comment on lines +75 to +77
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this constant to L60 and reuse it for the other torchrun envs:

// TorchEnvNumNodes is the env name for the number of training nodes.

// JobLauncher is the Job name for the launcher.
JobLauncher string = "launcher"

Expand Down Expand Up @@ -111,6 +114,8 @@ const (
// Distributed envs for mpirun.
// Values for OpenMPI implementation.
OpenMPIEnvHostFileLocation string = "OMPI_MCA_orte_default_hostfile"

UnsupportedRuntimeErrMsg string = "the specified runtime is not supported"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tenzen-y @akshaychitneni Do we want to keep this error in trainjob_controller.go or in constants?

)

var (
Expand Down
13 changes: 2 additions & 11 deletions pkg/controller/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ import (
jobruntimes "github.com/kubeflow/trainer/pkg/runtime"
)

var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")

type objsOpState int

const (
Expand Down Expand Up @@ -83,10 +81,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
return ctrl.Result{}, nil
}

runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtimeRefGK := jobruntimes.RuntimeRefToRuntimeRegistryKey(trainJob.Spec.RuntimeRef)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we name this function as follows:

Suggested change
runtimeRefGK := jobruntimes.RuntimeRefToRuntimeRegistryKey(trainJob.Spec.RuntimeRef)
runtimeRefGK := jobruntimes.RuntimeRefToGroupKind(trainJob.Spec.RuntimeRef)

runtime, ok := r.runtimes[runtimeRefGK]
if !ok {
return ctrl.Result{}, fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
return ctrl.Result{}, fmt.Errorf("%s: %s", constants.UnsupportedRuntimeErrMsg, runtimeRefGK)
}
opState, err := r.reconcileObjects(ctx, runtime, &trainJob)

Expand Down Expand Up @@ -214,13 +212,6 @@ func isTrainJobFinished(trainJob *trainer.TrainJob) bool {
meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed)
}

// runtimeRefToGroupKind converts a TrainJob runtime reference into a
// schema.GroupKind. An unset APIGroup or Kind pointer is mapped to the
// empty string.
func runtimeRefToGroupKind(runtimeRef trainer.RuntimeRef) schema.GroupKind {
	var gk schema.GroupKind
	gk.Group = ptr.Deref(runtimeRef.APIGroup, "")
	gk.Kind = ptr.Deref(runtimeRef.Kind, "")
	return gk
}

func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error {
b := ctrl.NewControllerManagedBy(mgr).
WithOptions(options).
Expand Down
13 changes: 9 additions & 4 deletions pkg/runtime/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/runtime"
Expand Down Expand Up @@ -69,14 +70,18 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu
}

func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
clusterTrainingRuntime := &trainer.ClusterTrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
Name: new.Spec.RuntimeRef.Name,
}, &trainer.ClusterTrainingRuntime{}); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "RuntimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.runtimeInfo(ctx, new, clusterTrainingRuntime.Spec.Template, clusterTrainingRuntime.Spec.MLPolicy, clusterTrainingRuntime.Spec.PodGroupPolicy)
jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: clusterTrainingRuntime.Spec.Template.Spec,
}
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
}
1 change: 0 additions & 1 deletion pkg/runtime/core/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package core

import (
"context"

"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kubeflow/trainer/pkg/runtime"
Expand Down
49 changes: 31 additions & 18 deletions pkg/runtime/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,26 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.Trai
// buildObjects produces the objects for a TrainJob from the given JobSet
// template spec and policies.
//
// It first derives the runtime Info, then runs the ML policy plugins followed
// by the pod group policy plugins; the first plugin error aborts the build.
// Finally the component builder plugins are run against a deep copy of a
// JobSet built from jobSetTemplateSpec.Spec, and their result is returned.
func (r *TrainingRuntime) buildObjects(
	ctx context.Context,
	trainJob *trainer.TrainJob,
	jobSetTemplateSpec trainer.JobSetTemplateSpec,
	mlPolicy *trainer.MLPolicy,
	podGroupPolicy *trainer.PodGroupPolicy,
) ([]client.Object, error) {
	rtInfo := r.runtimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy)

	err := r.framework.RunEnforceMLPolicyPlugins(rtInfo, trainJob)
	if err != nil {
		return nil, err
	}

	err = r.framework.RunEnforcePodGroupPolicyPlugins(rtInfo, trainJob)
	if err != nil {
		return nil, err
	}

	tmpl := &jobsetv1alpha2.JobSet{Spec: jobSetTemplateSpec.Spec}
	return r.framework.RunComponentBuilderPlugins(ctx, tmpl.DeepCopy(), rtInfo, trainJob)
}

func (r *TrainingRuntime) runtimeInfo(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Back to my point here: #2307 (comment)
Do we really need to generate the Info object when we perform validation?
The validation of the TrainingRuntime executes before the TrainJob is created, so we don't really need to construct the Info object from TrainJob + TrainingRuntime for the TrainJob validation.

@akshaychitneni @tenzen-y Am I missing something?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see that we use it here, since we need to fetch data from the TrainJob and the ClusterTrainingRuntime to determine what type of TrainJob validation we need to perform:

if runtimeInfo.RuntimePolicy.MLPolicy != nil && runtimeInfo.RuntimePolicy.MLPolicy.MPI != nil {
numProcPerNode := *newJobObj.Spec.Trainer.NumProcPerNode
if numProcPerNode.Type != intstr.Int {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value"))
}

@tenzen-y Is that something you had in mind when you designed the Runtime Framework?

ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy) *runtime.Info {

propagationLabels := jobSetTemplateSpec.Labels
if propagationLabels == nil && trainJob.Spec.Labels != nil {
propagationLabels = make(map[string]string, len(trainJob.Spec.Labels))
Expand Down Expand Up @@ -113,19 +133,7 @@ func (r *TrainingRuntime) buildObjects(

info := runtime.NewInfo(opts...)

if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: jobSetTemplateSpec.Spec,
}

return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
return info
}

func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
Expand All @@ -141,14 +149,19 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
}

// ValidateObjects validates a TrainJob against the TrainingRuntime it
// references. It looks the TrainingRuntime up by the namespace and the
// runtimeRef name taken from the TrainJob; if the lookup fails it returns a
// single field error on spec.runtimeRef, otherwise it returns whatever the
// framework's custom validation plugins report.
func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
trainingRuntime := &trainer.TrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
}, &trainer.TrainingRuntime{}); err != nil {
Namespace: new.Namespace,
Name: new.Spec.RuntimeRef.Name,
}, trainingRuntime); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.runtimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy)
jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: trainingRuntime.Spec.Template.Spec,
}
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
}
4 changes: 2 additions & 2 deletions pkg/runtime/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(info *runtime.Info, trainJob
return nil
}

func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
func (f *Framework) RunCustomValidationPlugins(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
var aggregatedWarnings admission.Warnings
var aggregatedErrors field.ErrorList
for _, plugin := range f.customValidationPlugins {
warnings, errs := plugin.Validate(oldObj, newObj)
warnings, errs := plugin.Validate(runtimeJobTemplate, info, oldObj, newObj)
if len(warnings) != 0 {
aggregatedWarnings = append(aggregatedWarnings, warnings...)
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/runtime/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func TestNew(t *testing.T) {
customValidationPlugins: []framework.CustomValidationPlugin{
&mpi.MPI{},
&torch.Torch{},
&jobset.JobSet{},
},
watchExtensionPlugins: []framework.WatchExtensionPlugin{
&coscheduling.CoScheduling{},
Expand Down Expand Up @@ -371,7 +372,9 @@ func TestRunCustomValidationPlugins(t *testing.T) {
if err != nil {
t.Fatal(err)
}
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj)
runtimeInfo := runtime.NewInfo()
jobSetTemplate := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test")
warnings, errs := fwk.RunCustomValidationPlugins(jobSetTemplate, runtimeInfo, tc.oldObj, tc.newObj)
if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 {
t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type Plugin interface {

// CustomValidationPlugin is a Plugin that validates a TrainJob. Validate
// receives the old and new TrainJob and returns admission warnings together
// with a list of field errors.
type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList)
Validate(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList)
}

type WatchExtensionPlugin interface {
Expand Down
54 changes: 54 additions & 0 deletions pkg/runtime/framework/plugins/jobset/jobset.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,24 @@ import (
"context"
"fmt"
"maps"
"slices"

"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/equality"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
apiruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
Expand All @@ -52,6 +56,7 @@ type JobSet struct {
var _ framework.WatchExtensionPlugin = (*JobSet)(nil)
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil)
var _ framework.TerminalConditionPlugin = (*JobSet)(nil)
var _ framework.CustomValidationPlugin = (*JobSet)(nil)

const Name = constants.JobSetKind

Expand Down Expand Up @@ -159,3 +164,52 @@ func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJ
}
return nil, nil
}

func (j *JobSet) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question: can we move it after the Name() API?


var allErrs field.ErrorList
specPath := field.NewPath("spec")
runtimeRefPath := specPath.Child("runtimeRef")

jobSet, ok := runtimeJobTemplate.(*jobsetv1alpha2.JobSet)
if !ok {
return nil, nil
}

if newObj.Spec.ModelConfig != nil && newObj.Spec.ModelConfig.Input != nil {
if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool {
return x.Name == constants.JobInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have %s job when trainJob is configured with input modelConfig", constants.JobInitializer)))
} else {
Comment on lines +179 to +184
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can simplify this: if the user sets DatasetConfig or ModelConfig, we need to check that the ReplicatedJobs contain the Initializer Job.

for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == constants.JobInitializer {
if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool {
return x.Name == constants.ContainerModelInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have container with name - %s in the %s job", constants.ContainerModelInitializer, constants.JobInitializer)))
}
}
}
}
}

if newObj.Spec.DatasetConfig != nil {
if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool {
return x.Name == constants.JobInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have %s job when trainJob is configured with input datasetConfig", constants.JobInitializer)))
} else {
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == constants.JobInitializer {
if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool {
return x.Name == constants.ContainerDatasetInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have container with name - %s in the %s job", constants.ContainerDatasetInitializer, constants.JobInitializer)))
}
}
}
}
}
return nil, allErrs
}
17 changes: 14 additions & 3 deletions pkg/runtime/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"k8s.io/apimachinery/pkg/util/intstr"
"maps"
"strconv"

Expand Down Expand Up @@ -75,9 +76,19 @@ func (m *MPI) Name() string {
return Name
}

// TODO: Need to implement validations for MPI Policy.
func (m *MPI) Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
func (m *MPI) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldJobObj, newJobObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to be consistent with the Enforce API and simply exit early when validation is not required:

Suggested change
var allErrs field.ErrorList
var allErrs field.ErrorList
if runtimeInfo == nil || runtimeInfo.RuntimePolicy.MLPolicy == nil || runtimeInfo.RuntimePolicy.MLPolicy.MPI == nil {
return nil, allErrs
}

specPath := field.NewPath("spec")
if newJobObj.Spec.Trainer != nil {
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
if runtimeInfo.RuntimePolicy.MLPolicy != nil && runtimeInfo.RuntimePolicy.MLPolicy.MPI != nil {
numProcPerNode := *newJobObj.Spec.Trainer.NumProcPerNode
if numProcPerNode.Type != intstr.Int {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value"))
}
}
}
return nil, allErrs
}

func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error {
Expand Down
35 changes: 30 additions & 5 deletions pkg/runtime/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package torch
import (
"context"
"fmt"
"slices"
"strings"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
Expand Down Expand Up @@ -49,11 +51,6 @@ func (t *Torch) Name() string {
return Name
}

// TODO: Need to implement validations for Torch policy.
func (t *Torch) Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
}

// TODO (andreyvelich): Add support for PyTorch elastic when JobSet supports Elastic Jobs.
func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error {
if info == nil || info.RuntimePolicy.MLPolicy == nil || info.RuntimePolicy.MLPolicy.Torch == nil {
Expand Down Expand Up @@ -140,3 +137,31 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)

return nil
}

func (t *Torch) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@akshaychitneni Could you please keep this Validate() function at the top of the torch.go file for consistency with the other plugins (e.g. MPI:

func (m *MPI) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldJobObj, newJobObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
specPath := field.NewPath("spec")
if newJobObj.Spec.Trainer != nil {
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
if runtimeInfo.RuntimePolicy.MLPolicy != nil && runtimeInfo.RuntimePolicy.MLPolicy.MPI != nil {
numProcPerNode := *newJobObj.Spec.Trainer.NumProcPerNode
if numProcPerNode.Type != intstr.Int {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value"))
}
}
}
return nil, allErrs
}
)?

var allErrs field.ErrorList
specPath := field.NewPath("spec")

if newObj.Spec.Trainer != nil {
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
if runtimeInfo.RuntimePolicy.MLPolicy != nil &&
runtimeInfo.RuntimePolicy.MLPolicy.Torch != nil && newObj.Spec.Trainer.NumProcPerNode != nil {
numProcPerNode := *newObj.Spec.Trainer.NumProcPerNode
if numProcPerNode.Type == intstr.String {
allowedStringValList := []string{"auto", "cpu", "gpu"}
if !slices.Contains(allowedStringValList, numProcPerNode.StrVal) {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newObj.Spec.Trainer.NumProcPerNode, "should have an int value or auto/cpu/gpu"))
}
}
}

if slices.ContainsFunc(newObj.Spec.Trainer.Env, func(x corev1.EnvVar) bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this TODO:

// TODO (andreyvelich): Add validation to check that TrainJob doesn't have "PET_" envs.

return strings.HasPrefix(x.Name, constants.TorchEnvNamePrefix)
}) {
trainerEnvsPath := specPath.Child("trainer").Child("env")
allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("should not have envs with name having prefix %s", constants.TorchEnvNamePrefix)))
}
}

return nil, allErrs
}
9 changes: 9 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"maps"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/utils/ptr"
kueuelr "sigs.k8s.io/kueue/pkg/util/limitrange"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
Expand Down Expand Up @@ -146,3 +148,10 @@ func NewInfo(opts ...InfoOption) *Info {

return info
}

// RuntimeRefToRuntimeRegistryKey returns the string form of the
// schema.GroupKind described by runtimeRef. An unset APIGroup or Kind pointer
// contributes an empty string to the GroupKind.
func RuntimeRefToRuntimeRegistryKey(runtimeRef trainer.RuntimeRef) string {
	group := ptr.Deref(runtimeRef.APIGroup, "")
	kind := ptr.Deref(runtimeRef.Kind, "")
	return schema.GroupKind{Group: group, Kind: kind}.String()
}
Loading
Loading