"""
dynamics.py
-----------
Unified entry point: load a YAML config file -> run the simulation ->
sample display frames -> plot (optional) -> launch the animation (optional).

Usage:
    python dynamics.py                              # uses input/parameters.yaml
    python dynamics.py input/parameters.yaml        # explicit config file
    python dynamics.py --config input/parameters.yaml --no-plot
"""
import os
import sys
import subprocess
import argparse
from contextlib import contextmanager
from pathlib import Path

import yaml
import numpy as np

# Make modules that live next to this script (compute.py) importable
# regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute


# Default value of every step switch read from the YAML config.  Shared by the
# summary printout and the actual step checks so the two can never disagree.
_STEP_DEFAULTS = {
    "step_simulate": 1,
    "step_sample": 1,
    "step_plot": 1,
    "step_animation": 0,
}

# Trajectory components stored as ``traj_<name>`` in trajectory.txt.
_COMPONENTS = ("x", "y", "z", "vx", "vy", "vz")

# Scalar metadata copied (as float) into display.txt, in output order.
_BOX_AND_INITIAL_KEYS = (
    "X_MIN", "X_MAX", "Y_MIN", "Y_MAX", "Z_MIN", "Z_MAX",
    "X0", "Y0", "Z0", "VX0", "VY0", "VZ0",
)
_STYLE_KEYS = (
    "ball_radius",
    "ball_color_r", "ball_color_g", "ball_color_b",
    "box_color_r", "box_color_g", "box_color_b",
)


def read_optional_index(data, key, default_value):
    """Read an optional non-negative integer index from structured txt metadata.

    Args:
        data: Mapping loaded from a structured txt file.
        key: Metadata key to look up.
        default_value: Returned when *key* is missing, ``None`` or negative.

    Returns:
        ``int(data[key])`` when it is present and ``>= 0``, else *default_value*.
    """
    if key not in data:
        return default_value
    value = data[key]
    if value is None:
        return default_value
    index = int(value)
    return default_value if index < 0 else index


