Skip to content

Commit

Permalink
[dev] sample and save exported data to npz, better progress
Browse files Browse the repository at this point in the history
  • Loading branch information
BanananaFish committed Apr 13, 2024
1 parent b24f363 commit 2fb0c1e
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 74 deletions.
4 changes: 1 addition & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
/exports/**/*.txt
/exports/**/*.pkl
/exports/raw/*
/exports/*
/models/**/*.mph
/models/**/*.mph.lock
/models/**/*.mph.bk
Expand Down
47 changes: 34 additions & 13 deletions config/cell.yaml
Original file line number Diff line number Diff line change
@@ -1,24 +1,45 @@
cell:
r:
init: 0.00000003
max: 0.00000032
min: 0.0000001
step: 0.00000003
hudu:
init: 0.001
max: 180
min: 5
step: 20
zeta:
init: 0.0005
init: 0.001
max: 180
min: 0
step: 30
rr:
step: 20
g:
init: 0.001
max: 0.0045
min: 0.0035
step: 0.0005
rrr:
max: 0.0000006
min: 0.0000001
step: 0.0000001
rr:
init: 0.001
max: 0.003
min: 0.001
step: 0.001
max: 0.00000003
min: 0.000000005
step: 0.000000005

export:
# 导出的目录
dir: "exports/0"
# 哪些导出关键字要进行采样压缩,只要含有关键字的导出标签都会被计算,请写成列表形式,注意 - 和 " 之间有空格
sample_keys:
- "flied"
# - "bd"
# - "export_foo"
# - "..."

train:
epoch: 200
epoch: 500
batch_size: 16
lr: 0.001
lr: 0.0001
early_stop: 30

dataset:
sampler: four_points # six_points, four_points
sampler: six_points # six_points, four_points
27 changes: 21 additions & 6 deletions src/comsol/cmdline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import click
from rich.console import Console
from rich.progress import Progress
from traitlets import default

from comsol.console import console

Expand All @@ -15,24 +16,38 @@ def main():
@click.option("--config", help="Path to the config yaml.", default="config/cell.yaml")
@click.option("--dump", help="Dump model file after study", is_flag=True)
@click.option("--raw", help="Save raw exported data", is_flag=True)
def run(model, config, raw, dump):
@click.option("--avg", help="Save grid avg exported data, infer in cfg", is_flag=True)
@click.option("--sample", help="sampled frac", default=0.1)
def run(model, config, dump, raw, avg, sample):
console.log(":baseball: [bold magenta italic]Comsol CLI! by Bananafish[/]")
from comsol.interface import Comsol
from comsol.utils import Config

console.log(f"Running model {model}, CFG: {config}, dump: {dump}, raw: {raw}")
console.log(
f"Running model {model}, CFG: {config}, dump: {dump}, raw: {raw}, sample: {sample}, avg: {avg}"
)
cfg = Config(config)
cli = Comsol(model, *cfg.params)
cli = Comsol(model, cfg["export"]["dir"], *cfg.params)

with Progress(console=console) as progress:
study_tast = progress.add_task("[cyan]Study", total=len(cfg.tasks))
for task in cfg.tasks:
# cli.update(**task)
# cli.study()
cli.study_count += 1
if raw:
cli.save_raw_data()
if raw or avg or sample:
if raw:
cli.save_raw_data()
if avg:
cli.save_avg_data(cfg["export"]["grid_avg"])
if sample:
cli.save_sampled_data(
frac=sample,
sample_keys=cfg["export"]["sample_keys"],
progress=progress,
)
else:
cli.save()
console.log("[red]No save option selected")
if dump:
cli.dump()
progress.update(study_tast, advance=1)
Expand Down
33 changes: 33 additions & 0 deletions src/comsol/csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from pathlib import Path

import numpy as np
import pandas as pd


def sample_cood(csv_path: Path, frac: float = 0.1) -> np.ndarray:
    """Load an exported CSV and return a random subset of its rows.

    Lines starting with ``%`` are treated as comments. The last such line
    supplies the column names and everything after it is parsed as data.
    If the file has no ``%`` line, the first line is used as the header.

    Args:
        csv_path: Path of the exported CSV file.
        frac: Fraction of the data rows to keep, forwarded to
            ``DataFrame.sample``.

    Returns:
        The sampled rows as a 2-D array (``DataFrame.to_numpy()``).

    Raises:
        ValueError: If the file is empty, so no header line can be found.
    """
    # Stream the file and remember only the last comment line, so a large
    # export is not held in memory twice (once here, once by pandas).
    header_index = 0
    header_line = None
    with open(csv_path, "r", encoding="utf-8") as f:
        for index, line in enumerate(f):
            # The first line is the fallback header when no "%" line exists.
            if index == 0 or line.startswith("%"):
                header_index, header_line = index, line

    if header_line is None:
        raise ValueError(f"Cannot sample an empty export file: {csv_path}")

    df = pd.read_csv(
        csv_path,
        skiprows=header_index + 1,
        header=None,
        names=header_line.strip("%\n").split(","),
    )
    # Honor the caller's fraction (it used to be hard-coded to 0.1).
    return df.sample(frac=frac).to_numpy()


