Skip to content

Commit

Permalink
fix(pt): finetuning property/dipole/polar/dos fitting with multi-dime…
Browse files Browse the repository at this point in the history
…nsional data causes error (#4145)

Fix issue #4108 

If a pretrained model is labeled with energy and the `out_bias` is one
dimension. If we want to finetune a dos/polar/dipole/property model
using this pretrained model, the `out_bias` of finetuning model is
multi-dimension(example: numb_dos = 250). An error occurs:
`RuntimeError: Error(s) in loading state_dict for ModelWrapper:`
` size mismatch for model.Default.atomic_model.out_bias: copying a param
with shape torch.Size([1, 118, 1]) from checkpoint, the shape in current
model is torch.Size([1, 118, 250]).`
` size mismatch for model.Default.atomic_model.out_std: copying a param
with shape torch.Size([1, 118, 1]) from checkpoint, the shape in current
model is torch.Size([1, 118, 250]).`

When using new fitting, old out_bias is useless because we will
recompute the new bias in later code. So we do not need to load old
out_bias when using new fitting finetune.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced parameter collection for fine-tuning, refining criteria for
parameter retention.
- Introduced a model checkpoint file for saving and resuming training
states, facilitating iterative development.

- **Tests**
- Added a new test class to validate training and fine-tuning processes,
ensuring model performance consistency across configurations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Chengqian-Zhang and pre-commit-ci[bot] authored Sep 25, 2024
1 parent 508759c commit 0b3f860
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
2 changes: 1 addition & 1 deletion deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def collect_single_finetune_params(
if i != "_extra_state" and f".{_model_key}." in i
]
for item_key in target_keys:
if _new_fitting and ".fitting_net." in item_key:
if _new_fitting and (".descriptor." not in item_key):
# print(f'Keep {item_key} in old model!')
_new_state_dict[item_key] = (
_random_state_dict[item_key].clone().detach()
Expand Down
68 changes: 68 additions & 0 deletions source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,5 +448,73 @@ def tearDown(self) -> None:
DPTrainTest.tearDown(self)


class TestPropFintuFromEnerModel(unittest.TestCase):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = data_file
self.config["model"] = deepcopy(model_dpa1)
self.config["model"]["type_map"] = ["H", "C", "N", "O"]
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1

property_input = str(Path(__file__).parent / "property/input.json")
with open(property_input) as f:
self.config_property = json.load(f)
prop_data_file = [str(Path(__file__).parent / "property/single")]
self.config_property["training"]["training_data"]["systems"] = prop_data_file
self.config_property["training"]["validation_data"]["systems"] = prop_data_file
self.config_property["model"]["descriptor"] = deepcopy(model_dpa1["descriptor"])
self.config_property["training"]["numb_steps"] = 1
self.config_property["training"]["save_freq"] = 1

def test_dp_train(self):
# test training from scratch
trainer = get_trainer(deepcopy(self.config))
trainer.run()
state_dict_trained = trainer.wrapper.model.state_dict()

# test fine-tuning using diffferent fitting_net, here using property fitting
finetune_model = self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
self.config_property["model"], finetune_links = get_finetune_rules(
finetune_model,
self.config_property["model"],
model_branch="RANDOM",
)
trainer_finetune = get_trainer(
deepcopy(self.config_property),
finetune_model=finetune_model,
finetune_links=finetune_links,
)

# check parameters
state_dict_finetuned = trainer_finetune.wrapper.model.state_dict()
for state_key in state_dict_finetuned:
if (
"out_bias" not in state_key
and "out_std" not in state_key
and "fitting" not in state_key
):
torch.testing.assert_close(
state_dict_trained[state_key],
state_dict_finetuned[state_key],
)

# check running
trainer_finetune.run()

def tearDown(self):
for f in os.listdir("."):
if f.startswith("model") and f.endswith(".pt"):
os.remove(f)
if f in ["lcurve.out"]:
os.remove(f)
if f in ["stat_files"]:
shutil.rmtree(f)


if __name__ == "__main__":
unittest.main()

0 comments on commit 0b3f860

Please sign in to comment.