Files
dynamics/dynamics.py
T
2026-05-17 08:47:25 +08:00

455 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
dynamics.py
------
统一入口:读取 YAML 配置文件 → 运行模拟 → 抽帧 → 绘图(可选)
用法:
python dynamics.py # 使用 input/parameters.yaml
python dynamics.py input/parameters.yaml # 指定配置文件
python dynamics.py --config input/parameters.yaml --no-plot
"""
import os
import sys
import subprocess
import argparse
from contextlib import contextmanager
from pathlib import Path
import yaml
import numpy as np
# 导入同目录下的模块
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute
def read_optional_index(data, key, default_value):
    """Return an optional non-negative integer index from structured txt metadata.

    Args:
        data: Mapping of metadata values (supports ``in`` and ``[]``).
        key: Metadata key to look up.
        default_value: Value returned when the key is missing, maps to None,
            or holds a negative number.

    Returns:
        ``int(data[key])`` when it is present and >= 0, else ``default_value``.
    """
    if key not in data:
        return default_value
    raw = data[key]
    index = None if raw is None else int(raw)
    # Negative values act as a "not set" sentinel, like a missing key.
    return default_value if index is None or index < 0 else index
def load_yaml_config(config_path):
    """Load a YAML configuration file and return it as a dict.

    Uses ``yaml.safe_load`` (no arbitrary object construction).

    Args:
        config_path: Path to the UTF-8 encoded YAML file.

    Returns:
        The top-level mapping. An empty file (for which ``safe_load`` returns
        ``None``) yields ``{}`` so callers can rely on ``config.get(...)``.

    Raises:
        ValueError: If the top-level YAML value is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"配置文件顶层必须是键值映射,实际为 {type(config).__name__}: {config_path}")
    return config
def resolve_path(base_dir, path_value):
    """Turn *path_value* into a Path anchored at *base_dir*.

    Absolute inputs are returned as-is (not resolved); relative inputs are
    joined onto ``base_dir`` and then resolved.
    """
    candidate = Path(path_value)
    return candidate if candidate.is_absolute() else (Path(base_dir) / candidate).resolve()
@contextmanager
def runtime_dir_overrides(input_dir=None, output_dir=None):
    """Temporarily export input/output directories for the compute helpers.

    Sets ``DYNAMICS_INPUT_DIR`` / ``DYNAMICS_OUTPUT_DIR`` to the resolved
    paths for every argument that is not None. On exit (normal or via an
    exception) each variable is restored to its previous value, or removed if
    it was previously unset.
    """
    overrides = (
        ("DYNAMICS_INPUT_DIR", input_dir),
        ("DYNAMICS_OUTPUT_DIR", output_dir),
    )
    saved = {name: os.environ.get(name) for name, _ in overrides}
    try:
        for name, directory in overrides:
            if directory is not None:
                os.environ[name] = str(Path(directory).resolve())
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
def build_sample_indices(total_steps, sample_step, sample_start, sample_end):
    """Validate the sampling window and return the frame indices to keep.

    Frames start at ``sample_start`` and advance by ``sample_step``. Exactly
    ``(sample_end - sample_start) // sample_step`` frames are produced, so a
    trailing partial interval contributes no frame.

    Returns:
        ``np.int64`` array of step indices.

    Raises:
        ValueError: If the step is not positive, the window lies outside
            ``[0, total_steps]``, is empty, or is shorter than one step.
    """
    if sample_step <= 0:
        raise ValueError(f"NSTEP 必须为正整数,实际为 {sample_step}")
    if sample_start < 0:
        raise ValueError(f"sample_start 不能小于 0,实际为 {sample_start}")
    if sample_end > total_steps:
        raise ValueError(f"sample_end 不能大于记录步数 {total_steps},实际为 {sample_end}")
    if sample_start >= sample_end:
        raise ValueError(f"sample_start 必须小于 sample_end,实际为 [{sample_start}, {sample_end})")

    window = sample_end - sample_start
    frame_count = window // sample_step
    if frame_count <= 0:
        raise ValueError(
            f"抽帧范围 [{sample_start}, {sample_end}) 过短,按 NSTEP={sample_step} 无法抽出任何帧")

    offsets = np.arange(frame_count, dtype=np.int64) * sample_step
    return offsets + sample_start
