Skip to content

Commit

Permalink
Adding v2 trainjob validation webhook
Browse files Browse the repository at this point in the history
fixing runtime

Signed-off-by: Akshay Chitneni <[email protected]>
  • Loading branch information
Akshay Chitneni committed Feb 24, 2025
1 parent a55f03a commit 32f04e3
Show file tree
Hide file tree
Showing 20 changed files with 376 additions and 75 deletions.
5 changes: 5 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ const (
// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"

// TorchEnvNamePrefix is the env name prefix for the distributed envs for torchrun.
TorchEnvNamePrefix = "PET_"

// JobLauncher is the Job name for the launcher.
JobLauncher string = "launcher"

Expand Down Expand Up @@ -111,6 +114,8 @@ const (
// Distributed envs for mpirun.
// Values for OpenMPI implementation.
OpenMPIEnvHostFileLocation string = "OMPI_MCA_orte_default_hostfile"

UnsupportedRuntimeErrMsg string = "the specified runtime is not supported"
)

var (
Expand Down
13 changes: 2 additions & 11 deletions pkg/controller/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ import (
jobruntimes "github.com/kubeflow/trainer/pkg/runtime"
)

var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")

type objsOpState int

const (
Expand Down Expand Up @@ -83,10 +81,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
return ctrl.Result{}, nil
}

runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtimeRefGK := jobruntimes.RuntimeRefToRuntimeRegistryKey(trainJob.Spec.RuntimeRef)
runtime, ok := r.runtimes[runtimeRefGK]
if !ok {
return ctrl.Result{}, fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
return ctrl.Result{}, fmt.Errorf("%s: %s", constants.UnsupportedRuntimeErrMsg, runtimeRefGK)
}
opState, err := r.reconcileObjects(ctx, runtime, &trainJob)

Expand Down Expand Up @@ -214,13 +212,6 @@ func isTrainJobFinished(trainJob *trainer.TrainJob) bool {
meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed)
}

func runtimeRefToGroupKind(runtimeRef trainer.RuntimeRef) schema.GroupKind {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Kind: ptr.Deref(runtimeRef.Kind, ""),
}
}

func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error {
b := ctrl.NewControllerManagedBy(mgr).
WithOptions(options).
Expand Down
13 changes: 9 additions & 4 deletions pkg/runtime/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/runtime"
Expand Down Expand Up @@ -69,14 +70,18 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu
}

func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
clusterTrainingRuntime := &trainer.ClusterTrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
Name: new.Spec.RuntimeRef.Name,
}, &trainer.ClusterTrainingRuntime{}); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "RuntimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.runtimeInfo(ctx, new, clusterTrainingRuntime.Spec.Template, clusterTrainingRuntime.Spec.MLPolicy, clusterTrainingRuntime.Spec.PodGroupPolicy)
jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: clusterTrainingRuntime.Spec.Template.Spec,
}
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
}
1 change: 0 additions & 1 deletion pkg/runtime/core/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package core

import (
"context"

"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kubeflow/trainer/pkg/runtime"
Expand Down
49 changes: 31 additions & 18 deletions pkg/runtime/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,26 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.Trai
func (r *TrainingRuntime) buildObjects(
ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy,
) ([]client.Object, error) {

info := r.runtimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy)
if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: jobSetTemplateSpec.Spec,
}

return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
}

func (r *TrainingRuntime) runtimeInfo(
ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy) *runtime.Info {

propagationLabels := jobSetTemplateSpec.Labels
if propagationLabels == nil && trainJob.Spec.Labels != nil {
propagationLabels = make(map[string]string, len(trainJob.Spec.Labels))
Expand Down Expand Up @@ -113,19 +133,7 @@ func (r *TrainingRuntime) buildObjects(

info := runtime.NewInfo(opts...)

if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: jobSetTemplateSpec.Spec,
}

return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
return info
}

func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
Expand All @@ -141,14 +149,19 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
}

func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
trainingRuntime := &trainer.TrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
}, &trainer.TrainingRuntime{}); err != nil {
Namespace: new.Namespace,
Name: new.Spec.RuntimeRef.Name,
}, trainingRuntime); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.runtimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy)
jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: trainingRuntime.Spec.Template.Spec,
}
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
}
4 changes: 2 additions & 2 deletions pkg/runtime/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(info *runtime.Info, trainJob
return nil
}