def load_yaml_config(config_path):
    """Load and return the configuration stored in the YAML file *config_path*.

    ``yaml.safe_load`` is used so the file cannot construct arbitrary
    Python objects.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def resolve_path(base_dir, path_value):
    """Return *path_value* as a Path, resolving relative paths against *base_dir*.

    Absolute paths are returned unchanged; relative ones are joined to
    *base_dir* and resolved.
    """
    path = Path(path_value)
    if path.is_absolute():
        return path
    return (Path(base_dir) / path).resolve()


@contextmanager
def runtime_dir_overrides(input_dir=None, output_dir=None):
    """Temporarily point compute helpers at other input/output directories.

    Sets ``DYNAMICS_INPUT_DIR`` / ``DYNAMICS_OUTPUT_DIR`` (only for the
    arguments that are not ``None``) for the duration of the ``with`` block
    and restores the previous values, or removes the variables, afterwards.
    """
    overrides = {
        "DYNAMICS_INPUT_DIR": input_dir,
        "DYNAMICS_OUTPUT_DIR": output_dir,
    }
    saved = {name: os.environ.get(name) for name in overrides}
    try:
        for name, directory in overrides.items():
            if directory is not None:
                os.environ[name] = str(Path(directory).resolve())
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous


def build_sample_indices(total_steps, sample_step, sample_start, sample_end):
    """Validate sampling settings and build the sampled frame indices.

    Frames are taken every *sample_step* steps from the half-open range
    ``[sample_start, sample_end)``.  The number of frames is
    ``(sample_end - sample_start) // sample_step``.

    Args:
        total_steps: Number of recorded steps available for indexing.
        sample_step: Stride between frames (NSTEP); must be positive.
        sample_start: First step index (inclusive), ``>= 0``.
        sample_end: Last step index (exclusive), ``<= total_steps``.

    Returns:
        ``np.int64`` array of step indices.

    Raises:
        ValueError: If any setting is out of range or no frame fits.
    """
    if sample_step <= 0:
        raise ValueError(f"NSTEP 必须为正整数,实际为 {sample_step}")
    if sample_start < 0:
        raise ValueError(f"sample_start 不能小于 0,实际为 {sample_start}")
    if sample_end > total_steps:
        raise ValueError(
            f"sample_end 不能大于记录步数 {total_steps},实际为 {sample_end}")
    if sample_start >= sample_end:
        raise ValueError(
            f"sample_start 必须小于 sample_end,实际为 [{sample_start}, {sample_end})")
    n_frames = (sample_end - sample_start) // sample_step
    if n_frames <= 0:
        raise ValueError(
            f"抽帧范围 [{sample_start}, {sample_end}) 过短,按 NSTEP={sample_step} 无法抽出任何帧")
    return np.arange(n_frames, dtype=np.int64) * sample_step + sample_start


def save_display_txt(data, out_dir=None):
    """Save sampled frame data (with all metadata) to output/display.txt.

    Args:
        data: Mapping passed to ``compute.save_text_data``.
        out_dir: Base directory handed to ``compute.get_output_dir``;
            defaults to the directory containing this script.

    Returns:
        Path of the written file.
    """
    if out_dir is None:
        out_dir = os.path.dirname(os.path.abspath(__file__))
    disp_path = os.path.join(compute.get_output_dir(out_dir), "display.txt")
    compute.save_text_data(disp_path, data)
    print(f"[sample] 显示数组已保存至: {disp_path}")
    return disp_path


def _step_enabled(config, name):
    """Return the value of step switch *name*, using its shared default."""
    return config.get(name, _STEP_DEFAULTS[name])


def _print_config_summary(config):
    """Echo the step switches and the main run settings of *config*."""
    step_flags = ", ".join(
        f"{name}={_step_enabled(config, name)}" for name in _STEP_DEFAULTS)
    print(f"[run] 步骤开关: {step_flags}")
    warmup = config.get("warmup_steps", 0)
    ss = config.get("sample_start", "从头")
    se = config.get("sample_end", "到尾")
    method = config.get("method", "explicit_euler")
    coord_file = config.get("coord_file", os.path.join("input", "coord.txt"))
    plot_atom = config.get("plot_atom", "第一个原子")
    print(f"[run] 算法: {method}")
    print(f"[run] 坐标文件: {coord_file}")
    print(f"[run] 绘图/动画原子序号: {plot_atom}")
    print(f"[run] 步骤控制: 预热={warmup}步, 抽帧范围=[{ss}, {se})")


def _apply_output_cache(config, output_dir_abs, traj_path, disp_path):
    """Skip simulation and sampling when their outputs already exist.

    If both trajectory.txt and display.txt are present and the simulation
    step is enabled, ``step_simulate`` and ``step_sample`` are set to 0 in
    *config* (mutated in place) so the cached results are reused.  This
    happens even when the config explicitly sets ``step_simulate: 1``; to
    force a recomputation, delete the output directory (or either file).
    """
    output_exists = (
        os.path.isdir(output_dir_abs)
        and os.path.exists(traj_path)
        and os.path.exists(disp_path)
    )
    if output_exists:
        if _step_enabled(config, "step_simulate"):
            print(f"[run] 检测到已有输出({traj_path}),自动跳过模拟与抽帧,直接进入后续步骤")
            config["step_simulate"] = 0
            config["step_sample"] = 0
        else:
            print("[run] 已有输出,步骤已被跳过")
    elif os.path.isdir(output_dir_abs):
        # Directory exists but files are incomplete: the cache is not
        # applied, so the configured steps run as-is.
        print("[run] output/ 目录存在但文件不完整,将重新计算")
    else:
        print("[run] output/ 目录不存在,将执行完整流程")


def _run_simulation(config, runtime_base, input_dir_path, output_dir_path):
    """Run the configured engine and write trajectory.txt through compute."""
    engine = config.get("engine", "python")
    total_steps = config["NT"]
    record_steps = total_steps - (config.get("warmup_steps") or 0)
    print(f"[run] 开始计算 总步数={total_steps} 记录步数={record_steps} DT={config['DT']}")
    base = str(runtime_base)
    if engine == "python":
        traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = compute.run_from_config(config, base)
        print(f"[run] 计算完成,记录 {record_steps} 步")
    else:
        # External engine: first load the config into compute's globals
        # (flagged with _skip_run), then run the engine, then let
        # save_trajectory_txt below fill in the metadata.  The flag is
        # removed even if loading fails so the caller's config stays clean.
        config["_skip_run"] = True
        try:
            compute.run_from_config(config, base)
        finally:
            config.pop("_skip_run", None)
        traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = compute.run_engine(
            engine, str(input_dir_path.resolve()), str(output_dir_path.resolve()), config)
    compute.save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, base)


def _atom_columns(arr):
    """Return *arr* as 2-D ``(steps, n_atoms)``; a 1-D array becomes one column."""
    return arr[:, None] if arr.ndim == 1 else arr


def _atom_masses(data):
    """Per-atom masses, falling back to the legacy single-atom ``M`` field."""
    return data["atom_masses"] if "atom_masses" in data else np.array([float(data["M"])])


def _has_bonds(data):
    """True when the metadata lists at least one bond pair."""
    bond_pairs = data.get("bond_pairs")
    return bond_pairs is not None and len(bond_pairs) > 0


def _sample_frames(traj_path, runtime_base, write_display):
    """Load trajectory.txt, pick the sampled frames and optionally write display.txt.

    The sampling range is validated against the number of rows actually
    stored in the trajectory arrays, since those are what get indexed.

    Returns:
        ``(data, all_cols)``: the loaded metadata mapping and a dict of
        per-component ``(steps, n_atoms)`` arrays, reused by the plot step.
    """
    data = compute.load_text_data(traj_path)
    NT = int(data["NT"])
    DT = float(data["DT"])
    NSTEP = int(data["NSTEP"])
    warmup_steps = int(data.get("warmup_steps", 0))
    plot_atom_row = int(data["plot_atom_row"]) if "plot_atom_row" in data else 0
    if "plot_atom_id" in data:
        plot_atom_id = int(data["plot_atom_id"])
    else:
        plot_atom_id = int(data["atom_ids"][plot_atom_row])

    all_cols = {comp: _atom_columns(data[f"traj_{comp}"]) for comp in _COMPONENTS}
    n_rows = all_cols["x"].shape[0]

    # Sampling range control.
    sample_start = read_optional_index(data, "sample_start", 0)
    sample_end = read_optional_index(data, "sample_end", n_rows)
    indices = build_sample_indices(n_rows, NSTEP, sample_start, sample_end)
    n_frames = len(indices)
    print(f"[run] 抽帧范围: [{sample_start}, {sample_end}), 共 {n_frames} 帧")

    if not write_display:
        print("[run] 步骤 [抽帧] 已跳过")
        return data, all_cols

    # A 1-D trajectory holds a single atom, so plot_atom_row does not apply.
    row = plot_atom_row if data["traj_x"].ndim > 1 else 0
    disp_data = {f"disp_{comp}": all_cols[comp][indices, row] for comp in _COMPONENTS}
    disp_data.update({f"disp_all_{comp}": all_cols[comp][indices] for comp in _COMPONENTS})
    disp_data.update({
        "disp_t": indices * DT,
        "disp_step": indices,
        "n_frames": n_frames,
        "NT": NT,
        "DT": DT,
        "NSTEP": NSTEP,
        "plot_atom_id": plot_atom_id,
        "plot_atom_row": plot_atom_row,
        "method": str(data["method"]) if "method" in data else "explicit_euler",
        "coord_file": (str(data["coord_file"]) if "coord_file" in data
                       else os.path.join("input", "coord.txt")),
        "atom_ids": data["atom_ids"] if "atom_ids" in data else np.array([1]),
        "atom_masses": _atom_masses(data),
        "atom_radii": (data["atom_radii"] if "atom_radii" in data
                       else np.array([float(data["ball_radius"])])),
        "atom_positions": (
            data["atom_positions"] if "atom_positions" in data
            else np.array([[float(data["X0"]), float(data["Y0"]), float(data["Z0"])]])),
        "atom_velocities": (
            data["atom_velocities"] if "atom_velocities" in data
            else np.array([[float(data["VX0"]), float(data["VY0"]), float(data["VZ0"])]])),
        "atom_fixed": data["atom_fixed"] if "atom_fixed" in data else np.array([[0, 0, 0]]),
        # np.asarray so a plain-list value is handled as well as an ndarray.
        "bond_pairs": np.asarray(
            data.get("bond_pairs", np.zeros((0, 2), dtype=np.int64))).tolist(),
        "warmup_steps": warmup_steps,
        "sample_start": sample_start,
        "sample_end": sample_end,
    })
    disp_data.update({key: float(data[key]) for key in _BOX_AND_INITIAL_KEYS})
    disp_data["M"] = float(data["M"]) if "M" in data else 1.0
    disp_data["alpha"] = data["alpha"]
    disp_data.update({key: float(data[key]) for key in _STYLE_KEYS})

    save_display_txt(disp_data, str(runtime_base))
    print(f"[run] 抽帧完成: {sample_end - sample_start} 步 -> {n_frames} 帧")
    return data, all_cols


def _compute_energies(data, all_cols):
    """Return per-atom kinetic, gravitational and bond energies.

    Each result has the shape of the trajectory arrays, ``(steps, n_atoms)``.
    Each bond's energy ½ k (d - d₀)² is split equally between its two
    atoms, so summing over atoms counts every bond exactly once.
    """
    x, y, z = all_cols["x"], all_cols["y"], all_cols["z"]
    vx, vy, vz = all_cols["vx"], all_cols["vy"], all_cols["vz"]
    masses = np.asarray(_atom_masses(data), dtype=float)[np.newaxis, :]
    g_vec = np.asarray(data.get("G", [0.0, 0.0, -9.8]), dtype=float)

    # Kinetic energy Ek = ½ m v²
    ek = 0.5 * masses * (vx**2 + vy**2 + vz**2)
    # Gravitational potential Ug = -m G·r
    ug = -masses * (g_vec[0] * x + g_vec[1] * y + g_vec[2] * z)
    # Elastic potential Us = ½ k (d - d₀)², shared half/half by the bond's atoms
    us = np.zeros_like(ek)
    if _has_bonds(data):
        stiffness = data.get("bond_stiffness")
        rest_lengths = data.get("bond_rest_lengths")
        for b, (i, j) in enumerate(data["bond_pairs"]):
            i, j = int(i), int(j)
            dist = np.sqrt((x[:, j] - x[:, i])**2
                           + (y[:, j] - y[:, i])**2
                           + (z[:, j] - z[:, i])**2)
            bond_energy = 0.5 * stiffness[b] * (dist - rest_lengths[b])**2
            us[:, i] += 0.5 * bond_energy
            us[:, j] += 0.5 * bond_energy
    return ek, ug, us


def _plot_trajectory(data, all_cols, output_dir_abs):
    """Plot positions, velocities and energies and save trajectory_plots.png.

    Prints a warning and returns without plotting when matplotlib is not
    installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("[run] 警告: 未安装 matplotlib,跳过绘图")
        return

    # Fonts capable of rendering the Chinese titles and labels.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False

    dt = float(data["DT"])
    n_steps, n_atoms = all_cols["x"].shape
    # Time axis built from the stored rows so it always matches the data.
    t_axis = np.arange(n_steps) * dt
    atom_ids_list = data.get("atom_ids", np.arange(n_atoms) + 1)
    labels = [f"原子 {int(atom_ids_list[i])}" for i in range(n_atoms)]
    colors = plt.cm.tab10(np.linspace(0, 1, n_atoms))

    fig, axes = plt.subplots(3, 3, figsize=(15, 13))
    fig.suptitle("轨迹与能量分析", fontsize=16)

    # Rows 1-2: position and velocity components versus time.
    panels = [
        (axes[0, 0], "x", "x - 时间"),
        (axes[0, 1], "y", "y - 时间"),
        (axes[0, 2], "z", "z - 时间"),
        (axes[1, 0], "vx", "vx - 时间"),
        (axes[1, 1], "vy", "vy - 时间"),
        (axes[1, 2], "vz", "vz - 时间"),
    ]
    for ax, comp, title in panels:
        for column, color, label in zip(all_cols[comp].T, colors, labels):
            ax.plot(t_axis, column, color=color, linewidth=1.5, label=label)
        ax.set_title(title)
        ax.set_xlabel("时间 (s)")
        ax.grid(True, alpha=0.3)
        ax.legend()

    ek, ug, us = _compute_energies(data, all_cols)
    e_total = ek + ug + us  # per-atom total energy
    ek_sys = ek.sum(axis=1)
    ug_sys = ug.sum(axis=1)
    us_sys = us.sum(axis=1)
    e_sys = ek_sys + ug_sys + us_sys

    # Row 3, left: total energy of each atom.
    ax_e = axes[2, 0]
    for column, color, label in zip(e_total.T, colors, labels):
        ax_e.plot(t_axis, column, color=color, linewidth=1.5, label=label)
    ax_e.set_title("各原子总能量")
    ax_e.set_xlabel("时间 (s)")
    ax_e.set_ylabel("能量")
    ax_e.grid(True, alpha=0.3)
    ax_e.legend(loc="upper right")

    # Row 3, middle: system energy components and their sum.
    ax_sys = axes[2, 1]
    ax_sys.plot(t_axis, ek_sys, 'b-', linewidth=1.5, label="系统动能")
    ax_sys.plot(t_axis, ug_sys, 'g-', linewidth=1.5, label="系统重力势能")
    if _has_bonds(data):
        ax_sys.plot(t_axis, us_sys, color='orange', linewidth=1.5, label="系统弹性势能")
    ax_sys.plot(t_axis, e_sys, 'r--', linewidth=1.5, label="系统总能量")
    ax_sys.set_title("系统总能量")
    ax_sys.set_xlabel("时间 (s)")
    ax_sys.set_ylabel("能量")
    ax_sys.grid(True, alpha=0.3)
    ax_sys.legend(loc="upper right")

    # Row 3, right: unused.
    axes[2, 2].set_visible(False)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plot_path = os.path.join(output_dir_abs, "trajectory_plots.png")
    plt.savefig(plot_path, dpi=300, bbox_inches="tight")
    print(f"[run] 轨迹图表已保存至: {plot_path}")
    plt.show()


