This commit is contained in:
2026-05-17 08:47:25 +08:00
parent 1159d86b8b
commit 45513fe334
27 changed files with 4734 additions and 2 deletions
+454
View File
@@ -0,0 +1,454 @@
"""
dynamics.py
------
统一入口:读取 YAML 配置文件 → 运行模拟 → 抽帧 → 绘图(可选)
用法:
python dynamics.py # 使用 input/parameters.yaml
python dynamics.py input/parameters.yaml # 指定配置文件
python dynamics.py --config input/parameters.yaml --no-plot
"""
import os
import sys
import subprocess
import argparse
from contextlib import contextmanager
from pathlib import Path
import yaml
import numpy as np
# 导入同目录下的模块
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute
def read_optional_index(data, key, default_value):
"""Read an optional integer index from structured txt metadata."""
if key not in data:
return default_value
value = data[key]
if value is None or int(value) < 0:
return default_value
return int(value)
def load_yaml_config(config_path):
"""从 YAML 文件加载配置字典。"""
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return config
def resolve_path(base_dir, path_value):
"""Resolve a path relative to the runtime base directory."""
path = Path(path_value)
if path.is_absolute():
return path
return (Path(base_dir) / path).resolve()
@contextmanager
def runtime_dir_overrides(input_dir=None, output_dir=None):
"""Temporarily override runtime input/output directories for compute helpers."""
old_input = os.environ.get("DYNAMICS_INPUT_DIR")
old_output = os.environ.get("DYNAMICS_OUTPUT_DIR")
try:
if input_dir is not None:
os.environ["DYNAMICS_INPUT_DIR"] = str(Path(input_dir).resolve())
if output_dir is not None:
os.environ["DYNAMICS_OUTPUT_DIR"] = str(Path(output_dir).resolve())
yield
finally:
if old_input is None:
os.environ.pop("DYNAMICS_INPUT_DIR", None)
else:
os.environ["DYNAMICS_INPUT_DIR"] = old_input
if old_output is None:
os.environ.pop("DYNAMICS_OUTPUT_DIR", None)
else:
os.environ["DYNAMICS_OUTPUT_DIR"] = old_output
def build_sample_indices(total_steps, sample_step, sample_start, sample_end):
"""Validate sampling settings and build frame indices."""
if sample_step <= 0:
raise ValueError(f"NSTEP 必须为正整数,实际为 {sample_step}")
if sample_start < 0:
raise ValueError(f"sample_start 不能小于 0,实际为 {sample_start}")
if sample_end > total_steps:
raise ValueError(
f"sample_end 不能大于记录步数 {total_steps},实际为 {sample_end}")
if sample_start >= sample_end:
raise ValueError(
f"sample_start 必须小于 sample_end,实际为 [{sample_start}, {sample_end})")
n_frames = (sample_end - sample_start) // sample_step
if n_frames <= 0:
raise ValueError(
f"抽帧范围 [{sample_start}, {sample_end}) 过短,按 NSTEP={sample_step} 无法抽出任何帧")
indices = np.arange(n_frames, dtype=np.int64) * sample_step + sample_start
return indices
def save_display_txt(data, out_dir=None):
"""将抽帧数据保存到 output/display.txt(含所有参数元数据)。"""
if out_dir is None:
out_dir = os.path.dirname(os.path.abspath(__file__))
disp_path = os.path.join(compute.get_output_dir(out_dir), "display.txt")
compute.save_text_data(disp_path, data)
print(f"[sample] 显示数组已保存至: {disp_path}")
return disp_path
def run_case(config_path, runtime_base, input_dir="input", output_dir="output", no_plot=False):
"""Run one case with explicit program path, input path, and output path."""
runtime_base = Path(runtime_base).resolve()
input_dir_path = resolve_path(runtime_base, input_dir)
output_dir_path = resolve_path(runtime_base, output_dir)
config_path = resolve_path(runtime_base, config_path)
with runtime_dir_overrides(input_dir_path, output_dir_path):
# 1. 加载 YAML 配置
config = load_yaml_config(config_path)
print(f"[run] 已加载配置: {config_path}")
# 显示步骤控制信息
steps_info = {k: config.get(k, 1) for k in ["step_simulate", "step_sample", "step_plot", "step_animation"]}
step_flags = ", ".join(f"{k}={v}" for k, v in steps_info.items())
print(f"[run] 步骤开关: {step_flags}")
warmup = config.get("warmup_steps", 0)
ss = config.get("sample_start", "从头")
se = config.get("sample_end", "到尾")
method = config.get("method", "explicit_euler")
coord_file = config.get("coord_file", os.path.join("input", "coord.txt"))
plot_atom = config.get("plot_atom", "第一个原子")
print(f"[run] 算法: {method}")
print(f"[run] 坐标文件: {coord_file}")
print(f"[run] 绘图/动画原子序号: {plot_atom}")
print(f"[run] 步骤控制: 预热={warmup}步, 抽帧范围=[{ss}, {se})")
output_dir_abs = compute.get_output_dir(str(runtime_base))
traj_path = os.path.join(output_dir_abs, "trajectory.txt")
disp_path = os.path.join(output_dir_abs, "display.txt")
# ── 自动缓存检测 ───────────────────────────────────────
# 若 output/ 中 trajectory.txt 和 display.txt 均已存在,
# 自动跳过模拟和抽帧,直接使用已有结果。
# 如需强制重新计算,删除 output/ 目录或设 step_simulate: 1 即可。
output_exists = (
os.path.isdir(output_dir_abs)
and os.path.exists(traj_path)
and os.path.exists(disp_path)
)
if output_exists:
if config.get("step_simulate", 1):
print(f"[run] 检测到已有输出({traj_path}),自动跳过模拟与抽帧,直接进入后续步骤")
config["step_simulate"] = 0
config["step_sample"] = 0
else:
print(f"[run] 已有输出,步骤已被跳过")
else:
# 目录存在但文件不全 → 强制重新计算
if os.path.isdir(output_dir_abs):
print(f"[run] output/ 目录存在但文件不完整,将重新计算")
else:
print(f"[run] output/ 目录不存在,将执行完整流程")
# 2. 运行物理模拟 → output/trajectory.txt
if config.get("step_simulate", 1):
engine = config.get("engine", "python")
total_steps = config["NT"]
record_steps = total_steps - (config.get("warmup_steps") or 0)
print(f"[run] 开始计算 总步数={total_steps} 记录步数={record_steps} DT={config['DT']}")
if engine == "python":
traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = compute.run_from_config(config, str(runtime_base))
print(f"[run] 计算完成,记录 {record_steps}")
compute.save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, str(runtime_base))
else:
# 外部引擎:先加载配置到全局变量,再运行引擎,再用 save_trajectory_txt 补全 metadata
config["_skip_run"] = True
compute.run_from_config(config, str(runtime_base))
config.pop("_skip_run", None)
input_dir_abs = str(input_dir_path.resolve())
output_dir_abs = str(output_dir_path.resolve())
traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = compute.run_engine(
engine, input_dir_abs, output_dir_abs, config)
compute.save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, str(runtime_base))
else:
print("[run] 步骤 [模拟] 已跳过,直接加载已有轨迹")
if not os.path.exists(traj_path):
print(f"[run] 错误: trajectory.txt 不存在,无法跳过模拟")
sys.exit(1)
# 3. 抽帧 → output/display.txt
traj_path = os.path.join(output_dir_abs, "trajectory.txt")
data = compute.load_text_data(traj_path)
NT = int(data["NT"]); DT = float(data["DT"]); NSTEP = int(data["NSTEP"])
warmup_steps = int(data.get("warmup_steps", 0))
plot_atom_row = int(data["plot_atom_row"]) if "plot_atom_row" in data else 0
plot_atom_id = int(data["plot_atom_id"]) if "plot_atom_id" in data else int(data["atom_ids"][plot_atom_row])
# 抽帧范围控制
sample_start = read_optional_index(data, "sample_start", 0)
sample_end = read_optional_index(data, "sample_end", NT)
indices = build_sample_indices(NT, NSTEP, sample_start, sample_end)
n_frames = len(indices)
print(f"[run] 抽帧范围: [{sample_start}, {sample_end}), 共 {n_frames}")
traj_x = data["traj_x"]
traj_y = data["traj_y"]
traj_z = data["traj_z"]
traj_vx = data["traj_vx"]
traj_vy = data["traj_vy"]
traj_vz = data["traj_vz"]
if traj_x.ndim == 1:
selected_x = traj_x
selected_y = traj_y
selected_z = traj_z
selected_vx = traj_vx
selected_vy = traj_vy
selected_vz = traj_vz
all_x = traj_x[:, None]
all_y = traj_y[:, None]
all_z = traj_z[:, None]
all_vx = traj_vx[:, None]
all_vy = traj_vy[:, None]
all_vz = traj_vz[:, None]
else:
selected_x = traj_x[:, plot_atom_row]
selected_y = traj_y[:, plot_atom_row]
selected_z = traj_z[:, plot_atom_row]
selected_vx = traj_vx[:, plot_atom_row]
selected_vy = traj_vy[:, plot_atom_row]
selected_vz = traj_vz[:, plot_atom_row]
all_x = traj_x
all_y = traj_y
all_z = traj_z
all_vx = traj_vx
all_vy = traj_vy
all_vz = traj_vz
if config.get("step_sample", 1):
disp_data = {
"disp_x": selected_x[indices],
"disp_y": selected_y[indices],
"disp_z": selected_z[indices],
"disp_vx": selected_vx[indices],
"disp_vy": selected_vy[indices],
"disp_vz": selected_vz[indices],
"disp_all_x": all_x[indices],
"disp_all_y": all_y[indices],
"disp_all_z": all_z[indices],
"disp_all_vx": all_vx[indices],
"disp_all_vy": all_vy[indices],
"disp_all_vz": all_vz[indices],
"disp_t": indices * DT,
"disp_step": indices,
"n_frames": n_frames,
"NT": NT, "DT": DT, "NSTEP": NSTEP,
"plot_atom_id": plot_atom_id,
"plot_atom_row": plot_atom_row,
"method": str(data["method"]) if "method" in data else "explicit_euler",
"coord_file": str(data["coord_file"]) if "coord_file" in data else os.path.join("input", "coord.txt"),
"atom_ids": data["atom_ids"] if "atom_ids" in data else np.array([1]),
"atom_masses": data["atom_masses"] if "atom_masses" in data else np.array([float(data["M"])]),
"atom_radii": data["atom_radii"] if "atom_radii" in data else np.array([float(data["ball_radius"])]),
"atom_positions": data["atom_positions"] if "atom_positions" in data else np.array([[float(data["X0"]), float(data["Y0"]), float(data["Z0"])]]),
"atom_velocities": data["atom_velocities"] if "atom_velocities" in data else np.array([[float(data["VX0"]), float(data["VY0"]), float(data["VZ0"])]]),
"atom_fixed": data["atom_fixed"] if "atom_fixed" in data else np.array([[0, 0, 0]]),
"bond_pairs": data.get("bond_pairs", np.zeros((0, 2), dtype=np.int64)).tolist(),
"warmup_steps": warmup_steps,
"sample_start": sample_start,
"sample_end": sample_end,
"X_MIN": float(data["X_MIN"]), "X_MAX": float(data["X_MAX"]),
"Y_MIN": float(data["Y_MIN"]), "Y_MAX": float(data["Y_MAX"]),
"Z_MIN": float(data["Z_MIN"]), "Z_MAX": float(data["Z_MAX"]),
"X0": float(data["X0"]), "Y0": float(data["Y0"]), "Z0": float(data["Z0"]),
"VX0": float(data["VX0"]), "VY0": float(data["VY0"]), "VZ0": float(data["VZ0"]),
"M": float(data["M"]) if "M" in data else 1.0,
"alpha": data["alpha"],
"ball_radius": float(data["ball_radius"]),
"ball_color_r": float(data["ball_color_r"]),
"ball_color_g": float(data["ball_color_g"]),
"ball_color_b": float(data["ball_color_b"]),
"box_color_r": float(data["box_color_r"]),
"box_color_g": float(data["box_color_g"]),
"box_color_b": float(data["box_color_b"]),
}
save_display_txt(disp_data, str(runtime_base))
print(f"[run] 抽帧完成: {sample_end - sample_start} 步 -> {n_frames}")
else:
print("[run] 步骤 [抽帧] 已跳过")
# 4. 绘图(可选)
if not no_plot and config.get("step_plot", 1):
try:
import matplotlib.pyplot as plt
# 配置中文字体支持
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
time = np.arange(NT) * DT
n_atoms = all_x.shape[1]
atom_ids_list = data.get("atom_ids", np.arange(n_atoms) + 1)
fig, axes = plt.subplots(3, 3, figsize=(15, 13))
fig.suptitle("轨迹与能量分析", fontsize=16)
# ── 位置 / 速度 6 子图(前 2 行,每行 3 列) ──
plot_configs = [
(axes[0, 0], all_x, "x - 时间"),
(axes[0, 1], all_y, "y - 时间"),
(axes[0, 2], all_z, "z - 时间"),
(axes[1, 0], all_vx, "vx - 时间"),
(axes[1, 1], all_vy, "vy - 时间"),
(axes[1, 2], all_vz, "vz - 时间"),
]
colors = plt.cm.tab10(np.linspace(0, 1, n_atoms))
for ax, data_arr, title in plot_configs:
for i in range(n_atoms):
atom_id = int(atom_ids_list[i])
ax.plot(time, data_arr[:, i], color=colors[i], linewidth=1.5, label=f"原子 {atom_id}")
ax.set_title(title)
ax.set_xlabel("时间 (s)")
ax.grid(True, alpha=0.3)
ax.legend()
# ── 能量计算 ─────────────────────────────────────
masses = np.array(data["atom_masses"]) # (n_atoms,)
G_vec = np.array(data.get("G", [0.0, 0.0, -9.8])) # [gx, gy, gz]
# 动能 Ek = ½ m v²
ek = 0.5 * masses[np.newaxis, :] * (all_vx**2 + all_vy**2 + all_vz**2)
# 重力势能 Ug = -m G·r
ug = -masses[np.newaxis, :] * (
G_vec[0] * all_x + G_vec[1] * all_y + G_vec[2] * all_z
)
# 弹性势能 Us = ½ k (d - d₀)²
us = np.zeros_like(ek)
bond_pairs = data.get("bond_pairs")
bond_stiffness = data.get("bond_stiffness")
bond_rest_lengths = data.get("bond_rest_lengths")
if bond_pairs is not None and len(bond_pairs) > 0:
for b_idx in range(len(bond_pairs)):
i, j = bond_pairs[b_idx]
dx = all_x[:, j] - all_x[:, i]
dy = all_y[:, j] - all_y[:, i]
dz = all_z[:, j] - all_z[:, i]
dist = np.sqrt(dx**2 + dy**2 + dz**2)
stretch = dist - bond_rest_lengths[b_idx]
us[:, i] += 0.5 * bond_stiffness[b_idx] * stretch**2
us[:, j] += 0.5 * bond_stiffness[b_idx] * stretch**2
# 各原子总能量
e_total = ek + ug + us # (NT, n_atoms)
# 系统能量分量
ek_sys = np.sum(ek, axis=1)
ug_sys = np.sum(ug, axis=1)
us_sys = np.sum(us, axis=1)
e_sys = ek_sys + ug_sys + us_sys
# ── 第 3 行左:各原子总能量 ──
ax_e = axes[2, 0]
for i in range(n_atoms):
aid = int(atom_ids_list[i])
ax_e.plot(time, e_total[:, i], color=colors[i], linewidth=1.5, label=f"原子 {aid}")
ax_e.set_title("各原子总能量")
ax_e.set_xlabel("时间 (s)")
ax_e.set_ylabel("能量")
ax_e.grid(True, alpha=0.3)
ax_e.legend(loc="upper right")
# ── 第 3 行右:系统总能量 ──
ax_sys = axes[2, 1]
ax_sys.plot(time, ek_sys, 'b-', linewidth=1.5, label="系统动能")
ax_sys.plot(time, ug_sys, 'g-', linewidth=1.5, label="系统重力势能")
if bond_pairs is not None and len(bond_pairs) > 0:
ax_sys.plot(time, us_sys, color='orange', linewidth=1.5, label="系统弹性势能")
ax_sys.plot(time, e_sys, 'r--', linewidth=1.5, label="系统总能量")
ax_sys.set_title("系统总能量")
ax_sys.set_xlabel("时间 (s)")
ax_sys.set_ylabel("能量")
ax_sys.grid(True, alpha=0.3)
ax_sys.legend(loc="upper right")
# 隐藏第 3 行第 3 列空子图
axes[2, 2].set_visible(False)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plot_path = os.path.join(output_dir_abs, "trajectory_plots.png")
plt.savefig(plot_path, dpi=300, bbox_inches="tight")
print(f"[run] 轨迹图表已保存至: {plot_path}")
plt.show()
except ImportError:
print("[run] 警告: 未安装 matplotlib,跳过绘图")
print(f"[run] 完成!输出目录: {output_dir_abs}")
# 5. 自动播放动画(可选)
if config.get("step_animation", 0):
draw_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "draw.py")
if os.path.exists(draw_script):
try:
print("[run] 正在启动 VisPy 3D 动画窗口…")
subprocess.Popen(
[sys.executable, draw_script],
cwd=runtime_base,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
except Exception as e:
print(f"[run] 启动动画失败: {e}")
else:
print(f"[run] 未找到动画脚本: {draw_script}")
else:
print("[run] 运行 python draw.py 查看动画。")
def main():
parser = argparse.ArgumentParser(description="物理模拟统一入口")
parser.add_argument("config_file", nargs="?", default=os.path.join("input", "parameters.yaml"),
help="YAML 配置文件路径(默认: input/parameters.yaml")
parser.add_argument("--config", dest="config_override",
help="YAML 配置文件路径(可选,优先于位置参数)")
parser.add_argument("--input-dir", default="input",
help="输入目录路径(默认: input")
parser.add_argument("--output-dir", default="output",
help="输出目录路径(默认: output")
parser.add_argument("--runtime-base", default=".",
help="案例运行根目录(默认: 当前目录)")
parser.add_argument("--no-plot", action="store_true",
help="跳过 matplotlib 绘图")
args = parser.parse_args()
config_path = args.config_override or args.config_file
runtime_base = resolve_path(os.getcwd(), args.runtime_base)
config_path_abs = resolve_path(runtime_base, config_path)
if not os.path.exists(config_path_abs):
print(f"错误: 配置文件不存在: {config_path_abs}")
sys.exit(1)
run_case(
config_path=config_path_abs,
runtime_base=runtime_base,
input_dir=args.input_dir,
output_dir=args.output_dir,
no_plot=args.no_plot,
)
if __name__ == "__main__":
main()