def compress_save(arr: np.ndarray, save_path: Path):
    """Write *arr* to *save_path* as a compressed NumPy ``.npz`` archive.

    The array is stored under ``arr_0``, the key NumPy assigns to the first
    positional array.
    """
    np.savez_compressed(save_path, arr_0=arr)


def grid_avg(csv_path: Path):
    """Placeholder for grid-averaging an exported CSV; not implemented yet.

    The body is a bare ``...``, so calling this currently does nothing and
    returns ``None``.
    """
    ...
151 changes: 100 additions & 51 deletions src/comsol/interface.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import pickle
import shutil
from os import PathLike
from pathlib import Path
from typing import List, TypeVar

import mph
import numpy as np
from typing_extensions import deprecated

from comsol.console import console
from comsol.csv import compress_save, grid_avg, sample_cood

T = TypeVar("T", int, float)

Expand All @@ -25,55 +28,59 @@ def filter(self, val):


class Comsol:
def __init__(self, model: PathLike | str, *optim_params: Param) -> None:
def __init__(
self, model: PathLike | str, export_dir: PathLike | str, *optim_params: Param
) -> None:
self.client: mph.Client = mph.start()
self.cell: mph.Model = self.client.load(model)

self.export_file = Path("exports") / "res.txt"
self.export_dir = Path(export_dir)

self.params_filter = {param.name: param for param in optim_params}

self.study_count = 0

@deprecated("parse_res is deprecated")
def parse_res(self):
self.cell.export()
with open(self.export_file, "r") as file:
lines = file.readlines()
xy_values = [
list(map(float, ",".join(line.split()).split(",")))
for line in lines
if not line.startswith("%")
]
arr = np.array(xy_values)
return arr
# with open(self.export_file, "r") as file:
# lines = file.readlines()
# xy_values = [
# list(map(float, ",".join(line.split()).split(",")))
# for line in lines
# if not line.startswith("%")
# ]
# arr = np.array(xy_values)
# return arr

@property
def params(self):
return {param: self.cell.parameter(param) for param in self.params_filter}

@property
@deprecated("data property is deprecated")
def data(self):
self.cell.export()
if Path(self.export_file).exists():
arr = self.parse_res()
arr_sorted = arr[arr[:, 0].argsort()] # 按照x值对arr进行排序

min_values: dict[
float, List[float]
] = {} # 初始化一个空的字典来存储每个x值的最小的两个数

for x, y in arr_sorted:
if x not in min_values:
min_values[x] = [y]
else:
if len(min_values[x]) < 2:
min_values[x].append(y)
else:
max_value = max(min_values[x])
if y < max_value:
min_values[x].remove(max_value)
min_values[x].append(y)
return min_values
# if Path(self.export_file).exists():
# arr = self.parse_res()
# arr_sorted = arr[arr[:, 0].argsort()] # 按照x值对arr进行排序

# min_values: dict[
# float, List[float]
# ] = {} # 初始化一个空的字典来存储每个x值的最小的两个数

# for x, y in arr_sorted:
# if x not in min_values:
# min_values[x] = [y]
# else:
# if len(min_values[x]) < 2:
# min_values[x].append(y)
# else:
# max_value = max(min_values[x])
# if y < max_value:
# min_values[x].remove(max_value)
# min_values[x].append(y)
# return min_values

def update(self, **kwargs):
for key, value in kwargs.items():
Expand All @@ -86,40 +93,82 @@ def update(self, **kwargs):
def study(self):
console.log(f"# {self.study_count + 1} Solving...")
self.cell.solve()
self.cell.export()
self.study_count += 1

def save(self):
def save(self, raw=False, avg=True):
if raw:
return self.save_raw_data()
elif avg:
return self.save_avg_data()
console.log("[red]No save option selected")

@deprecated("save_pkl is deprecated, use save_raw_data / save_avg_data instead")
def save_pkl(self):
self.cell.export()
dest = Path("exports") / "saved" / f"res_{self.study_count:05d}.pkl"
dest.parent.mkdir(parents=True, exist_ok=True)
if self.export_file.exists():
with open(dest, "wb") as f:
pickle.dump((self.params, self.parse_res()), f)
console.log(f"Results saved to {dest}")
dest_dir = self.export_dir / "raw" / f"study_{self.study_count:05d}"
dest_dir.mkdir(parents=True, exist_ok=True)
with open(dest_dir / f"res_{self.study_count:05d}.pkl", "wb") as f:
pickle.dump((self.params, self.parse_res()), f)
console.log(f"Results saved to {dest_dir}")