def _launch_animation(runtime_base):
    """Start draw.py (next to this file) as a child process without waiting.

    The child runs with *runtime_base* as its working directory and
    inherits the current environment.  Launch failures are reported, not
    raised.
    """
    draw_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "draw.py")
    if not os.path.exists(draw_script):
        print(f"[run] 未找到动画脚本: {draw_script}")
        return
    print("[run] 正在启动 VisPy 3D 动画窗口…")
    try:
        subprocess.Popen(
            [sys.executable, draw_script],
            cwd=runtime_base,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        print(f"[run] 启动动画失败: {e}")


def run_case(config_path, runtime_base, input_dir="input", output_dir="output", no_plot=False):
    """Run one case with explicit program, input and output paths.

    Steps: load config -> simulate -> sample frames -> plot -> animation.
    Each step is controlled by the ``step_*`` switches in the config.

    Args:
        config_path: YAML config path, relative to *runtime_base* if relative.
        runtime_base: Root directory of the case.
        input_dir: Input directory, relative to *runtime_base* if relative.
        output_dir: Output directory, relative to *runtime_base* if relative.
        no_plot: Skip the matplotlib step regardless of ``step_plot``.

    Exits the process with status 1 if simulation is skipped but
    trajectory.txt is missing.
    """
    runtime_base = Path(runtime_base).resolve()
    input_dir_path = resolve_path(runtime_base, input_dir)
    output_dir_path = resolve_path(runtime_base, output_dir)
    config_path = resolve_path(runtime_base, config_path)

    with runtime_dir_overrides(input_dir_path, output_dir_path):
        # 1. Load the YAML config.
        config = load_yaml_config(config_path)
        print(f"[run] 已加载配置: {config_path}")
        _print_config_summary(config)

        output_dir_abs = compute.get_output_dir(str(runtime_base))
        traj_path = os.path.join(output_dir_abs, "trajectory.txt")
        disp_path = os.path.join(output_dir_abs, "display.txt")
        _apply_output_cache(config, output_dir_abs, traj_path, disp_path)

        # 2. Physics simulation -> output/trajectory.txt
        if _step_enabled(config, "step_simulate"):
            _run_simulation(config, runtime_base, input_dir_path, output_dir_path)
        else:
            print("[run] 步骤 [模拟] 已跳过,直接加载已有轨迹")
            if not os.path.exists(traj_path):
                print("[run] 错误: trajectory.txt 不存在,无法跳过模拟")
                sys.exit(1)

        # 3. Frame sampling -> output/display.txt
        data, all_cols = _sample_frames(
            traj_path, runtime_base, _step_enabled(config, "step_sample"))

        # 4. Plots (optional)
        if not no_plot and _step_enabled(config, "step_plot"):
            _plot_trajectory(data, all_cols, output_dir_abs)
        print(f"[run] 完成!输出目录: {output_dir_abs}")

        # 5. Animation (optional)
        if _step_enabled(config, "step_animation"):
            _launch_animation(runtime_base)
        else:
            print("[run] 运行 python draw.py 查看动画。")


def main():
    """Parse command-line arguments and run a single case."""
    parser = argparse.ArgumentParser(description="物理模拟统一入口")
    parser.add_argument("config_file", nargs="?",
                        default=os.path.join("input", "parameters.yaml"),
                        help="YAML 配置文件路径(默认: input/parameters.yaml)")
    parser.add_argument("--config", dest="config_override",
                        help="YAML 配置文件路径(可选,优先于位置参数)")
    parser.add_argument("--input-dir", default="input",
                        help="输入目录路径(默认: input)")
    parser.add_argument("--output-dir", default="output",
                        help="输出目录路径(默认: output)")
    parser.add_argument("--runtime-base", default=".",
                        help="案例运行根目录(默认: 当前目录)")
    parser.add_argument("--no-plot", action="store_true",
                        help="跳过 matplotlib 绘图")
    args = parser.parse_args()

    config_path = args.config_override or args.config_file
    runtime_base = resolve_path(os.getcwd(), args.runtime_base)
    config_path_abs = resolve_path(runtime_base, config_path)
    if not os.path.exists(config_path_abs):
        print(f"错误: 配置文件不存在: {config_path_abs}")
        sys.exit(1)

    run_case(
        config_path=config_path_abs,
        runtime_base=runtime_base,
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        no_plot=args.no_plot,
    )


if __name__ == "__main__":
    main()