39ff650539
Use Agg backend and show=False only when saving to file. When neither gif nor mp4 is requested, show the animation window interactively. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
438 lines
20 KiB
Python
438 lines
20 KiB
Python
"""
|
||
dynamics.py
|
||
------
|
||
统一入口:读取 YAML 配置文件 → 运行模拟 → 抽帧 → 绘图(可选)
|
||
|
||
用法:
|
||
python dynamics.py # 使用 input/input.txt
|
||
python dynamics.py input/input.txt # 指定配置文件(YAML 格式)
|
||
python dynamics.py --config input/input.txt --no-plot
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import subprocess
|
||
import time
|
||
import argparse
|
||
import json
|
||
from contextlib import contextmanager
|
||
from pathlib import Path
|
||
|
||
import yaml
|
||
import numpy as np
|
||
|
||
# 导入同目录下的模块
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
import compute
|
||
|
||
|
||
def _fmt_alpha(v):
|
||
"""将 alpha 值格式化为逗号分隔字符串,兼容 numpy 数组/list/标量。"""
|
||
if isinstance(v, (list, tuple, np.ndarray)):
|
||
return ",".join(str(float(x)) for x in v)
|
||
return str(float(v))
|
||
|
||
|
||
def _json_field(value):
|
||
"""Serialize arrays/lists for display header metadata."""
|
||
if isinstance(value, np.ndarray):
|
||
value = value.tolist()
|
||
return json.dumps(value, ensure_ascii=False)
|
||
|
||
|
||
def _load_camera_kf(config, runtime_base):
|
||
"""加载 move_camera.txt(速度段格式)→ JSON 字符串。"""
|
||
import re, json
|
||
if not int(config.get("move_camera", 0)):
|
||
return ""
|
||
cam_file = str(config.get("move_camera_file",
|
||
os.path.join("input", "move_camera.txt")))
|
||
cam_path = cam_file
|
||
if not os.path.isabs(cam_file):
|
||
cam_path = os.path.join(runtime_base, cam_file)
|
||
if not os.path.exists(cam_path):
|
||
return ""
|
||
segments = []
|
||
with open(cam_path, "r", encoding="utf-8") as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line or line.startswith("#"):
|
||
continue
|
||
# 解析帧范围:支持 "all"(全程)或 "N-M"(区间)
|
||
if line.lower().startswith("all") or re.match(r'^\s*all\s', line, re.IGNORECASE):
|
||
start, end = 0, 10**9
|
||
else:
|
||
m = re.match(r'(\d+)\s*-\s*(\d+)', line)
|
||
if not m:
|
||
continue
|
||
start, end = int(m.group(1)), int(m.group(2))
|
||
v, r = [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]
|
||
for i, axis in enumerate(['x', 'y', 'z']):
|
||
m2 = re.search(r'v' + axis + r'\s*=\s*([-\d.]+)', line)
|
||
if m2: v[i] = float(m2.group(1))
|
||
m2 = re.search(r'r' + axis + r'\s*=\s*([-\d.]+)', line)
|
||
if m2: r[i] = float(m2.group(1))
|
||
if any(v) or any(r):
|
||
segments.append({"start": start, "end": end, "v": v, "r": r})
|
||
return json.dumps(segments) if segments else ""
|
||
|
||
|
||
def read_optional_index(data, key, default_value):
    """Fetch an optional non-negative integer index from parsed metadata.

    ``default_value`` is returned when ``key`` is missing, when it maps to
    None, or when its integer value is negative; otherwise ``int(data[key])``.
    """
    if key in data and data[key] is not None:
        index = int(data[key])
        if index >= 0:
            return index
    return default_value
|
||
|
||
|
||
def load_yaml_config(config_path):
    """Load a YAML configuration file into a dict.

    Uses ``yaml.safe_load``. An empty file yields ``{}``. Previously it yielded
    None, which made callers fail later with an unrelated ``TypeError``.

    Args:
        config_path: path of the YAML file (read as UTF-8).

    Returns:
        The top-level mapping as a dict.

    Raises:
        ValueError: if the document's top level is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if config is None:
        # safe_load returns None for an empty document.
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"配置文件顶层必须是键值映射,实际为 {type(config).__name__}: {config_path}")
    return config
