-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support spin in dpgen2. #265
base: master
Are you sure you want to change the base?
Changes from 12 commits
b886dcc
23ef302
243b9ac
396bdda
521c822
084e4f8
d951a57
8435208
4132bfb
262912a
cc6dc7b
81e2f4f
6f7ee72
6f4953a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
from .deviation_manager import ( | ||
DeviManager, | ||
) | ||
from .deviation_spin import ( | ||
DeviManagerSpin, | ||
) | ||
from .deviation_std import ( | ||
DeviManagerStd, | ||
) |
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,127 @@ | ||||||||||||||||||
from collections import ( | ||||||||||||||||||
defaultdict, | ||||||||||||||||||
) | ||||||||||||||||||
from typing import ( | ||||||||||||||||||
Dict, | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unused import The import Apply this diff to remove the unused import: from typing import (
- Dict,
List,
Optional,
) 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff
|
||||||||||||||||||
List, | ||||||||||||||||||
Optional, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
import numpy as np | ||||||||||||||||||
|
||||||||||||||||||
from .deviation_manager import ( | ||||||||||||||||||
DeviManager, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class DeviManagerSpin(DeviManager): | ||||||||||||||||||
r"""The class which is responsible for DeepSPIN model deviation management. | ||||||||||||||||||
|
||||||||||||||||||
This is the implementation of DeviManager for DeepSPIN model. Each deviation | ||||||||||||||||||
(e.g. max_devi_af, max_devi_mf in file `model_devi.out`) is stored | ||||||||||||||||||
as a List[Optional[np.ndarray]], where np.array is a one-dimensional | ||||||||||||||||||
array. | ||||||||||||||||||
A List[np.ndarray][ii][jj] is the force model deviation of the jj-th | ||||||||||||||||||
frame of the ii-th trajectory. | ||||||||||||||||||
The model deviation can be List[None], where len(List[None]) is | ||||||||||||||||||
the number of trajectory files. | ||||||||||||||||||
|
||||||||||||||||||
""" | ||||||||||||||||||
|
||||||||||||||||||
MAX_DEVI_AF = "max_devi_af" | ||||||||||||||||||
MIN_DEVI_AF = "min_devi_af" | ||||||||||||||||||
AVG_DEVI_AF = "avg_devi_af" | ||||||||||||||||||
MAX_DEVI_MF = "max_devi_mf" | ||||||||||||||||||
MIN_DEVI_MF = "min_devi_mf" | ||||||||||||||||||
AVG_DEVI_MF = "avg_devi_mf" | ||||||||||||||||||
|
||||||||||||||||||
def __init__(self): | ||||||||||||||||||
super().__init__() | ||||||||||||||||||
self._data = defaultdict(list) | ||||||||||||||||||
|
||||||||||||||||||
Comment on lines
+38
to
+41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Initialize The attribute Apply this diff to initialize def __init__(self):
super().__init__()
self._data = defaultdict(list)
+ self.ntraj = 0 📝 Committable suggestion
Suggested change
|
||||||||||||||||||
def _check_name(self, name: str): | ||||||||||||||||||
assert name in ( | ||||||||||||||||||
DeviManager.MAX_DEVI_V, | ||||||||||||||||||
DeviManager.MIN_DEVI_V, | ||||||||||||||||||
DeviManager.AVG_DEVI_V, | ||||||||||||||||||
self.MAX_DEVI_AF, | ||||||||||||||||||
self.MIN_DEVI_AF, | ||||||||||||||||||
self.AVG_DEVI_AF, | ||||||||||||||||||
self.MAX_DEVI_MF, | ||||||||||||||||||
self.MIN_DEVI_MF, | ||||||||||||||||||
self.AVG_DEVI_MF, | ||||||||||||||||||
), f"Error: unknown deviation name {name}" | ||||||||||||||||||
|
||||||||||||||||||
Comment on lines
+42
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Use exceptions instead of Using Consider replacing def _check_name(self, name: str):
- assert name in (
+ if name not in (
DeviManager.MAX_DEVI_V,
DeviManager.MIN_DEVI_V,
DeviManager.AVG_DEVI_V,
self.MAX_DEVI_AF,
self.MIN_DEVI_AF,
self.AVG_DEVI_AF,
self.MAX_DEVI_MF,
self.MIN_DEVI_MF,
self.AVG_DEVI_MF,
- ), f"Error: unknown deviation name {name}"
+ ):
+ raise ValueError(f"Error: unknown deviation name {name}") Similarly, in the def _add(self, name: str, deviation: np.ndarray) -> None:
- assert isinstance(
- deviation, np.ndarray
- ), f"Error: deviation(type: {type(deviation)}) is not a np.ndarray"
+ if not isinstance(deviation, np.ndarray):
+ raise TypeError(f"Error: deviation(type: {type(deviation)}) is not a np.ndarray")
- assert len(deviation.shape) == 1, (
+ if len(deviation.shape) != 1:
f"Error: deviation(shape: {deviation.shape}) is not a "
+ "one-dimensional array"
+ )
+ raise ValueError(
+ f"Error: deviation(shape: {deviation.shape}) is not a one-dimensional array"
+ ) In the if len(self._data[name]) > 0:
- assert len(self._data[name]) == self.ntraj, (
+ if len(self._data[name]) != self.ntraj:
f"Error: the number of model deviation {name} "
+ f"({len(self._data[name])}) and trajectory files ({self.ntraj}) "
+ "are not equal."
+ )
+ raise ValueError(
+ f"Error: the number of model deviation {name} "
+ f"({len(self._data[name])}) and trajectory files ({self.ntraj}) "
+ "are not equal."
+ ) By raising explicit exceptions, you ensure that validations are always active and provide clearer error handling. Also applies to: 55-65, 95-105 |
||||||||||||||||||
def _add(self, name: str, deviation: np.ndarray) -> None: | ||||||||||||||||||
assert isinstance( | ||||||||||||||||||
deviation, np.ndarray | ||||||||||||||||||
), f"Error: deviation(type: {type(deviation)}) is not a np.ndarray" | ||||||||||||||||||
assert len(deviation.shape) == 1, ( | ||||||||||||||||||
f"Error: deviation(shape: {deviation.shape}) is not a " | ||||||||||||||||||
+ f"one-dimensional array" | ||||||||||||||||||
) | ||||||||||||||||||
Comment on lines
+59
to
+62
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unnecessary At line 61, the string being concatenated does not contain any placeholders, so the Apply this diff to remove the extraneous f"Error: deviation(shape: {deviation.shape}) is not a "
- + f"one-dimensional array"
+ + "one-dimensional array" 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff
|
||||||||||||||||||
self._data[name].append(deviation) | ||||||||||||||||||
self.ntraj = max(self.ntraj, len(self._data[name])) | ||||||||||||||||||
|
||||||||||||||||||
def _get(self, name: str) -> List[Optional[np.ndarray]]: | ||||||||||||||||||
if self.ntraj == 0: | ||||||||||||||||||
return [] | ||||||||||||||||||
elif len(self._data[name]) == 0: | ||||||||||||||||||
return [None for _ in range(self.ntraj)] | ||||||||||||||||||
else: | ||||||||||||||||||
return self._data[name] | ||||||||||||||||||
|
||||||||||||||||||
def clear(self) -> None: | ||||||||||||||||||
self.__init__() | ||||||||||||||||||
return None | ||||||||||||||||||
Comment on lines
+74
to
+76
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid calling Calling Consider refactoring the def clear(self) -> None:
- self.__init__()
+ self._data.clear()
+ self.ntraj = 0
return None
|
||||||||||||||||||
|
||||||||||||||||||
def _check_data(self) -> None: | ||||||||||||||||||
r"""Check if data is valid""" | ||||||||||||||||||
model_devi_names = ( | ||||||||||||||||||
DeviManager.MAX_DEVI_V, | ||||||||||||||||||
DeviManager.MIN_DEVI_V, | ||||||||||||||||||
DeviManager.AVG_DEVI_V, | ||||||||||||||||||
self.MAX_DEVI_AF, | ||||||||||||||||||
self.MIN_DEVI_AF, | ||||||||||||||||||
self.AVG_DEVI_AF, | ||||||||||||||||||
self.MAX_DEVI_MF, | ||||||||||||||||||
self.MIN_DEVI_MF, | ||||||||||||||||||
self.AVG_DEVI_MF, | ||||||||||||||||||
Comment on lines
+84
to
+89
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be added to devi manager. |
||||||||||||||||||
) | ||||||||||||||||||
# check the length of model deviations | ||||||||||||||||||
frames = {} | ||||||||||||||||||
for name in model_devi_names: | ||||||||||||||||||
if len(self._data[name]) > 0: | ||||||||||||||||||
assert len(self._data[name]) == self.ntraj, ( | ||||||||||||||||||
f"Error: the number of model deviation {name} " | ||||||||||||||||||
+ f"({len(self._data[name])}) and trajectory files ({self.ntraj}) " | ||||||||||||||||||
+ f"are not equal." | ||||||||||||||||||
) | ||||||||||||||||||
Comment on lines
+96
to
+99
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unnecessary At line 98, the string being concatenated does not contain any placeholders, so the Apply this diff to remove the extraneous f"Error: the number of model deviation {name} "
+ f"({len(self._data[name])}) and trajectory files ({self.ntraj}) "
- + f"are not equal."
+ + "are not equal." 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff
|
||||||||||||||||||
for idx, ndarray in enumerate(self._data[name]): | ||||||||||||||||||
assert isinstance(ndarray, np.ndarray), ( | ||||||||||||||||||
f"Error: model deviation in {name} is not ndarray, " | ||||||||||||||||||
+ f"index: {idx}, type: {type(ndarray)}" | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
frames[name] = [arr.shape[0] for arr in self._data[name]] | ||||||||||||||||||
if len(frames[name]) == 0: | ||||||||||||||||||
frames.pop(name) | ||||||||||||||||||
|
||||||||||||||||||
# check if "max_devi_af" and "max_devi_mf" exist | ||||||||||||||||||
assert ( | ||||||||||||||||||
len(self._data[self.MAX_DEVI_AF]) == self.ntraj | ||||||||||||||||||
), f"Error: cannot find model deviation {self.MAX_DEVI_AF}" | ||||||||||||||||||
assert ( | ||||||||||||||||||
len(self._data[self.MAX_DEVI_MF]) == self.ntraj | ||||||||||||||||||
), f"Error: cannot find model deviation {self.MAX_DEVI_MF}" | ||||||||||||||||||
|
||||||||||||||||||
# check if the length of the arrays corresponding to the same | ||||||||||||||||||
# trajectory has the same number of frames | ||||||||||||||||||
non_empty_deviations = list(frames.keys()) | ||||||||||||||||||
for name in non_empty_deviations[1:]: | ||||||||||||||||||
assert frames[name] == frames[non_empty_deviations[0]], ( | ||||||||||||||||||
f"Error: the number of frames in {name} is different " | ||||||||||||||||||
+ f"with that in {non_empty_deviations[0]}.\n" | ||||||||||||||||||
+ f"{name}: {frames[name]}\n" | ||||||||||||||||||
+ f"{non_empty_deviations[0]}: {frames[non_empty_deviations[0]]}\n" | ||||||||||||||||||
) |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,75 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from pathlib import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TYPE_CHECKING, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
List, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Optional, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Tuple, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Union, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+8
to
+9
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unused imports The imports Apply this diff to remove the unused imports: from typing import (
TYPE_CHECKING,
List,
Optional,
- Tuple,
- Union,
) 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import dpdata | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..deviation import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DeviManager, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unused import The Apply this diff to remove the unused import: from ..deviation import (
- DeviManager,
DeviManagerSpin,
) 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DeviManagerSpin, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from .traj_render import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TrajRender, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if TYPE_CHECKING: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from dpgen2.exploration.selector import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ConfFilters, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+1
to
+26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unused imports There are several unused imports in this file. To improve code cleanliness and reduce potential confusion, please remove the following unused imports: Apply this diff to remove the unused imports: from typing import (
TYPE_CHECKING,
List,
Optional,
- Tuple,
- Union,
)
from ..deviation import (
- DeviManager,
DeviManagerSpin,
) 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class TrajRenderLammpsSpin(TrajRender): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps you can derive TrajRenderLammpsSpin from TrajRenderLammps, overriding only necessary methods, so that it can inherit new features from TrajRenderLammps. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nopbc: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_ele_temp: int = 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+32
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused The parameter Apply this diff to remove the unused parameter: def __init__(
self,
nopbc: bool = False,
- use_ele_temp: int = 0,
):
self.nopbc = nopbc 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.nopbc = nopbc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_model_devi( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
files: List[Path], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> DeviManagerSpin: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ntraj = len(files) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi = DeviManagerSpin() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in range(ntraj): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._load_one_model_devi(files[ii], model_devi) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return model_devi | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _load_one_model_devi(self, fname, model_devi): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dd = np.loadtxt(fname) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+49
to
+50
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add error handling when loading deviation data The method You could modify the method as follows: def _load_one_model_devi(self, fname, model_devi):
- dd = np.loadtxt(fname)
+ try:
+ dd = np.loadtxt(fname)
+ except Exception as e:
+ raise RuntimeError(f"Failed to load deviation data from {fname}: {e}") This addition ensures that users receive clear feedback if the data loading process encounters problems. 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi.add(DeviManagerSpin.MAX_DEVI_AF, dd[:, 4]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi.add(DeviManagerSpin.MIN_DEVI_AF, dd[:, 5]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi.add(DeviManagerSpin.AVG_DEVI_AF, dd[:, 6]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi.add(DeviManagerSpin.MAX_DEVI_MF, dd[:, 7]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi.add(DeviManagerSpin.MIN_DEVI_MF, dd[:, 8]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi.add(DeviManagerSpin.AVG_DEVI_MF, dd[:, 9]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+50
to
+56
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add validation for the shape of loaded data In the dd = np.loadtxt(fname)
model_devi.add(DeviManagerSpin.MAX_DEVI_AF, dd[:, 4])
model_devi.add(DeviManagerSpin.MIN_DEVI_AF, dd[:, 5])
model_devi.add(DeviManagerSpin.AVG_DEVI_AF, dd[:, 6])
model_devi.add(DeviManagerSpin.MAX_DEVI_MF, dd[:, 7])
model_devi.add(DeviManagerSpin.MIN_DEVI_MF, dd[:, 8])
model_devi.add(DeviManagerSpin.AVG_DEVI_MF, dd[:, 9]) If the loaded data does not have the expected number of columns, this will raise an Apply this diff to add data shape validation: def _load_one_model_devi(self, fname, model_devi):
dd = np.loadtxt(fname)
+ if dd.shape[1] < 10:
+ raise ValueError(f"Expected at least 10 columns in {fname}, but got {dd.shape[1]}")
model_devi.add(DeviManagerSpin.MAX_DEVI_AF, dd[:, 4])
model_devi.add(DeviManagerSpin.MIN_DEVI_AF, dd[:, 5])
model_devi.add(DeviManagerSpin.AVG_DEVI_AF, dd[:, 6])
model_devi.add(DeviManagerSpin.MAX_DEVI_MF, dd[:, 7])
model_devi.add(DeviManagerSpin.MIN_DEVI_MF, dd[:, 8])
model_devi.add(DeviManagerSpin.AVG_DEVI_MF, dd[:, 9])
Comment on lines
+49
to
+56
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling and data validation in The Consider adding error handling and data validation:
Here's an example of how you could modify the method: def _load_one_model_devi(self, fname, model_devi):
try:
dd = np.loadtxt(fname)
if dd.shape[1] < 10:
raise ValueError(f"Expected at least 10 columns in {fname}, but got {dd.shape[1]}")
model_devi.add(DeviManagerSpin.MAX_DEVI_AF, dd[:, 4])
model_devi.add(DeviManagerSpin.MIN_DEVI_AF, dd[:, 5])
model_devi.add(DeviManagerSpin.AVG_DEVI_AF, dd[:, 6])
model_devi.add(DeviManagerSpin.MAX_DEVI_MF, dd[:, 7])
model_devi.add(DeviManagerSpin.MIN_DEVI_MF, dd[:, 8])
model_devi.add(DeviManagerSpin.AVG_DEVI_MF, dd[:, 9])
except Exception as e:
raise RuntimeError(f"Failed to load or process deviation data from {fname}: {e}") This modification will provide more informative error messages and prevent potential index out of bounds errors. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_confs( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
trajs: List[Path], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
id_selected: List[List[int]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
type_map: Optional[List[str]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
conf_filters: Optional["ConfFilters"] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> dpdata.MultiSystems: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
del conf_filters # by far does not support conf filters | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Consider removing the The del conf_filters # by far does not support conf filters If Apply this diff to remove the unused parameter: def get_confs(
self,
trajs: List[Path],
id_selected: List[List[int]],
type_map: Optional[List[str]] = None,
- conf_filters: Optional["ConfFilters"] = None,
) -> dpdata.MultiSystems:
- del conf_filters # by far does not support conf filters
ntraj = len(trajs) |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ntraj = len(trajs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
traj_fmt = "lammps/dump" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ms = dpdata.MultiSystems(type_map=type_map) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in range(ntraj): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if len(id_selected[ii]) > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+67
to
+71
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure In the You might add a validation check at the beginning of the method: ntraj = len(trajs)
+ if ntraj != len(id_selected):
+ raise ValueError("The lengths of 'trajs' and 'id_selected' must match.") Alternatively, you could iterate using for traj, ids in zip(trajs, id_selected):
if len(ids) > 0:
ss = dpdata.System(traj, fmt=traj_fmt, type_map=type_map)
ss.nopbc = self.nopbc
ss = ss.sub_system(ids)
ms.append(ss) This approach inherently handles lists of unequal lengths by stopping at the shortest list's end. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ss = dpdata.System(trajs[ii], fmt=traj_fmt, type_map=type_map) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ss.nopbc = self.nopbc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ss = ss.sub_system(id_selected[ii]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ms.append(ss) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return ms |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codebase verification
Remove Unused Import
LmpSpinTaskGroup
The
LmpSpinTaskGroup
is imported indpgen2/entrypoint/submit.py
but is not used within this file. Please remove this import to keep the codebase clean and maintainable.🔗 Analysis chain
Consider removing unused import.
The
LmpSpinTaskGroup
is imported but not used in this file. Consider removing this import to keep the codebase clean.To verify if this import is used elsewhere in the codebase, we can run the following script:
If the script doesn't return any results, it's safe to remove the import. Would you like me to remove this unused import?
Update Shell Script to Verify Usage of
LmpSpinTaskGroup
The previous script encountered an error because
python
is not a recognized file type inrg
. Please run the updated script below to check ifLmpSpinTaskGroup
is used elsewhere in the codebase.🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
Length of output: 100
Script:
Length of output: 461
🧰 Tools
🪛 Ruff