diff --git a/arctic_training/__init__.py b/arctic_training/__init__.py index f6c1fd4..3ed67e4 100644 --- a/arctic_training/__init__.py +++ b/arctic_training/__init__.py @@ -30,6 +30,7 @@ from .config.scheduler import SchedulerConfig from .config.tokenizer import TokenizerConfig from .config.trainer import TrainerConfig +from .config.trainer import get_config from .data.factory import DataFactory from .data.sft_factory import SFTDataFactory from .data.source import DataSource diff --git a/arctic_training/checkpoint/engine.py b/arctic_training/checkpoint/engine.py index 7ec4447..9583f6b 100644 --- a/arctic_training/checkpoint/engine.py +++ b/arctic_training/checkpoint/engine.py @@ -18,7 +18,6 @@ from pathlib import Path from typing import TYPE_CHECKING from typing import Any -from typing import Type import torch @@ -39,13 +38,13 @@ class CheckpointEngine(ABC, CallbackMixin): engine in the registry. """ - config_type: Type[CheckpointConfig] = CheckpointConfig + config: CheckpointConfig """ The configuration class for the checkpoint engine. This is used to validate the configuration passed to the engine. """ - def __init__(self, trainer: "Trainer", config: "CheckpointConfig") -> None: + def __init__(self, trainer: "Trainer", config: CheckpointConfig) -> None: self._trainer = trainer self.config = config diff --git a/arctic_training/config/trainer.py b/arctic_training/config/trainer.py index 519f455..b04255f 100644 --- a/arctic_training/config/trainer.py +++ b/arctic_training/config/trainer.py @@ -19,25 +19,31 @@ from pathlib import Path from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Union +from typing import cast +import deepspeed import yaml +from deepspeed.accelerator import get_accelerator from pydantic import Field +from pydantic import ValidationInfo from pydantic import field_validator from pydantic import model_validator from typing_extensions import Self -if TYPE_CHECKING: - from arctic_training.checkpoint.engine import CheckpointEngine - -import deepspeed -from deepspeed.accelerator import get_accelerator -from pydantic import ValidationInfo - from arctic_training.config import BaseConfig +from arctic_training.config.checkpoint import CheckpointConfig +from arctic_training.config.data import DataConfig from arctic_training.config.enums import DType +from arctic_training.config.logger import LoggerConfig +from arctic_training.config.model import ModelConfig +from arctic_training.config.optimizer import OptimizerConfig +from arctic_training.config.scheduler import SchedulerConfig +from arctic_training.config.tokenizer import TokenizerConfig +from arctic_training.config.wandb import WandBConfig from arctic_training.registry.checkpoint import get_registered_checkpoint_engine from arctic_training.registry.data import get_registered_data_factory from arctic_training.registry.model import get_registered_model_factory @@ -45,17 +51,13 @@ from arctic_training.registry.scheduler import get_registered_scheduler_factory from arctic_training.registry.tokenizer import get_registered_tokenizer_factory from arctic_training.registry.trainer import get_registered_trainer +from arctic_training.registry.utils import _get_class_attr_type_hints from arctic_training.utils import get_local_rank from arctic_training.utils import get_world_size -from .checkpoint import CheckpointConfig -from .data import DataConfig -from .logger import LoggerConfig -from .model import ModelConfig -from .optimizer import 
OptimizerConfig -from .scheduler import SchedulerConfig -from .tokenizer import TokenizerConfig -from .wandb import WandBConfig +if TYPE_CHECKING: + from arctic_training.checkpoint.engine import CheckpointEngine + TRAINER_DEFAULT = "sft" CUSTOM_CODE_DEFAULT = Path("train.py") @@ -70,6 +72,9 @@ class TrainerConfig(BaseConfig): code: Path = CUSTOM_CODE_DEFAULT """ Path to the python script containing custom trainer implementation. """ + skip_validation: bool = False + """ Skips validation of types for subconfigs and registered classes. """ + model: ModelConfig """ Model configuration. """ @@ -100,7 +105,7 @@ class TrainerConfig(BaseConfig): loss_log_interval: int = Field(default=1, ge=0) """ Number of steps between logging loss. """ - gradient_accumulation_steps: int = 1 + gradient_accumulation_steps: int = Field(default=1, ge=1) """ Number of gradient accumulation steps. """ micro_batch_size: int = Field(default=1, ge=1) @@ -147,82 +152,161 @@ def checkpoint_engines(self) -> List[partial["CheckpointEngine"]]: def zero_3_enabled(self) -> bool: return self.deepspeed.get("zero_optimization", {}).get("stage", 0) == 3 - @field_validator("eval_frequency", mode="after") - def validate_eval_frequency(cls, v: int, info: ValidationInfo) -> int: + @staticmethod + def _get_subconfig_object( + v: Union[Dict, BaseConfig], + info: ValidationInfo, + get_class_fn: Callable, + attr_name: str, + ) -> BaseConfig: + # Get the trainer class as it will tell us which types of factory + # classes (and thus configs) are default/compatible + trainer_type = info.data["type"] + trainer_cls = get_registered_trainer(trainer_type) + + # Get type hints for this factory class. This is a list of compatible + # classes for the given attribute field. + attribute_type_hints = _get_class_attr_type_hints(trainer_cls, attr_name) + + # Convert to a dictionary as default values are the base config classes + # and we likely need to use a different class based on the trainer type + # or user requested `type` field value. + if isinstance(v, dict): + config_dict = v + else: + config_dict = v.model_dump() + + # Determine which attribute class to use (e.g., for `model`: + # HFModelFactory, LigerModelFactory, etc.) + if config_dict.get("type", ""): + # User explicitly specified the type + attr_cls = get_class_fn(config_dict["type"]) + else: + # User did not specify the type, use the first (maybe only) hint as default type + attr_cls = attribute_type_hints[0] + + # Check that the requested/resolved type is compatible with the trainer if ( - info.data["data"].eval_sources - or info.data["data"].train_eval_split[1] > 0.0 + not info.data.get("skip_validation") + and attr_cls not in attribute_type_hints ): - assert v > 0, "eval_frequency must be set if eval dataset is provided." - return v + raise ValueError( + f"{attr_cls.__name__} is not supported for {attr_name} in" + f" {trainer_cls.__name__}. Supported types are" + f" {[cls.__name__ for cls in attribute_type_hints]}." 
+ ) - @field_validator("tokenizer", mode="after") - def set_tokenizer(cls, v: TokenizerConfig, info: ValidationInfo) -> TokenizerConfig: - if not v.name_or_path and "model" in info.data: - v.name_or_path = info.data["model"].name_or_path + # Make sure the `type` field is set in the config dict + config_dict["type"] = attr_cls.name + + # Get the config class for the factory class and creat the config + config_cls = _get_class_attr_type_hints(attr_cls, "config")[0] + return config_cls(**config_dict) + + @staticmethod + def _to_list(v: Union[Any, List[Any]]) -> List[Any]: + if not isinstance(v, list): + return [v] return v - @field_validator( - "checkpoint", - "data", - "model", - "optimizer", - "scheduler", - "tokenizer", - mode="before", - ) + @field_validator("checkpoint", mode="before") @classmethod - def parse_sub_config( + def init_checkpoint_configs( cls, - v: Any, + v: Union[Union[Dict, CheckpointConfig], List[Union[Dict, CheckpointConfig]]], info: ValidationInfo, - ) -> Union[BaseConfig, List[BaseConfig]]: - trainer_attr_map = { - "checkpoint": "checkpoint_engine_type", - "data": "data_factory_type", - "model": "model_factory_type", - "optimizer": "optimizer_factory_type", - "scheduler": "scheduler_factory_type", - "tokenizer": "tokenizer_factory_type", - } - field_name: str = info.field_name # type: ignore - trainer_type: str = info.data["type"] - trainer_cls = get_registered_trainer(trainer_type) - trainer_field_default = getattr(trainer_cls, trainer_attr_map[field_name])[0] - - if isinstance(v, tuple) or isinstance(v, list): - return_list = [] - for sub_v in v: - if isinstance(sub_v, BaseConfig): - sub_v = sub_v.model_dump() - field_cls = cls._get_config_cls( - sub_v, field_name, trainer_field_default + ) -> List[CheckpointConfig]: + v = cls._to_list(v) + return_list = [] + for sub_v in v: + return_list.append( + cls._get_subconfig_object( + v=sub_v, + info=info, + get_class_fn=get_registered_checkpoint_engine, + attr_name="checkpoint_engine", ) - sub_v["type"] = field_cls.name - return_list.append(field_cls.config_type(**sub_v)) - return return_list + ) + return [cast(CheckpointConfig, subconfig) for subconfig in return_list] - if isinstance(v, BaseConfig): - v = v.model_dump() - field_cls = cls._get_config_cls(v, field_name, trainer_field_default) - v["type"] = field_cls.name - return field_cls.config_type(**v) + @field_validator("data", mode="before") + @classmethod + def init_data_config( + cls, v: Union[Dict, DataConfig], info: ValidationInfo + ) -> DataConfig: + subconfig = cls._get_subconfig_object( + v=v, + info=info, + get_class_fn=get_registered_data_factory, + attr_name="data_factory", + ) + return cast(DataConfig, subconfig) + + @field_validator("model", mode="before") + @classmethod + def init_model_config( + cls, v: Union[Dict, ModelConfig], info: ValidationInfo + ) -> ModelConfig: + subconfig = cls._get_subconfig_object( + v=v, + info=info, + get_class_fn=get_registered_model_factory, + attr_name="model_factory", + ) + return cast(ModelConfig, subconfig) + + @field_validator("optimizer", mode="before") + @classmethod + def init_optimizer_config( + cls, v: Union[Dict, OptimizerConfig], info: ValidationInfo + ) -> OptimizerConfig: + subconfig = cls._get_subconfig_object( + v=v, + info=info, + get_class_fn=get_registered_optimizer_factory, + attr_name="optimizer_factory", + ) + return cast(OptimizerConfig, subconfig) + + @field_validator("scheduler", mode="before") + @classmethod + def init_scheduler_config( + cls, v: Union[Dict, SchedulerConfig], info: ValidationInfo 
+ ) -> SchedulerConfig: + subconfig = cls._get_subconfig_object( + v=v, + info=info, + get_class_fn=get_registered_scheduler_factory, + attr_name="scheduler_factory", + ) + return cast(SchedulerConfig, subconfig) + @field_validator("tokenizer", mode="before") @classmethod - def _get_config_cls(cls, config_dict, field_name, default_cls): - get_class_fn_map = { - "checkpoint": get_registered_checkpoint_engine, - "data": get_registered_data_factory, - "model": get_registered_model_factory, - "optimizer": get_registered_optimizer_factory, - "scheduler": get_registered_scheduler_factory, - "tokenizer": get_registered_tokenizer_factory, - } - field_type = config_dict.get("type", "") - if field_type == "": - field_type = default_cls - field_cls = get_class_fn_map[field_name](field_type) - return field_cls + def init_tokenizer_config( + cls, v: Union[Dict, TokenizerConfig], info: ValidationInfo + ) -> TokenizerConfig: + subconfig = cls._get_subconfig_object( + v=v, + info=info, + get_class_fn=get_registered_tokenizer_factory, + attr_name="tokenizer_factory", + ) + return cast(TokenizerConfig, subconfig) + + @model_validator(mode="after") + def validate_eval_frequency(self) -> Self: + if self.data.eval_sources or self.data.train_eval_split[1] > 0.0: + assert ( + self.eval_frequency > 0 + ), "eval_frequency must be set if eval dataset is provided." + return self + + @model_validator(mode="after") + def set_tokenizer(self) -> Self: + if not self.tokenizer.name_or_path: + self.tokenizer.name_or_path = self.model.name_or_path + return self @field_validator("logger", mode="after") @classmethod @@ -232,12 +316,6 @@ def initialize_logger(cls, v: LoggerConfig) -> LoggerConfig: setup_logger(v) return v - @field_validator("checkpoint", mode="before") - def checkpoint_to_list(cls, v: Union[Dict, List[Dict]]) -> List[Dict]: - if not isinstance(v, list): - return [v] - return v - @model_validator(mode="after") def build_deepspeed_config(self) -> Self: ds_config = self.deepspeed @@ -312,7 +390,7 @@ def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig: sys.path = original_sys_path trainer_cls = get_registered_trainer(trainer_type) - config_cls = trainer_cls.config_type + config_cls = _get_class_attr_type_hints(trainer_cls, "config")[0] config = config_cls(**config_dict) diff --git a/arctic_training/data/factory.py b/arctic_training/data/factory.py index 3662dee..3b9293e 100644 --- a/arctic_training/data/factory.py +++ b/arctic_training/data/factory.py @@ -23,7 +23,6 @@ from typing import List from typing import Optional from typing import Tuple -from typing import Type import torch from datasets import Dataset @@ -52,20 +51,18 @@ class DataFactory(ABC, CallbackMixin): specify the DataFactory to use. """ - config_type: Type[DataConfig] = DataConfig + config: DataConfig """ The type of the DataConfig object that this DataFactory uses. Any DataFactory-specific options should be specified in this class. 
""" - def __init__( - self, trainer: "Trainer", data_config: Optional["DataConfig"] = None - ) -> None: - if data_config is None: - data_config = trainer.config.data + def __init__(self, trainer: "Trainer", config: Optional[DataConfig] = None) -> None: + if config is None: + config = trainer.config.data self._trainer = trainer - self.config = data_config + self.config = config def __call__(self) -> Tuple[DataLoader, Optional[DataLoader]]: train_dataset = self._load_data( diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index bca2f86..801ca20 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -26,7 +26,6 @@ from transformers import BatchEncoding from transformers import PreTrainedTokenizerBase -from arctic_training.config.data import DataConfig from arctic_training.data.factory import DataFactory from arctic_training.registry import register @@ -223,7 +222,6 @@ def packing_sft_dataset( @register class SFTDataFactory(DataFactory): name = "sft" - config_type = DataConfig def tokenize_fn( self, diff --git a/arctic_training/model/factory.py b/arctic_training/model/factory.py index 96213ac..dbfc5fc 100644 --- a/arctic_training/model/factory.py +++ b/arctic_training/model/factory.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Optional -from typing import Type from transformers import PreTrainedModel @@ -40,14 +39,14 @@ class ModelFactory(ABC, CallbackMixin): model factory to be used. """ - config_type: Type[ModelConfig] = ModelConfig + config: ModelConfig """ The type of config class that the model factory uses. This should contain all model-specific parameters. """ def __init__( - self, trainer: "Trainer", model_config: Optional["ModelConfig"] = None + self, trainer: "Trainer", model_config: Optional[ModelConfig] = None ) -> None: if model_config is None: model_config = trainer.config.model diff --git a/arctic_training/optimizer/factory.py b/arctic_training/optimizer/factory.py index 2fa29d4..acb2a71 100644 --- a/arctic_training/optimizer/factory.py +++ b/arctic_training/optimizer/factory.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Optional -from typing import Type from arctic_training.callback.mixin import CallbackMixin from arctic_training.callback.mixin import callback_wrapper @@ -38,7 +37,7 @@ class OptimizerFactory(ABC, CallbackMixin): to identify which optimizer factory to be used. """ - config_type: Type[OptimizerConfig] = OptimizerConfig + config: OptimizerConfig """ The type of config class that the optimizer factory uses. This should contain all optimizer-specific parameters. 
@@ -47,7 +46,7 @@ class OptimizerFactory(ABC, CallbackMixin): def __init__( self, trainer: "Trainer", - optimizer_config: Optional["OptimizerConfig"] = None, + optimizer_config: Optional[OptimizerConfig] = None, ) -> None: if optimizer_config is None: optimizer_config = trainer.config.optimizer diff --git a/arctic_training/registry/checkpoint.py b/arctic_training/registry/checkpoint.py index f3f070c..fb6da23 100644 --- a/arctic_training/registry/checkpoint.py +++ b/arctic_training/registry/checkpoint.py @@ -24,6 +24,7 @@ from arctic_training.logging import logger from arctic_training.registry.utils import AlreadyRegisteredError from arctic_training.registry.utils import _validate_class_attribute_set +from arctic_training.registry.utils import _validate_class_attribute_type _supported_checkpoint_registry: Dict[str, Type["CheckpointEngine"]] = {} @@ -33,6 +34,7 @@ def register_checkpoint_engine( ) -> Type["CheckpointEngine"]: global _supported_checkpoint_registry from arctic_training.checkpoint.engine import CheckpointEngine + from arctic_training.config.checkpoint import CheckpointConfig if not issubclass(cls, CheckpointEngine): raise ValueError( @@ -41,7 +43,7 @@ def register_checkpoint_engine( ) _validate_class_attribute_set(cls, "name") - _validate_class_attribute_set(cls, "config_type") + _validate_class_attribute_type(cls, "config", CheckpointConfig) if cls.name in _supported_checkpoint_registry: if not force: diff --git a/arctic_training/registry/data.py b/arctic_training/registry/data.py index 28a2126..defd34d 100644 --- a/arctic_training/registry/data.py +++ b/arctic_training/registry/data.py @@ -27,6 +27,7 @@ from arctic_training.logging import logger from arctic_training.registry.utils import AlreadyRegisteredError from arctic_training.registry.utils import _validate_class_attribute_set +from arctic_training.registry.utils import _validate_class_attribute_type from arctic_training.registry.utils import _validate_method_definition _supported_data_source_registry: Dict[str, Dict[str, Type["DataSource"]]] = defaultdict( @@ -65,9 +66,10 @@ def register_data_factory( cls: Type["DataFactory"], force: bool = False ) -> Type["DataFactory"]: global _supported_data_factory_registry + from arctic_training.config.data import DataConfig _validate_class_attribute_set(cls, "name") - _validate_class_attribute_set(cls, "config_type") + _validate_class_attribute_type(cls, "config", DataConfig) _validate_method_definition(cls, "load", ["self", "num_proc", "eval"]) if cls.name in _supported_data_factory_registry: diff --git a/arctic_training/registry/model.py b/arctic_training/registry/model.py index cc461c5..b2c5867 100644 --- a/arctic_training/registry/model.py +++ b/arctic_training/registry/model.py @@ -19,14 +19,14 @@ from typing import Union from arctic_training.logging import logger - -if TYPE_CHECKING: - from arctic_training.model.factory import ModelFactory - from arctic_training.registry.utils import AlreadyRegisteredError from arctic_training.registry.utils import _validate_class_attribute_set +from arctic_training.registry.utils import _validate_class_attribute_type from arctic_training.registry.utils import _validate_method_definition +if TYPE_CHECKING: + from arctic_training.model.factory import ModelFactory + _supported_model_factory_registry: Dict[str, Type["ModelFactory"]] = {} @@ -34,14 +34,17 @@ def register_model_factory( cls: Type["ModelFactory"], force: bool = False ) -> Type["ModelFactory"]: global _supported_model_factory_registry + from arctic_training.config.model import 
ModelConfig from arctic_training.model.factory import ModelFactory if not issubclass(cls, ModelFactory): raise ValueError( - f"New Model Factory {cls.__name__} clss must be a subclass of ModelFactory." + f"New Model Factory {cls.__name__} class must be a subclass of" + " ModelFactory." ) _validate_class_attribute_set(cls, "name") + _validate_class_attribute_type(cls, "config", ModelConfig) _validate_method_definition(cls, "create_config", ["self"]) _validate_method_definition(cls, "create_model", ["self", "model_config"]) diff --git a/arctic_training/registry/optimizer.py b/arctic_training/registry/optimizer.py index 93fc210..53252ed 100644 --- a/arctic_training/registry/optimizer.py +++ b/arctic_training/registry/optimizer.py @@ -19,14 +19,14 @@ from typing import Union from arctic_training.logging import logger - -if TYPE_CHECKING: - from arctic_training.optimizer.factory import OptimizerFactory - from arctic_training.registry.utils import AlreadyRegisteredError from arctic_training.registry.utils import _validate_class_attribute_set +from arctic_training.registry.utils import _validate_class_attribute_type from arctic_training.registry.utils import _validate_method_definition +if TYPE_CHECKING: + from arctic_training.optimizer.factory import OptimizerFactory + _supported_optimizer_factory_registry: Dict[str, Type["OptimizerFactory"]] = {} @@ -34,6 +34,7 @@ def register_optimizer_factory( cls: Type["OptimizerFactory"], force: bool = False ) -> Type["OptimizerFactory"]: global _supported_optimizer_factory_registry + from arctic_training.config.optimizer import OptimizerConfig from arctic_training.optimizer.factory import OptimizerFactory if not issubclass(cls, OptimizerFactory): @@ -43,7 +44,7 @@ def register_optimizer_factory( ) _validate_class_attribute_set(cls, "name") - _validate_class_attribute_set(cls, "config_type") + _validate_class_attribute_type(cls, "config", OptimizerConfig) _validate_method_definition( cls, "create_optimizer", ["self", "model", "optimizer_config"] ) diff --git a/arctic_training/registry/scheduler.py b/arctic_training/registry/scheduler.py index 35e8225..cc5c2cb 100644 --- a/arctic_training/registry/scheduler.py +++ b/arctic_training/registry/scheduler.py @@ -21,6 +21,7 @@ from arctic_training.logging import logger from arctic_training.registry.utils import AlreadyRegisteredError from arctic_training.registry.utils import _validate_class_attribute_set +from arctic_training.registry.utils import _validate_class_attribute_type from arctic_training.registry.utils import _validate_method_definition if TYPE_CHECKING: @@ -33,6 +34,7 @@ def register_scheduler_factory( cls: Type["SchedulerFactory"], force: bool = False ) -> Type["SchedulerFactory"]: global _supported_scheduler_factory_registry + from arctic_training.config.scheduler import SchedulerConfig from arctic_training.scheduler.factory import SchedulerFactory if not issubclass(cls, SchedulerFactory): @@ -42,7 +44,7 @@ def register_scheduler_factory( ) _validate_class_attribute_set(cls, "name") - _validate_class_attribute_set(cls, "config_type") + _validate_class_attribute_type(cls, "config", SchedulerConfig) _validate_method_definition(cls, "create_scheduler", ["self", "optimizer"]) if cls.name in _supported_scheduler_factory_registry: diff --git a/arctic_training/registry/tokenizer.py b/arctic_training/registry/tokenizer.py index 0a46988..9f8092c 100644 --- a/arctic_training/registry/tokenizer.py +++ b/arctic_training/registry/tokenizer.py @@ -19,14 +19,14 @@ from typing import Union from 
arctic_training.logging import logger - -if TYPE_CHECKING: - from arctic_training.tokenizer.factory import TokenizerFactory - from arctic_training.registry.utils import AlreadyRegisteredError from arctic_training.registry.utils import _validate_class_attribute_set +from arctic_training.registry.utils import _validate_class_attribute_type from arctic_training.registry.utils import _validate_method_definition +if TYPE_CHECKING: + from arctic_training.tokenizer.factory import TokenizerFactory + _supported_tokenizer_factory_registry: Dict[str, Type["TokenizerFactory"]] = {} @@ -34,6 +34,7 @@ def register_tokenizer_factory( cls: Type["TokenizerFactory"], force: bool = False ) -> Type["TokenizerFactory"]: global _supported_tokenizer_factory_registry + from arctic_training.config.tokenizer import TokenizerConfig from arctic_training.tokenizer.factory import TokenizerFactory if not issubclass(cls, TokenizerFactory): @@ -43,7 +44,7 @@ def register_tokenizer_factory( ) _validate_class_attribute_set(cls, "name") - _validate_class_attribute_set(cls, "config_type") + _validate_class_attribute_type(cls, "config", TokenizerConfig) _validate_method_definition(cls, "create_tokenizer", ["self"]) if cls.name in _supported_tokenizer_factory_registry: diff --git a/arctic_training/registry/trainer.py b/arctic_training/registry/trainer.py index 0272f4a..01a6b20 100644 --- a/arctic_training/registry/trainer.py +++ b/arctic_training/registry/trainer.py @@ -15,25 +15,20 @@ from typing import TYPE_CHECKING from typing import Dict -from typing import Iterable from typing import Type from typing import Union from arctic_training.logging import logger -from arctic_training.registry.checkpoint import get_registered_checkpoint_engine from arctic_training.registry.checkpoint import register_checkpoint_engine -from arctic_training.registry.data import get_registered_data_factory from arctic_training.registry.data import register_data_factory -from arctic_training.registry.model import get_registered_model_factory from arctic_training.registry.model import register_model_factory -from arctic_training.registry.optimizer import get_registered_optimizer_factory from arctic_training.registry.optimizer import register_optimizer_factory -from arctic_training.registry.scheduler import get_registered_scheduler_factory from arctic_training.registry.scheduler import register_scheduler_factory -from arctic_training.registry.tokenizer import get_registered_tokenizer_factory from arctic_training.registry.tokenizer import register_tokenizer_factory from arctic_training.registry.utils import AlreadyRegisteredError +from arctic_training.registry.utils import _get_class_attr_type_hints from arctic_training.registry.utils import _validate_class_attribute_set +from arctic_training.registry.utils import _validate_class_attribute_type if TYPE_CHECKING: from arctic_training.trainer.trainer import Trainer @@ -43,6 +38,13 @@ def register_trainer(cls: Type["Trainer"], force: bool = False) -> Type["Trainer"]: global _supported_trainer_registry + from arctic_training.checkpoint.engine import CheckpointEngine + from arctic_training.config.trainer import TrainerConfig + from arctic_training.data.factory import DataFactory + from arctic_training.model.factory import ModelFactory + from arctic_training.optimizer.factory import OptimizerFactory + from arctic_training.scheduler.factory import SchedulerFactory + from arctic_training.tokenizer.factory import TokenizerFactory from arctic_training.trainer.trainer import Trainer if not issubclass(cls, Trainer): 
@@ -51,59 +53,49 @@ def register_trainer(cls: Type["Trainer"], force: bool = False) -> Type["Trainer ) _validate_class_attribute_set(cls, "name") - _validate_class_attribute_set(cls, "config_type") + _validate_class_attribute_type(cls, "config", TrainerConfig) trainer_attributes = [ ( - "data_factory_type", - get_registered_data_factory, + "data_factory", + DataFactory, register_data_factory, ), ( - "model_factory_type", - get_registered_model_factory, + "model_factory", + ModelFactory, register_model_factory, ), ( - "checkpoint_engine_type", - get_registered_checkpoint_engine, + "checkpoint_engine", + CheckpointEngine, register_checkpoint_engine, ), ( - "optimizer_factory_type", - get_registered_optimizer_factory, + "optimizer_factory", + OptimizerFactory, register_optimizer_factory, ), ( - "scheduler_factory_type", - get_registered_scheduler_factory, + "scheduler_factory", + SchedulerFactory, register_scheduler_factory, ), ( - "tokenizer_factory_type", - get_registered_tokenizer_factory, + "tokenizer_factory", + TokenizerFactory, register_tokenizer_factory, ), ] - for attr, get_class, register_class in trainer_attributes: - _validate_class_attribute_set(cls, attr) - - # Coerce to list if not already - if not isinstance(getattr(cls, attr), Iterable) or isinstance( - getattr(cls, attr), str - ): - setattr(cls, attr, [getattr(cls, attr)]) - - for class_type in getattr(cls, attr): - # Try to register the class, skip if already registered - if not issubclass(class_type, str): - try: - _ = register_class(class_type) - except AlreadyRegisteredError: - pass - - # Verify that the class is registered - _ = get_class(class_type) + for attr, type_, register_class in trainer_attributes: + _validate_class_attribute_type(cls, attr, type_) + + # Try to register the class, skip if already registered + for attr_type_hint in _get_class_attr_type_hints(cls, attr): + try: + _ = register_class(attr_type_hint) + except AlreadyRegisteredError: + pass if cls.name in _supported_trainer_registry: if not force: diff --git a/arctic_training/registry/utils.py b/arctic_training/registry/utils.py index 0448faf..33c41a5 100644 --- a/arctic_training/registry/utils.py +++ b/arctic_training/registry/utils.py @@ -16,8 +16,12 @@ import inspect from typing import TYPE_CHECKING from typing import List +from typing import Tuple from typing import Type from typing import Union +from typing import get_args +from typing import get_origin +from typing import get_type_hints if TYPE_CHECKING: from arctic_training.checkpoint.engine import CheckpointEngine @@ -99,3 +103,23 @@ def _validate_method_definition( def _validate_class_attribute_set(cls: RegistryClassTypes, attribute: str) -> None: if not getattr(cls, attribute, None): raise ValueError(f"{cls.__name__} must define {attribute} attribute.") + + +def _validate_class_attribute_type( + cls: RegistryClassTypes, attribute: str, type_: Type +) -> None: + for attr_type_hint in _get_class_attr_type_hints(cls, attribute): + if not issubclass(attr_type_hint, type_): + raise TypeError( + f"{cls.__name__}.{attribute} must be an instance of {type_.__name__}." + f" But got {attr_type_hint.__name__}." 
+ ) + + +def _get_class_attr_type_hints(cls: RegistryClassTypes, attribute: str) -> Tuple[Type]: + cls_type_hints = get_type_hints(cls) + if get_origin(cls_type_hints[attribute]) is Union: + attribute_type_hints = get_args(cls_type_hints[attribute]) + else: + attribute_type_hints = (cls_type_hints[attribute],) + return attribute_type_hints diff --git a/arctic_training/scheduler/factory.py b/arctic_training/scheduler/factory.py index 8b47dd7..2aa7c0d 100644 --- a/arctic_training/scheduler/factory.py +++ b/arctic_training/scheduler/factory.py @@ -17,7 +17,7 @@ from abc import abstractmethod from typing import TYPE_CHECKING from typing import Any -from typing import Type +from typing import Optional from arctic_training.callback.mixin import CallbackMixin from arctic_training.callback.mixin import callback_wrapper @@ -36,7 +36,7 @@ class SchedulerFactory(ABC, CallbackMixin): factory in the registry. """ - config_type: Type[SchedulerConfig] = SchedulerConfig + config: SchedulerConfig """ The configuration class for the scheduler factory. This is used to validate the configuration passed to the factory. @@ -45,7 +45,7 @@ class SchedulerFactory(ABC, CallbackMixin): def __init__( self, trainer: "Trainer", - scheduler_config=None, + scheduler_config: Optional[SchedulerConfig] = None, ) -> None: if scheduler_config is None: scheduler_config = trainer.config.scheduler diff --git a/arctic_training/scheduler/hf_factory.py b/arctic_training/scheduler/hf_factory.py index 1897d08..452b8b7 100644 --- a/arctic_training/scheduler/hf_factory.py +++ b/arctic_training/scheduler/hf_factory.py @@ -29,7 +29,7 @@ class HFSchedulerConfig(SchedulerConfig): @register class HFSchedulerFactory(SchedulerFactory): name = "huggingface" - config_type = HFSchedulerConfig + config: HFSchedulerConfig def create_scheduler(self, optimizer: Any) -> Any: return get_scheduler( diff --git a/arctic_training/tokenizer/factory.py b/arctic_training/tokenizer/factory.py index 2ee632a..fd90caa 100644 --- a/arctic_training/tokenizer/factory.py +++ b/arctic_training/tokenizer/factory.py @@ -17,7 +17,6 @@ from abc import abstractmethod from typing import TYPE_CHECKING from typing import Optional -from typing import Type from transformers import PreTrainedTokenizer @@ -38,14 +37,14 @@ class TokenizerFactory(ABC, CallbackMixin): factory in the registry. """ - config_type: Type[TokenizerConfig] = TokenizerConfig + config: TokenizerConfig """ The configuration class for the tokenizer factory. This is used to validate the configuration passed to the factory. """ def __init__( - self, trainer: "Trainer", tokenizer_config: Optional["TokenizerConfig"] = None + self, trainer: "Trainer", tokenizer_config: Optional[TokenizerConfig] = None ) -> None: if tokenizer_config is None: tokenizer_config = trainer.config.tokenizer diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py index 7022a76..3086c02 100644 --- a/arctic_training/trainer/sft_trainer.py +++ b/arctic_training/trainer/sft_trainer.py @@ -14,12 +14,12 @@ # limitations under the License. 
from typing import Dict +from typing import Union import torch from arctic_training.checkpoint.ds_engine import DSCheckpointEngine from arctic_training.checkpoint.hf_engine import HFCheckpointEngine -from arctic_training.config.trainer import TrainerConfig from arctic_training.data.sft_factory import SFTDataFactory from arctic_training.model.hf_factory import HFModelFactory from arctic_training.model.liger_factory import LigerModelFactory @@ -40,13 +40,12 @@ def to_device(batch: Dict, device: str) -> Dict: @register class SFTTrainer(Trainer): name = "sft" - config_type = TrainerConfig - data_factory_type = [SFTDataFactory] - model_factory_type = [HFModelFactory, LigerModelFactory] - checkpoint_engine_type = [DSCheckpointEngine, HFCheckpointEngine] - optimizer_factory_type = [FusedAdamOptimizerFactory] - scheduler_factory_type = [HFSchedulerFactory] - tokenizer_factory_type = [HFTokenizerFactory] + data_factory: SFTDataFactory + model_factory: Union[HFModelFactory, LigerModelFactory] + checkpoint_engine: Union[DSCheckpointEngine, HFCheckpointEngine] + optimizer_factory: Union[FusedAdamOptimizerFactory] + scheduler_factory: Union[HFSchedulerFactory] + tokenizer_factory: Union[HFTokenizerFactory] def loss(self, batch) -> torch.Tensor: batch = to_device(batch, self.device) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index e1bfc4d..4a013b4 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -16,12 +16,10 @@ import random from abc import ABC from abc import abstractmethod -from typing import TYPE_CHECKING from typing import Callable from typing import Dict from typing import List from typing import Tuple -from typing import Type import deepspeed import numpy as np @@ -34,16 +32,14 @@ from arctic_training.callback.logging import post_loss_log_cb from arctic_training.callback.mixin import CallbackMixin from arctic_training.callback.mixin import callback_wrapper +from arctic_training.checkpoint.engine import CheckpointEngine from arctic_training.config.trainer import TrainerConfig +from arctic_training.data.factory import DataFactory from arctic_training.logging import logger - -if TYPE_CHECKING: - from arctic_training.checkpoint.engine import CheckpointEngine - from arctic_training.data.factory import DataFactory - from arctic_training.model.factory import ModelFactory - from arctic_training.optimizer.factory import OptimizerFactory - from arctic_training.scheduler.factory import SchedulerFactory - from arctic_training.tokenizer.factory import TokenizerFactory +from arctic_training.model.factory import ModelFactory +from arctic_training.optimizer.factory import OptimizerFactory +from arctic_training.scheduler.factory import SchedulerFactory +from arctic_training.tokenizer.factory import TokenizerFactory try: from transformers.integrations.deepspeed import HfDeepSpeedConfig @@ -61,48 +57,48 @@ class Trainer(ABC, CallbackMixin): trainer to be used. """ - config_type: Type[TrainerConfig] + config: TrainerConfig """ The type of the config class that the trainer uses. This should be a subclass of TrainerConfig and add any trainer-specific fields. """ - data_factory_type: List[Type["DataFactory"]] + data_factory: DataFactory """ A List of valid data factory types that the trainer can use. These should inherit from DataFactory. The first item in the list will be used as the default if the type is not explicitly set in the YAML config. 
""" - model_factory_type: List[Type["ModelFactory"]] + model_factory: ModelFactory """ A List of valid model factory types that the trainer can use. These should inherit from ModelFactory. The first item in the list will be used as the default if the type is not explicitly set in the YAML config. """ - checkpoint_engine_type: List[Type["CheckpointEngine"]] + checkpoint_engine: CheckpointEngine """ A List of valid checkpoint engine types that the trainer can use. These should inherit from CheckpointEngine. The first item in the list will be used as the default if the type is not explicitly set in the YAML config. """ - optimizer_factory_type: List[Type["OptimizerFactory"]] + optimizer_factory: OptimizerFactory """ A List of valid optimizer factory types that the trainer can use. These should inherit from OptimizerFactory. The first item in the list will be used as the default if the type is not explicitly set in the YAML config. """ - scheduler_factory_type: List[Type["SchedulerFactory"]] + scheduler_factory: SchedulerFactory """ A List of valid scheduler factory types that the trainer can use. These should inherit from SchedulerFactory. The first item in the list will be used as the default if the type is not explicitly set in the YAML config. """ - tokenizer_factory_type: List[Type["TokenizerFactory"]] + tokenizer_factory: TokenizerFactory """ A List of valid tokenizer factory types that the trainer can use. These should inherit from TokenizerFactory. The first item in the list will be @@ -117,7 +113,7 @@ class Trainer(ABC, CallbackMixin): `post-` for `init`, `train`, `epoch`, `step`, and `checkpoint`. """ - def __init__(self, config: "TrainerConfig") -> None: + def __init__(self, config: TrainerConfig) -> None: logger.info(f"Initializing Trainer with config:\n{debug.format(config)}") self.config = config self.epoch_idx = 0 diff --git a/docs/checkpoint.rst b/docs/checkpoint.rst index 84e36ef..66355c3 100644 --- a/docs/checkpoint.rst +++ b/docs/checkpoint.rst @@ -15,10 +15,10 @@ Attributes ---------- Similar to the ``*Factory`` classes of ArcticTraining, the CheckpointEngine -class requires only the :attr:`~.CheckpointEngine.name` and -:attr:`~.CheckpointEngine.config_type` attributes to be defined. The name -attribute is used to identify the engine when registering it with ArcticTraining -and the config_type attribute is used to validate the config object passed to +class requires only the :attr:`~.CheckpointEngine.name` be defined and the +:attr:`~.CheckpointEngine.config` attribute type hint. The ``name`` attribute is +used to identify the engine when registering it with ArcticTraining and the +``config`` attribute type hint is used to validate the config object passed to the engine. Properties diff --git a/docs/data.rst b/docs/data.rst index c72e255..99edc64 100644 --- a/docs/data.rst +++ b/docs/data.rst @@ -39,9 +39,9 @@ datasets used in the training pipeline. Attributes ^^^^^^^^^^ -To define a custom data factory, you must subclass the DataFactory and define -two attributes: :attr:`~.DataFactory.name` and -:attr:`~.DataFactory.config_type`. +To define a custom data factory, you must subclass the DataFactory, define the +:attr:`~.DataFactory.name` attribute, and give a type hint for the +:attr:`~.DataFactory.config` attribute. 
Properties ^^^^^^^^^^ diff --git a/docs/model.rst b/docs/model.rst index ea182aa..1806c82 100644 --- a/docs/model.rst +++ b/docs/model.rst @@ -18,8 +18,8 @@ Attributes Similar to other Factory classes in ArcticTraining, the ModelFactory class must have a :attr:`~.ModelFactory.name` attribute that is used to identify the factory when registering it with ArcticTraining and a -:attr:`~.ModelFactory.config_type` attribute that is used to validate the config -object passed to the factory. +:attr:`~.ModelFactory.config` attribute type hint that is used to validate the +config object passed to the factory. Properties ---------- diff --git a/docs/optimizer.rst b/docs/optimizer.rst index 18a3f01..c67568c 100644 --- a/docs/optimizer.rst +++ b/docs/optimizer.rst @@ -18,8 +18,8 @@ Attributes Similar to other Factory classes in ArcticTraining, the :class:`~.OptimizerFactory` class must have a :attr:`~.OptimizerFactory.name` attribute that is used to identify the factory when registering it with -ArcticTraining and a :attr:`~.OptimizerFactory.config_type` attribute that is -used to validate the config object passed to the factory. +ArcticTraining and a :attr:`~.OptimizerFactory.config` attribute type hint that +is used to validate the config object passed to the factory. Properties ---------- diff --git a/docs/scheduler.rst b/docs/scheduler.rst index e7780cc..520cf48 100644 --- a/docs/scheduler.rst +++ b/docs/scheduler.rst @@ -18,8 +18,8 @@ Attributes Similar to other Factory classes in ArcticTraining, the SchedulerFactory class must have a :attr:`~.SchedulerFactory.name` attribute that is used to identify the factory when registering it with ArcticTraining and a -:attr:`~.SchedulerFactory.config_type` attribute that is used to validate the -config object passed to the factory. +:attr:`~.SchedulerFactory.config` attribute type hint that is used to validate +the config object passed to the factory. Properties ---------- diff --git a/docs/tokenizer.rst b/docs/tokenizer.rst index 94b43f4..5092f8a 100644 --- a/docs/tokenizer.rst +++ b/docs/tokenizer.rst @@ -18,8 +18,8 @@ Attributes Similar to other Factory classes in ArcticTraining, the TokenizerFactory class must have a :attr:`~.TokenizerFactory.name` attribute that is used to identify the factory when registering it with ArcticTraining and a -:attr:`~.TokenizerFactory.config_type` attribute that is used to validate the -config object passed to the factory. +:attr:`~.TokenizerFactory.config` attribute type hint that is used to validate +the config object passed to the factory. Properties ---------- diff --git a/docs/trainer.rst b/docs/trainer.rst index a173afa..d1465a9 100644 --- a/docs/trainer.rst +++ b/docs/trainer.rst @@ -21,16 +21,21 @@ Attributes .. _trainer-attributes: -There are several attributes that must be defined in the Trainer class to create -a new custom trainer. These attributes include: :attr:`~.Trainer.name`, -:attr:`~.Trainer.config_type`, :attr:`~.Trainer.data_factory_type`, -:attr:`~.Trainer.model_factory_type`, :attr:`~.Trainer.checkpoint_engine_type`, -:attr:`~.Trainer.optimizer_factory_type`, -:attr:`~.Trainer.scheduler_factory_type`, and -:attr:`~.Trainer.tokenizer_factory_type`. - -These attributes are used when registering new custom trainers with -ArcticTraining and to validate training recipes that use the trainer. +Creating a custom trainer starts with Inheriting from the base +:class:`~.Trainer` class and defining the :attr:`~.Trainer.name` attribute. 
The +name attribute is used to identify the trainer when registering it with +ArcticTraining. Additionally, you can define custom types for +:attr:`~.Trainer.config`, :attr:`~.Trainer.data_factory`, +:attr:`~.Trainer.model_factory`, :attr:`~.Trainer.checkpoint_engine`, +:attr:`~.Trainer.optimizer_factory`, :attr:`~.Trainer.scheduler_factory`, and +:attr:`~.Trainer.tokenizer_factory` to specify the default factories for each +component. + +Specifying the type hints for these attributes tells ArcticTraining which building +blocks are compatible with your custom trainer. You may define multiple +compatible building blocks by using ``typing.Union`` in the type hint. When +multiple types are specified for one of these attributes, the first is used as +the default when ``type`` is not specified in the input config. Properties ---------- @@ -92,8 +97,9 @@ in :ref:`Trainer Attributes`. We use a custom data factory, SFTDataFactory, which we describe in greater detail in the :ref:`Data Factory` section. The remainder of the attributes use the base building blocks from ArcticTraining. For example the model factory defaults to the -HFModelFactory (because it is listed first in the model_factory_type attribute), -but this trainer can work with either `HFModelFactory` or `LigerModelFactory`. +HFModelFactory (because it is listed first in the ``model_factory`` attribute +type hint), but this trainer can work with either `HFModelFactory` or +`LigerModelFactory`. .. literalinclude:: ../arctic_training/trainer/sft_trainer.py :pyobject: SFTTrainer diff --git a/docs/usage.rst b/docs/usage.rst index 01d6b72..7d5ca84 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -88,7 +88,7 @@ these steps: @register class CustomTrainer(SFTTrainer): name = "my_custom_trainer" - model_factory_type = CustomModelFactory + model_factory: CustomModelFactory def loss(self, batch): # Custom loss function implementation @@ -123,6 +123,7 @@ API: from arctic_training import register from arctic_training.model import HFModelFactory from arctic_training import SFTTrainer + from arctic_training import get_config class CustomModelFactory(HFModelFactory): name = "my_custom_model_factory" @@ -134,14 +135,14 @@ API: @register class CustomTrainer(SFTTrainer): name = "my_custom_trainer" - model_factory_type = CustomModelFactory + model_factory: CustomModelFactory def loss(self, batch): # Custom loss function implementation return loss if __name__ == "__main__": - config = { + config_dict = { "type": "my_custom_trainer", "model": { "name_or_path": "meta-llama/Llama-3.1-8B-Instruct" @@ -158,6 +159,6 @@ API: ] } - config = CustomTrainer.config_type(**config) + config = get_config(config_dict) trainer = CustomTrainer(config) trainer.train() diff --git a/projects/mlp_speculator/train.py b/projects/mlp_speculator/train.py index e11ea50..b9a0f0e 100644 --- a/projects/mlp_speculator/train.py +++ b/projects/mlp_speculator/train.py @@ -88,7 +88,7 @@ class MLPSpeculatorModelConfig(ModelConfig): class MLPSpeculatorModelFactory(HFModelFactory): name = "spec-decode" - config_type = MLPSpeculatorModelConfig + config: MLPSpeculatorModelConfig def post_create_model_callback(self, model): hidden_size = model.lm_head.in_features @@ -228,9 +228,9 @@ def save(self) -> None: @register class MLPSpeculatorTrainer(SFTTrainer): name = "spec-decode" - config_type = MLPSpeculatorTrainerConfig - model_factory_type = [MLPSpeculatorModelFactory] - checkpoint_engine_type = [MLPSpeculatorCheckpointEngine] + config: MLPSpeculatorTrainerConfig + model_factory: 
MLPSpeculatorModelFactory + checkpoint_engine: MLPSpeculatorCheckpointEngine def generate( self, diff --git a/projects/swiftkv/train.py b/projects/swiftkv/train.py index 80b2acc..8051f86 100644 --- a/projects/swiftkv/train.py +++ b/projects/swiftkv/train.py @@ -14,6 +14,7 @@ # limitations under the License. from typing import Any +from typing import Union import llama_swiftkv import torch @@ -37,7 +38,7 @@ class SwiftKVModelConfig(ModelConfig): class SwiftKVModelFactory(HFModelFactory): name = "swiftkv" - config_type = SwiftKVModelConfig + config: SwiftKVModelConfig def post_create_config_callback(self, hf_config): llama_swiftkv.register_auto() @@ -95,9 +96,9 @@ class SwiftKVTrainerConfig(TrainerConfig): @register class SwiftKVTrainer(SFTTrainer): name = "swiftkv" - config_type = SwiftKVTrainerConfig - model_factory_type = SwiftKVModelFactory - checkpoint_engine_type = HFCheckpointEngine + config: SwiftKVTrainerConfig + model_factory: SwiftKVModelFactory + checkpoint_engine: Union[HFCheckpointEngine] def loss(self, batch: Any) -> torch.Tensor: batch = to_device(batch, self.device) diff --git a/scripts/upgrade_user_code.py b/scripts/upgrade_user_code.py new file mode 100644 index 0000000..7d1c58b --- /dev/null +++ b/scripts/upgrade_user_code.py @@ -0,0 +1,75 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
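+
+# Helper for migrating user code to the type-hint style attributes introduced
+# in this change: it rewrites assignments such as `config_type = SomeConfig`
+# into annotations such as `config: SomeConfig`, turns list-valued attributes
+# like `model_factory_type = [A, B]` into `model_factory: Union[A, B]`, and
+# prepends `from typing import Union` when a Union is needed and no such
+# import is already present.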
+ +import re +import sys +from pathlib import Path + + +def update_file(file_path: Path): + with open(file_path, "r") as f: + content = f.read() + + union_import_needed = False + + class_attrs = [ + "config_type", + "data_factory_type", + "model_factory_type", + "checkpoint_engine_type", + "optimizer_factory_type", + "scheduler_factory_type", + "tokenizer_factory_type", + ] + + for attr in class_attrs: + content = re.sub( + rf"{attr}\s*=\s*([A-Za-z_][A-Za-z0-9_]*)", + lambda match, attr_name=attr.replace( + "_type", "" + ): f"{attr_name}: {match.group(1)}", + content, + ) + + def replace_with_union(match, attr_name): + nonlocal union_import_needed + union_import_needed = True + return f"{attr_name}: Union[{match.group(1)}]" + + for attr in class_attrs: + content = re.sub( + rf"{attr}\s*=\s*\[([A-Za-z0-9_,\s]+)\]", + lambda match, attr_name=attr.replace("_type", ""): replace_with_union( + match, attr_name + ), + content, + ) + + if union_import_needed and not re.search(r"from typing import .*Union.*", content): + content = "from typing import Union\n" + content + + with open(file_path, "w") as f: + f.write(content) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python upgrade_user_code.py ") + sys.exit(1) + file_path = Path(sys.argv[1]) + if not file_path.exists(): + print(f"File {file_path} not found") + sys.exit(1) + update_file(file_path) diff --git a/tests/checkpoint/test_ds_engine.py b/tests/checkpoint/test_ds_engine.py index 0db6217..51fdc0d 100644 --- a/tests/checkpoint/test_ds_engine.py +++ b/tests/checkpoint/test_ds_engine.py @@ -25,6 +25,7 @@ def test_ds_engine(tmp_path): config_dict = { "type": "sft", "exit_iteration": 2, + "skip_validation": True, "model": { "type": "random-weight-hf", "name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct", diff --git a/tests/checkpoint/test_hf_engine.py b/tests/checkpoint/test_hf_engine.py index 48b42ca..f7ab2fa 100644 --- a/tests/checkpoint/test_hf_engine.py +++ b/tests/checkpoint/test_hf_engine.py @@ -24,6 +24,7 @@ def test_hf_engine(tmp_path): config_dict = { "type": "sft", + "skip_validation": True, "model": { "type": "random-weight-hf", "name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct", diff --git a/tests/trainer/test_sft_trainer.py b/tests/trainer/test_sft_trainer.py index 5a4f60f..ccae4dd 100644 --- a/tests/trainer/test_sft_trainer.py +++ b/tests/trainer/test_sft_trainer.py @@ -24,6 +24,7 @@ def test_sft_trainer(tmp_path): config_dict = { "type": "sft", + "skip_validation": True, "exit_iteration": 2, "micro_batch_size": 1, "model": { @@ -51,6 +52,7 @@ def test_sft_trainer(tmp_path): def test_sft_trainer_cpu(tmp_path): config_dict = { "type": "sft", + "skip_validation": True, "exit_iteration": 2, "micro_batch_size": 1, "model": {