Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support spin in dpgen2. #265

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
7 changes: 6 additions & 1 deletion dpgen2/entrypoint/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
)
from dpgen2.exploration.render import (
TrajRenderLammps,
TrajRenderLammpsSpin,
)
from dpgen2.exploration.report import (
ExplorationReportTrustLevelsRandom,
Expand All @@ -81,6 +82,7 @@
CustomizedLmpTemplateTaskGroup,
ExplorationStage,
ExplorationTask,
LmpSpinTaskGroup,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Remove Unused Import LmpSpinTaskGroup

The LmpSpinTaskGroup is imported in dpgen2/entrypoint/submit.py but is not used within this file. Please remove this import to keep the codebase clean and maintainable.

🔗 Analysis chain

Consider removing unused import.

The LmpSpinTaskGroup is imported but not used in this file. Consider removing this import to keep the codebase clean.

To verify if this import is used elsewhere in the codebase, we can run the following script:

If the script doesn't return any results, it's safe to remove the import. Would you like me to remove this unused import?


Update Shell Script to Verify Usage of LmpSpinTaskGroup

The previous script encountered an error because python is not a recognized file type in rg. Please run the updated script below to check if LmpSpinTaskGroup is used elsewhere in the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if LmpSpinTaskGroup is used in other files

# Search for LmpSpinTaskGroup usage in Python files
rg --type python "LmpSpinTaskGroup" -g '!dpgen2/entrypoint/submit.py'

Length of output: 100


Script:

#!/bin/bash
# Description: Check if LmpSpinTaskGroup is used in other Python files

# Search for LmpSpinTaskGroup usage in Python files, excluding submit.py
rg --type py "LmpSpinTaskGroup" -g '!dpgen2/entrypoint/submit.py'

Length of output: 461

🧰 Tools
🪛 Ruff

85-85: dpgen2.exploration.task.LmpSpinTaskGroup imported but unused

Remove unused import

(F401)

LmpTemplateTaskGroup,
NPTTaskGroup,
caly_normalize,
Expand Down Expand Up @@ -377,7 +379,10 @@ def make_lmp_naive_exploration_scheduler(config):
# report
conv_style = convergence.pop("type")
report = conv_styles[conv_style](**convergence)
render = TrajRenderLammps(nopbc=output_nopbc, use_ele_temp=use_ele_temp)
if "spin" in conv_style:
render = TrajRenderLammpsSpin(nopbc=output_nopbc, use_ele_temp=use_ele_temp)
else:
render = TrajRenderLammps(nopbc=output_nopbc, use_ele_temp=use_ele_temp)
# selector
selector = ConfSelectorFrames(
render,
Expand Down
3 changes: 3 additions & 0 deletions dpgen2/exploration/deviation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .deviation_manager import (
DeviManager,
)
from .deviation_spin import (
DeviManagerSpin,
)
from .deviation_std import (
DeviManagerStd,
)
127 changes: 127 additions & 0 deletions dpgen2/exploration/deviation/deviation_spin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from collections import (
defaultdict,
)
from typing import (
Dict,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unused import Dict

The import Dict from the typing module is not used in the code. Removing it will clean up unnecessary imports.

Apply this diff to remove the unused import:

 from typing import (
-    Dict,
     List,
     Optional,
 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Dict,
from typing import (
List,
Optional,
)
🧰 Tools
🪛 Ruff

5-5: typing.Dict imported but unused

Remove unused import: typing.Dict

(F401)

List,
Optional,
)

import numpy as np

from .deviation_manager import (
DeviManager,
)


class DeviManagerSpin(DeviManager):
r"""The class which is responsible for DeepSPIN model deviation management.

This is the implementation of DeviManager for DeepSPIN model. Each deviation
(e.g. max_devi_af, max_devi_mf in file `model_devi.out`) is stored
as a List[Optional[np.ndarray]], where np.array is a one-dimensional
array.
A List[np.ndarray][ii][jj] is the force model deviation of the jj-th
frame of the ii-th trajectory.
The model deviation can be List[None], where len(List[None]) is
the number of trajectory files.

"""

MAX_DEVI_AF = "max_devi_af"
MIN_DEVI_AF = "min_devi_af"
AVG_DEVI_AF = "avg_devi_af"
MAX_DEVI_MF = "max_devi_mf"
MIN_DEVI_MF = "min_devi_mf"
AVG_DEVI_MF = "avg_devi_mf"

def __init__(self):
super().__init__()
self._data = defaultdict(list)

Comment on lines +38 to +41
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Initialize self.ntraj in the constructor

The attribute self.ntraj is used in the _add and _get methods but is not initialized in the constructor. This could lead to an AttributeError if these methods are called before self.ntraj is defined.

Apply this diff to initialize self.ntraj:

     def __init__(self):
         super().__init__()
         self._data = defaultdict(list)
+        self.ntraj = 0
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def __init__(self):
super().__init__()
self._data = defaultdict(list)
def __init__(self):
super().__init__()
self._data = defaultdict(list)
self.ntraj = 0

def _check_name(self, name: str):
assert name in (
DeviManager.MAX_DEVI_V,
DeviManager.MIN_DEVI_V,
DeviManager.AVG_DEVI_V,
self.MAX_DEVI_AF,
self.MIN_DEVI_AF,
self.AVG_DEVI_AF,
self.MAX_DEVI_MF,
self.MIN_DEVI_MF,
self.AVG_DEVI_MF,
), f"Error: unknown deviation name {name}"

Comment on lines +42 to +54
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use exceptions instead of assert statements for input validation

Using assert statements for input validation is not recommended in production code because assertions can be disabled with Python optimizations (-O flag). It's better to raise appropriate exceptions to ensure that validation checks are always enforced.

Consider replacing assert statements with explicit exception handling. For example, in the _check_name method:

     def _check_name(self, name: str):
-        assert name in (
+        if name not in (
             DeviManager.MAX_DEVI_V,
             DeviManager.MIN_DEVI_V,
             DeviManager.AVG_DEVI_V,
             self.MAX_DEVI_AF,
             self.MIN_DEVI_AF,
             self.AVG_DEVI_AF,
             self.MAX_DEVI_MF,
             self.MIN_DEVI_MF,
             self.AVG_DEVI_MF,
-        ), f"Error: unknown deviation name {name}"
+        ):
+            raise ValueError(f"Error: unknown deviation name {name}")

Similarly, in the _add method:

     def _add(self, name: str, deviation: np.ndarray) -> None:
-        assert isinstance(
-            deviation, np.ndarray
-        ), f"Error: deviation(type: {type(deviation)}) is not a np.ndarray"
+        if not isinstance(deviation, np.ndarray):
+            raise TypeError(f"Error: deviation(type: {type(deviation)}) is not a np.ndarray")

-        assert len(deviation.shape) == 1, (
+        if len(deviation.shape) != 1:
             f"Error: deviation(shape: {deviation.shape}) is not a "
             + "one-dimensional array"
+            )
+            raise ValueError(
+                f"Error: deviation(shape: {deviation.shape}) is not a one-dimensional array"
+            )

In the _check_data method:

             if len(self._data[name]) > 0:
-                assert len(self._data[name]) == self.ntraj, (
+                if len(self._data[name]) != self.ntraj:
                     f"Error: the number of model deviation {name} "
                     + f"({len(self._data[name])}) and trajectory files ({self.ntraj}) "
                     + "are not equal."
+                    )
+                    raise ValueError(
+                        f"Error: the number of model deviation {name} "
+                        f"({len(self._data[name])}) and trajectory files ({self.ntraj}) "
+                        "are not equal."
+                    )

By raising explicit exceptions, you ensure that validations are always active and provide clearer error handling.

Also applies to: 55-65, 95-105

def _add(self, name: str, deviation: np.ndarray) -> None:
assert isinstance(
deviation, np.ndarray
), f"Error: deviation(type: {type(deviation)}) is not a np.ndarray"
assert len(deviation.shape) == 1, (
f"Error: deviation(shape: {deviation.shape}) is not a "
+ f"one-dimensional array"
)
Comment on lines +59 to +62
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unnecessary f prefix in string concatenation

At line 61, the string being concatenated does not contain any placeholders, so the f prefix is unnecessary.

Apply this diff to remove the extraneous f prefix:

             f"Error: deviation(shape: {deviation.shape}) is not a "
-            + f"one-dimensional array"
+            + "one-dimensional array"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
assert len(deviation.shape) == 1, (
f"Error: deviation(shape: {deviation.shape}) is not a "
+ f"one-dimensional array"
)
assert len(deviation.shape) == 1, (
f"Error: deviation(shape: {deviation.shape}) is not a "
+ "one-dimensional array"
)
🧰 Tools
🪛 Ruff

61-61: f-string without any placeholders

Remove extraneous f prefix

(F541)

self._data[name].append(deviation)
self.ntraj = max(self.ntraj, len(self._data[name]))

def _get(self, name: str) -> List[Optional[np.ndarray]]:
if self.ntraj == 0:
return []
elif len(self._data[name]) == 0:
return [None for _ in range(self.ntraj)]
else:
return self._data[name]

def clear(self) -> None:
self.__init__()
return None
Comment on lines +74 to +76
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid calling __init__ directly within clear method

Calling self.__init__() inside the clear method is not recommended, as it can lead to unexpected behavior, especially in complex inheritance hierarchies. It's better to reset the necessary attributes directly.

Consider refactoring the clear method to reset the attributes explicitly:

     def clear(self) -> None:
-        self.__init__()
+        self._data.clear()
+        self.ntraj = 0
         return None

Committable suggestion was skipped due to low confidence.


def _check_data(self) -> None:
r"""Check if data is valid"""
model_devi_names = (
DeviManager.MAX_DEVI_V,
DeviManager.MIN_DEVI_V,
DeviManager.AVG_DEVI_V,
self.MAX_DEVI_AF,
self.MIN_DEVI_AF,
self.AVG_DEVI_AF,
self.MAX_DEVI_MF,
self.MIN_DEVI_MF,
self.AVG_DEVI_MF,
Comment on lines +84 to +89

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be added to devi manager.
what is the difference between DeviManager.XXX_DEVI_F and self.XXX_DEVI_AF? they should be referring to the same thing.

)
# check the length of model deviations
frames = {}
for name in model_devi_names:
if len(self._data[name]) > 0:
assert len(self._data[name]) == self.ntraj, (
f"Error: the number of model deviation {name} "
+ f"({len(self._data[name])}) and trajectory files ({self.ntraj}) "
+ f"are not equal."
)
Comment on lines +96 to +99
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unnecessary f prefix in string concatenation

At line 98, the string being concatenated does not contain any placeholders, so the f prefix is unnecessary.

Apply this diff to remove the extraneous f prefix:

                     f"Error: the number of model deviation {name} "
                     + f"({len(self._data[name])}) and trajectory files ({self.ntraj}) "
-                    + f"are not equal."
+                    + "are not equal."
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
f"Error: the number of model deviation {name} "
+ f"({len(self._data[name])}) and trajectory files ({self.ntraj}) "
+ f"are not equal."
)
f"Error: the number of model deviation {name} "
+ f"({len(self._data[name])}) and trajectory files ({self.ntraj}) "
+ "are not equal."
🧰 Tools
🪛 Ruff

98-98: f-string without any placeholders

Remove extraneous f prefix

(F541)

for idx, ndarray in enumerate(self._data[name]):
assert isinstance(ndarray, np.ndarray), (
f"Error: model deviation in {name} is not ndarray, "
+ f"index: {idx}, type: {type(ndarray)}"
)

frames[name] = [arr.shape[0] for arr in self._data[name]]
if len(frames[name]) == 0:
frames.pop(name)

# check if "max_devi_af" and "max_devi_mf" exist
assert (
len(self._data[self.MAX_DEVI_AF]) == self.ntraj
), f"Error: cannot find model deviation {self.MAX_DEVI_AF}"
assert (
len(self._data[self.MAX_DEVI_MF]) == self.ntraj
), f"Error: cannot find model deviation {self.MAX_DEVI_MF}"

# check if the length of the arrays corresponding to the same
# trajectory has the same number of frames
non_empty_deviations = list(frames.keys())
for name in non_empty_deviations[1:]:
assert frames[name] == frames[non_empty_deviations[0]], (
f"Error: the number of frames in {name} is different "
+ f"with that in {non_empty_deviations[0]}.\n"
+ f"{name}: {frames[name]}\n"
+ f"{non_empty_deviations[0]}: {frames[non_empty_deviations[0]]}\n"
)
3 changes: 3 additions & 0 deletions dpgen2/exploration/render/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
from .traj_render_lammps import (
TrajRenderLammps,
)
from .traj_render_lammps_spin import (
TrajRenderLammpsSpin,
)
75 changes: 75 additions & 0 deletions dpgen2/exploration/render/traj_render_lammps_spin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from pathlib import (
Path,
)
from typing import (
TYPE_CHECKING,
List,
Optional,
Tuple,
Union,
Comment on lines +8 to +9
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unused imports Tuple and Union

The imports Tuple (line 8) and Union (line 9) from the typing module are not used in the code. Removing unused imports helps keep the code clean and maintain readability.

Apply this diff to remove the unused imports:

 from typing import (
     TYPE_CHECKING,
     List,
     Optional,
-    Tuple,
-    Union,
 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Tuple,
Union,
from typing import (
TYPE_CHECKING,
List,
Optional,
)
🧰 Tools
🪛 Ruff

8-8: typing.Tuple imported but unused

Remove unused import

(F401)


9-9: typing.Union imported but unused

Remove unused import

(F401)

)

import dpdata
import numpy as np

from ..deviation import (
DeviManager,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unused import DeviManager

The DeviManager import from ..deviation is not utilized in the code. It's best practice to remove such unused imports to reduce clutter and potential confusion.

Apply this diff to remove the unused import:

 from ..deviation import (
-    DeviManager,
     DeviManagerSpin,
 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
DeviManager,
from ..deviation import (
DeviManagerSpin,
)
🧰 Tools
🪛 Ruff

16-16: ..deviation.DeviManager imported but unused

Remove unused import: ..deviation.DeviManager

(F401)

DeviManagerSpin,
)
from .traj_render import (
TrajRender,
)

if TYPE_CHECKING:
from dpgen2.exploration.selector import (
ConfFilters,
)
Comment on lines +1 to +26
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unused imports

There are several unused imports in this file. To improve code cleanliness and reduce potential confusion, please remove the following unused imports:

Apply this diff to remove the unused imports:

 from typing import (
     TYPE_CHECKING,
     List,
     Optional,
-    Tuple,
-    Union,
 )

 from ..deviation import (
-    DeviManager,
     DeviManagerSpin,
 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from pathlib import (
Path,
)
from typing import (
TYPE_CHECKING,
List,
Optional,
Tuple,
Union,
)
import dpdata
import numpy as np
from ..deviation import (
DeviManager,
DeviManagerSpin,
)
from .traj_render import (
TrajRender,
)
if TYPE_CHECKING:
from dpgen2.exploration.selector import (
ConfFilters,
)
from pathlib import (
Path,
)
from typing import (
TYPE_CHECKING,
List,
Optional,
)
import dpdata
import numpy as np
from ..deviation import (
DeviManagerSpin,
)
from .traj_render import (
TrajRender,
)
if TYPE_CHECKING:
from dpgen2.exploration.selector import (
ConfFilters,
)
🧰 Tools
🪛 Ruff

8-8: typing.Tuple imported but unused

Remove unused import

(F401)


9-9: typing.Union imported but unused

Remove unused import

(F401)


16-16: ..deviation.DeviManager imported but unused

Remove unused import: ..deviation.DeviManager

(F401)



class TrajRenderLammpsSpin(TrajRender):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps you can derive TrajRenderLammpsSpin from TrajRenderLammps, overriding only necessary methods, so that it can inherit new features from TrajRenderLammps.

def __init__(
self,
nopbc: bool = False,
use_ele_temp: int = 0,
):
Comment on lines +32 to +34
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Unused use_ele_temp parameter in __init__ method

The parameter use_ele_temp in the __init__ method is not used within the class. If it's not needed, consider removing it to clean up the code. If you plan to use it in the future, you might want to store it as an instance variable.

Apply this diff to remove the unused parameter:

 def __init__(
     self,
     nopbc: bool = False,
-    use_ele_temp: int = 0,
 ):
     self.nopbc = nopbc
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
nopbc: bool = False,
use_ele_temp: int = 0,
):
nopbc: bool = False,
):

self.nopbc = nopbc

def get_model_devi(
self,
files: List[Path],
) -> DeviManagerSpin:
ntraj = len(files)

model_devi = DeviManagerSpin()
for ii in range(ntraj):
self._load_one_model_devi(files[ii], model_devi)

return model_devi

def _load_one_model_devi(self, fname, model_devi):
dd = np.loadtxt(fname)
Comment on lines +49 to +50
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add error handling when loading deviation data

The method _load_one_model_devi uses np.loadtxt(fname) to load data from a file, which might raise exceptions if the file is missing, inaccessible, or improperly formatted. Consider adding error handling to gracefully manage these potential issues and provide informative error messages.

You could modify the method as follows:

 def _load_one_model_devi(self, fname, model_devi):
-    dd = np.loadtxt(fname)
+    try:
+        dd = np.loadtxt(fname)
+    except Exception as e:
+        raise RuntimeError(f"Failed to load deviation data from {fname}: {e}")

This addition ensures that users receive clear feedback if the data loading process encounters problems.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _load_one_model_devi(self, fname, model_devi):
dd = np.loadtxt(fname)
def _load_one_model_devi(self, fname, model_devi):
try:
dd = np.loadtxt(fname)
except Exception as e:
raise RuntimeError(f"Failed to load deviation data from {fname}: {e}")

model_devi.add(DeviManagerSpin.MAX_DEVI_AF, dd[:, 4])
model_devi.add(DeviManagerSpin.MIN_DEVI_AF, dd[:, 5])
model_devi.add(DeviManagerSpin.AVG_DEVI_AF, dd[:, 6])
model_devi.add(DeviManagerSpin.MAX_DEVI_MF, dd[:, 7])
model_devi.add(DeviManagerSpin.MIN_DEVI_MF, dd[:, 8])
model_devi.add(DeviManagerSpin.AVG_DEVI_MF, dd[:, 9])
Comment on lines +50 to +56
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add validation for the shape of loaded data

In the _load_one_model_devi method, the code assumes that the data loaded from fname has at least 10 columns:

dd = np.loadtxt(fname)
model_devi.add(DeviManagerSpin.MAX_DEVI_AF, dd[:, 4])
model_devi.add(DeviManagerSpin.MIN_DEVI_AF, dd[:, 5])
model_devi.add(DeviManagerSpin.AVG_DEVI_AF, dd[:, 6])
model_devi.add(DeviManagerSpin.MAX_DEVI_MF, dd[:, 7])
model_devi.add(DeviManagerSpin.MIN_DEVI_MF, dd[:, 8])
model_devi.add(DeviManagerSpin.AVG_DEVI_MF, dd[:, 9])

If the loaded data does not have the expected number of columns, this will raise an IndexError. Consider adding a validation step to ensure the data has the correct shape before accessing specific columns.

Apply this diff to add data shape validation:

 def _load_one_model_devi(self, fname, model_devi):
     dd = np.loadtxt(fname)
+    if dd.shape[1] < 10:
+        raise ValueError(f"Expected at least 10 columns in {fname}, but got {dd.shape[1]}")
     model_devi.add(DeviManagerSpin.MAX_DEVI_AF, dd[:, 4])
     model_devi.add(DeviManagerSpin.MIN_DEVI_AF, dd[:, 5])
     model_devi.add(DeviManagerSpin.AVG_DEVI_AF, dd[:, 6])
     model_devi.add(DeviManagerSpin.MAX_DEVI_MF, dd[:, 7])
     model_devi.add(DeviManagerSpin.MIN_DEVI_MF, dd[:, 8])
     model_devi.add(DeviManagerSpin.AVG_DEVI_MF, dd[:, 9])

Comment on lines +49 to +56
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add error handling and data validation in _load_one_model_devi

The _load_one_model_devi method currently lacks error handling for file loading and doesn't validate the shape of the loaded data. This could lead to unexpected errors if the file is missing, inaccessible, or doesn't have the expected format.

Consider adding error handling and data validation:

  1. Wrap the np.loadtxt call in a try-except block to handle potential file loading errors.
  2. Validate the shape of the loaded data before accessing specific columns.

Here's an example of how you could modify the method:

def _load_one_model_devi(self, fname, model_devi):
    try:
        dd = np.loadtxt(fname)
        if dd.shape[1] < 10:
            raise ValueError(f"Expected at least 10 columns in {fname}, but got {dd.shape[1]}")
        model_devi.add(DeviManagerSpin.MAX_DEVI_AF, dd[:, 4])
        model_devi.add(DeviManagerSpin.MIN_DEVI_AF, dd[:, 5])
        model_devi.add(DeviManagerSpin.AVG_DEVI_AF, dd[:, 6])
        model_devi.add(DeviManagerSpin.MAX_DEVI_MF, dd[:, 7])
        model_devi.add(DeviManagerSpin.MIN_DEVI_MF, dd[:, 8])
        model_devi.add(DeviManagerSpin.AVG_DEVI_MF, dd[:, 9])
    except Exception as e:
        raise RuntimeError(f"Failed to load or process deviation data from {fname}: {e}")

This modification will provide more informative error messages and prevent potential index out of bounds errors.


def get_confs(
self,
trajs: List[Path],
id_selected: List[List[int]],
type_map: Optional[List[str]] = None,
conf_filters: Optional["ConfFilters"] = None,
) -> dpdata.MultiSystems:
del conf_filters # by far does not support conf filters
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider removing the conf_filters parameter if unsupported

The get_confs method accepts a conf_filters parameter but immediately deletes it since it's not supported:

del conf_filters  # by far does not support conf filters

If conf_filters is not intended for future use, consider removing it from the method signature to avoid confusion. If you plan to support it later, you might leave a TODO comment explaining when it will be implemented.

Apply this diff to remove the unused parameter:

 def get_confs(
     self,
     trajs: List[Path],
     id_selected: List[List[int]],
     type_map: Optional[List[str]] = None,
-    conf_filters: Optional["ConfFilters"] = None,
 ) -> dpdata.MultiSystems:
-    del conf_filters  # by far does not support conf filters
     ntraj = len(trajs)

ntraj = len(trajs)
traj_fmt = "lammps/dump"
ms = dpdata.MultiSystems(type_map=type_map)
for ii in range(ntraj):
if len(id_selected[ii]) > 0:
Comment on lines +67 to +71
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure trajs and id_selected have matching lengths

In the get_confs method, there's an assumption that trajs and id_selected are lists of the same length, as they are accessed using the same index ii. To prevent potential IndexError exceptions, ensure that both lists are of equal length or handle cases where they might differ.

You might add a validation check at the beginning of the method:

         ntraj = len(trajs)
+        if ntraj != len(id_selected):
+            raise ValueError("The lengths of 'trajs' and 'id_selected' must match.")

Alternatively, you could iterate using zip to pair elements directly:

for traj, ids in zip(trajs, id_selected):
    if len(ids) > 0:
        ss = dpdata.System(traj, fmt=traj_fmt, type_map=type_map)
        ss.nopbc = self.nopbc
        ss = ss.sub_system(ids)
        ms.append(ss)

This approach inherently handles lists of unequal lengths by stopping at the shortest list's end.

ss = dpdata.System(trajs[ii], fmt=traj_fmt, type_map=type_map)
ss.nopbc = self.nopbc
ss = ss.sub_system(id_selected[ii])
ms.append(ss)
return ms
4 changes: 4 additions & 0 deletions dpgen2/exploration/report/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
from .report_trust_levels_random import (
ExplorationReportTrustLevelsRandom,
)
from .report_trust_levels_spin import (
ExplorationReportTrustLevelsSpin,
)

conv_styles = {
"fixed-levels": ExplorationReportTrustLevelsRandom,
"fixed-levels-max-select": ExplorationReportTrustLevelsMax,
"fixed-levels-max-select-spin": ExplorationReportTrustLevelsSpin,
"adaptive-lower": ExplorationReportAdaptiveLower,
}
Loading