
Commit

Refactor to better use class type hints (#47)
sfc-gh-mwyatt authored Feb 13, 2025
1 parent ab7bb18 commit 655377a
Showing 34 changed files with 414 additions and 235 deletions.
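
The core of the refactor is visible in the engine and factory files below: explicit class attributes such as `config_type: Type[CheckpointConfig] = CheckpointConfig` are dropped in favor of plain annotations like `config: CheckpointConfig`, and the registry helper `_get_class_attr_type_hints` now recovers the config class from that annotation (see its use in `trainer.py`). The helper's implementation is not part of this diff; the following is a minimal sketch, under the assumption that it boils down to reading the class's type hints, of why a helper is needed at all:

# Minimal sketch, not arctic_training code: a bare annotation creates no class
# attribute, only an entry in the class's type hints, so a helper is required
# to read the config class back at registration/validation time.
from typing import List, get_type_hints


class CheckpointConfig:  # stand-in for the real pydantic config class
    pass


class CheckpointEngine:
    # Old style:  config_type: Type[CheckpointConfig] = CheckpointConfig
    # New style:  the hint below is the only place the config class is named.
    config: CheckpointConfig


def get_class_attr_type_hints(cls: type, attr: str) -> List[type]:
    # Hypothetical stand-in for _get_class_attr_type_hints; the real helper
    # may also unwrap Union/Optional hints and handle missing annotations.
    hint = get_type_hints(cls)[attr]
    return list(getattr(hint, "__args__", (hint,)))


print(hasattr(CheckpointEngine, "config"))                    # False
print(get_class_attr_type_hints(CheckpointEngine, "config"))  # [<class '__main__.CheckpointConfig'>]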
1 change: 1 addition & 0 deletions arctic_training/__init__.py
@@ -30,6 +30,7 @@
from .config.scheduler import SchedulerConfig
from .config.tokenizer import TokenizerConfig
from .config.trainer import TrainerConfig
from .config.trainer import get_config
from .data.factory import DataFactory
from .data.sft_factory import SFTDataFactory
from .data.source import DataSource
5 changes: 2 additions & 3 deletions arctic_training/checkpoint/engine.py
@@ -18,7 +18,6 @@
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Type

import torch

@@ -39,13 +38,13 @@ class CheckpointEngine(ABC, CallbackMixin):
engine in the registry.
"""

config_type: Type[CheckpointConfig] = CheckpointConfig
config: CheckpointConfig
"""
The configuration class for the checkpoint engine. This is used to validate
the configuration passed to the engine.
"""

def __init__(self, trainer: "Trainer", config: "CheckpointConfig") -> None:
def __init__(self, trainer: "Trainer", config: CheckpointConfig) -> None:
self._trainer = trainer
self.config = config

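
Concrete checkpoint engines then follow the same convention: the subclass's own `config` annotation replaces the old `config_type` assignment, and user-supplied options are validated against that class. A purely illustrative sketch with stand-in classes (names are hypothetical, not from this repository):

# Hypothetical subclass illustrating the annotation-only convention; these
# classes are stand-ins, not arctic_training code.
from dataclasses import dataclass


@dataclass
class DemoBaseCheckpointConfig:
    output_dir: str = "checkpoints"


@dataclass
class DemoCheckpointConfig(DemoBaseCheckpointConfig):
    keep_last_n: int = 3  # hypothetical engine-specific option


class DemoCheckpointEngine:
    # Old style:  config_type: Type[DemoCheckpointConfig] = DemoCheckpointConfig
    # New style:  the annotation alone names the config class; the registry
    # recovers it with the same type-hint lookup sketched above.
    config: DemoCheckpointConfig

    def __init__(self, config: DemoCheckpointConfig) -> None:
        self.config = config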
252 changes: 165 additions & 87 deletions arctic_training/config/trainer.py
@@ -19,43 +19,45 @@
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Union
from typing import cast

import deepspeed
import yaml
from deepspeed.accelerator import get_accelerator
from pydantic import Field
from pydantic import ValidationInfo
from pydantic import field_validator
from pydantic import model_validator
from typing_extensions import Self

if TYPE_CHECKING:
from arctic_training.checkpoint.engine import CheckpointEngine

import deepspeed
from deepspeed.accelerator import get_accelerator
from pydantic import ValidationInfo

from arctic_training.config import BaseConfig
from arctic_training.config.checkpoint import CheckpointConfig
from arctic_training.config.data import DataConfig
from arctic_training.config.enums import DType
from arctic_training.config.logger import LoggerConfig
from arctic_training.config.model import ModelConfig
from arctic_training.config.optimizer import OptimizerConfig
from arctic_training.config.scheduler import SchedulerConfig
from arctic_training.config.tokenizer import TokenizerConfig
from arctic_training.config.wandb import WandBConfig
from arctic_training.registry.checkpoint import get_registered_checkpoint_engine
from arctic_training.registry.data import get_registered_data_factory
from arctic_training.registry.model import get_registered_model_factory
from arctic_training.registry.optimizer import get_registered_optimizer_factory
from arctic_training.registry.scheduler import get_registered_scheduler_factory
from arctic_training.registry.tokenizer import get_registered_tokenizer_factory
from arctic_training.registry.trainer import get_registered_trainer
from arctic_training.registry.utils import _get_class_attr_type_hints
from arctic_training.utils import get_local_rank
from arctic_training.utils import get_world_size

from .checkpoint import CheckpointConfig
from .data import DataConfig
from .logger import LoggerConfig
from .model import ModelConfig
from .optimizer import OptimizerConfig
from .scheduler import SchedulerConfig
from .tokenizer import TokenizerConfig
from .wandb import WandBConfig
if TYPE_CHECKING:
from arctic_training.checkpoint.engine import CheckpointEngine


TRAINER_DEFAULT = "sft"
CUSTOM_CODE_DEFAULT = Path("train.py")
Expand All @@ -70,6 +72,9 @@ class TrainerConfig(BaseConfig):
code: Path = CUSTOM_CODE_DEFAULT
""" Path to the python script containing custom trainer implementation. """

skip_validation: bool = False
""" Skips validation of types for subconfigs and registered classes. """

model: ModelConfig
""" Model configuration. """

@@ -100,7 +105,7 @@
loss_log_interval: int = Field(default=1, ge=0)
""" Number of steps between logging loss. """

gradient_accumulation_steps: int = 1
gradient_accumulation_steps: int = Field(default=1, ge=1)
""" Number of gradient accumulation steps. """

micro_batch_size: int = Field(default=1, ge=1)
@@ -147,82 +152,161 @@ def checkpoint_engines(self) -> List[partial["CheckpointEngine"]]:
def zero_3_enabled(self) -> bool:
return self.deepspeed.get("zero_optimization", {}).get("stage", 0) == 3

@field_validator("eval_frequency", mode="after")
def validate_eval_frequency(cls, v: int, info: ValidationInfo) -> int:
@staticmethod
def _get_subconfig_object(
v: Union[Dict, BaseConfig],
info: ValidationInfo,
get_class_fn: Callable,
attr_name: str,
) -> BaseConfig:
# Get the trainer class as it will tell us which types of factory
# classes (and thus configs) are default/compatible
trainer_type = info.data["type"]
trainer_cls = get_registered_trainer(trainer_type)

# Get type hints for this factory class. This is a list of compatible
# classes for the given attribute field.
attribute_type_hints = _get_class_attr_type_hints(trainer_cls, attr_name)

# Convert to a dictionary as default values are the base config classes
# and we likely need to use a different class based on the trainer type
# or user requested `type` field value.
if isinstance(v, dict):
config_dict = v
else:
config_dict = v.model_dump()

# Determine which attribute class to use (e.g., for `model`:
# HFModelFactory, LigerModelFactory, etc.)
if config_dict.get("type", ""):
# User explicitly specified the type
attr_cls = get_class_fn(config_dict["type"])
else:
# User did not specify the type, use the first (maybe only) hint as default type
attr_cls = attribute_type_hints[0]

# Check that the requested/resolved type is compatible with the trainer
if (
info.data["data"].eval_sources
or info.data["data"].train_eval_split[1] > 0.0
not info.data.get("skip_validation")
and attr_cls not in attribute_type_hints
):
assert v > 0, "eval_frequency must be set if eval dataset is provided."
return v
raise ValueError(
f"{attr_cls.__name__} is not supported for {attr_name} in"
f" {trainer_cls.__name__}. Supported types are"
f" {[cls.__name__ for cls in attribute_type_hints]}."
)

@field_validator("tokenizer", mode="after")
def set_tokenizer(cls, v: TokenizerConfig, info: ValidationInfo) -> TokenizerConfig:
if not v.name_or_path and "model" in info.data:
v.name_or_path = info.data["model"].name_or_path
# Make sure the `type` field is set in the config dict
config_dict["type"] = attr_cls.name

# Get the config class for the factory class and create the config
config_cls = _get_class_attr_type_hints(attr_cls, "config")[0]
return config_cls(**config_dict)

@staticmethod
def _to_list(v: Union[Any, List[Any]]) -> List[Any]:
if not isinstance(v, list):
return [v]
return v

@field_validator(
"checkpoint",
"data",
"model",
"optimizer",
"scheduler",
"tokenizer",
mode="before",
)
@field_validator("checkpoint", mode="before")
@classmethod
def parse_sub_config(
def init_checkpoint_configs(
cls,
v: Any,
v: Union[Union[Dict, CheckpointConfig], List[Union[Dict, CheckpointConfig]]],
info: ValidationInfo,
) -> Union[BaseConfig, List[BaseConfig]]:
trainer_attr_map = {
"checkpoint": "checkpoint_engine_type",
"data": "data_factory_type",
"model": "model_factory_type",
"optimizer": "optimizer_factory_type",
"scheduler": "scheduler_factory_type",
"tokenizer": "tokenizer_factory_type",
}
field_name: str = info.field_name # type: ignore
trainer_type: str = info.data["type"]
trainer_cls = get_registered_trainer(trainer_type)
trainer_field_default = getattr(trainer_cls, trainer_attr_map[field_name])[0]

if isinstance(v, tuple) or isinstance(v, list):
return_list = []
for sub_v in v:
if isinstance(sub_v, BaseConfig):
sub_v = sub_v.model_dump()
field_cls = cls._get_config_cls(
sub_v, field_name, trainer_field_default
) -> List[CheckpointConfig]:
v = cls._to_list(v)
return_list = []
for sub_v in v:
return_list.append(
cls._get_subconfig_object(
v=sub_v,
info=info,
get_class_fn=get_registered_checkpoint_engine,
attr_name="checkpoint_engine",
)
sub_v["type"] = field_cls.name
return_list.append(field_cls.config_type(**sub_v))
return return_list
)
return [cast(CheckpointConfig, subconfig) for subconfig in return_list]

if isinstance(v, BaseConfig):
v = v.model_dump()
field_cls = cls._get_config_cls(v, field_name, trainer_field_default)
v["type"] = field_cls.name
return field_cls.config_type(**v)
@field_validator("data", mode="before")
@classmethod
def init_data_config(
cls, v: Union[Dict, DataConfig], info: ValidationInfo
) -> DataConfig:
subconfig = cls._get_subconfig_object(
v=v,
info=info,
get_class_fn=get_registered_data_factory,
attr_name="data_factory",
)
return cast(DataConfig, subconfig)

@field_validator("model", mode="before")
@classmethod
def init_model_config(
cls, v: Union[Dict, ModelConfig], info: ValidationInfo
) -> ModelConfig:
subconfig = cls._get_subconfig_object(
v=v,
info=info,
get_class_fn=get_registered_model_factory,
attr_name="model_factory",
)
return cast(ModelConfig, subconfig)

@field_validator("optimizer", mode="before")
@classmethod
def init_optimizer_config(
cls, v: Union[Dict, OptimizerConfig], info: ValidationInfo
) -> OptimizerConfig:
subconfig = cls._get_subconfig_object(
v=v,
info=info,
get_class_fn=get_registered_optimizer_factory,
attr_name="optimizer_factory",
)
return cast(OptimizerConfig, subconfig)

@field_validator("scheduler", mode="before")
@classmethod
def init_scheduler_config(
cls, v: Union[Dict, SchedulerConfig], info: ValidationInfo
) -> SchedulerConfig:
subconfig = cls._get_subconfig_object(
v=v,
info=info,
get_class_fn=get_registered_scheduler_factory,
attr_name="scheduler_factory",
)
return cast(SchedulerConfig, subconfig)

@field_validator("tokenizer", mode="before")
@classmethod
def _get_config_cls(cls, config_dict, field_name, default_cls):
get_class_fn_map = {
"checkpoint": get_registered_checkpoint_engine,
"data": get_registered_data_factory,
"model": get_registered_model_factory,
"optimizer": get_registered_optimizer_factory,
"scheduler": get_registered_scheduler_factory,
"tokenizer": get_registered_tokenizer_factory,
}
field_type = config_dict.get("type", "")
if field_type == "":
field_type = default_cls
field_cls = get_class_fn_map[field_name](field_type)
return field_cls
def init_tokenizer_config(
cls, v: Union[Dict, TokenizerConfig], info: ValidationInfo
) -> TokenizerConfig:
subconfig = cls._get_subconfig_object(
v=v,
info=info,
get_class_fn=get_registered_tokenizer_factory,
attr_name="tokenizer_factory",
)
return cast(TokenizerConfig, subconfig)

@model_validator(mode="after")
def validate_eval_frequency(self) -> Self:
if self.data.eval_sources or self.data.train_eval_split[1] > 0.0:
assert (
self.eval_frequency > 0
), "eval_frequency must be set if eval dataset is provided."
return self

@model_validator(mode="after")
def set_tokenizer(self) -> Self:
if not self.tokenizer.name_or_path:
self.tokenizer.name_or_path = self.model.name_or_path
return self

@field_validator("logger", mode="after")
@classmethod
@@ -232,12 +316,6 @@ def initialize_logger(cls, v: LoggerConfig) -> LoggerConfig:
setup_logger(v)
return v

@field_validator("checkpoint", mode="before")
def checkpoint_to_list(cls, v: Union[Dict, List[Dict]]) -> List[Dict]:
if not isinstance(v, list):
return [v]
return v

@model_validator(mode="after")
def build_deepspeed_config(self) -> Self:
ds_config = self.deepspeed
@@ -312,7 +390,7 @@ def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig:
sys.path = original_sys_path

trainer_cls = get_registered_trainer(trainer_type)
config_cls = trainer_cls.config_type
config_cls = _get_class_attr_type_hints(trainer_cls, "config")[0]

config = config_cls(**config_dict)

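
`get_config` (now also re-exported from the package root, per the `__init__.py` change above) resolves the trainer class from the `type` key and then instantiates whatever config class that trainer's `config` hint names. A hedged usage sketch — the field names mirror the validators shown in this diff, but the model and data values are illustrative and the exact required keys depend on config classes not shown here:

from arctic_training import get_config

# Keys mirror TrainerConfig fields in this diff; values are examples only.
config = get_config(
    {
        "type": "sft",                      # TRAINER_DEFAULT in trainer.py
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 4,   # must be >= 1 after this commit
        "model": {"name_or_path": "meta-llama/Llama-3.1-8B-Instruct"},
        "data": {"sources": ["HuggingFaceH4/ultrachat_200k"]},  # field name assumed
        # "tokenizer" may be omitted: set_tokenizer() falls back to the
        # model's name_or_path.
    }
)
print(type(config).__name__)  # whichever config class the SFT trainer's hint names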
13 changes: 5 additions & 8 deletions arctic_training/data/factory.py
@@ -23,7 +23,6 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type

import torch
from datasets import Dataset
@@ -52,20 +51,18 @@ class DataFactory(ABC, CallbackMixin):
specify the DataFactory to use.
"""

config_type: Type[DataConfig] = DataConfig
config: DataConfig
"""
The type of the DataConfig object that this DataFactory uses. Any
DataFactory-specific options should be specified in this class.
"""

def __init__(
self, trainer: "Trainer", data_config: Optional["DataConfig"] = None
) -> None:
if data_config is None:
data_config = trainer.config.data
def __init__(self, trainer: "Trainer", config: Optional[DataConfig] = None) -> None:
if config is None:
config = trainer.config.data

self._trainer = trainer
self.config = data_config
self.config = config

def __call__(self) -> Tuple[DataLoader, Optional[DataLoader]]:
train_dataset = self._load_data(
(Diffs for the remaining changed files are not shown.)