|
||
|
||
|
||
def resolve_path(base_dir, path_value):
    """Anchor ``path_value`` at ``base_dir`` and return it as a Path.

    Absolute inputs come back unchanged (not resolved). Relative inputs are
    joined onto ``base_dir`` and then fully resolved.
    """
    candidate = Path(path_value)
    if not candidate.is_absolute():
        return Path(base_dir).joinpath(candidate).resolve()
    return candidate
|
||
|
||
|
||
@contextmanager
def runtime_dir_overrides(input_dir=None, output_dir=None):
    """Point DYNAMICS_INPUT_DIR / DYNAMICS_OUTPUT_DIR at the given directories.

    Each non-None argument is written to its environment variable as an
    absolute resolved path for the duration of the ``with`` block. On exit,
    whether normal or via an exception, each variable gets back its previous
    value, or is removed if it was previously unset.
    """
    targets = {
        "DYNAMICS_INPUT_DIR": input_dir,
        "DYNAMICS_OUTPUT_DIR": output_dir,
    }
    saved = {name: os.environ.get(name) for name in targets}
    try:
        for name, directory in targets.items():
            if directory is not None:
                os.environ[name] = str(Path(directory).resolve())
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
|
||
|
||
|
||
def build_sample_indices(total_steps, sample_step, sample_start, sample_end):
    """Validate a sampling window and return the sampled step indices.

    The result is ``sample_start + k * sample_step`` for
    ``k = 0 .. (sample_end - sample_start) // sample_step - 1``. Because of the
    floor division, a trailing partial ``sample_step`` window is not sampled.

    Returns:
        numpy array of indices built from an int64 ``arange``.

    Raises:
        ValueError: if sample_step <= 0, sample_start < 0,
            sample_end > total_steps, sample_start >= sample_end, or the
            window is shorter than one full sample_step.
    """
    if sample_step <= 0:
        raise ValueError(f"NSTEP 必须为正整数,实际为 {sample_step}")
    if sample_start < 0:
        raise ValueError(f"sample_start 不能小于 0,实际为 {sample_start}")
    if sample_end > total_steps:
        raise ValueError(
            f"sample_end 不能大于记录步数 {total_steps},实际为 {sample_end}")
    if sample_start >= sample_end:
        raise ValueError(
            f"sample_start 必须小于 sample_end,实际为 [{sample_start}, {sample_end})")

    window = sample_end - sample_start
    frame_count = window // sample_step
    if frame_count <= 0:
        raise ValueError(
            f"抽帧范围 [{sample_start}, {sample_end}) 过短,按 NSTEP={sample_step} 无法抽出任何帧")

    offsets = sample_step * np.arange(frame_count, dtype=np.int64)
    return sample_start + offsets
|
||
|
||
|
||
def _read_cached_nt(traj_path):
    """Return the ``"NT"`` value found in the last 4 KiB of ``traj_path``, or None.

    None is returned when the file cannot be read or when no
    ``"NT": <digits>`` entry appears in that tail.
    """
    import re

    try:
        with open(traj_path, 'rb') as fh:
            size = fh.seek(0, os.SEEK_END)
            # Clamp at 0: seek(-4096, SEEK_END) raises OSError on files
            # shorter than 4 KiB, which silently disabled this check.
            fh.seek(max(0, size - 4096))
            tail = fh.read().decode('utf-8', errors='replace')
    except OSError:
        return None
    match = re.search(r'"NT":\s*(\d+)', tail)
    return int(match.group(1)) if match else None


