format training logging (#3397)
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
njzjz and pre-commit-ci[bot] authored Mar 3, 2024
1 parent 13a8adf commit e826260
Showing 4 changed files with 131 additions and 34 deletions.
4 changes: 2 additions & 2 deletions deepmd/loggers/loggers.py
@@ -29,14 +29,14 @@
 )
 CFORMATTER = logging.Formatter(
     # "%(app_name)s %(levelname)-7s |-> %(name)-45s %(message)s"
-    "%(app_name)s %(levelname)-7s %(message)s"
+    "[%(asctime)s] %(app_name)s %(levelname)-7s %(message)s"
 )
 FFORMATTER_MPI = logging.Formatter(
     "[%(asctime)s] %(app_name)s rank:%(rank)-2s %(levelname)-7s %(name)-45s %(message)s"
 )
 CFORMATTER_MPI = logging.Formatter(
     # "%(app_name)s rank:%(rank)-2s %(levelname)-7s |-> %(name)-45s %(message)s"
-    "%(app_name)s rank:%(rank)-2s %(levelname)-7s %(message)s"
+    "[%(asctime)s] %(app_name)s rank:%(rank)-2s %(levelname)-7s %(message)s"
 )


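For context (not part of the diff): the only change here is prefixing the console formats with `[%(asctime)s]`, matching what the file formatters already did. A minimal sketch of how the new console line renders, assuming `app_name` is supplied through the `extra` mapping (deepmd's own log setup injects this field for you):

```python
# Illustrative sketch only; the timestamp and "DEEPMD" app name are placeholders.
import logging

CFORMATTER = logging.Formatter(
    "[%(asctime)s] %(app_name)s %(levelname)-7s %(message)s"
)

handler = logging.StreamHandler()
handler.setFormatter(CFORMATTER)

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# "app_name" is not a standard LogRecord attribute, so it must be provided,
# e.g. via `extra=`; deepmd wires this up in its own handler setup.
logger.info("training started", extra={"app_name": "DEEPMD"})
# -> [2024-03-03 12:00:00,123] DEEPMD INFO    training started
```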
34 changes: 34 additions & 0 deletions deepmd/loggers/training.py
@@ -0,0 +1,34 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Dict,
+    Optional,
+)
+
+
+def format_training_message(
+    batch: int,
+    wall_time: float,
+):
+    """Format a training message."""
+    return f"batch {batch:7d}: " f"total wall time = {wall_time:.2f} s"
+
+
+def format_training_message_per_task(
+    batch: int,
+    task_name: str,
+    rmse: Dict[str, float],
+    learning_rate: Optional[float],
+):
+    if task_name:
+        task_name += ": "
+    if learning_rate is None:
+        lr = ""
+    else:
+        lr = f", lr = {learning_rate:8.2e}"
+    # sort rmse
+    rmse = dict(sorted(rmse.items()))
+    return (
+        f"batch {batch:7d}: {task_name}"
+        f"{', '.join([f'{kk} = {vv:8.2e}' for kk, vv in rmse.items()])}"
+        f"{lr}"
+    )
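A quick usage sketch of the two helpers above (illustrative values; the spacing follows directly from the `:7d` and `:8.2e` format specs):

```python
from deepmd.loggers.training import (
    format_training_message,
    format_training_message_per_task,
)

print(format_training_message(batch=1000, wall_time=12.34))
# batch    1000: total wall time = 12.34 s

print(
    format_training_message_per_task(
        batch=1000,
        task_name="trn",
        rmse={"rmse_f": 5.6e-02, "rmse_e": 1.2e-03},  # keys are sorted before printing
        learning_rate=1.0e-03,
    )
)
# batch    1000: trn: rmse_e = 1.20e-03, rmse_f = 5.60e-02, lr = 1.00e-03

print(
    format_training_message_per_task(
        batch=1000,
        task_name="val",
        rmse={"rmse_f": 6.1e-02, "rmse_e": 1.5e-03},
        learning_rate=None,  # validation lines carry no learning rate
    )
)
# batch    1000: val: rmse_e = 1.50e-03, rmse_f = 6.10e-02
```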
83 changes: 53 additions & 30 deletions deepmd/pt/train/training.py
@@ -19,6 +19,10 @@
 from deepmd.common import (
     symlink_prefix_files,
 )
+from deepmd.loggers.training import (
+    format_training_message,
+    format_training_message_per_task,
+)
 from deepmd.pt.loss import (
     DenoiseLoss,
     EnergyStdLoss,
@@ -693,33 +697,24 @@ def step(_step_id, task_key="Default"):
             # Log and persist
             if _step_id % self.disp_freq == 0:
                 self.wrapper.eval()
-                msg = f"step={_step_id}, lr={cur_lr:.2e}"
 
                 def log_loss_train(_loss, _more_loss, _task_key="Default"):
                     results = {}
-                    if not self.multi_task:
-                        suffix = ""
-                    else:
-                        suffix = f"_{_task_key}"
-                    _msg = f"loss{suffix}={_loss:.4f}"
                     rmse_val = {
                         item: _more_loss[item]
                         for item in _more_loss
                         if "l2_" not in item
                     }
                     for item in sorted(rmse_val.keys()):
-                        _msg += f", {item}_train{suffix}={rmse_val[item]:.4f}"
                         results[item] = rmse_val[item]
-                    return _msg, results
+                    return results
 
                 def log_loss_valid(_task_key="Default"):
                     single_results = {}
                     sum_natoms = 0
                     if not self.multi_task:
-                        suffix = ""
                         valid_numb_batch = self.valid_numb_batch
                     else:
-                        suffix = f"_{_task_key}"
                         valid_numb_batch = self.valid_numb_batch[_task_key]
                     for ii in range(valid_numb_batch):
                         self.optimizer.zero_grad()
@@ -744,22 +739,32 @@ def log_loss_valid(_task_key="Default"):
                                 single_results.get(k, 0.0) + v * natoms
                             )
                     results = {k: v / sum_natoms for k, v in single_results.items()}
-                    _msg = ""
-                    for item in sorted(results.keys()):
-                        _msg += f", {item}_valid{suffix}={results[item]:.4f}"
-                    return _msg, results
+                    return results
 
                 if not self.multi_task:
-                    temp_msg, train_results = log_loss_train(loss, more_loss)
-                    msg += "\n" + temp_msg
-                    temp_msg, valid_results = log_loss_valid()
-                    msg += temp_msg
+                    train_results = log_loss_train(loss, more_loss)
+                    valid_results = log_loss_valid()
+                    log.info(
+                        format_training_message_per_task(
+                            batch=_step_id,
+                            task_name="trn",
+                            rmse=train_results,
+                            learning_rate=cur_lr,
+                        )
+                    )
+                    if valid_results is not None:
+                        log.info(
+                            format_training_message_per_task(
+                                batch=_step_id,
+                                task_name="val",
+                                rmse=valid_results,
+                                learning_rate=None,
+                            )
+                        )
                 else:
                     train_results = {_key: {} for _key in self.model_keys}
                     valid_results = {_key: {} for _key in self.model_keys}
-                    train_msg = {}
-                    valid_msg = {}
-                    train_msg[task_key], train_results[task_key] = log_loss_train(
+                    train_results[task_key] = log_loss_train(
                         loss, more_loss, _task_key=task_key
                     )
                     for _key in self.model_keys:
@@ -774,19 +779,37 @@ def log_loss_valid(_task_key="Default"):
                             label=label_dict,
                             task_key=_key,
                         )
-                        train_msg[_key], train_results[_key] = log_loss_train(
+                        train_results[_key] = log_loss_train(
                             loss, more_loss, _task_key=_key
                         )
-                        valid_msg[_key], valid_results[_key] = log_loss_valid(
-                            _task_key=_key
-                        )
-                        msg += "\n" + train_msg[_key]
-                        msg += valid_msg[_key]
+                        valid_results[_key] = log_loss_valid(_task_key=_key)
+                        log.info(
+                            format_training_message_per_task(
+                                batch=_step_id,
+                                task_name=_key + "_trn",
+                                rmse=train_results[_key],
+                                learning_rate=cur_lr,
+                            )
+                        )
+                        if valid_results is not None:
+                            log.info(
+                                format_training_message_per_task(
+                                    batch=_step_id,
+                                    task_name=_key + "_val",
+                                    rmse=valid_results[_key],
+                                    learning_rate=None,
+                                )
+                            )
 
-                train_time = time.time() - self.t0
-                self.t0 = time.time()
-                msg += f", speed={train_time:.2f} s/{self.disp_freq if _step_id else 1} batches"
-                log.info(msg)
+                current_time = time.time()
+                train_time = current_time - self.t0
+                self.t0 = current_time
+                log.info(
+                    format_training_message(
+                        batch=_step_id,
+                        wall_time=train_time,
+                    )
+                )
 
                 if fout:
                     if self.lcurve_should_print_header:
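The net effect on the PyTorch trainer is that `log_loss_train` / `log_loss_valid` now return plain RMSE dictionaries and every displayed line goes through the shared formatters, with per-task lines tagged `<key>_trn` / `<key>_val` in multi-task mode. A rough, self-contained sketch of that pattern (the `model1` key and all values are made up; the real trainer pulls them from its loss helpers):

```python
import logging

from deepmd.loggers.training import format_training_message_per_task

logging.basicConfig(level=logging.INFO, format="%(message)s")
log = logging.getLogger(__name__)

cur_lr = 1.0e-03
# Hypothetical per-task results, shaped like what log_loss_train/log_loss_valid return.
train_results = {"model1": {"rmse_e": 1.2e-03, "rmse_f": 5.6e-02}}
valid_results = {"model1": {"rmse_e": 1.5e-03, "rmse_f": 6.1e-02}}

for _key in train_results:
    log.info(
        format_training_message_per_task(
            batch=1000,
            task_name=_key + "_trn",
            rmse=train_results[_key],
            learning_rate=cur_lr,
        )
    )
    log.info(
        format_training_message_per_task(
            batch=1000,
            task_name=_key + "_val",
            rmse=valid_results[_key],
            learning_rate=None,
        )
    )
```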
44 changes: 42 additions & 2 deletions deepmd/tf/train/trainer.py
@@ -23,6 +23,10 @@
 from deepmd.common import (
     symlink_prefix_files,
 )
+from deepmd.loggers.training import (
+    format_training_message,
+    format_training_message_per_task,
+)
 from deepmd.tf.common import (
     data_requirement,
     get_precision,
@@ -774,8 +778,10 @@ def train(self, train_data=None, valid_data=None):
                 test_time = toc - tic
                 wall_time = toc - wall_time_tic
                 log.info(
-                    "batch %7d training time %.2f s, testing time %.2f s, total wall time %.2f s"
-                    % (cur_batch, train_time, test_time, wall_time)
+                    format_training_message(
+                        batch=cur_batch,
+                        wall_time=wall_time,
+                    )
                 )
                 # the first training time is not accurate
                 if cur_batch > self.disp_freq or stop_batch < 2 * self.disp_freq:
@@ -959,6 +965,23 @@ def print_on_training(
             for k in train_results.keys():
                 print_str += prop_fmt % (train_results[k])
             print_str += " %8.1e\n" % cur_lr
+            log.info(
+                format_training_message_per_task(
+                    batch=cur_batch,
+                    task_name="trn",
+                    rmse=train_results,
+                    learning_rate=cur_lr,
+                )
+            )
+            if valid_results is not None:
+                log.info(
+                    format_training_message_per_task(
+                        batch=cur_batch,
+                        task_name="val",
+                        rmse=valid_results,
+                        learning_rate=None,
+                    )
+                )
         else:
             for fitting_key in train_results:
                 if valid_results[fitting_key] is not None:
@@ -974,6 +997,23 @@
                     for k in train_results[fitting_key].keys():
                         print_str += prop_fmt % (train_results[fitting_key][k])
                     print_str += " %8.1e\n" % cur_lr_dict[fitting_key]
+                    log.info(
+                        format_training_message_per_task(
+                            batch=cur_batch,
+                            task_name=f"{fitting_key}_trn",
+                            rmse=train_results[fitting_key],
+                            learning_rate=cur_lr_dict[fitting_key],
+                        )
+                    )
+                    if valid_results is not None:
+                        log.info(
+                            format_training_message_per_task(
+                                batch=cur_batch,
+                                task_name=f"{fitting_key}_val",
+                                rmse=valid_results[fitting_key],
+                                learning_rate=None,
+                            )
+                        )
         fp.write(print_str)
         fp.flush()
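Combined with the console formatter change in deepmd/loggers/loggers.py, one display interval in the TensorFlow trainer now reads roughly as follows (timestamps, values, and the `DEEPMD` app name are placeholders; the lcurve output written through `fp.write(print_str)` is untouched):

```
[2024-03-03 12:00:00,123] DEEPMD INFO    batch    1000: trn: rmse_e = 1.20e-03, rmse_f = 5.60e-02, lr = 1.00e-03
[2024-03-03 12:00:00,124] DEEPMD INFO    batch    1000: val: rmse_e = 1.50e-03, rmse_f = 6.10e-02
[2024-03-03 12:00:00,125] DEEPMD INFO    batch    1000: total wall time = 8.76 s
```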
