""" compute.py ---------- 纯 Python 计算脚本,不依赖 VisPy。 功能: 1. 运行 NT 步物理模拟( kinematics / dynamics 等运动模式) 2. 将每一步的 (x, y, z, vx, vy, vz) 保存到 output/trajectory.txt 3. 同时保存所有模拟参数元数据 """ import json import numpy as np import os import platform import subprocess import sys import time import datetime from tqdm import trange # =========================================================================== # 物理参数(必须与 draw.py 加载的边界常量保持一致) # =========================================================================== # 全局参数变量,将在 load_parameters 中填充 box_a = None alpha = None COORD_FILE = None X0 = None Y0 = None Z0 = None VX0 = None VY0 = None VZ0 = None M = None ATOM_IDS = None ATOM_MASSES = None ATOM_RADII = None ATOM_POSITIONS = None ATOM_VELOCITIES = None ATOM_FIXED = None BOND_CONNECTION_FILE = None BOND_PARAMETER_FILE = None BOND_PAIRS = None BOND_NAMES = None BOND_STIFFNESS = None BOND_REST_LENGTHS = None PLOT_ATOM_ID = None PLOT_ATOM_ROW = None G = None B = None METHOD = None NT = None DT = None NSTEP = None warmup_steps = None # 预热步数(跳过不保存) sample_start = None # 抽帧起始索引 sample_end = None # 抽帧结束索引 ball_radius = None ball_color_r = None ball_color_g = None ball_color_b = None box_color_r = None box_color_g = None box_color_b = None # 派生边界(根据 box_a 计算) X_MIN = None X_MAX = None Y_MIN = None Y_MAX = None Z_MIN = None Z_MAX = None def _to_text_value(value): """Convert numpy-heavy objects into JSON-friendly plain Python values.""" if isinstance(value, np.ndarray): return value.tolist() if isinstance(value, (np.integer,)): return int(value) if isinstance(value, (np.floating,)): return float(value) if isinstance(value, dict): return {key: _to_text_value(val) for key, val in value.items()} if isinstance(value, (list, tuple)): return [_to_text_value(item) for item in value] return value def _from_text_value(value): """Convert nested numeric lists back into numpy arrays when appropriate.""" if isinstance(value, list): if not value: return np.array([], dtype=np.float64) if all(not isinstance(item, (list, dict)) for item in value): if all(isinstance(item, (int, float, bool)) for item in value): return np.array(value) return value if all(isinstance(item, list) for item in value): return np.array(value) return [_from_text_value(item) for item in value] if isinstance(value, dict): return {key: _from_text_value(val) for key, val in value.items()} return value def save_text_data(path, data): """Save structured simulation data as formatted JSON text.""" os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w", encoding="utf-8") as f: json.dump(_to_text_value(data), f, ensure_ascii=False, indent=2) return path def load_text_data(path): """Load structured simulation data from JSON text.""" with open(path, "r", encoding="utf-8") as f: data = json.load(f) return _from_text_value(data) def get_output_dir(base_dir=None): """Return the output directory used for generated artifacts.""" override = os.environ.get("DYNAMICS_OUTPUT_DIR") if override: output_dir = os.path.abspath(override) else: if base_dir is None: base_dir = os.path.dirname(os.path.abspath(__file__)) output_dir = os.path.join(base_dir, "output") os.makedirs(output_dir, exist_ok=True) return output_dir def get_input_dir(base_dir=None): """Return the input directory used for configuration and source data.""" override = os.environ.get("DYNAMICS_INPUT_DIR") if override: input_dir = os.path.abspath(override) else: if base_dir is None: base_dir = os.path.dirname(os.path.abspath(__file__)) input_dir = os.path.join(base_dir, "input") os.makedirs(input_dir, exist_ok=True) return input_dir def load_coord_file(coord_path): """Load atom ids, masses, radii, coordinates, and velocities.""" if not os.path.exists(coord_path): raise FileNotFoundError(f"坐标文件不存在: {coord_path}") rows = [] with open(coord_path, "r", encoding="utf-8") as f: header = f.readline().strip().split() expected = ["n", "mass", "radius", "x", "y", "z", "vx", "vy", "vz"] legacy = expected + ["fixed_x", "fixed_y", "fixed_z"] if header not in (expected, legacy): raise ValueError( f"坐标文件表头应为: {' '.join(expected)},实际为: {' '.join(header)}") for line_no, line in enumerate(f, start=2): stripped = line.strip() if not stripped or stripped.startswith("#"): continue parts = stripped.split() if header == expected and len(parts) != 9: raise ValueError(f"{coord_path}:{line_no} 应有 9 列,实际为 {len(parts)} 列") if header == legacy and len(parts) != 12: raise ValueError(f"{coord_path}:{line_no} 应有 12 列,实际为 {len(parts)} 列") rows.append(parts) if not rows: raise ValueError(f"坐标文件没有原子数据: {coord_path}") atom_ids = np.array([int(row[0]) for row in rows], dtype=np.int64) masses = np.array([float(row[1]) for row in rows], dtype=np.float64) radii = np.array([float(row[2]) for row in rows], dtype=np.float64) positions = np.array([[float(row[3]), float(row[4]), float(row[5])] for row in rows], dtype=np.float64) velocities = np.array([[float(row[6]), float(row[7]), float(row[8])] for row in rows], dtype=np.float64) fixed = np.zeros((len(rows), 3), dtype=np.int64) if header == legacy: fixed = np.array([[int(row[9]), int(row[10]), int(row[11])] for row in rows], dtype=np.int64) if np.any(masses <= 0): raise ValueError("坐标文件中的质量必须为正数") if np.any(radii <= 0): raise ValueError("坐标文件中的半径必须为正数") print(f"[compute] 已加载坐标文件: {coord_path},原子数={len(atom_ids)}") return atom_ids, masses, radii, positions, velocities, fixed def load_bond_parameters(bond_path): """Load bond stiffness and optional rest length by bond name. File format (表头): bond_name k # 2 列:从初始坐标计算键长 bond_name k rest_length # 3 列:显式指定键长 """ if not os.path.exists(bond_path): raise FileNotFoundError(f"键参数文件不存在: {bond_path}") bond_map = {} with open(bond_path, "r", encoding="utf-8") as f: header = f.readline().strip().split() if header == ["bond_name", "k"]: has_rest_length = False elif header == ["bond_name", "k", "rest_length"]: has_rest_length = True else: expected = ["bond_name", "k"] + (["rest_length"] if len(header) == 3 else []) raise ValueError( f"键参数文件表头应为: {' '.join(expected)},实际为: {' '.join(header)}") ncols = 3 if has_rest_length else 2 for line_no, line in enumerate(f, start=2): stripped = line.strip() if not stripped or stripped.startswith("#"): continue parts = stripped.split() if len(parts) != ncols: raise ValueError( f"{bond_path}:{line_no} 应有 {ncols} 列,实际为 {len(parts)} 列") bond_name = parts[0] stiffness = float(parts[1]) if stiffness < 0: raise ValueError(f"{bond_path}:{line_no} 劲度系数必须非负") rest_length = None if has_rest_length: rest_length = float(parts[2]) if rest_length <= 0: raise ValueError(f"{bond_path}:{line_no} 键长必须为正数") bond_map[bond_name] = {"stiffness": stiffness, "rest_length": rest_length} if not bond_map: raise ValueError(f"键参数文件没有有效数据: {bond_path}") return bond_map def load_bond_connections(connection_path, atom_ids, atom_positions, bond_map): """Load bonded atom pairs and derive rest lengths. 键长优先级: 1. bond_map 中显式指定的 rest_length(bond.txt 第三列) 2. 否则从初始坐标计算 """ if not os.path.exists(connection_path): raise FileNotFoundError(f"成键连接文件不存在: {connection_path}") atom_index = {int(atom_id): idx for idx, atom_id in enumerate(atom_ids)} pairs = [] bond_names = [] stiffness = [] rest_lengths = [] with open(connection_path, "r", encoding="utf-8") as f: header = f.readline().strip().split() expected = ["n1", "n2", "bond_name"] if header != expected: raise ValueError( f"成键连接文件表头应为: {' '.join(expected)},实际为: {' '.join(header)}") for line_no, line in enumerate(f, start=2): stripped = line.strip() if not stripped or stripped.startswith("#"): continue parts = stripped.split() if len(parts) != 3: raise ValueError(f"{connection_path}:{line_no} 应有 3 列,实际为 {len(parts)} 列") atom_1 = int(parts[0]) atom_2 = int(parts[1]) bond_name = parts[2] if atom_1 not in atom_index or atom_2 not in atom_index: raise ValueError( f"{connection_path}:{line_no} 中的原子序号 {atom_1}, {atom_2} 不在坐标文件里") if bond_name not in bond_map: raise ValueError( f"{connection_path}:{line_no} 使用了未定义的键类型 {bond_name}") idx_1 = atom_index[atom_1] idx_2 = atom_index[atom_2] bond_entry = bond_map[bond_name] rest_length = bond_entry["rest_length"] if rest_length is None: # 未指定键长时从初始坐标计算 delta = atom_positions[idx_2] - atom_positions[idx_1] rest_length = float(np.linalg.norm(delta)) pairs.append((idx_1, idx_2)) bond_names.append(bond_name) stiffness.append(bond_entry["stiffness"]) rest_lengths.append(rest_length) if not pairs: return ( np.zeros((0, 2), dtype=np.int64), [], np.zeros(0, dtype=np.float64), np.zeros(0, dtype=np.float64), ) return ( np.array(pairs, dtype=np.int64), bond_names, np.array(stiffness, dtype=np.float64), np.array(rest_lengths, dtype=np.float64), ) def parse_gravity_vector(value): """Parse gravity into a 3D acceleration vector. Backward compatibility: - scalar `G: 9.8` -> [0, 0, -9.8] - vector `G: [gx, gy, gz]` -> [gx, gy, gz] """ if isinstance(value, (int, float, np.integer, np.floating)): return np.array([0.0, 0.0, -float(value)], dtype=np.float64) vector = np.asarray(value, dtype=np.float64) if vector.shape != (3,): raise ValueError(f"G 必须是长度为 3 的分量数组,实际为 {value}") return vector def parse_damping_vector(value): """Parse damping into a 3D component vector. Backward compatibility: - scalar `B: 0.5` -> [0.5, 0.5, 0.5] - vector `B: [bx, by, bz]` -> [bx, by, bz] """ if isinstance(value, (int, float, np.integer, np.floating)): scalar = float(value) return np.array([scalar, scalar, scalar], dtype=np.float64) vector = np.asarray(value, dtype=np.float64) if vector.shape != (3,): raise ValueError(f"B 必须是长度为 3 的分量数组,实际为 {value}") return vector def run_from_config(config, out_dir=None): """直接接收配置字典,运行模拟并保存结果。 config 字典结构: box_a, alpha, coord_file, connection_file, bond_file, plot_atom, G, B, method, NT, DT, NSTEP, warmup_steps, sample_start, sample_end(可选), ball_radius, ball_color_r/g/b, box_color_r/g/b(可选) out_dir: 输出目录,默认同脚本所在目录 返回:(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz) """ global box_a, alpha, COORD_FILE, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP global PLOT_ATOM_ID, PLOT_ATOM_ROW global ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS, ATOM_VELOCITIES, ATOM_FIXED global BOND_CONNECTION_FILE, BOND_PARAMETER_FILE, BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS, BOND_REST_LENGTHS global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX global ball_radius, ball_color_r, ball_color_g, ball_color_b global box_color_r, box_color_g, box_color_b global warmup_steps, sample_start, sample_end box_a = float(config.get("box_a", 10.0)) raw_alpha = config.get("alpha", 0.2) if isinstance(raw_alpha, (list, tuple)): alpha = [float(a) for a in raw_alpha] else: alpha = float(raw_alpha) COORD_FILE = str(config.get("coord_file", os.path.join("input", "coord.txt"))) coord_path = COORD_FILE if out_dir is not None and not os.path.isabs(coord_path): coord_path = os.path.join(out_dir, coord_path) (ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS, ATOM_VELOCITIES, ATOM_FIXED) = load_coord_file(coord_path) BOND_CONNECTION_FILE = str(config.get("connection_file", os.path.join("input", "connection.txt"))) BOND_PARAMETER_FILE = str(config.get("bond_file", os.path.join("input", "bond.txt"))) connection_path = BOND_CONNECTION_FILE bond_path = BOND_PARAMETER_FILE if out_dir is not None and not os.path.isabs(connection_path): connection_path = os.path.join(out_dir, connection_path) if out_dir is not None and not os.path.isabs(bond_path): bond_path = os.path.join(out_dir, bond_path) bond_map = load_bond_parameters(bond_path) (BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS, BOND_REST_LENGTHS) = load_bond_connections( connection_path, ATOM_IDS, ATOM_POSITIONS, bond_map) PLOT_ATOM_ID = int(config.get("plot_atom", int(ATOM_IDS[0]))) matches = np.where(ATOM_IDS == PLOT_ATOM_ID)[0] if len(matches) == 0: raise ValueError(f"plot_atom={PLOT_ATOM_ID} 不在坐标文件原子序号中") PLOT_ATOM_ROW = int(matches[0]) M = float(ATOM_MASSES[PLOT_ATOM_ROW]) X0, Y0, Z0 = [float(v) for v in ATOM_POSITIONS[PLOT_ATOM_ROW]] VX0, VY0, VZ0 = [float(v) for v in ATOM_VELOCITIES[PLOT_ATOM_ROW]] G = parse_gravity_vector(config["G"]) B = parse_damping_vector(config["B"]) METHOD = normalize_method_name(config.get("method", "explicit_euler")) NT = int(config["NT"]) DT = float(config["DT"]) NSTEP = int(config["NSTEP"]) # 步骤控制参数(可选,有默认值) warmup_steps = int(config.get("warmup_steps", 0)) sample_start = config.get("sample_start") # None 表示从头开始 sample_end = config.get("sample_end") # None 表示到末尾 X_MIN = -box_a; X_MAX = box_a Y_MIN = -box_a; Y_MAX = box_a Z_MIN = -box_a; Z_MAX = box_a ball_radius = float(config.get("ball_radius", ATOM_RADII[PLOT_ATOM_ROW])) ball_color_r = float(config.get("ball_color_r", 0.9)) ball_color_g = float(config.get("ball_color_g", 0.2)) ball_color_b = float(config.get("ball_color_b", 0.2)) box_color_r = float(config.get("box_color_r", 0.8)) box_color_g = float(config.get("box_color_g", 0.8)) box_color_b = float(config.get("box_color_b", 0.85)) print(f"[compute] 使用算法: {METHOD}") print(f"[compute] 已加载成键信息: {len(BOND_PAIRS)} 条键") if config.get("_skip_run", False): return None, None, None, None, None, None t_start = time.time() t_start_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = run_simulation() t_end = time.time() t_end_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") elapsed = t_end - t_start # 写入 dynamics.log log_dir = get_output_dir(out_dir) log_path = os.path.join(log_dir, "dynamics.log") with open(log_path, "w", encoding="utf-8") as f: f.write("=" * 50 + "\n") f.write("Dynamics 计算日志\n") f.write("=" * 50 + "\n") f.write(f"计算开始时间: {t_start_str}\n") f.write(f"计算结束时间: {t_end_str}\n") f.write(f"计算耗时: {elapsed:.3f} s\n") f.write("-" * 50 + "\n") f.write("计算参数:\n") f.write(f" box_a: {box_a}\n") _alpha_str = alpha if isinstance(alpha, str) else (str(alpha) if isinstance(alpha, list) else str(alpha)) f.write(f" alpha: {_alpha_str}\n") f.write(f" coord_file: {COORD_FILE}\n") f.write(f" method: {METHOD}\n") f.write(f" NT: {NT}\n") f.write(f" DT: {DT}\n") f.write(f" NSTEP: {NSTEP}\n") f.write(f" G: ({G[0]:.2f}, {G[1]:.2f}, {G[2]:.2f})\n") f.write(f" B: ({B[0]:.2f}, {B[1]:.2f}, {B[2]:.2f})\n") f.write(f" warmup: {warmup_steps}\n") f.write(f" 原子数: {len(ATOM_IDS)}\n") f.write(f" 键数: {len(BOND_PAIRS)}\n") f.write("-" * 50 + "\n") f.write("初始状态 (plot_atom):\n") f.write(f" position: ({X0:.6f}, {Y0:.6f}, {Z0:.6f})\n") f.write(f" velocity: ({VX0:.6f}, {VY0:.6f}, {VZ0:.6f})\n") f.write(f" mass: {M}\n") f.write("=" * 50 + "\n") print(f"[compute] 日志已保存至: {log_path}") print(f"[compute] 计算耗时: {elapsed:.3f} s") return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz def run_engine(engine, input_dir, output_dir, config): """调用外部计算引擎(C/C++/Fortran),生成 trajectory.txt。 Args: engine: 引擎名称 ("c", "cpp", "fortran") input_dir: 输入目录(含 coord.txt, connection.txt, bond.txt) output_dir: 输出目录(轨迹文件写入位置) config: YAML 配置字典 Returns: (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz) """ script_dir = os.path.dirname(os.path.abspath(__file__)) system = platform.system().lower() engine_map = { "c": "engines/c/build/dynamics_c", "cpp": "engines/cpp/build/dynamics_cpp", "fortran": "engines/fortran/build/dynamics_f90", } if engine not in engine_map: raise ValueError(f"不支持的引擎: {engine},可选: {list(engine_map.keys())}") engine_rel = engine_map[engine] engine_path = os.path.join(script_dir, engine_rel) # 自动检测可执行文件后缀和平台专用版本 candidates = [ engine_path, # 无后缀 engine_path + ".exe", # Windows .exe engine_path + f"_{system}.exe", # 平台专用 (c_linux.exe, c_darwin.exe) ] found = None for p in candidates: if os.path.exists(p): found = p break if found is None: raise FileNotFoundError( f"引擎可执行文件不存在: 尝试了 {candidates}\n" f"请先编译: cd engines/{engine} && make\n" f"或安装交叉编译器后: cd engines/{engine} && make {system}") # 构造 param.json(数值参数) G = parse_gravity_vector(config.get("G", [0, 0, -9.8])) B = parse_damping_vector(config.get("B", [0, 0, 0])) param_json = { "box_a": float(config.get("box_a", 10.0)), "NT": int(config["NT"]), "DT": float(config["DT"]), "NSTEP": int(config.get("NSTEP", 1)), "warmup_steps": int(config.get("warmup_steps", 0)), "G": [float(v) for v in G], "B": [float(v) for v in B], } param_path = os.path.join(script_dir, "engines", engine, "param.json") os.makedirs(os.path.dirname(param_path), exist_ok=True) with open(param_path, "w", encoding="utf-8") as f: json.dump(param_json, f, indent=2) # 确保输出目录存在 os.makedirs(output_dir, exist_ok=True) print(f"[compute] 引擎: {engine} → {os.path.basename(engine_path)}") print(f"[compute] input: {input_dir}") print(f"[compute] output: {output_dir}") t_start = time.time() t_start_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") result = subprocess.run( [engine_path, os.path.abspath(input_dir), os.path.abspath(output_dir), param_path], capture_output=True, text=True, timeout=600) t_end = time.time() elapsed = t_end - t_start t_end_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") # 打印引擎输出 if result.stdout: for line in result.stdout.strip().split("\n"): line = line.strip() if line: print(f" {line}") if result.returncode != 0: print(f"[compute] 引擎错误:\n{result.stderr}") raise RuntimeError(f"引擎 {engine} 返回错误码 {result.returncode}") # 加载输出的 trajectory.txt traj_path = os.path.join(os.path.abspath(output_dir), "trajectory.txt") if not os.path.exists(traj_path): raise FileNotFoundError(f"引擎未生成 trajectory.txt: {traj_path}") data = load_text_data(traj_path) traj_x = data["traj_x"] traj_y = data["traj_y"] traj_z = data["traj_z"] traj_vx = data["traj_vx"] traj_vy = data["traj_vy"] traj_vz = data["traj_vz"] # 写入日志文件 log_path = os.path.join(output_dir, "dynamics.log") with open(log_path, "w", encoding="utf-8") as f: f.write("=" * 50 + "\n") f.write("Dynamics 计算日志\n") f.write("=" * 50 + "\n") f.write(f"引擎: {engine}\n") f.write(f"计算开始时间: {t_start_str}\n") f.write(f"计算结束时间: {t_end_str}\n") f.write(f"计算耗时: {elapsed:.3f} s\n") print(f"[compute] 引擎完成: {len(traj_x)} 步, {traj_x.shape[1]} 原子, {elapsed:.3f} s") return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz def save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, out_dir=None): """将轨迹数据保存到结构化 txt 文件(含所有参数元数据)。""" if out_dir is None: out_dir = os.path.dirname(os.path.abspath(__file__)) out_path = os.path.join(get_output_dir(out_dir), "trajectory.txt") # 记录的实际步数(扣除预热后) record_steps = len(traj_x) sample_start_to_save = -1 if sample_start is None else int(sample_start) sample_end_to_save = -1 if sample_end is None else int(sample_end) payload = { "traj_x": traj_x, "traj_y": traj_y, "traj_z": traj_z, "traj_vx": traj_vx, "traj_vy": traj_vy, "traj_vz": traj_vz, "NT": record_steps, "DT": DT, "NSTEP": NSTEP, "method": METHOD, "coord_file": COORD_FILE, "connection_file": BOND_CONNECTION_FILE, "bond_file": BOND_PARAMETER_FILE, "atom_ids": ATOM_IDS, "atom_masses": ATOM_MASSES, "atom_radii": ATOM_RADII, "atom_positions": ATOM_POSITIONS, "atom_velocities": ATOM_VELOCITIES, "atom_fixed": ATOM_FIXED, "bond_pairs": BOND_PAIRS, "bond_names": BOND_NAMES, "bond_stiffness": BOND_STIFFNESS, "bond_rest_lengths": BOND_REST_LENGTHS, "G": G.tolist() if G is not None else None, "B": B.tolist() if B is not None else None, "plot_atom_id": PLOT_ATOM_ID, "plot_atom_row": PLOT_ATOM_ROW, "warmup_steps": warmup_steps or 0, "sample_start": sample_start_to_save, "sample_end": sample_end_to_save, "X_MIN": X_MIN, "X_MAX": X_MAX, "Y_MIN": Y_MIN, "Y_MAX": Y_MAX, "Z_MIN": Z_MIN, "Z_MAX": Z_MAX, "X0": X0, "Y0": Y0, "Z0": Z0, "VX0": VX0, "VY0": VY0, "VZ0": VZ0, "M": M, "alpha": alpha, "ball_radius": ball_radius, "ball_color_r": ball_color_r, "ball_color_g": ball_color_g, "ball_color_b": ball_color_b, "box_color_r": box_color_r, "box_color_g": box_color_g, "box_color_b": box_color_b, } save_text_data(out_path, payload) print(f"[compute] 轨迹数据已保存至: {out_path}") return out_path def load_parameters(param_file): """从文本文件加载参数,文件格式:参数 = 值""" global box_a, alpha, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX global ball_radius, ball_color_r, ball_color_g, ball_color_b global box_color_r, box_color_g, box_color_b print(f"[compute] 正在加载参数文件: {param_file}") with open(param_file, 'r', encoding='utf-8') as f: lines = f.readlines() for line in lines: line = line.strip() if not line or line.startswith('#'): continue if '=' not in line: continue # 分割键值,并去除行内注释 key, value = line.split('=', 1) key = key.strip() value = value.strip() # 去除值中可能存在的注释(# 之后的部分) if '#' in value: value = value.split('#', 1)[0].strip() # 尝试转换为 int 或 float try: if '.' in value or 'e' in value.lower(): value = float(value) else: value = int(value) except ValueError: pass # 保持字符串 # 设置全局变量 if key in globals(): globals()[key] = value else: print(f"警告: 未知参数 '{key}',忽略") # 计算派生边界 if box_a is not None: X_MIN = -box_a X_MAX = box_a Y_MIN = -box_a Y_MAX = box_a Z_MIN = -box_a Z_MAX = box_a else: print("错误: 未找到 box_a 参数") sys.exit(1) # 检查必需参数 required = ['alpha', 'X0', 'Y0', 'Z0', 'VX0', 'VY0', 'VZ0', 'M', 'G', 'B', 'NT', 'DT', 'NSTEP'] for param in required: if globals()[param] is None: print(f"错误: 未找到必需参数 '{param}'") sys.exit(1) METHOD = normalize_method_name(METHOD or "explicit_euler") # 设置绘图参数的默认值(如果未提供) if ball_radius is None: ball_radius = 0.28 if ball_color_r is None: ball_color_r = 0.9 if ball_color_g is None: ball_color_g = 0.2 if ball_color_b is None: ball_color_b = 0.2 if box_color_r is None: box_color_r = 0.8 if box_color_g is None: box_color_g = 0.8 if box_color_b is None: box_color_b = 0.85 print(f"[compute] 参数加载完成: box_a={box_a}, alpha={alpha}, NT={NT}, DT={DT}, method={METHOD}") def normalize_method_name(method): """Normalize method aliases from YAML/text config.""" aliases = { "explicit_euler": "explicit_euler", "explicit": "explicit_euler", "euler_explicit": "explicit_euler", "显式欧拉": "explicit_euler", "显式欧拉法": "explicit_euler", "implicit_euler": "implicit_euler", "implicit": "implicit_euler", "euler_implicit": "implicit_euler", "隐式欧拉": "implicit_euler", "隐式欧拉法": "implicit_euler", "midpoint": "midpoint", "mid_point": "midpoint", "midpoint_method": "midpoint", "中点法": "midpoint", "中点差分法": "midpoint", "leapfrog": "leapfrog", "leap_frog": "leapfrog", "蛙跳法": "leapfrog", } key = str(method).strip().lower() if key not in aliases: valid = ", ".join(["explicit_euler", "implicit_euler", "midpoint", "leapfrog"]) raise ValueError(f"未知算法 '{method}'。可选值: {valid}") return aliases[key] def compute_force(x, y, z, vx, vy, vz, m, g, b): """Compute total force from the current state. Current model: - gravity: F = m * g_vec - linear drag: F_drag = -B_vec * v - spring bonds: Hooke force based on bonded pair distance """ fx = m * g[0] - b[0] * vx fy = m * g[1] - b[1] * vy fz = m * g[2] - b[2] * vz if BOND_PAIRS is not None and len(BOND_PAIRS) > 0: for bond_idx, (idx_1, idx_2) in enumerate(BOND_PAIRS): dx = x[idx_2] - x[idx_1] dy = y[idx_2] - y[idx_1] dz = z[idx_2] - z[idx_1] dist = np.sqrt(dx * dx + dy * dy + dz * dz) if dist <= 1e-12: continue stretch = dist - BOND_REST_LENGTHS[bond_idx] force_mag = BOND_STIFFNESS[bond_idx] * stretch ux = dx / dist uy = dy / dist uz = dz / dist fx_bond = force_mag * ux fy_bond = force_mag * uy fz_bond = force_mag * uz fx[idx_1] += fx_bond fy[idx_1] += fy_bond fz[idx_1] += fz_bond fx[idx_2] -= fx_bond fy[idx_2] -= fy_bond fz[idx_2] -= fz_bond return fx, fy, fz def compute_acceleration(x, y, z, vx, vy, vz, m, g, b): """Compute acceleration from the shared force model.""" fx, fy, fz = compute_force(x, y, z, vx, vy, vz, m, g, b) return fx / m, fy / m, fz / m def Explicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b): ax, ay, az = compute_acceleration(x, y, z, vx, vy, vz, m, g, b) x = x + vx * dt y = y + vy * dt z = z + vz * dt vx = vx + ax * dt vy = vy + ay * dt vz = vz + az * dt return x, y, z, vx, vy, vz def Implicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b): gamma_x = b[0] / m gamma_y = b[1] / m gamma_z = b[2] / m vx_next = (vx + g[0] * dt) / (1.0 + gamma_x * dt) vy_next = (vy + g[1] * dt) / (1.0 + gamma_y * dt) vz_next = (vz + g[2] * dt) / (1.0 + gamma_z * dt) ax_next, ay_next, az_next = compute_acceleration( x, y, z, vx_next, vy_next, vz_next, m, g, b) vx = vx + ax_next * dt vy = vy + ay_next * dt vz = vz + az_next * dt x = x + vx * dt y = y + vy * dt z = z + vz * dt return x, y, z, vx, vy, vz def Midpoint_Method(x, y, z, vx, vy, vz, dt, m, g, b): ax, ay, az = compute_acceleration(x, y, z, vx, vy, vz, m, g, b) x_mid = x + 0.5 * vx * dt y_mid = y + 0.5 * vy * dt z_mid = z + 0.5 * vz * dt vx_mid = vx + 0.5 * ax * dt vy_mid = vy + 0.5 * ay * dt vz_mid = vz + 0.5 * az * dt x = x + vx_mid * dt y = y + vy_mid * dt z = z + vz_mid * dt ax_mid, ay_mid, az_mid = compute_acceleration( x_mid, y_mid, z_mid, vx_mid, vy_mid, vz_mid, m, g, b) vx = vx + ax_mid * dt vy = vy + ay_mid * dt vz = vz + az_mid * dt return x, y, z, vx, vy, vz def Leapfrog_Method(x, y, z, vx, vy, vz, dt, m, g, b): ax, ay, az = compute_acceleration(x, y, z, vx, vy, vz, m, g, b) vx_half = vx + 0.5 * ax * dt vy_half = vy + 0.5 * ay * dt vz_half = vz + 0.5 * az * dt x = x + vx_half * dt y = y + vy_half * dt z = z + vz_half * dt gamma_x = b[0] / m gamma_y = b[1] / m gamma_z = b[2] / m vx_next = (vx_half + 0.5 * g[0] * dt) / (1.0 + 0.5 * gamma_x * dt) vy_next = (vy_half + 0.5 * g[1] * dt) / (1.0 + 0.5 * gamma_y * dt) vz_next = (vz_half + 0.5 * g[2] * dt) / (1.0 + 0.5 * gamma_z * dt) ax_next, ay_next, az_next = compute_acceleration( x, y, z, vx_next, vy_next, vz_next, m, g, b) vx = vx_half + 0.5 * ax_next * dt vy = vy_half + 0.5 * ay_next * dt vz = vz_half + 0.5 * az_next * dt return x, y, z, vx, vy, vz def Limit_in_box(a, amin, amax, va): """限制物体在边界内,发生碰撞时反弹。""" over = a > amax under = a < amin a = np.where(over, amax, a) a = np.where(under, amin, a) va = np.where(over | under, -va, va) return a, va def apply_motion_update(x, y, z, vx, vy, vz, dt, m, g, b): """按配置选择的位置更新算法推进一步。""" if METHOD == "explicit_euler": x, y, z, vx, vy, vz = Explicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b) elif METHOD == "implicit_euler": x, y, z, vx, vy, vz = Implicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b) elif METHOD == "midpoint": x, y, z, vx, vy, vz = Midpoint_Method(x, y, z, vx, vy, vz, dt, m, g, b) elif METHOD == "leapfrog": x, y, z, vx, vy, vz = Leapfrog_Method(x, y, z, vx, vy, vz, dt, m, g, b) else: raise ValueError(f"未知算法: {METHOD}") x, vx = Limit_in_box(x, X_MIN, X_MAX, vx) y, vy = Limit_in_box(y, Y_MIN, Y_MAX, vy) z, vz = Limit_in_box(z, Z_MIN, Z_MAX, vz) return x, y, z, vx, vy, vz def apply_fixed_constraints(x, y, z, vx, vy, vz): """Keep fixed degrees of freedom at their initial coordinate with zero speed.""" fixed = ATOM_FIXED != 0 positions = np.column_stack((x, y, z)) velocities = np.column_stack((vx, vy, vz)) positions = np.where(fixed, ATOM_POSITIONS, positions) velocities = np.where(fixed, 0.0, velocities) return ( positions[:, 0], positions[:, 1], positions[:, 2], velocities[:, 0], velocities[:, 1], velocities[:, 2], ) def wrap_position(x, y, z): """边界回绕( dynamics 模式)。""" x = np.where(x > X_MAX, X_MIN, x) x = np.where(x < X_MIN, X_MAX, x) y = np.where(y > Y_MAX, Y_MIN, y) y = np.where(y < Y_MIN, Y_MAX, y) z = np.where(z > Z_MAX, Z_MIN, z) z = np.where(z < Z_MIN, Z_MAX, z) return x, y, z # =========================================================================== # 主计算流程 # =========================================================================== def run_simulation(): """计算 NT 步轨迹,返回位置/速度数组。 步骤控制: - warmup_steps: 预热步数,跳过不记录(用于稳定初始状态) - 实际记录步数 = NT - warmup_steps """ # 预热阶段 x = ATOM_POSITIONS[:, 0].copy() y = ATOM_POSITIONS[:, 1].copy() z = ATOM_POSITIONS[:, 2].copy() vx = ATOM_VELOCITIES[:, 0].copy() vy = ATOM_VELOCITIES[:, 1].copy() vz = ATOM_VELOCITIES[:, 2].copy() x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz) if warmup_steps is not None and warmup_steps > 0: print(f"[compute] 预热阶段: 前 {warmup_steps} 步不记录") for step in trange(warmup_steps, desc="[compute] 预热"): t = (step + 1) * DT x, y, z, vx, vy, vz = apply_motion_update(x, y, z, vx, vy, vz, DT, ATOM_MASSES, G, B) x, y, z = wrap_position(x, y, z) x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz) print( f"[compute] 预热完成,展示原子位置: " f"({x[PLOT_ATOM_ROW]:.4f}, {y[PLOT_ATOM_ROW]:.4f}, {z[PLOT_ATOM_ROW]:.4f})" ) # 记录阶段 record_steps = NT - (warmup_steps or 0) n_atoms = len(ATOM_IDS) traj_x = np.zeros((record_steps, n_atoms), dtype=np.float64) traj_y = np.zeros((record_steps, n_atoms), dtype=np.float64) traj_z = np.zeros((record_steps, n_atoms), dtype=np.float64) traj_vx = np.zeros((record_steps, n_atoms), dtype=np.float64) traj_vy = np.zeros((record_steps, n_atoms), dtype=np.float64) traj_vz = np.zeros((record_steps, n_atoms), dtype=np.float64) for step in trange(record_steps, desc="[compute] 计算中"): traj_x[step] = x traj_y[step] = y traj_z[step] = z traj_vx[step] = vx traj_vy[step] = vy traj_vz[step] = vz x, y, z, vx, vy, vz = apply_motion_update(x, y, z, vx, vy, vz, DT, ATOM_MASSES, G, B) x, y, z = wrap_position(x, y, z) x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz) return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz def save_trajectory_table_txt(txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT): """将轨迹数据保存为逐行表格文本,便于教学查看。""" if np.ndim(traj_x) > 1: row = PLOT_ATOM_ROW if PLOT_ATOM_ROW is not None else 0 traj_x = traj_x[:, row] traj_y = traj_y[:, row] traj_z = traj_z[:, row] traj_vx = traj_vx[:, row] traj_vy = traj_vy[:, row] traj_vz = traj_vz[:, row] with open(txt_path, 'w', encoding='utf-8') as f: # 写入表头 f.write("# 步数, 时间(s), x, y, z, vx, vy, vz\n") for step in range(NT): t = step * DT x = traj_x[step] y = traj_y[step] z = traj_z[step] vx = traj_vx[step] vy = traj_vy[step] vz = traj_vz[step] f.write(f"{step+1:6d}, {t:.6f}, {x:.6f}, {y:.6f}, {z:.6f}, {vx:.6f}, {vy:.6f}, {vz:.6f}\n") print(f"[compute] 轨迹文本文件已保存至: {txt_path}") def main(): # 默认参数文件 script_dir = os.path.dirname(os.path.abspath(__file__)) param_file = os.path.join(get_input_dir(script_dir), "input.txt") if len(sys.argv) > 1: param_file = sys.argv[1] # 加载参数 load_parameters(param_file) script_dir = os.path.dirname(os.path.abspath(__file__)) output_dir = get_output_dir(script_dir) print(f"[compute] 开始计算 NT={NT} DT={DT} ") traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = run_simulation() print(f"[compute] 计算完成,共 {NT} 步") save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, script_dir) # 同时保存为逐行表格,便于直接查看 txt_path = os.path.join(output_dir, "trajectory_table.txt") save_trajectory_table_txt(txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT) if __name__ == "__main__": main()