def _legacy_display_header(data, config, runtime_base, dt, nstep, nt, record_steps):
    """Build the display.txt header fields for the legacy trajectory.txt path.

    Simulation metadata is copied from ``data`` (loaded from trajectory.txt),
    falling back to the literal defaults below. ``use_marker`` and the camera
    settings come from the YAML ``config``.
    """
    return {
        "DT": str(dt), "NSTEP": str(nstep), "method": str(data.get("method", "")),
        "warmup_steps": str(data.get("warmup_steps", 0)),
        "dynamic_steps": str(record_steps),
        "T_total": str(nt * dt),
        "X_MAX": str(data.get("X_MAX", 10)), "X_MIN": str(data.get("X_MIN", -10)),
        "Y_MAX": str(data.get("Y_MAX", 10)), "Y_MIN": str(data.get("Y_MIN", -10)),
        "Z_MAX": str(data.get("Z_MAX", 10)), "Z_MIN": str(data.get("Z_MIN", -10)),
        "ball_radius": str(data.get("ball_radius", 0.5)),
        "ball_color_r": str(data.get("ball_color_r", 0.9)),
        "ball_color_g": str(data.get("ball_color_g", 0.2)),
        "ball_color_b": str(data.get("ball_color_b", 0.2)),
        "box_color_r": str(data.get("box_color_r", 0.8)),
        "box_color_g": str(data.get("box_color_g", 0.8)),
        "box_color_b": str(data.get("box_color_b", 0.85)),
        "gravity_field": str(data.get("gravity_field", 1)),
        "gravity_interaction": str(data.get("gravity_interaction", 0)),
        "elastic_force": str(data.get("elastic_force", 1)),
        "damping_force": str(data.get("damping_force", 0)),
        "gravity_strength": str(data.get("gravity_strength", 1.0)),
        "driving_force": str(data.get("driving_force", 0)),
        "use_marker": str(config.get("use_marker", 0)),
        "alpha": _fmt_alpha(data.get("alpha", 0.2)),
        "atom_masses": _json_field(data.get("atom_masses", [])),
        "atom_positions": _json_field(data.get("atom_positions", [])),
        "bond_pairs": _json_field(data.get("bond_pairs", [])),
        "bond_stiffness": _json_field(data.get("bond_stiffness", [])),
        "bond_rest_lengths": _json_field(data.get("bond_rest_lengths", [])),
        "G": _json_field(data.get("G", [0.0, 0.0, 0.0])),
        "atom_radii": _fmt_alpha(data.get("atom_radii", [])),
        "camera_distance": str(config.get("camera_distance", 40.0)),
        "camera_elevation": str(config.get("camera_elevation", 0)),
        "camera_azimuth": str(config.get("camera_azimuth", 0)),
        "camera_keyframes": _load_camera_kf(config, str(runtime_base)),
    }


def _extract_display_from_trajectory(traj_path, disp_path, config, runtime_base, save_traj):
    """Down-sample trajectory.txt into display.txt for legacy external engines.

    Keeps every NSTEP-th recorded step, or step 0 alone if that yields
    nothing. When ``save_traj`` is falsy, it then tries to delete trajectory.txt
    and ignores ``OSError``.
    """
    data = compute.load_text_data(traj_path)
    nt = int(data["NT"])
    dt = float(data["DT"])
    nstep = int(data.get("NSTEP", 1))
    record_steps = nt - int(data.get("warmup_steps", 0))
    n_atoms = len(data["atom_ids"])

    indices = np.arange(0, record_steps, nstep, dtype=np.int64)
    if len(indices) == 0:
        indices = np.array([0])

    traj_x, traj_y, traj_z = data["traj_x"], data["traj_y"], data["traj_z"]
    traj_vx, traj_vy, traj_vz = data["traj_vx"], data["traj_vy"], data["traj_vz"]

    hf = _legacy_display_header(data, config, runtime_base, dt, nstep, nt, record_steps)

    n_frames = len(indices)
    compute.save_display_txt(
        disp_path,
        traj_x[indices], traj_y[indices], traj_z[indices],
        traj_vx[indices], traj_vy[indices], traj_vz[indices],
        np.array(data["atom_ids"]), n_frames, n_atoms,
        header_fields=hf)
    print(f"[run] 从 trajectory.txt 抽帧生成 display.txt ({n_frames} 帧)")

    # save_trajectory=0: remove the full trajectory once display.txt exists.
    if not save_traj:
        try:
            os.remove(traj_path)
            print(f"[run] save_trajectory=0,已删除 {traj_path}")
        except OSError:
            pass


