From 29db7918e0d03910773a89dc0d24a34ec8af55ce Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 3 Jul 2024 22:09:23 +0800 Subject: [PATCH] feat(pt): add datafile option for change-bias (#3945) ## Summary by CodeRabbit - **New Features** - Added optional `--datafile` argument to specify a file for system data processing. - **Bug Fixes** - Improved `help` messages for `--datafile` argument to clarify its usage. - **Tests** - Enhanced test coverage for changing bias with a new method that handles data from a system file. --- deepmd/main.py | 9 +++++++- deepmd/pt/entrypoints/main.py | 7 +++++- source/tests/pt/test_change_bias.py | 34 +++++++++++++++++++++++++++-- 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/deepmd/main.py b/deepmd/main.py index 32a73fb4d0..727e027359 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -370,7 +370,7 @@ def main_parser() -> argparse.ArgumentParser: "--datafile", default=None, type=str, - help="The path to file of test list.", + help="The path to the datafile, each line of which is a path to one data system.", ) parser_tst.add_argument( "-S", @@ -685,6 +685,13 @@ def main_parser() -> argparse.ArgumentParser: type=str, help="The system dir. Recursively detect systems in this directory", ) + parser_change_bias_source.add_argument( + "-f", + "--datafile", + default=None, + type=str, + help="The path to the datafile, each line of which is a path to one data system.", + ) parser_change_bias_source.add_argument( "-b", "--bias-value", diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 1413283afd..3dfeecb670 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -469,7 +469,12 @@ def change_bias(FLAGS): updated_model = model_to_change else: # calculate bias on given systems - data_systems = process_systems(expand_sys_str(FLAGS.system)) + if FLAGS.datafile is not None: + with open(FLAGS.datafile) as datalist: + all_sys = datalist.read().splitlines() + else: + all_sys = expand_sys_str(FLAGS.system) + data_systems = process_systems(all_sys) data_single = DpLoaderSet( data_systems, 1, diff --git a/source/tests/pt/test_change_bias.py b/source/tests/pt/test_change_bias.py index 67a08730ea..f76be40b3f 100644 --- a/source/tests/pt/test_change_bias.py +++ b/source/tests/pt/test_change_bias.py @@ -2,6 +2,7 @@ import json import os import shutil +import tempfile import unittest from copy import ( deepcopy, @@ -36,6 +37,9 @@ to_torch_tensor, ) +from .common import ( + run_dp, +) from .model.test_permutation import ( model_se_e2_a, ) @@ -77,12 +81,15 @@ def setUp(self): self.model_path_data_bias = Path(current_path) / ( model_name + "data_bias" + ".pt" ) + self.model_path_data_file_bias = Path(current_path) / ( + model_name + "data_file_bias" + ".pt" + ) self.model_path_user_bias = Path(current_path) / ( model_name + "user_bias" + ".pt" ) def test_change_bias_with_data(self): - os.system( + run_dp( f"dp --pt change-bias {self.model_path!s} -s {self.data_file[0]} -o {self.model_path_data_bias!s}" ) state_dict = torch.load(str(self.model_path_data_bias), map_location=DEVICE) @@ -99,9 +106,32 @@ def test_change_bias_with_data(self): expected_bias = expected_model.get_out_bias() torch.testing.assert_close(updated_bias, expected_bias) + def test_change_bias_with_data_sys_file(self): + tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt") + with open(tmp_file.name, "w") as f: + f.writelines([sys + "\n" for sys in self.data_file]) + run_dp( + f"dp --pt change-bias {self.model_path!s} -f {tmp_file.name} -o {self.model_path_data_file_bias!s}" + ) + state_dict = torch.load( + str(self.model_path_data_file_bias), map_location=DEVICE + ) + model_params = state_dict["model"]["_extra_state"]["model_params"] + model_for_wrapper = get_model_for_wrapper(model_params) + wrapper = ModelWrapper(model_for_wrapper) + wrapper.load_state_dict(state_dict["model"]) + updated_bias = wrapper.model["Default"].get_out_bias() + expected_model = model_change_out_bias( + self.trainer.wrapper.model["Default"], + self.sampled, + _bias_adjust_mode="change-by-statistic", + ) + expected_bias = expected_model.get_out_bias() + torch.testing.assert_close(updated_bias, expected_bias) + def test_change_bias_with_user_defined(self): user_bias = [0.1, 3.2, -0.5] - os.system( + run_dp( f"dp --pt change-bias {self.model_path!s} -b {' '.join([str(_) for _ in user_bias])} -o {self.model_path_user_bias!s}" ) state_dict = torch.load(str(self.model_path_user_bias), map_location=DEVICE)