Skip to content

Commit

Permalink
Recovered all the skipped test for hybrid descriptor (#3400)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
3 people authored Mar 3, 2024
1 parent e826260 commit ec32340
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 12 deletions.
15 changes: 15 additions & 0 deletions source/tests/pt/model/test_autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
eval_model,
model_dpa1,
model_dpa2,
model_hybrid,
model_se_e2_a,
model_zbl,
)
Expand Down Expand Up @@ -192,6 +193,20 @@ def setUp(self):
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelHybridForce(unittest.TestCase, ForceTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelHybridVirial(unittest.TestCase, VirialTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelZBLForce(unittest.TestCase, ForceTest):
def setUp(self):
model_params = copy.deepcopy(model_zbl)
Expand Down
4 changes: 1 addition & 3 deletions source/tests/pt/model/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def tearDown(self):
JITTest.tearDown(self)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, JITTest):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
Expand All @@ -118,7 +117,6 @@ def tearDown(self):
JITTest.tearDown(self)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid2(unittest.TestCase, JITTest):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
Expand All @@ -128,7 +126,7 @@ def setUp(self):
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = data_file
self.config["model"] = deepcopy(model_hybrid)
self.config["model"]["descriptor"]["hybrid_mode"] = "sequential"
# self.config["model"]["descriptor"]["hybrid_mode"] = "sequential"
self.config["training"]["numb_steps"] = 10
self.config["training"]["save_freq"] = 10

Expand Down
2 changes: 0 additions & 2 deletions source/tests/pt/model/test_null_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,13 @@ def setUp(self):
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, NullTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestForceModelHybrid(unittest.TestCase, NullTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
Expand Down
2 changes: 0 additions & 2 deletions source/tests/pt/model/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,13 @@ def setUp(self):
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, PermutationTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestForceModelHybrid(unittest.TestCase, PermutationTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
Expand Down
2 changes: 0 additions & 2 deletions source/tests/pt/model/test_rot.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,13 @@ def setUp(self):
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, RotTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestForceModelHybrid(unittest.TestCase, RotTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
Expand Down
1 change: 0 additions & 1 deletion source/tests/pt/model/test_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def setUp(self):
self.epsilon, self.aprec = None, None


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, SmoothTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
Expand Down
2 changes: 0 additions & 2 deletions source/tests/pt/model/test_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,13 @@ def setUp(self):
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, TransTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


@unittest.skip("hybrid not supported at the moment")
class TestForceModelHybrid(unittest.TestCase, TransTest):
def setUp(self):
model_params = copy.deepcopy(model_hybrid)
Expand Down

0 comments on commit ec32340

Please sign in to comment.