def save_display_txt(data, out_dir=None):
    """Write sampled display data to ``display.txt`` in the output directory.

    Args:
        data: Mapping handed to ``compute.save_text_data``.
        out_dir: Base directory passed to ``compute.get_output_dir``; defaults
            to the directory containing this script.

    Returns:
        Full path of the written ``display.txt``.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__)) if out_dir is None else out_dir
    target = os.path.join(compute.get_output_dir(base_dir), "display.txt")
    compute.save_text_data(target, data)
    print(f"[sample] 显示数组已保存至: {target}")
    return target
# Trajectory components stored in trajectory.txt as ``traj_<component>``.
_TRAJ_COMPONENTS = ("x", "y", "z", "vx", "vy", "vz")


def _split_atom_columns(traj, plot_atom_row):
    """Return ``(plotted_atom_series, all_atoms_2d)`` for one trajectory component.

    A 1-D array is treated as a single-atom trajectory and reshaped to a single
    column; otherwise the array is indexed as ``traj[:, plot_atom_row]``.
    """
    if traj.ndim == 1:
        return traj, traj[:, None]
    return traj[:, plot_atom_row], traj


def _atom_masses(data):
    """Return per-atom masses from metadata, falling back to the scalar ``M``."""
    if "atom_masses" in data:
        return data["atom_masses"]
    return np.array([float(data["M"])])


def _run_simulation(config, runtime_base, input_dir_path, output_dir_path):
    """Run the configured engine and persist the trajectory via ``compute.save_trajectory_txt``."""
    engine = config.get("engine", "python")
    total_steps = config["NT"]
    record_steps = total_steps - (config.get("warmup_steps") or 0)
    print(f"[run] 开始计算 总步数={total_steps} 记录步数={record_steps} DT={config['DT']}")
    if engine == "python":
        traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = compute.run_from_config(config, str(runtime_base))
        print(f"[run] 计算完成,记录 {record_steps}")
    else:
        # External engine: load the config into compute's globals without
        # simulating, run the engine, then let save_trajectory_txt add metadata.
        config["_skip_run"] = True
        try:
            compute.run_from_config(config, str(runtime_base))
        finally:
            config.pop("_skip_run", None)
        traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = compute.run_engine(
            engine, str(input_dir_path.resolve()), str(output_dir_path.resolve()), config)
    compute.save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, str(runtime_base))


def _build_display_data(data, selected, all_series, indices, *, NT, DT, NSTEP,
                        plot_atom_id, plot_atom_row, warmup_steps,
                        sample_start, sample_end):
    """Assemble the sampled-frame mapping written to display.txt.

    Key order is fixed (it is the order handed to ``compute.save_text_data``).
    Per-atom metadata missing from *data* falls back to the single-atom
    scalar fields (``M``, ``ball_radius``, ``X0``..``VZ0``).
    """
    disp = {f"disp_{c}": selected[c][indices] for c in _TRAJ_COMPONENTS}
    disp.update({f"disp_all_{c}": all_series[c][indices] for c in _TRAJ_COMPONENTS})
    disp.update({
        "disp_t": indices * DT,
        "disp_step": indices,
        "n_frames": len(indices),
        "NT": NT, "DT": DT, "NSTEP": NSTEP,
        "plot_atom_id": plot_atom_id,
        "plot_atom_row": plot_atom_row,
        "method": str(data["method"]) if "method" in data else "explicit_euler",
        "coord_file": str(data["coord_file"]) if "coord_file" in data else os.path.join("input", "coord.txt"),
        "atom_ids": data["atom_ids"] if "atom_ids" in data else np.array([1]),
        "atom_masses": _atom_masses(data),
        "atom_radii": data["atom_radii"] if "atom_radii" in data else np.array([float(data["ball_radius"])]),
        "atom_positions": data["atom_positions"] if "atom_positions" in data else np.array([[float(data["X0"]), float(data["Y0"]), float(data["Z0"])]]),
        "atom_velocities": data["atom_velocities"] if "atom_velocities" in data else np.array([[float(data["VX0"]), float(data["VY0"]), float(data["VZ0"])]]),
        "atom_fixed": data["atom_fixed"] if "atom_fixed" in data else np.array([[0, 0, 0]]),
        "bond_pairs": data.get("bond_pairs", np.zeros((0, 2), dtype=np.int64)).tolist(),
        "warmup_steps": warmup_steps,
        "sample_start": sample_start,
        "sample_end": sample_end,
    })
    disp.update({k: float(data[k]) for k in (
        "X_MIN", "X_MAX", "Y_MIN", "Y_MAX", "Z_MIN", "Z_MAX",
        "X0", "Y0", "Z0", "VX0", "VY0", "VZ0")})
    disp["M"] = float(data["M"]) if "M" in data else 1.0
    disp["alpha"] = data["alpha"]
    disp.update({k: float(data[k]) for k in (
        "ball_radius", "ball_color_r", "ball_color_g", "ball_color_b",
        "box_color_r", "box_color_g", "box_color_b")})
    return disp


def _energy_terms(masses, gravity, positions, velocities,
                  bond_pairs=None, bond_stiffness=None, bond_rest_lengths=None):
    """Return per-atom kinetic, gravitational and elastic energies.

    Args:
        masses: Per-atom masses, shape ``(n_atoms,)``.
        gravity: Gravity vector ``[gx, gy, gz]``.
        positions: ``(x, y, z)``, each of shape ``(n_steps, n_atoms)``.
        velocities: ``(vx, vy, vz)``, each of shape ``(n_steps, n_atoms)``.
        bond_pairs: Sequence of ``(i, j)`` atom column indices, or None.
        bond_stiffness: Per-bond spring constant ``k``.
        bond_rest_lengths: Per-bond rest length ``d0``.

    Returns:
        ``(ek, ug, us)``, each ``(n_steps, n_atoms)``. Each bond's energy is
        split evenly between its two atoms, so ``us.sum(axis=1)`` counts every
        bond exactly once.
    """
    m = np.asarray(masses)[np.newaxis, :]
    x, y, z = positions
    vx, vy, vz = velocities
    # Kinetic energy: Ek = 1/2 m v^2
    ek = 0.5 * m * (vx**2 + vy**2 + vz**2)
    # Gravitational potential: Ug = -m G.r
    ug = -m * (gravity[0] * x + gravity[1] * y + gravity[2] * z)
    # Elastic potential per bond: Us = 1/2 k (d - d0)^2
    us = np.zeros_like(ek)
    if bond_pairs is not None and len(bond_pairs) > 0:
        for b_idx, (i, j) in enumerate(bond_pairs):
            i, j = int(i), int(j)
            dist = np.sqrt((x[:, j] - x[:, i])**2 + (y[:, j] - y[:, i])**2 + (z[:, j] - z[:, i])**2)
            bond_energy = 0.5 * bond_stiffness[b_idx] * (dist - bond_rest_lengths[b_idx])**2
            # Half to each endpoint: adding the full amount to both atoms would
            # double the system elastic energy when summing over atoms.
            us[:, i] += 0.5 * bond_energy
            us[:, j] += 0.5 * bond_energy
    return ek, ug, us


def _save_trajectory_plots(data, all_series, DT, output_dir_abs):
    """Plot per-atom position/velocity histories and energy curves, save a PNG, show it.

    Raises:
        ImportError: If matplotlib is not installed (handled by the caller).
    """
    import matplotlib.pyplot as plt

    # Font fallbacks so the CJK titles/labels render.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False

    n_steps, n_atoms = all_series["x"].shape
    # Time axis from the recorded row count, not the NT metadata: the two can
    # differ (recorded steps = NT - warmup_steps), and plot() rejects x/y of
    # different lengths.
    time = np.arange(n_steps) * DT
    atom_ids_list = data.get("atom_ids", np.arange(n_atoms) + 1)
    colors = plt.cm.tab10(np.linspace(0, 1, n_atoms))

    fig, axes = plt.subplots(3, 3, figsize=(15, 13))
    fig.suptitle("轨迹与能量分析", fontsize=16)

    # Rows 1-2: position / velocity components, one line per atom.
    panels = [
        (axes[0, 0], "x", "x - 时间"),
        (axes[0, 1], "y", "y - 时间"),
        (axes[0, 2], "z", "z - 时间"),
        (axes[1, 0], "vx", "vx - 时间"),
        (axes[1, 1], "vy", "vy - 时间"),
        (axes[1, 2], "vz", "vz - 时间"),
    ]
    for ax, comp, title in panels:
        series = all_series[comp]
        for i in range(n_atoms):
            ax.plot(time, series[:, i], color=colors[i], linewidth=1.5, label=f"原子 {int(atom_ids_list[i])}")
        ax.set_title(title)
        ax.set_xlabel("时间 (s)")
        ax.grid(True, alpha=0.3)
        ax.legend()

    # Energies.
    bond_pairs = data.get("bond_pairs")
    has_bonds = bond_pairs is not None and len(bond_pairs) > 0
    ek, ug, us = _energy_terms(
        np.array(_atom_masses(data)),
        np.array(data.get("G", [0.0, 0.0, -9.8])),
        (all_series["x"], all_series["y"], all_series["z"]),
        (all_series["vx"], all_series["vy"], all_series["vz"]),
        bond_pairs if has_bonds else None,
        data.get("bond_stiffness"),
        data.get("bond_rest_lengths"),
    )
    e_total = ek + ug + us  # (n_steps, n_atoms)
    ek_sys = np.sum(ek, axis=1)
    ug_sys = np.sum(ug, axis=1)
    us_sys = np.sum(us, axis=1)
    e_sys = ek_sys + ug_sys + us_sys

    # Row 3, left: total energy of each atom.
    ax_e = axes[2, 0]
    for i in range(n_atoms):
        ax_e.plot(time, e_total[:, i], color=colors[i], linewidth=1.5, label=f"原子 {int(atom_ids_list[i])}")
    ax_e.set_title("各原子总能量")
    ax_e.set_xlabel("时间 (s)")
    ax_e.set_ylabel("能量")
    ax_e.grid(True, alpha=0.3)
    ax_e.legend(loc="upper right")

    # Row 3, middle: system energy components.
    ax_sys = axes[2, 1]
    ax_sys.plot(time, ek_sys, 'b-', linewidth=1.5, label="系统动能")
    ax_sys.plot(time, ug_sys, 'g-', linewidth=1.5, label="系统重力势能")
    if has_bonds:
        ax_sys.plot(time, us_sys, color='orange', linewidth=1.5, label="系统弹性势能")
    ax_sys.plot(time, e_sys, 'r--', linewidth=1.5, label="系统总能量")
    ax_sys.set_title("系统总能量")
    ax_sys.set_xlabel("时间 (s)")
    ax_sys.set_ylabel("能量")
    ax_sys.grid(True, alpha=0.3)
    ax_sys.legend(loc="upper right")

    # Row 3, right: unused.
    axes[2, 2].set_visible(False)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plot_path = os.path.join(output_dir_abs, "trajectory_plots.png")
    plt.savefig(plot_path, dpi=300, bbox_inches="tight")
    print(f"[run] 轨迹图表已保存至: {plot_path}")
    plt.show()


def _launch_animation(runtime_base):
    """Start draw.py (next to this file) as a detached subprocess, if present."""
    draw_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "draw.py")
    if not os.path.exists(draw_script):
        print(f"[run] 未找到动画脚本: {draw_script}")
        return
    try:
        print("[run] 正在启动 VisPy 3D 动画窗口…")
        subprocess.Popen(
            [sys.executable, draw_script],
            cwd=runtime_base,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except Exception as e:  # best effort: report and keep going
        print(f"[run] 启动动画失败: {e}")


def run_case(config_path, runtime_base, input_dir="input", output_dir="output", no_plot=False):
    """Run one simulation case: simulate, sample frames, plot, and optionally animate.

    Relative ``config_path``, ``input_dir`` and ``output_dir`` are resolved
    against ``runtime_base``. The input/output directories are exported to the
    compute helpers via ``runtime_dir_overrides`` for the whole run, including
    the spawned animation process.

    Args:
        config_path: YAML configuration file.
        runtime_base: Case root directory.
        input_dir: Input directory (default ``"input"``).
        output_dir: Output directory (default ``"output"``).
        no_plot: Skip the matplotlib plots even when ``step_plot`` is enabled.

    Side effects:
        Writes trajectory/display/plot files under the output directory, may
        open a plot window, may spawn ``draw.py``, and calls ``sys.exit(1)``
        when simulation is skipped but no trajectory.txt exists.
    """
    runtime_base = Path(runtime_base).resolve()
    input_dir_path = resolve_path(runtime_base, input_dir)
    output_dir_path = resolve_path(runtime_base, output_dir)
    config_path = resolve_path(runtime_base, config_path)
    with runtime_dir_overrides(input_dir_path, output_dir_path):
        # 1. Load the YAML configuration.
        config = load_yaml_config(config_path)
        print(f"[run] 已加载配置: {config_path}")

        # Report step switches using the same defaults applied below
        # (step_animation defaults to 0, the others to 1).
        step_defaults = {"step_simulate": 1, "step_sample": 1, "step_plot": 1, "step_animation": 0}
        step_flags = ", ".join(f"{k}={config.get(k, d)}" for k, d in step_defaults.items())
        print(f"[run] 步骤开关: {step_flags}")
        warmup = config.get("warmup_steps", 0)
        ss = config.get("sample_start", "从头")
        se = config.get("sample_end", "到尾")
        print(f"[run] 算法: {config.get('method', 'explicit_euler')}")
        print(f"[run] 坐标文件: {config.get('coord_file', os.path.join('input', 'coord.txt'))}")
        print(f"[run] 绘图/动画原子序号: {config.get('plot_atom', '第一个原子')}")
        print(f"[run] 步骤控制: 预热={warmup}步, 抽帧范围=[{ss}, {se})")

        # Both simulation branches save through
        # save_trajectory_txt(..., runtime_base), so this one location is used
        # for loading, plotting and reporting.
        output_dir_abs = compute.get_output_dir(str(runtime_base))
        traj_path = os.path.join(output_dir_abs, "trajectory.txt")
        disp_path = os.path.join(output_dir_abs, "display.txt")

        # Automatic cache: when trajectory.txt and display.txt both exist,
        # simulation and sampling are skipped and existing results reused.
        # NOTE: step_simulate: 1 does NOT force a rerun here (it is overridden
        # to 0 below); delete those files or the output directory to recompute.
        output_exists = (
            os.path.isdir(output_dir_abs)
            and os.path.exists(traj_path)
            and os.path.exists(disp_path)
        )
        if output_exists:
            if config.get("step_simulate", 1):
                print(f"[run] 检测到已有输出({traj_path}),自动跳过模拟与抽帧,直接进入后续步骤")
                config["step_simulate"] = 0
                config["step_sample"] = 0
            else:
                print("[run] 已有输出,步骤已被跳过")
        elif os.path.isdir(output_dir_abs):
            print("[run] output/ 目录存在但文件不完整,将重新计算")
        else:
            print("[run] output/ 目录不存在,将执行完整流程")

        # 2. Physics simulation -> trajectory.txt
        if config.get("step_simulate", 1):
            _run_simulation(config, runtime_base, input_dir_path, output_dir_path)
        else:
            print("[run] 步骤 [模拟] 已跳过,直接加载已有轨迹")
            if not os.path.exists(traj_path):
                print("[run] 错误: trajectory.txt 不存在,无法跳过模拟")
                sys.exit(1)

        # 3. Frame sampling -> display.txt
        data = compute.load_text_data(traj_path)
        NT = int(data["NT"])
        DT = float(data["DT"])
        NSTEP = int(data["NSTEP"])
        warmup_steps = int(data.get("warmup_steps", 0))
        plot_atom_row = int(data["plot_atom_row"]) if "plot_atom_row" in data else 0
        plot_atom_id = int(data["plot_atom_id"]) if "plot_atom_id" in data else int(data["atom_ids"][plot_atom_row])
        sample_start = read_optional_index(data, "sample_start", 0)
        sample_end = read_optional_index(data, "sample_end", NT)
        # NOTE(review): the window is validated against the NT metadata; if
        # trajectory.txt stores fewer rows (warmup excluded), confirm that
        # sample_end stays within the recorded arrays.
        indices = build_sample_indices(NT, NSTEP, sample_start, sample_end)
        n_frames = len(indices)
        print(f"[run] 抽帧范围: [{sample_start}, {sample_end}), 共 {n_frames}")

        selected, all_series = {}, {}
        for comp in _TRAJ_COMPONENTS:
            selected[comp], all_series[comp] = _split_atom_columns(data[f"traj_{comp}"], plot_atom_row)

        if config.get("step_sample", 1):
            disp_data = _build_display_data(
                data, selected, all_series, indices,
                NT=NT, DT=DT, NSTEP=NSTEP,
                plot_atom_id=plot_atom_id, plot_atom_row=plot_atom_row,
                warmup_steps=warmup_steps,
                sample_start=sample_start, sample_end=sample_end,
            )
            save_display_txt(disp_data, str(runtime_base))
            print(f"[run] 抽帧完成: {sample_end - sample_start} 步 -> {n_frames}")
        else:
            print("[run] 步骤 [抽帧] 已跳过")

        # 4. Plots (optional).
        if not no_plot and config.get("step_plot", 1):
            try:
                _save_trajectory_plots(data, all_series, DT, output_dir_abs)
            except ImportError:
                print("[run] 警告: 未安装 matplotlib,跳过绘图")
        print(f"[run] 完成!输出目录: {output_dir_abs}")

        # 5. Animation (optional).
        if config.get("step_animation", 0):
            _launch_animation(runtime_base)
        else:
            print("[run] 运行 python draw.py 查看动画。")
def main():
    """Command-line entry point: parse arguments and run one case via run_case.

    The config path comes from ``--config`` when given, otherwise from the
    positional argument. It is resolved against ``--runtime-base``, which is
    itself resolved against the current working directory. Exits with status 1
    if the resolved config file does not exist.
    """
    parser = argparse.ArgumentParser(description="物理模拟统一入口")
    parser.add_argument("config_file", nargs="?", default=os.path.join("input", "parameters.yaml"),
                        help="YAML 配置文件路径(默认: input/parameters.yaml)")
    parser.add_argument("--config", dest="config_override",
                        help="YAML 配置文件路径(可选,优先于位置参数)")
    parser.add_argument("--input-dir", default="input",
                        help="输入目录路径(默认: input)")
    parser.add_argument("--output-dir", default="output",
                        help="输出目录路径(默认: output)")
    parser.add_argument("--runtime-base", default=".",
                        help="案例运行根目录(默认: 当前目录)")
    parser.add_argument("--no-plot", action="store_true",
                        help="跳过 matplotlib 绘图")
    args = parser.parse_args()

    config_path = args.config_override or args.config_file
    runtime_base = resolve_path(os.getcwd(), args.runtime_base)
    config_path_abs = resolve_path(runtime_base, config_path)
    if not os.path.exists(config_path_abs):
        print(f"错误: 配置文件不存在: {config_path_abs}")
        sys.exit(1)

    run_case(
        config_path=config_path_abs,
        runtime_base=runtime_base,
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        no_plot=args.no_plot,
    )


if __name__ == "__main__":
    main()