def _launch_animation(output_dir_abs, runtime_base):
    """Start draw.py on ``output_dir_abs`` in a separate process, if possible.

    Reports (without raising) a missing draw.py, a missing display.txt, a
    child that exits within 0.5 s, or any exception while launching.
    """
    draw_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "draw.py")
    if not os.path.exists(draw_script):
        print(f"[run] 未找到动画脚本: {draw_script}")
        return

    # display.txt may be absent, e.g. when step_sample=0 was used.
    disp_path = os.path.join(output_dir_abs, "display.txt")
    if not os.path.exists(disp_path):
        print(f"[run] 错误: 找不到 {disp_path}")
        print(f"[run] 启动动画需要先运行抽帧(step_sample: 1),或手动保留 output/display.txt")
        return

    try:
        print("[run] 正在启动 VisPy 3D 动画窗口…")
        ansi_log = os.path.join(output_dir_abs, "animation.log")
        # On Windows give the child its own console so closing the parent's
        # console does not take it down too.
        creation_flags = subprocess.CREATE_NEW_CONSOLE if sys.platform == "win32" else 0
        # The child gets its own copy of the log handle; closing ours after
        # Popen returns avoids leaking an open file in this process.
        with open(ansi_log, "w", encoding="utf-8") as log_fh:
            proc = subprocess.Popen(
                [sys.executable, draw_script, output_dir_abs],
                cwd=runtime_base,
                stdout=subprocess.DEVNULL,
                stderr=log_fh,
                creationflags=creation_flags,
            )
        # Give the child half a second to fail fast before reporting success.
        time.sleep(0.5)
        if proc.poll() is not None:
            print(f"[run] ⚠ 动画进程已退出,返回码={proc.returncode}")
            print(f"[run] 请查看错误日志: {ansi_log}")
        else:
            print(f"[run] VisPy 动画窗口已启动(PID={proc.pid})")
    except Exception as e:
        print(f"[run] 启动动画失败: {e}")


def _plot_wave(config, output_dir_abs):
    """Draw the wave/energy animation via the plot_wave module.

    When ``plot_wave_save_gif`` or ``plot_wave_save_mp4`` is set, it switches
    matplotlib to the non-interactive Agg backend and saves to file. Otherwise
    it shows the window. Failures are printed with a traceback, not raised.
    """
    try:
        save_gif = int(config.get("plot_wave_save_gif", 0))
        save_mp4 = int(config.get("plot_wave_save_mp4", 0))
        to_file = bool(save_gif or save_mp4)
        if to_file:
            import matplotlib
            matplotlib.use("Agg")  # non-interactive backend when saving to file
        import plot_wave as pw
        print("[run] 正在绘制波形与能量图…")
        gif = pw.plot_wave(
            str(output_dir_abs),
            save_gif=save_gif,
            save_mp4=save_mp4,
            show=not to_file,  # pop up an interactive window when not saving
        )
        if gif:
            print(f"[run] 波形 GIF 已保存: {gif}")
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"[run] 绘制波形图失败: {e}")