def save_raw_data(self):
dest_dir = Path("exports") / "raw" / f"study_{self.study_count:05d}"
dest_dir = self.export_dir / "raw" / f"study_{self.study_count:05d}"
dest_dir.mkdir(parents=True, exist_ok=True)
export_tasks = self.cell.exports()
for name in export_tasks:
csv = ".." / dest_dir / f"{name}.csv"
console.log(f"Results({name}) saved to {csv}")
self.cell.export(name, csv)

def save_avg_data(self, avg_list: List[str] = ["flied"]):
raise NotImplementedError("save_avg_data is not implemented")
dest_dir = self.export_dir / "raw" / f"study_{self.study_count:05d}"
tmp_dir = self.export_dir / "avg" / "tmp"
dest_dir.mkdir(parents=True, exist_ok=True)
tmp_dir.mkdir(parents=True, exist_ok=True)
export_tasks = self.cell.exports()
for task in export_tasks:
csv_name = f"{task}.csv"
self.cell.export(task, (tmp_dir / csv_name).absolute())
if task in avg_list:
grid_avg(tmp_dir / csv_name)
console.log(f"Results({task}) cal grid avg")
shutil.copy(tmp_dir / csv_name, dest_dir / csv_name)
console.log(f"Results({task}) saved to {dest_dir}")

def save_sampled_data(
    self, frac: float, sample_keys: List[str], console=console, progress=None
):
    """Export every model export task, down-sampling the selected ones.

    Each export is first written as a CSV into ``<export_dir>/tmp``. If the
    export's name contains any entry of *sample_keys*, the CSV is row-sampled
    and stored as a compressed ``.npz``. Otherwise the CSV is copied
    unchanged. Results go to ``<export_dir>/sampled/study_<NNNNN>``.

    Args:
        frac: Fraction of rows to keep for sampled exports.
        sample_keys: Substrings that mark an export name for sampling.
        console: Object whose ``log`` method receives status messages.
        progress: Optional progress display. When given, a temporary
            "Sampled" task is added, advanced once per export and removed
            at the end.
    """
    dest_dir = self.export_dir / "sampled" / f"study_{self.study_count:05d}"
    tmp_dir = self.export_dir / "tmp"
    dest_dir.mkdir(parents=True, exist_ok=True)
    tmp_dir.mkdir(parents=True, exist_ok=True)
    export_tasks = self.cell.exports()

    # Compare against None explicitly: a task id may legitimately be 0 and
    # a progress object should not be judged by its truthiness.
    sampled_task = None
    if progress is not None:
        sampled_task = progress.add_task(
            "[light_cyan3]Sampled", total=len(export_tasks)
        )

    for task in export_tasks:
        tmp_csv = tmp_dir / f"{task}.csv"
        # The export runs with the model's location as its working directory,
        # so the target has to be given as an absolute path.
        self.cell.export(task, tmp_csv.absolute())
        if any(sample_key in task for sample_key in sample_keys):
            # Pass the requested fraction through (it used to be dropped, so
            # the sampler always fell back to its own default).
            arr = sample_cood(tmp_csv, frac=frac)
            compress_save(arr, dest_dir / f"{task}.npz")
            console.log(f"Results({task}) sampled! frac: {frac:.3f}")
        else:
            shutil.copy(tmp_csv, dest_dir / tmp_csv.name)
            console.log(f"Results({task}) skip sample, saved to {dest_dir}")

        if progress is not None:
            progress.update(sampled_task, advance=1)

    if progress is not None:
        progress.stop_task(sampled_task)
        progress.remove_task(sampled_task)

def dump(self):
    """Save the current model to ``models/saved/cell_<NNNNN>.mph``.

    ``<NNNNN>`` is the zero-padded study counter. The target directory is
    created if it does not exist yet.
    """
    target = Path("models", "saved", f"cell_{self.study_count:05d}.mph")
    target.parent.mkdir(parents=True, exist_ok=True)
    self.cell.save(target)
    console.log(f"Model dumped to {target}")


if __name__ == "__main__":
comsol = Comsol(
"models/cell.mph", *[Param(name, 0, 3) for name in ["r", "rr", "p"]]
)

console.log(comsol.data)
comsol.update(r=0.0015)
comsol.study()
comsol.dump()
2 changes: 1 addition & 1 deletion src/comsol/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

class Config:
def __init__(self, config_file):
with open(config_file, "r") as f:
with open(config_file, "r", encoding="utf-8") as f:
self.config = yaml.load(f, Loader=yaml.FullLoader)

def __getitem__(self, key):
Expand Down

0 comments on commit 2fb0c1e

Please sign in to comment.