diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index b1745b384e..c32f202625 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -19,6 +19,7 @@ eval_model, model_dpa1, model_dpa2, + model_hybrid, model_se_e2_a, model_zbl, ) @@ -192,6 +193,20 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) +class TestEnergyModelHybridForce(unittest.TestCase, ForceTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelHybridVirial(unittest.TestCase, VirialTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + class TestEnergyModelZBLForce(unittest.TestCase, ForceTest): def setUp(self): model_params = copy.deepcopy(model_zbl) diff --git a/source/tests/pt/model/test_jit.py b/source/tests/pt/model/test_jit.py index a1aa9658fc..fc07267b88 100644 --- a/source/tests/pt/model/test_jit.py +++ b/source/tests/pt/model/test_jit.py @@ -101,7 +101,6 @@ def tearDown(self): JITTest.tearDown(self) -@unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid(unittest.TestCase, JITTest): def setUp(self): input_json = str(Path(__file__).parent / "water/se_atten.json") @@ -118,7 +117,6 @@ def tearDown(self): JITTest.tearDown(self) -@unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid2(unittest.TestCase, JITTest): def setUp(self): input_json = str(Path(__file__).parent / "water/se_atten.json") @@ -128,7 +126,7 @@ def setUp(self): self.config["training"]["training_data"]["systems"] = data_file self.config["training"]["validation_data"]["systems"] = data_file self.config["model"] = deepcopy(model_hybrid) - self.config["model"]["descriptor"]["hybrid_mode"] = "sequential" + # self.config["model"]["descriptor"]["hybrid_mode"] = "sequential" self.config["training"]["numb_steps"] = 10 self.config["training"]["save_freq"] = 10 diff --git a/source/tests/pt/model/test_null_input.py b/source/tests/pt/model/test_null_input.py index 93a3ff8511..eb8ff714e8 100644 --- a/source/tests/pt/model/test_null_input.py +++ b/source/tests/pt/model/test_null_input.py @@ -119,7 +119,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid(unittest.TestCase, NullTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) @@ -127,7 +126,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("hybrid not supported at the moment") class TestForceModelHybrid(unittest.TestCase, NullTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index 45790bf43d..fa97281718 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -279,7 +279,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid(unittest.TestCase, PermutationTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) @@ -287,7 +286,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("hybrid not supported at the moment") class TestForceModelHybrid(unittest.TestCase, PermutationTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py index 0c3a34e2d5..19f671e619 100644 --- a/source/tests/pt/model/test_rot.py +++ b/source/tests/pt/model/test_rot.py @@ -154,7 +154,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) @@ -162,7 +161,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("hybrid not supported at the moment") class TestForceModelHybrid(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index 88d75a040c..bc1d26bffa 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -195,7 +195,6 @@ def setUp(self): self.epsilon, self.aprec = None, None -@unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid(unittest.TestCase, SmoothTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py index 23365f3c9a..b9affac3aa 100644 --- a/source/tests/pt/model/test_trans.py +++ b/source/tests/pt/model/test_trans.py @@ -110,7 +110,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) @@ -118,7 +117,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("hybrid not supported at the moment") class TestForceModelHybrid(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_hybrid)