def run_case(config_path, runtime_base, input_dir="input", output_dir="output", no_plot=False):
    """Run one case with explicit program path, input path, and output path.

    Steps:
        1. Load the YAML config and derive NT from T_total when needed.
        2. Optionally simulate (Python engine or an external engine).
        3. Make sure display.txt exists, down-sampling trajectory.txt when a
           legacy external engine left only the trajectory.
        4. Print the legacy step_plot notice.
        5. Optionally launch draw.py.
        6. Optionally draw the wave plot.

    Args:
        config_path: YAML config file; relative paths resolve against runtime_base.
        runtime_base: case root directory.
        input_dir: input directory (relative to runtime_base unless absolute).
        output_dir: output directory (relative to runtime_base unless absolute).
            Both directories are exported via DYNAMICS_INPUT_DIR /
            DYNAMICS_OUTPUT_DIR while the case runs.
        no_plot: suppress the legacy step_plot notice.
    """
    runtime_base = Path(runtime_base).resolve()
    input_dir_path = resolve_path(runtime_base, input_dir)
    output_dir_path = resolve_path(runtime_base, output_dir)
    config_path = resolve_path(runtime_base, config_path)

    with runtime_dir_overrides(input_dir_path, output_dir_path):
        # 1. Load the YAML config.
        config = load_yaml_config(config_path)
        print(f"[run] 已加载配置: {config_path}")

        # T_total -> NT conversion (NT wins when both are given).
        if "T_total" in config and "NT" not in config:
            dt = float(config["DT"])
            ratio = float(config["T_total"]) / dt
            nearest = round(ratio)
            # Plain int() truncation lost a step to float noise
            # (0.3 / 0.1 == 2.9999999999999996). Snap to the nearest integer
            # when the ratio is integral up to rounding error; otherwise
            # keep the original truncation.
            if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
                config["NT"] = int(nearest)
            else:
                config["NT"] = int(ratio)
            print(f"[run] T_total={config['T_total']} → NT={config['NT']} (DT={dt})")
        elif "T_total" in config and "NT" in config:
            print(f"[run] 同时指定了 T_total 和 NT,使用 NT={config['NT']}")

        # Report step switches and key settings.
        steps_info = {k: config.get(k, 1) for k in ["step_simulate", "step_sample", "step_plot", "step_plot_wave", "step_animation"]}
        step_flags = ", ".join(f"{k}={v}" for k, v in steps_info.items())
        print(f"[run] 步骤开关: {step_flags}")

        warmup = config.get("warmup_steps", 0)
        ss = config.get("sample_start", "从头")
        se = config.get("sample_end", "到尾")
        method = config.get("method", "explicit_euler")
        coord_file = config.get("coord_file", os.path.join("input", "coord.txt"))
        plot_atom = config.get("plot_atom", "第一个原子")
        print(f"[run] 算法: {method}")
        print(f"[run] 坐标文件: {coord_file}")
        print(f"[run] 绘图/动画原子序号: {plot_atom}")
        print(f"[run] 步骤控制: 预热={warmup}步, 抽帧范围=[{ss}, {se})")

        output_dir_abs = compute.get_output_dir(str(runtime_base))
        traj_path = os.path.join(output_dir_abs, "trajectory.txt")
        disp_path = os.path.join(output_dir_abs, "display.txt")
        cache_present = (os.path.isdir(output_dir_abs)
                         and os.path.exists(traj_path)
                         and os.path.exists(disp_path))

        # Cache handling:
        #   force_calc=1 -> always recompute and re-sample, ignoring any cache.
        #   force_calc=0 -> honour step_simulate; when simulating over a cache
        #                   whose NT differs from the config, force re-sampling.
        force_calc = int(config.get("force_calc", 0))
        if force_calc:
            print("[run] force_calc=1,跳过缓存,强制重新计算")
            config["step_simulate"] = 1
            config["step_sample"] = 1
        elif config.get("step_simulate", 1):
            if cache_present:
                cached_nt = _read_cached_nt(traj_path)
                config_nt = int(config.get("NT", 0))
                if cached_nt is not None and cached_nt != config_nt:
                    print(f"[run] 参数已变更(缓存 NT={cached_nt},配置 NT={config_nt}),"
                          f"将重新计算")
                    config["step_sample"] = 1
        elif cache_present:
            print("[run] 已有输出,步骤已被跳过")
        else:
            print("[run] 没有可用的缓存输出,但 step_simulate=0,将跳过模拟")

        # Bind the engine before the simulate branch: step 3 needs it even
        # when step_simulate=0 (it used to raise UnboundLocalError there).
        engine = config.get("engine", "python")

        # 2. Run the simulation.
        if config.get("step_simulate", 1):
            total_steps = config["NT"]
            record_steps = total_steps - (config.get("warmup_steps") or 0)
            print(f"[run] 开始计算 总步数={total_steps} 记录步数={record_steps} DT={config['DT']}")

            t0 = time.time()
            if engine == "python":
                compute.run_from_config(config, str(runtime_base))
            else:
                # External engine: load the config into compute's state
                # without running the Python integrator, then run the engine.
                config["_skip_run"] = True
                compute.run_from_config(config, str(runtime_base))
                config.pop("_skip_run", None)
                input_dir_abs = str(input_dir_path.resolve())
                output_dir_abs = str(output_dir_path.resolve())
                traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = compute.run_engine(
                    engine, input_dir_abs, output_dir_abs, config)
                if int(config.get("save_trajectory", 0)):
                    compute.save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, str(runtime_base))

            elapsed = time.time() - t0
            print(f"[run] 引擎: {engine} 计算完成: {record_steps} 步 {elapsed:.3f} s")
        else:
            print("[run] 步骤 [模拟] 已跳过")

        # 3. Check / generate display.txt.
        disp_path = os.path.join(output_dir_abs, "display.txt")
        traj_path = os.path.join(output_dir_abs, "trajectory.txt")
        save_traj = int(config.get("save_trajectory", 0))

        if os.path.exists(disp_path):
            # Written directly by the Python engine or a newer external engine.
            print("[run] 发现已有 display.txt(引擎直接抽帧)")
        elif engine != "python" and os.path.exists(traj_path):
            # Legacy external engine: down-sample trajectory.txt ourselves.
            _extract_display_from_trajectory(traj_path, disp_path, config, runtime_base, save_traj)
        else:
            # Previously unreachable: neither file is available. Report it and
            # continue; step 5 also skips the animation in this case.
            print("[run] 错误: 找不到 trajectory.txt 或 display.txt")

        # 4. Plotting (legacy path currently disabled).
        if not no_plot and config.get("step_plot", 1):
            print("[run] 注意: 旧版 step_plot 绘图路径依赖完整轨迹局部变量,当前已暂时跳过。")
            print("[run] 如需波形/能量动画,请使用 step_plot_wave: 1。")

        print(f"[run] 完成!输出目录: {output_dir_abs}")

        # 5. Launch the animation (optional).
        if config.get("step_animation", 0):
            _launch_animation(output_dir_abs, runtime_base)
        else:
            print("[run] 运行 python draw.py 查看动画。")

        # 6. Wave / energy animation (optional).
        if config.get("step_plot_wave", 0):
            _plot_wave(config, output_dir_abs)
