Fix trainer double init call (#48)
sfc-gh-mwyatt authored Feb 13, 2025
1 parent 5ef7cff commit ab7bb18
Showing 5 changed files with 18 additions and 11 deletions.
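What changed, in short: the trainer config previously exposed a trainer property that constructed a fresh trainer object on every attribute access, so any code path that touched config.trainer more than once initialized the trainer (and its model) repeatedly. This commit removes the property; call sites now resolve the trainer class from the registry once and instantiate it exactly once. Below is a minimal, self-contained sketch of the failure mode; DummyTrainer and BuggyConfig are illustrative stand-ins, not classes from this repo.

# Sketch of the double-init bug this commit fixes (hypothetical names).
class DummyTrainer:
    def __init__(self, config):
        print("expensive init: model load, optimizer setup, ...")
        self.config = config

class BuggyConfig:
    type = "dummy"

    @property
    def trainer(self):
        # A property body runs on EVERY access, so each
        # config.trainer lookup builds a brand-new trainer.
        return DummyTrainer(config=self)

config = BuggyConfig()
t1 = config.trainer  # first init
t2 = config.trainer  # second init: the "double init"
assert t1 is not t2

# The fixed pattern (mirroring the diffs below): construct exactly once.
trainer = DummyTrainer(config)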
4 changes: 3 additions & 1 deletion arctic_training/cli.py
@@ -49,6 +49,7 @@ def run_script():
     import deepspeed.comm as dist

     from arctic_training.config.trainer import get_config
+    from arctic_training.registry.trainer import get_registered_trainer

     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -69,7 +70,8 @@ def run_script():
         raise FileNotFoundError(f"Config file {args.config} not found.")

     config = get_config(args.config)
-    trainer = config.trainer
+    trainer_cls = get_registered_trainer(config.type)
+    trainer = trainer_cls(config)
     trainer.train()
     if dist.is_initialized():
         dist.barrier()
4 changes: 0 additions & 4 deletions arctic_training/config/trainer.py
@@ -135,10 +135,6 @@ def copy_lr(self) -> Self:
         self.optimizer.learning_rate = self.scheduler.learning_rate
         return self

-    @property
-    def trainer(self):
-        return get_registered_trainer(self.type)(config=self)
-
     @property
     def checkpoint_engines(self) -> List[partial["CheckpointEngine"]]:
         checkpoint_engines = []
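With the property removed, each call site follows the same explicit two-step pattern seen in the cli.py hunk above: resolve the class once with trainer_cls = get_registered_trainer(config.type), then construct it once with trainer = trainer_cls(config). A repeated attribute lookup can no longer trigger a second initialization. The remaining diffs apply this pattern to the tests.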
7 changes: 5 additions & 2 deletions tests/checkpoint/test_ds_engine.py
@@ -17,6 +17,7 @@
 from utils import models_are_equal

 from arctic_training.config.trainer import get_config
+from arctic_training.registry.trainer import get_registered_trainer


 @pytest.mark.cpu
@@ -49,7 +50,8 @@ def test_ds_engine(tmp_path):
     }

     config = get_config(config_dict)
-    trainer = config.trainer
+    trainer_cls = get_registered_trainer(config.type)
+    trainer = trainer_cls(config)

     # Force checkpoint to be saved despite no training happening
     trainer.training_finished = True
@@ -60,7 +62,8 @@

     config_dict["seed"] = 0  # Make sure newly initialized model is different
     config = get_config(config_dict)
-    trainer = config.trainer
+    trainer_cls = get_registered_trainer(config.type)
+    trainer = trainer_cls(config)

     loaded_model = trainer.model
     assert models_are_equal(original_model, loaded_model), "Models are not equal"
7 changes: 5 additions & 2 deletions tests/checkpoint/test_hf_engine.py
@@ -17,6 +17,7 @@
 from utils import models_are_equal

 from arctic_training.config.trainer import get_config
+from arctic_training.registry.trainer import get_registered_trainer


 @pytest.mark.cpu
@@ -47,7 +48,8 @@ def test_hf_engine(tmp_path):
     }

     config = get_config(config_dict)
-    trainer = config.trainer
+    trainer_cls = get_registered_trainer(config.type)
+    trainer = trainer_cls(config)

     # Force checkpoint to be saved despite no training happening
     trainer.training_finished = True
@@ -60,7 +62,8 @@
         trainer.checkpoint_engines[0].checkpoint_dir
     )
     config = get_config(config_dict)
-    trainer = config.trainer
+    trainer_cls = get_registered_trainer(config.type)
+    trainer = trainer_cls(config)

     loaded_model = trainer.model
     assert models_are_equal(original_model, loaded_model), "Models are not equal"
7 changes: 5 additions & 2 deletions tests/trainer/test_sft_trainer.py
@@ -17,6 +17,7 @@
 import yaml

 from arctic_training.config.trainer import get_config
+from arctic_training.registry.trainer import get_registered_trainer


 @pytest.mark.gpu
@@ -40,7 +41,8 @@ def test_sft_trainer(tmp_path):
         f.write(yaml.dump(config_dict))

     config = get_config(config_path)
-    trainer = config.trainer
+    trainer_cls = get_registered_trainer(config.type)
+    trainer = trainer_cls(config)
     trainer.train()
     assert trainer.global_step > 0, "Training did not run"

@@ -75,6 +77,7 @@ def test_sft_trainer_cpu(tmp_path):
         f.write(yaml.dump(config_dict))

     config = get_config(config_path)
-    trainer = config.trainer
+    trainer_cls = get_registered_trainer(config.type)
+    trainer = trainer_cls(config)
     trainer.train()
     assert trainer.global_step > 0, "Training did not run"
