-
Notifications
You must be signed in to change notification settings - Fork 733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
KEP-2170: Adding validation webhook for v2 trainjob #2307
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,6 +72,9 @@ const ( | |
// TorchEnvMasterPort is the env name for the master node port. | ||
TorchEnvMasterPort string = "PET_MASTER_PORT" | ||
|
||
// TorchEnvNamePrefix is the env name prefix for the distributed envs for torchrun. | ||
TorchEnvNamePrefix = "PET_" | ||
|
||
// JobLauncher is the Job name for the launcher. | ||
JobLauncher string = "launcher" | ||
|
||
|
@@ -111,6 +114,8 @@ const ( | |
// Distributed envs for mpirun. | ||
// Values for OpenMPI implementation. | ||
OpenMPIEnvHostFileLocation string = "OMPI_MCA_orte_default_hostfile" | ||
|
||
UnsupportedRuntimeErrMsg string = "the specified runtime is not supported" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @tenzen-y @akshaychitneni Do we want to keep this error in the |
||
) | ||
|
||
var ( | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -39,8 +39,6 @@ import ( | |||||
jobruntimes "github.com/kubeflow/trainer/pkg/runtime" | ||||||
) | ||||||
|
||||||
var errorUnsupportedRuntime = errors.New("the specified runtime is not supported") | ||||||
|
||||||
type objsOpState int | ||||||
|
||||||
const ( | ||||||
|
@@ -83,10 +81,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c | |||||
return ctrl.Result{}, nil | ||||||
} | ||||||
|
||||||
runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String() | ||||||
runtimeRefGK := jobruntimes.RuntimeRefToRuntimeRegistryKey(trainJob.Spec.RuntimeRef) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we call this function as:
Suggested change
|
||||||
runtime, ok := r.runtimes[runtimeRefGK] | ||||||
if !ok { | ||||||
return ctrl.Result{}, fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK) | ||||||
return ctrl.Result{}, fmt.Errorf("%s: %s", constants.UnsupportedRuntimeErrMsg, runtimeRefGK) | ||||||
} | ||||||
opState, err := r.reconcileObjects(ctx, runtime, &trainJob) | ||||||
|
||||||
|
@@ -214,13 +212,6 @@ func isTrainJobFinished(trainJob *trainer.TrainJob) bool { | |||||
meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed) | ||||||
} | ||||||
|
||||||
func runtimeRefToGroupKind(runtimeRef trainer.RuntimeRef) schema.GroupKind { | ||||||
return schema.GroupKind{ | ||||||
Group: ptr.Deref(runtimeRef.APIGroup, ""), | ||||||
Kind: ptr.Deref(runtimeRef.Kind, ""), | ||||||
} | ||||||
} | ||||||
|
||||||
func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error { | ||||||
b := ctrl.NewControllerManagedBy(mgr). | ||||||
WithOptions(options). | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -83,6 +83,26 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.Trai | |||||||||||
func (r *TrainingRuntime) buildObjects( | ||||||||||||
ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy, | ||||||||||||
) ([]client.Object, error) { | ||||||||||||
|
||||||||||||
info := r.runtimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy) | ||||||||||||
if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil { | ||||||||||||
return nil, err | ||||||||||||
} | ||||||||||||
|
||||||||||||
if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil { | ||||||||||||
return nil, err | ||||||||||||
} | ||||||||||||
|
||||||||||||
jobSetTemplate := jobsetv1alpha2.JobSet{ | ||||||||||||
Spec: jobSetTemplateSpec.Spec, | ||||||||||||
} | ||||||||||||
|
||||||||||||
return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob) | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *TrainingRuntime) runtimeInfo( | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Back to my point here: #2307 (comment) @akshaychitneni @tenzen-y Am I missing something? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I can see that we use it here, since we need to fetch data from the TrainJob and ClusterTrainingRuntime to define what type of TrainJob validation we need to perform: trainer/pkg/runtime/framework/plugins/mpi/mpi.go Lines 84 to 88 in 32f04e3
@tenzen-y Is that something that you had in mind when you designed the Runtime Framework? |
||||||||||||
ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy) *runtime.Info { | ||||||||||||
|
||||||||||||
propagationLabels := jobSetTemplateSpec.Labels | ||||||||||||
if propagationLabels == nil && trainJob.Spec.Labels != nil { | ||||||||||||
propagationLabels = make(map[string]string, len(trainJob.Spec.Labels)) | ||||||||||||
|
@@ -113,19 +133,7 @@ func (r *TrainingRuntime) buildObjects( | |||||||||||
|
||||||||||||
info := runtime.NewInfo(opts...) | ||||||||||||
|
||||||||||||
if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil { | ||||||||||||
return nil, err | ||||||||||||
} | ||||||||||||
|
||||||||||||
if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil { | ||||||||||||
return nil, err | ||||||||||||
} | ||||||||||||
|
||||||||||||
jobSetTemplate := jobsetv1alpha2.JobSet{ | ||||||||||||
Spec: jobSetTemplateSpec.Spec, | ||||||||||||
} | ||||||||||||
|
||||||||||||
return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob) | ||||||||||||
return info | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { | ||||||||||||
|
@@ -141,14 +149,19 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { | |||||||||||
} | ||||||||||||
|
||||||||||||
func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) { | ||||||||||||
trainingRuntime := &trainer.TrainingRuntime{} | ||||||||||||
if err := r.client.Get(ctx, client.ObjectKey{ | ||||||||||||
Namespace: old.Namespace, | ||||||||||||
Name: old.Spec.RuntimeRef.Name, | ||||||||||||
}, &trainer.TrainingRuntime{}); err != nil { | ||||||||||||
Namespace: new.Namespace, | ||||||||||||
Name: new.Spec.RuntimeRef.Name, | ||||||||||||
}, trainingRuntime); err != nil { | ||||||||||||
return nil, field.ErrorList{ | ||||||||||||
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef, | ||||||||||||
field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef, | ||||||||||||
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)), | ||||||||||||
} | ||||||||||||
} | ||||||||||||
return r.framework.RunCustomValidationPlugins(old, new) | ||||||||||||
info := r.runtimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy) | ||||||||||||
jobSetTemplate := jobsetv1alpha2.JobSet{ | ||||||||||||
Spec: trainingRuntime.Spec.Template.Spec, | ||||||||||||
} | ||||||||||||
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new) | ||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,20 +20,24 @@ import ( | |
"context" | ||
"fmt" | ||
"maps" | ||
"slices" | ||
|
||
"github.com/go-logr/logr" | ||
corev1 "k8s.io/api/core/v1" | ||
"k8s.io/apimachinery/pkg/api/equality" | ||
apierrors "k8s.io/apimachinery/pkg/api/errors" | ||
"k8s.io/apimachinery/pkg/api/meta" | ||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" | ||
apiruntime "k8s.io/apimachinery/pkg/runtime" | ||
"k8s.io/apimachinery/pkg/runtime/schema" | ||
"k8s.io/apimachinery/pkg/util/validation/field" | ||
"k8s.io/utils/ptr" | ||
ctrl "sigs.k8s.io/controller-runtime" | ||
"sigs.k8s.io/controller-runtime/pkg/builder" | ||
"sigs.k8s.io/controller-runtime/pkg/cache" | ||
"sigs.k8s.io/controller-runtime/pkg/client" | ||
ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" | ||
"sigs.k8s.io/controller-runtime/pkg/webhook/admission" | ||
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" | ||
|
||
trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" | ||
|
@@ -52,6 +56,7 @@ type JobSet struct { | |
var _ framework.WatchExtensionPlugin = (*JobSet)(nil) | ||
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil) | ||
var _ framework.TerminalConditionPlugin = (*JobSet)(nil) | ||
var _ framework.CustomValidationPlugin = (*JobSet)(nil) | ||
|
||
const Name = constants.JobSetKind | ||
|
||
|
@@ -159,3 +164,52 @@ func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJ | |
} | ||
return nil, nil | ||
} | ||
|
||
func (j *JobSet) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question, can we move it after |
||
|
||
var allErrs field.ErrorList | ||
specPath := field.NewPath("spec") | ||
runtimeRefPath := specPath.Child("runtimeRef") | ||
|
||
jobSet, ok := runtimeJobTemplate.(*jobsetv1alpha2.JobSet) | ||
if !ok { | ||
return nil, nil | ||
} | ||
|
||
if newObj.Spec.ModelConfig != nil && newObj.Spec.ModelConfig.Input != nil { | ||
if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool { | ||
return x.Name == constants.JobInitializer | ||
}) { | ||
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have %s job when trainJob is configured with input modelConfig", constants.JobInitializer))) | ||
} else { | ||
Comment on lines
+179
to
+184
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think, we can simplify this since if user sets DatasetConfig or ModelConfig we need to check that ReplicatedJob contains Initializer Job. |
||
for _, job := range jobSet.Spec.ReplicatedJobs { | ||
if job.Name == constants.JobInitializer { | ||
if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool { | ||
return x.Name == constants.ContainerModelInitializer | ||
}) { | ||
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have container with name - %s in the %s job", constants.ContainerModelInitializer, constants.JobInitializer))) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
if newObj.Spec.DatasetConfig != nil { | ||
if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool { | ||
return x.Name == constants.JobInitializer | ||
}) { | ||
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have %s job when trainJob is configured with input datasetConfig", constants.JobInitializer))) | ||
} else { | ||
for _, job := range jobSet.Spec.ReplicatedJobs { | ||
if job.Name == constants.JobInitializer { | ||
if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool { | ||
return x.Name == constants.ContainerDatasetInitializer | ||
}) { | ||
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have container with name - %s in the %s job", constants.ContainerDatasetInitializer, constants.JobInitializer))) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
return nil, allErrs | ||
} |
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -25,6 +25,7 @@ import ( | |||||||||||||||
"crypto/x509" | ||||||||||||||||
"encoding/pem" | ||||||||||||||||
"fmt" | ||||||||||||||||
"k8s.io/apimachinery/pkg/util/intstr" | ||||||||||||||||
"maps" | ||||||||||||||||
"strconv" | ||||||||||||||||
|
||||||||||||||||
|
@@ -75,9 +76,19 @@ func (m *MPI) Name() string { | |||||||||||||||
return Name | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
// TODO: Need to implement validations for MPI Policy. | ||||||||||||||||
func (m *MPI) Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { | ||||||||||||||||
return nil, nil | ||||||||||||||||
func (m *MPI) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldJobObj, newJobObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { | ||||||||||||||||
var allErrs field.ErrorList | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We might want to be consistent with Enforce API, and just exit this validation if validation is not required:
Suggested change
|
||||||||||||||||
specPath := field.NewPath("spec") | ||||||||||||||||
if newJobObj.Spec.Trainer != nil { | ||||||||||||||||
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode") | ||||||||||||||||
if runtimeInfo.RuntimePolicy.MLPolicy != nil && runtimeInfo.RuntimePolicy.MLPolicy.MPI != nil { | ||||||||||||||||
numProcPerNode := *newJobObj.Spec.Trainer.NumProcPerNode | ||||||||||||||||
if numProcPerNode.Type != intstr.Int { | ||||||||||||||||
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value")) | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
return nil, allErrs | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error { | ||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,6 +19,8 @@ package torch | |||||||||||||||||||||||||||||
import ( | ||||||||||||||||||||||||||||||
"context" | ||||||||||||||||||||||||||||||
"fmt" | ||||||||||||||||||||||||||||||
"slices" | ||||||||||||||||||||||||||||||
"strings" | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
corev1 "k8s.io/api/core/v1" | ||||||||||||||||||||||||||||||
"k8s.io/apimachinery/pkg/util/intstr" | ||||||||||||||||||||||||||||||
|
@@ -49,11 +51,6 @@ func (t *Torch) Name() string { | |||||||||||||||||||||||||||||
return Name | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
// TODO: Need to implement validations for Torch policy. | ||||||||||||||||||||||||||||||
func (t *Torch) Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { | ||||||||||||||||||||||||||||||
return nil, nil | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
// TODO (andreyvelich): Add support for PyTorch elastic when JobSet supports Elastic Jobs. | ||||||||||||||||||||||||||||||
func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error { | ||||||||||||||||||||||||||||||
if info == nil || info.RuntimePolicy.MLPolicy == nil || info.RuntimePolicy.MLPolicy.Torch == nil { | ||||||||||||||||||||||||||||||
|
@@ -140,3 +137,31 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
return nil | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
func (t *Torch) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @akshaychitneni Please can you keep this trainer/pkg/runtime/framework/plugins/mpi/mpi.go Lines 79 to 92 in 32f04e3
|
||||||||||||||||||||||||||||||
var allErrs field.ErrorList | ||||||||||||||||||||||||||||||
specPath := field.NewPath("spec") | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if newObj.Spec.Trainer != nil { | ||||||||||||||||||||||||||||||
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode") | ||||||||||||||||||||||||||||||
if runtimeInfo.RuntimePolicy.MLPolicy != nil && | ||||||||||||||||||||||||||||||
runtimeInfo.RuntimePolicy.MLPolicy.Torch != nil && newObj.Spec.Trainer.NumProcPerNode != nil { | ||||||||||||||||||||||||||||||
numProcPerNode := *newObj.Spec.Trainer.NumProcPerNode | ||||||||||||||||||||||||||||||
if numProcPerNode.Type == intstr.String { | ||||||||||||||||||||||||||||||
allowedStringValList := []string{"auto", "cpu", "gpu"} | ||||||||||||||||||||||||||||||
if !slices.Contains(allowedStringValList, numProcPerNode.StrVal) { | ||||||||||||||||||||||||||||||
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newObj.Spec.Trainer.NumProcPerNode, "should have an int value or auto/cpu/gpu")) | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if slices.ContainsFunc(newObj.Spec.Trainer.Env, func(x corev1.EnvVar) bool { | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove this TODO:
|
||||||||||||||||||||||||||||||
return strings.HasPrefix(x.Name, constants.TorchEnvNamePrefix) | ||||||||||||||||||||||||||||||
}) { | ||||||||||||||||||||||||||||||
trainerEnvsPath := specPath.Child("trainer").Child("env") | ||||||||||||||||||||||||||||||
allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("should not have envs with name having prefix %s", constants.TorchEnvNamePrefix))) | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
return nil, allErrs | ||||||||||||||||||||||||||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move this constant to line 60 and reuse it for the other torchrun envs:
trainer/pkg/constants/constants.go
Line 60 in 32f04e3