|
||
|
||
|
||
def main():
    """Command-line entry point: parse arguments and run a single case.

    Exits with status 1 when the resolved config file does not exist.
    """
    cli = argparse.ArgumentParser(description="物理模拟统一入口")
    # (flags, keyword arguments) for each option, registered in this order.
    option_specs = (
        (("config_file",), {
            "nargs": "?",
            "default": os.path.join("input", "input.txt"),
            "help": "YAML 配置文件路径(默认: input/input.txt,虽然是 .txt 后缀但使用 YAML 格式)",
        }),
        (("--config",), {
            "dest": "config_override",
            "help": "YAML 配置文件路径(可选,优先于位置参数)",
        }),
        (("--input-dir",), {"default": "input", "help": "输入目录路径(默认: input)"}),
        (("--output-dir",), {"default": "output", "help": "输出目录路径(默认: output)"}),
        (("--runtime-base",), {"default": ".", "help": "案例运行根目录(默认: 当前目录)"}),
        (("--no-plot",), {"action": "store_true", "help": "跳过 matplotlib 绘图"}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    opts = cli.parse_args()

    # --config takes precedence over the positional config_file when given.
    chosen_config = opts.config_override if opts.config_override else opts.config_file
    base_dir = resolve_path(os.getcwd(), opts.runtime_base)
    config_abs = resolve_path(base_dir, chosen_config)

    if not os.path.exists(config_abs):
        print(f"错误: 配置文件不存在: {config_abs}")
        sys.exit(1)

    run_case(
        config_path=config_abs,
        runtime_base=base_dir,
        input_dir=opts.input_dir,
        output_dir=opts.output_dir,
        no_plot=opts.no_plot,
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|