func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
func (f *Framework) RunCustomValidationPlugins(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
var aggregatedWarnings admission.Warnings
var aggregatedErrors field.ErrorList
for _, plugin := range f.customValidationPlugins {
warnings, errs := plugin.Validate(oldObj, newObj)
warnings, errs := plugin.Validate(runtimeJobTemplate, info, oldObj, newObj)
if len(warnings) != 0 {
aggregatedWarnings = append(aggregatedWarnings, warnings...)
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/runtime/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func TestNew(t *testing.T) {
customValidationPlugins: []framework.CustomValidationPlugin{
&mpi.MPI{},
&torch.Torch{},
&jobset.JobSet{},
},
watchExtensionPlugins: []framework.WatchExtensionPlugin{
&coscheduling.CoScheduling{},
Expand Down Expand Up @@ -371,7 +372,9 @@ func TestRunCustomValidationPlugins(t *testing.T) {
if err != nil {
t.Fatal(err)
}
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj)
runtimeInfo := runtime.NewInfo()
jobSetTemplate := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test")
warnings, errs := fwk.RunCustomValidationPlugins(jobSetTemplate, runtimeInfo, tc.oldObj, tc.newObj)
if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 {
t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type Plugin interface {

type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList)
Validate(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList)
}

type WatchExtensionPlugin interface {
Expand Down
54 changes: 54 additions & 0 deletions pkg/runtime/framework/plugins/jobset/jobset.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,24 @@ import (
"context"
"fmt"
"maps"
"slices"

"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/equality"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
apiruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
Expand All @@ -52,6 +56,7 @@ type JobSet struct {
var _ framework.WatchExtensionPlugin = (*JobSet)(nil)
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil)
var _ framework.TerminalConditionPlugin = (*JobSet)(nil)
var _ framework.CustomValidationPlugin = (*JobSet)(nil)

const Name = constants.JobSetKind

Expand Down Expand Up @@ -159,3 +164,52 @@ func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJ
}
return nil, nil
}

func (j *JobSet) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {

var allErrs field.ErrorList
specPath := field.NewPath("spec")
runtimeRefPath := specPath.Child("runtimeRef")

jobSet, ok := runtimeJobTemplate.(*jobsetv1alpha2.JobSet)
if !ok {
return nil, nil
}

if newObj.Spec.ModelConfig != nil && newObj.Spec.ModelConfig.Input != nil {
if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool {
return x.Name == constants.JobInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have %s job when trainJob is configured with input modelConfig", constants.JobInitializer)))
} else {
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == constants.JobInitializer {
if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool {
return x.Name == constants.ContainerModelInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have container with name - %s in the %s job", constants.ContainerModelInitializer, constants.JobInitializer)))
}
}
}
}
}

if newObj.Spec.DatasetConfig != nil {
if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool {
return x.Name == constants.JobInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have %s job when trainJob is configured with input datasetConfig", constants.JobInitializer)))
} else {
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == constants.JobInitializer {
if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool {
return x.Name == constants.ContainerDatasetInitializer
}) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have container with name - %s in the %s job", constants.ContainerDatasetInitializer, constants.JobInitializer)))
}
}
}
}
}
return nil, allErrs
}
17 changes: 14 additions & 3 deletions pkg/runtime/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"k8s.io/apimachinery/pkg/util/intstr"
"maps"
"strconv"

Expand Down Expand Up @@ -75,9 +76,19 @@ func (m *MPI) Name() string {
return Name
}

// TODO: Need to implement validations for MPI Policy.
func (m *MPI) Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
func (m *MPI) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldJobObj, newJobObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
specPath := field.NewPath("spec")
if newJobObj.Spec.Trainer != nil {
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
if runtimeInfo.RuntimePolicy.MLPolicy != nil && runtimeInfo.RuntimePolicy.MLPolicy.MPI != nil {
numProcPerNode := *newJobObj.Spec.Trainer.NumProcPerNode
if numProcPerNode.Type != intstr.Int {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value"))
}
}
}
return nil, allErrs
}

func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error {
Expand Down
35 changes: 30 additions & 5 deletions pkg/runtime/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package torch
import (
"context"
"fmt"
"slices"
"strings"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
Expand Down Expand Up @@ -49,11 +51,6 @@ func (t *Torch) Name() string {
return Name
}

// TODO: Need to implement validations for Torch policy.
func (t *Torch) Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
}

// TODO (andreyvelich): Add support for PyTorch elastic when JobSet supports Elastic Jobs.
func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error {
if info == nil || info.RuntimePolicy.MLPolicy == nil || info.RuntimePolicy.MLPolicy.Torch == nil {
Expand Down Expand Up @@ -140,3 +137,31 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)

return nil
}

func (t *Torch) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
specPath := field.NewPath("spec")

if newObj.Spec.Trainer != nil {
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
if runtimeInfo.RuntimePolicy.MLPolicy != nil &&
runtimeInfo.RuntimePolicy.MLPolicy.Torch != nil && newObj.Spec.Trainer.NumProcPerNode != nil {
numProcPerNode := *newObj.Spec.Trainer.NumProcPerNode
if numProcPerNode.Type == intstr.String {
allowedStringValList := []string{"auto", "cpu", "gpu"}
if !slices.Contains(allowedStringValList, numProcPerNode.StrVal) {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newObj.Spec.Trainer.NumProcPerNode, "should have an int value or auto/cpu/gpu"))
}
}
}

if slices.ContainsFunc(newObj.Spec.Trainer.Env, func(x corev1.EnvVar) bool {
return strings.HasPrefix(x.Name, constants.TorchEnvNamePrefix)
}) {
trainerEnvsPath := specPath.Child("trainer").Child("env")
allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("should not have envs with name having prefix %s", constants.TorchEnvNamePrefix)))
}
}

return nil, allErrs
}
9 changes: 9 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"maps"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/utils/ptr"
kueuelr "sigs.k8s.io/kueue/pkg/util/limitrange"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
Expand Down Expand Up @@ -146,3 +148,10 @@ func NewInfo(opts ...InfoOption) *Info {

return info
}

func RuntimeRefToRuntimeRegistryKey(runtimeRef trainer.RuntimeRef) string {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Kind: ptr.Deref(runtimeRef.Kind, ""),
}.String()
}
Loading

0 comments on commit 32f04e3

Please sign in to comment.