1033 lines
38 KiB
Python
1033 lines
38 KiB
Python
"""
|
||
compute.py
|
||
----------
|
||
纯 Python 计算脚本,不依赖 VisPy。
|
||
|
||
功能:
|
||
1. 运行 NT 步物理模拟( kinematics / dynamics 等运动模式)
|
||
2. 将每一步的 (x, y, z, vx, vy, vz) 保存到 output/trajectory.txt
|
||
3. 同时保存所有模拟参数元数据
|
||
"""
|
||
|
||
import json
|
||
import numpy as np
|
||
import os
|
||
import platform
|
||
import subprocess
|
||
import sys
|
||
import time
|
||
import datetime
|
||
from tqdm import trange
|
||
|
||
# ===========================================================================
# Physical parameters (must stay consistent with the boundary constants draw.py loads)
# ===========================================================================
# Module-level parameter slots. Every slot starts as None and is filled in by
# load_parameters() or run_from_config().

# Box half-width (bounds are -box_a .. box_a) and the alpha setting
box_a = alpha = None

# Legacy single-particle initial state: position, velocity, mass
X0 = Y0 = Z0 = None
VX0 = VY0 = VZ0 = None
M = None

# Coordinate file and the per-atom arrays loaded from it
COORD_FILE = None
ATOM_IDS = ATOM_MASSES = ATOM_RADII = None
ATOM_POSITIONS = ATOM_VELOCITIES = ATOM_FIXED = None

# Bond input files and the per-bond arrays derived from them
BOND_CONNECTION_FILE = BOND_PARAMETER_FILE = None
BOND_PAIRS = BOND_NAMES = BOND_STIFFNESS = BOND_REST_LENGTHS = None

# Atom whose trajectory is reported (id from the coordinate file, row in the arrays)
PLOT_ATOM_ID = PLOT_ATOM_ROW = None

# Gravity, damping, integrator name and step controls
G = B = METHOD = None
NT = DT = NSTEP = None
warmup_steps = None  # warm-up steps that are integrated but not saved
sample_start = sample_end = None  # frame-sampling start / end indices

# Drawing parameters
ball_radius = None
ball_color_r = ball_color_g = ball_color_b = None
box_color_r = box_color_g = box_color_b = None

# Derived box bounds (computed from box_a)
X_MIN = X_MAX = None
Y_MIN = Y_MAX = None
Z_MIN = Z_MAX = None
|
||
|
||
|
||
def _to_text_value(value):
|
||
"""Convert numpy-heavy objects into JSON-friendly plain Python values."""
|
||
if isinstance(value, np.ndarray):
|
||
return value.tolist()
|
||
if isinstance(value, (np.integer,)):
|
||
return int(value)
|
||
if isinstance(value, (np.floating,)):
|
||
return float(value)
|
||
if isinstance(value, dict):
|
||
return {key: _to_text_value(val) for key, val in value.items()}
|
||
if isinstance(value, (list, tuple)):
|
||
return [_to_text_value(item) for item in value]
|
||
return value
|
||
|
||
|
||
def _from_text_value(value):
|
||
"""Convert nested numeric lists back into numpy arrays when appropriate."""
|
||
if isinstance(value, list):
|
||
if not value:
|
||
return np.array([], dtype=np.float64)
|
||
if all(not isinstance(item, (list, dict)) for item in value):
|
||
if all(isinstance(item, (int, float, bool)) for item in value):
|
||
return np.array(value)
|
||
return value
|
||
if all(isinstance(item, list) for item in value):
|
||
return np.array(value)
|
||
return [_from_text_value(item) for item in value]
|
||
if isinstance(value, dict):
|
||
return {key: _from_text_value(val) for key, val in value.items()}
|
||
return value
|
||
|
||
|
||
def save_text_data(path, data):
    """Save structured simulation data as formatted UTF-8 JSON text.

    Args:
        path: destination file path. Its parent directory is created when the
            path has one; a bare filename is written to the current directory.
        data: nested dict/list structure, possibly holding numpy arrays or
            scalars, which ``_to_text_value`` converts to plain Python values.

    Returns:
        ``path`` unchanged.
    """
    # Serialize first so a conversion/serialization error cannot leave a
    # truncated file behind.
    text = json.dumps(_to_text_value(data), ensure_ascii=False, indent=2)

    parent = os.path.dirname(path)
    # os.makedirs("") raises, so only create a directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
    return path
|
||
|
||
|
||
def load_text_data(path):
    """Read a JSON file written by ``save_text_data`` and restore numpy arrays.

    Args:
        path: path to a UTF-8 encoded JSON file.

    Returns:
        The decoded structure, with numeric lists turned back into numpy
        arrays by ``_from_text_value``.
    """
    with open(path, encoding="utf-8") as handle:
        raw = json.load(handle)
    return _from_text_value(raw)
|
||
|
||
|
||
def get_output_dir(base_dir=None):
    """Resolve (and create) the directory where generated artifacts are written.

    A non-empty ``DYNAMICS_OUTPUT_DIR`` environment variable takes precedence
    and is made absolute. Otherwise the directory is ``<base_dir>/output``,
    where ``base_dir`` defaults to the folder containing this script. The
    directory is created if it does not already exist.

    Returns:
        The directory path.
    """
    env_dir = os.environ.get("DYNAMICS_OUTPUT_DIR")
    if not env_dir:
        root = os.path.dirname(os.path.abspath(__file__)) if base_dir is None else base_dir
        target = os.path.join(root, "output")
    else:
        target = os.path.abspath(env_dir)
    os.makedirs(target, exist_ok=True)
    return target
|
||
|
||
|
||
def get_input_dir(base_dir=None):
    """Resolve (and create) the directory holding configuration and source data.

    Precedence: a non-empty ``DYNAMICS_INPUT_DIR`` environment variable (made
    absolute); otherwise ``<base_dir>/input`` with ``base_dir`` defaulting to
    the folder containing this script. The directory is created if missing.

    Returns:
        The directory path.
    """
    env_value = os.environ.get("DYNAMICS_INPUT_DIR")
    if env_value:
        resolved = os.path.abspath(env_value)
    else:
        anchor = base_dir if base_dir is not None else os.path.dirname(os.path.abspath(__file__))
        resolved = os.path.join(anchor, "input")
    os.makedirs(resolved, exist_ok=True)
    return resolved
|
||
|
||
|
||
def load_coord_file(coord_path):
    """Load per-atom ids, masses, radii, positions, velocities and fixed flags.

    The first line is a header that must be exactly one of::

        n mass radius x y z vx vy vz
        n mass radius x y z vx vy vz fixed_x fixed_y fixed_z

    Blank lines and lines starting with ``#`` are skipped. Every other line
    must have as many whitespace-separated columns as the header.

    Args:
        coord_path: path to the UTF-8 coordinate text file.

    Returns:
        ``(atom_ids, masses, radii, positions, velocities, fixed)``:

        - ``atom_ids``: int64 array of shape (N,).
        - ``masses``, ``radii``: float64 arrays of shape (N,).
        - ``positions``, ``velocities``: float64 arrays of shape (N, 3).
        - ``fixed``: int64 array of shape (N, 3); all zeros when the header
          has no ``fixed_*`` columns.

    Raises:
        FileNotFoundError: ``coord_path`` does not exist.
        ValueError: bad header, wrong column count, unparsable number, no
            atom rows, duplicate atom ids, or non-positive mass/radius.
    """
    if not os.path.exists(coord_path):
        raise FileNotFoundError(f"坐标文件不存在: {coord_path}")

    expected = ["n", "mass", "radius", "x", "y", "z", "vx", "vy", "vz"]
    legacy = expected + ["fixed_x", "fixed_y", "fixed_z"]

    id_rows, mass_rows, radius_rows = [], [], []
    position_rows, velocity_rows, fixed_rows = [], [], []
    with open(coord_path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split()
        if header not in (expected, legacy):
            raise ValueError(
                f"坐标文件表头应为: {' '.join(expected)} 或 {' '.join(legacy)},"
                f"实际为: {' '.join(header)}")
        ncols = len(header)
        has_fixed = header == legacy

        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) != ncols:
                raise ValueError(f"{coord_path}:{line_no} 应有 {ncols} 列,实际为 {len(parts)} 列")
            # Parse per line so a malformed number reports where it is.
            try:
                id_rows.append(int(parts[0]))
                mass_rows.append(float(parts[1]))
                radius_rows.append(float(parts[2]))
                position_rows.append([float(v) for v in parts[3:6]])
                velocity_rows.append([float(v) for v in parts[6:9]])
                fixed_rows.append([int(v) for v in parts[9:12]] if has_fixed else [0, 0, 0])
            except ValueError as err:
                raise ValueError(f"{coord_path}:{line_no} 数值格式错误: {err}") from err

    if not id_rows:
        raise ValueError(f"坐标文件没有原子数据: {coord_path}")

    atom_ids = np.array(id_rows, dtype=np.int64)
    masses = np.array(mass_rows, dtype=np.float64)
    radii = np.array(radius_rows, dtype=np.float64)
    positions = np.array(position_rows, dtype=np.float64)
    velocities = np.array(velocity_rows, dtype=np.float64)
    fixed = np.array(fixed_rows, dtype=np.int64)

    # Atom ids are used as lookup keys (bond connections, plot_atom), so they
    # must be unique; otherwise later rows would silently shadow earlier ones.
    unique_ids, counts = np.unique(atom_ids, return_counts=True)
    duplicates = unique_ids[counts > 1]
    if duplicates.size:
        raise ValueError(f"坐标文件中存在重复的原子序号: {duplicates.tolist()}")

    if np.any(masses <= 0):
        raise ValueError("坐标文件中的质量必须为正数")
    if np.any(radii <= 0):
        raise ValueError("坐标文件中的半径必须为正数")

    print(f"[compute] 已加载坐标文件: {coord_path},原子数={len(atom_ids)}")
    return atom_ids, masses, radii, positions, velocities, fixed
|
||
|
||
|
||
def load_bond_parameters(bond_path):
    """Load bond stiffness and optional rest length, keyed by bond name.

    File format (header line)::

        bond_name k              # 2 columns: rest length derived later from initial coordinates
        bond_name k rest_length  # 3 columns: rest length given explicitly

    Blank lines and lines starting with ``#`` are skipped.

    Args:
        bond_path: path to the UTF-8 bond parameter file.

    Returns:
        dict mapping bond name -> ``{"stiffness": float, "rest_length": float | None}``.

    Raises:
        FileNotFoundError: ``bond_path`` does not exist.
        ValueError: bad header or column count, unparsable number, stiffness
            that is negative or non-finite, rest length that is non-positive or
            non-finite, a bond name defined twice, or no data rows.
    """
    if not os.path.exists(bond_path):
        raise FileNotFoundError(f"键参数文件不存在: {bond_path}")

    bond_map = {}
    with open(bond_path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split()
        if header == ["bond_name", "k"]:
            has_rest_length = False
        elif header == ["bond_name", "k", "rest_length"]:
            has_rest_length = True
        else:
            expected = ["bond_name", "k"] + (["rest_length"] if len(header) == 3 else [])
            raise ValueError(
                f"键参数文件表头应为: {' '.join(expected)},实际为: {' '.join(header)}")

        ncols = 3 if has_rest_length else 2
        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) != ncols:
                raise ValueError(
                    f"{bond_path}:{line_no} 应有 {ncols} 列,实际为 {len(parts)} 列")

            bond_name = parts[0]
            # A repeated name would silently override the earlier definition.
            if bond_name in bond_map:
                raise ValueError(f"{bond_path}:{line_no} 键类型 {bond_name} 重复定义")

            try:
                stiffness = float(parts[1])
                rest_length = float(parts[2]) if has_rest_length else None
            except ValueError as err:
                raise ValueError(f"{bond_path}:{line_no} 数值格式错误: {err}") from err

            # NaN compares False with everything, so check finiteness explicitly.
            if not np.isfinite(stiffness) or stiffness < 0:
                raise ValueError(f"{bond_path}:{line_no} 劲度系数必须非负且有限")
            if rest_length is not None and (not np.isfinite(rest_length) or rest_length <= 0):
                raise ValueError(f"{bond_path}:{line_no} 键长必须为正的有限数")

            bond_map[bond_name] = {"stiffness": stiffness, "rest_length": rest_length}

    if not bond_map:
        raise ValueError(f"键参数文件没有有效数据: {bond_path}")
    return bond_map
|
||
|
||
|
||
def load_bond_connections(connection_path, atom_ids, atom_positions, bond_map):
    """Load bonded atom pairs and resolve each bond's stiffness and rest length.

    File format: header ``n1 n2 bond_name``, then one bond per line. Blank
    lines and ``#`` comments are skipped. Atom numbers refer to ids in the
    coordinate file; ``bond_name`` must be a key of ``bond_map``.

    Rest-length priority:
      1. the explicit ``rest_length`` in ``bond_map`` (third column of bond.txt);
      2. otherwise the distance between the two atoms in ``atom_positions``.

    Args:
        connection_path: path to the UTF-8 connection file.
        atom_ids: atom ids, in the same row order as ``atom_positions``.
        atom_positions: (N, 3) initial coordinates.
        bond_map: mapping produced by ``load_bond_parameters``.

    Returns:
        ``(pairs, bond_names, stiffness, rest_lengths)``:

        - ``pairs``: int64 array of shape (B, 2) holding row indices.
        - ``bond_names``: list of B names.
        - ``stiffness``, ``rest_lengths``: float64 arrays of shape (B,).

        With no bonds the shapes are (0, 2), [], (0,), (0,).

    Raises:
        FileNotFoundError: ``connection_path`` does not exist.
        ValueError: bad header or column count, unparsable atom number,
            unknown atom id, self-bond, or undefined bond type.
    """
    if not os.path.exists(connection_path):
        raise FileNotFoundError(f"成键连接文件不存在: {connection_path}")

    atom_index = {int(atom_id): idx for idx, atom_id in enumerate(atom_ids)}
    pairs, bond_names, stiffness, rest_lengths = [], [], [], []

    with open(connection_path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split()
        expected = ["n1", "n2", "bond_name"]
        if header != expected:
            raise ValueError(
                f"成键连接文件表头应为: {' '.join(expected)},实际为: {' '.join(header)}")

        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) != 3:
                raise ValueError(f"{connection_path}:{line_no} 应有 3 列,实际为 {len(parts)} 列")

            try:
                atom_1 = int(parts[0])
                atom_2 = int(parts[1])
            except ValueError as err:
                raise ValueError(f"{connection_path}:{line_no} 原子序号格式错误: {err}") from err
            bond_name = parts[2]

            if atom_1 not in atom_index or atom_2 not in atom_index:
                raise ValueError(
                    f"{connection_path}:{line_no} 中的原子序号 {atom_1}, {atom_2} 不在坐标文件里")
            # A self-bond has zero length and no direction, so it can never
            # produce a force; treat it as an input error instead.
            if atom_1 == atom_2:
                raise ValueError(f"{connection_path}:{line_no} 原子 {atom_1} 不能与自身成键")
            if bond_name not in bond_map:
                raise ValueError(
                    f"{connection_path}:{line_no} 使用了未定义的键类型 {bond_name}")

            idx_1 = atom_index[atom_1]
            idx_2 = atom_index[atom_2]
            bond_entry = bond_map[bond_name]
            rest_length = bond_entry["rest_length"]
            if rest_length is None:
                # No explicit rest length: use the initial interatomic distance.
                delta = atom_positions[idx_2] - atom_positions[idx_1]
                rest_length = float(np.linalg.norm(delta))

            pairs.append((idx_1, idx_2))
            bond_names.append(bond_name)
            stiffness.append(bond_entry["stiffness"])
            rest_lengths.append(rest_length)

    return (
        # reshape keeps the (0, 2) shape when there are no bonds.
        np.array(pairs, dtype=np.int64).reshape(-1, 2),
        bond_names,
        np.array(stiffness, dtype=np.float64),
        np.array(rest_lengths, dtype=np.float64),
    )
|
||
|
||
|
||
def parse_gravity_vector(value):
    """Parse gravity into a 3D acceleration vector.

    Accepted forms (backward compatible):
    - scalar ``G: 9.8`` (Python/numpy number or 0-d array) -> [0, 0, -9.8]
    - vector ``G: [gx, gy, gz]`` -> [gx, gy, gz]
    - string ``"9.8"``, ``"0, 0, -9.8"`` or ``"[0 0 -9.8]"`` -> parsed as above

    Returns:
        float64 array of shape (3,).

    Raises:
        ValueError: ``value`` is None, not numeric, or not a scalar or 3-vector.
    """
    if value is None:
        raise ValueError(f"G 必须是长度为 3 的分量数组,实际为 {value}")

    if isinstance(value, str):
        tokens = value.strip().strip("[]()").replace(",", " ").split()
        numbers = [float(token) for token in tokens]
        value = numbers[0] if len(numbers) == 1 else numbers

    # np.ndim also treats 0-d numpy arrays as scalars.
    if np.ndim(value) == 0:
        return np.array([0.0, 0.0, -float(value)], dtype=np.float64)

    vector = np.asarray(value, dtype=np.float64)
    if vector.shape != (3,):
        raise ValueError(f"G 必须是长度为 3 的分量数组,实际为 {value}")
    return vector
|
||
|
||
|
||
def parse_damping_vector(value):
    """Parse damping into a 3D component vector.

    Accepted forms (backward compatible):
    - scalar ``B: 0.5`` (Python/numpy number or 0-d array) -> [0.5, 0.5, 0.5]
    - vector ``B: [bx, by, bz]`` -> [bx, by, bz]
    - string ``"0.5"``, ``"0.1, 0.2, 0.3"`` or ``"[0.1 0.2 0.3]"`` -> parsed as above

    Returns:
        float64 array of shape (3,).

    Raises:
        ValueError: ``value`` is None, not numeric, or not a scalar or 3-vector.
    """
    if value is None:
        raise ValueError(f"B 必须是长度为 3 的分量数组,实际为 {value}")

    if isinstance(value, str):
        tokens = value.strip().strip("[]()").replace(",", " ").split()
        numbers = [float(token) for token in tokens]
        value = numbers[0] if len(numbers) == 1 else numbers

    # np.ndim also treats 0-d numpy arrays as scalars.
    if np.ndim(value) == 0:
        scalar = float(value)
        return np.array([scalar, scalar, scalar], dtype=np.float64)

    vector = np.asarray(value, dtype=np.float64)
    if vector.shape != (3,):
        raise ValueError(f"B 必须是长度为 3 的分量数组,实际为 {value}")
    return vector
|
||
|
||
|
||
def run_from_config(config, out_dir=None):
    """Run the simulation from a configuration dict, then write dynamics.log.

    Config keys:
        required: G, B, NT, DT, NSTEP
        optional:
            box_a (10.0), alpha (0.2; scalar or list),
            coord_file / connection_file / bond_file (default
                input/coord.txt, input/connection.txt, input/bond.txt),
            plot_atom (default: first atom id), method ("explicit_euler"),
            warmup_steps (0), sample_start, sample_end,
            ball_radius (default: radius of the plot atom),
            ball_color_r/g/b, box_color_r/g/b,
            _skip_run (when true, only load parameters; do not simulate).

    Relative input paths are joined onto ``out_dir`` when it is given;
    otherwise they are opened relative to the current working directory.

    Args:
        config: configuration dictionary (e.g. parsed from YAML).
        out_dir: base directory for relative input paths. It is also passed to
            ``get_output_dir``, which defaults to the script directory when None.

    Returns:
        (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz), each of shape
        (NT - warmup_steps, n_atoms). All six are None when ``_skip_run`` is set.

    Raises:
        ValueError: plot_atom not found, unknown method, DT not positive, or
            warmup_steps outside [0, NT]. File loaders may raise as well.
    """
    global box_a, alpha, COORD_FILE, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP
    global PLOT_ATOM_ID, PLOT_ATOM_ROW
    global ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS, ATOM_VELOCITIES, ATOM_FIXED
    global BOND_CONNECTION_FILE, BOND_PARAMETER_FILE, BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS, BOND_REST_LENGTHS
    global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX
    global ball_radius, ball_color_r, ball_color_g, ball_color_b
    global box_color_r, box_color_g, box_color_b
    global warmup_steps, sample_start, sample_end

    def _resolve(path):
        """Interpret a relative input path relative to out_dir, when given."""
        if out_dir is not None and not os.path.isabs(path):
            return os.path.join(out_dir, path)
        return path

    box_a = float(config.get("box_a", 10.0))
    raw_alpha = config.get("alpha", 0.2)
    if isinstance(raw_alpha, (list, tuple)):
        alpha = [float(a) for a in raw_alpha]
    else:
        alpha = float(raw_alpha)

    COORD_FILE = str(config.get("coord_file", os.path.join("input", "coord.txt")))
    (ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS,
     ATOM_VELOCITIES, ATOM_FIXED) = load_coord_file(_resolve(COORD_FILE))

    BOND_CONNECTION_FILE = str(config.get("connection_file", os.path.join("input", "connection.txt")))
    BOND_PARAMETER_FILE = str(config.get("bond_file", os.path.join("input", "bond.txt")))
    bond_map = load_bond_parameters(_resolve(BOND_PARAMETER_FILE))
    (BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS,
     BOND_REST_LENGTHS) = load_bond_connections(
        _resolve(BOND_CONNECTION_FILE), ATOM_IDS, ATOM_POSITIONS, bond_map)

    PLOT_ATOM_ID = int(config.get("plot_atom", int(ATOM_IDS[0])))
    matches = np.where(ATOM_IDS == PLOT_ATOM_ID)[0]
    if len(matches) == 0:
        raise ValueError(f"plot_atom={PLOT_ATOM_ID} 不在坐标文件原子序号中")
    PLOT_ATOM_ROW = int(matches[0])

    M = float(ATOM_MASSES[PLOT_ATOM_ROW])
    X0, Y0, Z0 = [float(v) for v in ATOM_POSITIONS[PLOT_ATOM_ROW]]
    VX0, VY0, VZ0 = [float(v) for v in ATOM_VELOCITIES[PLOT_ATOM_ROW]]

    G = parse_gravity_vector(config["G"])
    B = parse_damping_vector(config["B"])
    METHOD = normalize_method_name(config.get("method", "explicit_euler"))
    NT = int(config["NT"])
    DT = float(config["DT"])
    NSTEP = int(config["NSTEP"])

    # Step-control parameters (optional, with defaults).
    warmup_steps = int(config.get("warmup_steps", 0))
    sample_start = config.get("sample_start")  # None means from the beginning
    sample_end = config.get("sample_end")  # None means to the end

    X_MIN, X_MAX = -box_a, box_a
    Y_MIN, Y_MAX = -box_a, box_a
    Z_MIN, Z_MAX = -box_a, box_a

    ball_radius = float(config.get("ball_radius", ATOM_RADII[PLOT_ATOM_ROW]))
    ball_color_r = float(config.get("ball_color_r", 0.9))
    ball_color_g = float(config.get("ball_color_g", 0.2))
    ball_color_b = float(config.get("ball_color_b", 0.2))
    box_color_r = float(config.get("box_color_r", 0.8))
    box_color_g = float(config.get("box_color_g", 0.8))
    box_color_b = float(config.get("box_color_b", 0.85))

    print(f"[compute] 使用算法: {METHOD}")
    print(f"[compute] 已加载成键信息: {len(BOND_PAIRS)} 条键")

    if config.get("_skip_run", False):
        return None, None, None, None, None, None

    # Validate step controls before simulating. A negative warmup would
    # silently record more than NT frames; warmup > NT would fail deep inside
    # numpy with an opaque error. `not DT > 0` also rejects NaN.
    if not DT > 0:
        raise ValueError(f"DT 必须为正数,实际为 {DT}")
    if not 0 <= warmup_steps <= NT:
        raise ValueError(
            f"warmup_steps 必须满足 0 <= warmup_steps <= NT,"
            f"实际 warmup_steps={warmup_steps}, NT={NT}")

    t_start = time.time()
    t_start_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = run_simulation()
    t_end = time.time()
    t_end_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    elapsed = t_end - t_start

    log_path = os.path.join(get_output_dir(out_dir), "dynamics.log")
    _write_compute_log(log_path, t_start_str, t_end_str, elapsed)
    print(f"[compute] 日志已保存至: {log_path}")
    print(f"[compute] 计算耗时: {elapsed:.3f} s")

    return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz


def _write_compute_log(log_path, t_start_str, t_end_str, elapsed):
    """Write dynamics.log for a Python-backend run from the module-level parameters."""
    lines = [
        "=" * 50,
        "Dynamics 计算日志",
        "=" * 50,
        f"计算开始时间: {t_start_str}",
        f"计算结束时间: {t_end_str}",
        f"计算耗时: {elapsed:.3f} s",
        "-" * 50,
        "计算参数:",
        f" box_a: {box_a}",
        f" alpha: {alpha}",
        f" coord_file: {COORD_FILE}",
        f" method: {METHOD}",
        f" NT: {NT}",
        f" DT: {DT}",
        f" NSTEP: {NSTEP}",
        f" G: ({G[0]:.2f}, {G[1]:.2f}, {G[2]:.2f})",
        f" B: ({B[0]:.2f}, {B[1]:.2f}, {B[2]:.2f})",
        f" warmup: {warmup_steps}",
        f" 原子数: {len(ATOM_IDS)}",
        f" 键数: {len(BOND_PAIRS)}",
        "-" * 50,
        "初始状态 (plot_atom):",
        f" position: ({X0:.6f}, {Y0:.6f}, {Z0:.6f})",
        f" velocity: ({VX0:.6f}, {VY0:.6f}, {VZ0:.6f})",
        f" mass: {M}",
        "=" * 50,
    ]
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
|
||
|
||
|
||
def run_engine(engine, input_dir, output_dir, config):
    """Run an external compute engine (C/C++/Fortran) and load its trajectory.

    The executable is looked up under ``engines/<engine>/build`` next to this
    script. The variants tried, in order, are: the bare name, ``.exe``, and
    ``_<platform>.exe``; the first existing one is executed. Numerical
    parameters are written to ``engines/<engine>/param.json``. The engine is
    invoked as ``<exe> <input_dir> <output_dir> <param.json>`` and must write
    ``trajectory.txt`` into ``output_dir``.

    Args:
        engine: engine name ("c", "cpp", "fortran").
        input_dir: input directory (contains coord.txt, connection.txt, bond.txt).
        output_dir: output directory for trajectory.txt and dynamics.log.
        config: YAML configuration dict. The optional ``engine_timeout`` key
            sets the run limit in seconds (default 600; None disables it).

    Returns:
        (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz) loaded from trajectory.txt.

    Raises:
        ValueError: unsupported engine name.
        FileNotFoundError: no executable found, or trajectory.txt not produced.
        RuntimeError: the engine exited with a non-zero status.
        subprocess.TimeoutExpired: the engine exceeded the timeout.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    system = platform.system().lower()
    engine_map = {
        "c": "engines/c/build/dynamics_c",
        "cpp": "engines/cpp/build/dynamics_cpp",
        "fortran": "engines/fortran/build/dynamics_f90",
    }
    if engine not in engine_map:
        raise ValueError(f"不支持的引擎: {engine},可选: {list(engine_map.keys())}")

    engine_path = os.path.join(script_dir, engine_map[engine])

    # Try the supported executable naming variants; the first match is the
    # one that actually gets executed below.
    candidates = [
        engine_path,                      # no suffix
        engine_path + ".exe",             # Windows .exe
        engine_path + f"_{system}.exe",   # platform-specific build (e.g. _linux.exe)
    ]
    found = next((p for p in candidates if os.path.exists(p)), None)
    if found is None:
        raise FileNotFoundError(
            f"引擎可执行文件不存在: 尝试了 {candidates}\n"
            f"请先编译: cd engines/{engine} && make\n"
            f"或安装交叉编译器后: cd engines/{engine} && make {system}")

    # Build param.json (numeric parameters).
    G = parse_gravity_vector(config.get("G", [0, 0, -9.8]))
    B = parse_damping_vector(config.get("B", [0, 0, 0]))
    param_json = {
        "box_a": float(config.get("box_a", 10.0)),
        "NT": int(config["NT"]),
        "DT": float(config["DT"]),
        "NSTEP": int(config.get("NSTEP", 1)),
        "warmup_steps": int(config.get("warmup_steps", 0)),
        "G": [float(v) for v in G],
        "B": [float(v) for v in B],
    }
    param_path = os.path.join(script_dir, "engines", engine, "param.json")
    os.makedirs(os.path.dirname(param_path), exist_ok=True)
    with open(param_path, "w", encoding="utf-8") as f:
        json.dump(param_json, f, indent=2)

    os.makedirs(output_dir, exist_ok=True)

    timeout_cfg = config.get("engine_timeout", 600)
    timeout = None if timeout_cfg is None else float(timeout_cfg)

    print(f"[compute] 引擎: {engine} → {os.path.basename(found)}")
    print(f"[compute] input: {input_dir}")
    print(f"[compute] output: {output_dir}")

    t_start = time.time()
    t_start_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    result = subprocess.run(
        [found, os.path.abspath(input_dir), os.path.abspath(output_dir), param_path],
        capture_output=True, text=True, timeout=timeout)
    t_end = time.time()
    elapsed = t_end - t_start
    t_end_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Echo the engine's stdout, one non-empty line at a time.
    if result.stdout:
        for line in result.stdout.strip().split("\n"):
            line = line.strip()
            if line:
                print(f" {line}")
    if result.returncode != 0:
        print(f"[compute] 引擎错误:\n{result.stderr}")
        raise RuntimeError(f"引擎 {engine} 返回错误码 {result.returncode}")

    traj_path = os.path.join(os.path.abspath(output_dir), "trajectory.txt")
    if not os.path.exists(traj_path):
        raise FileNotFoundError(f"引擎未生成 trajectory.txt: {traj_path}")
    data = load_text_data(traj_path)
    traj_x = data["traj_x"]
    traj_y = data["traj_y"]
    traj_z = data["traj_z"]
    traj_vx = data["traj_vx"]
    traj_vy = data["traj_vy"]
    traj_vz = data["traj_vz"]

    log_path = os.path.join(output_dir, "dynamics.log")
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("=" * 50 + "\n")
        f.write("Dynamics 计算日志\n")
        f.write("=" * 50 + "\n")
        f.write(f"引擎: {engine}\n")
        f.write(f"计算开始时间: {t_start_str}\n")
        f.write(f"计算结束时间: {t_end_str}\n")
        f.write(f"计算耗时: {elapsed:.3f} s\n")

    # An empty trajectory is loaded as a 1-D empty array, which has no shape[1].
    n_atoms = traj_x.shape[1] if getattr(traj_x, "ndim", 0) == 2 else 0
    print(f"[compute] 引擎完成: {len(traj_x)} 步, {n_atoms} 原子, {elapsed:.3f} s")
    return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz
|
||
|
||
|
||
def save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, out_dir=None):
    """Save the trajectory plus all parameter metadata to ``<output>/trajectory.txt``.

    The file is JSON text written by ``save_text_data``. Metadata comes from
    the module-level parameters set by ``run_from_config``/``load_parameters``.

    Args:
        traj_x .. traj_vz: recorded trajectory arrays (one row per saved frame).
        out_dir: base directory passed to ``get_output_dir``; defaults to the
            folder containing this script.

    Returns:
        The path of the written file.
    """
    if out_dir is None:
        out_dir = os.path.dirname(os.path.abspath(__file__))
    out_path = os.path.join(get_output_dir(out_dir), "trajectory.txt")

    # Number of frames actually recorded (warm-up steps are excluded).
    record_steps = len(traj_x)
    sample_start_to_save = -1 if sample_start is None else int(sample_start)
    sample_end_to_save = -1 if sample_end is None else int(sample_end)

    payload = {
        "traj_x": traj_x, "traj_y": traj_y, "traj_z": traj_z,
        "traj_vx": traj_vx, "traj_vy": traj_vy, "traj_vz": traj_vz,
        "NT": record_steps, "DT": DT, "NSTEP": NSTEP,
        "method": METHOD,
        "coord_file": COORD_FILE,
        "connection_file": BOND_CONNECTION_FILE,
        "bond_file": BOND_PARAMETER_FILE,
        "atom_ids": ATOM_IDS,
        "atom_masses": ATOM_MASSES,
        "atom_radii": ATOM_RADII,
        "atom_positions": ATOM_POSITIONS,
        "atom_velocities": ATOM_VELOCITIES,
        "atom_fixed": ATOM_FIXED,
        "bond_pairs": BOND_PAIRS,
        "bond_names": BOND_NAMES,
        "bond_stiffness": BOND_STIFFNESS,
        "bond_rest_lengths": BOND_REST_LENGTHS,
        # np.asarray accepts vectors and plain scalars alike; G/B may be a
        # bare number when set from a legacy text parameter file.
        "G": None if G is None else np.asarray(G).tolist(),
        "B": None if B is None else np.asarray(B).tolist(),
        "plot_atom_id": PLOT_ATOM_ID,
        "plot_atom_row": PLOT_ATOM_ROW,
        "warmup_steps": warmup_steps or 0,
        "sample_start": sample_start_to_save,
        "sample_end": sample_end_to_save,
        "X_MIN": X_MIN, "X_MAX": X_MAX,
        "Y_MIN": Y_MIN, "Y_MAX": Y_MAX,
        "Z_MIN": Z_MIN, "Z_MAX": Z_MAX,
        "X0": X0, "Y0": Y0, "Z0": Z0,
        "VX0": VX0, "VY0": VY0, "VZ0": VZ0,
        "M": M,
        "alpha": alpha,
        "ball_radius": ball_radius,
        "ball_color_r": ball_color_r,
        "ball_color_g": ball_color_g,
        "ball_color_b": ball_color_b,
        "box_color_r": box_color_r,
        "box_color_g": box_color_g,
        "box_color_b": box_color_b,
    }
    save_text_data(out_path, payload)
    print(f"[compute] 轨迹数据已保存至: {out_path}")
    return out_path
|
||
|
||
|
||
def _parse_text_parameter(text):
    """Convert one raw value from the legacy ``key = value`` parameter file.

    Bracketed values such as ``[0, 0, -9.8]`` are decoded as JSON lists.
    Other values become ``int`` if possible, else ``float``, else stay strings.
    """
    if text.startswith("["):
        try:
            return json.loads(text)
        except ValueError:
            return text
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            pass
    return text


def load_parameters(param_file):
    """Load legacy ``key = value`` parameters and set up a single-particle run.

    Line format: ``name = value``, with an optional trailing ``# comment``.
    Blank lines, comment lines and lines without ``=`` are ignored. Only known
    parameter names are accepted; any other name prints a warning.

    After loading, this function:
    - derives the box bounds from ``box_a``;
    - normalizes METHOD;
    - converts G and B to 3-vectors;
    - fills drawing defaults;
    - builds one-atom ATOM_* arrays (atom id 1) and empty BOND_* arrays from
      X0..VZ0 and M, because run_simulation reads those arrays.

    Exits the process with status 1 when ``box_a`` or a required parameter is
    missing.
    """
    global box_a, alpha, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP
    global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX
    global ball_radius, ball_color_r, ball_color_g, ball_color_b
    global box_color_r, box_color_g, box_color_b
    global ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS, ATOM_VELOCITIES, ATOM_FIXED
    global BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS, BOND_REST_LENGTHS
    global PLOT_ATOM_ID, PLOT_ATOM_ROW

    # Only genuine parameters may be set from the file. Writing any name that
    # exists in globals() could clobber functions or imported modules.
    allowed = {
        "box_a", "alpha", "X0", "Y0", "Z0", "VX0", "VY0", "VZ0", "M", "G", "B",
        "METHOD", "NT", "DT", "NSTEP", "warmup_steps", "sample_start", "sample_end",
        "ball_radius", "ball_color_r", "ball_color_g", "ball_color_b",
        "box_color_r", "box_color_g", "box_color_b",
    }

    print(f"[compute] 正在加载参数文件: {param_file}")
    with open(param_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    for line in lines:
        line = line.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, value = line.split('=', 1)
        key = key.strip()
        # Drop an inline comment (everything after '#').
        value = _parse_text_parameter(value.split('#', 1)[0].strip())
        if key in allowed:
            globals()[key] = value
        else:
            print(f"警告: 未知参数 '{key}',忽略")

    # Derived box bounds.
    if box_a is not None:
        X_MIN, X_MAX = -box_a, box_a
        Y_MIN, Y_MAX = -box_a, box_a
        Z_MIN, Z_MAX = -box_a, box_a
    else:
        print("错误: 未找到 box_a 参数")
        sys.exit(1)

    required = ['alpha', 'X0', 'Y0', 'Z0', 'VX0', 'VY0', 'VZ0', 'M', 'G', 'B', 'NT', 'DT', 'NSTEP']
    for param in required:
        if globals()[param] is None:
            print(f"错误: 未找到必需参数 '{param}'")
            sys.exit(1)

    METHOD = normalize_method_name(METHOD or "explicit_euler")
    # compute_force indexes g[0..2] / b[0..2], so scalars must become vectors.
    G = parse_gravity_vector(G)
    B = parse_damping_vector(B)

    # Drawing defaults when not provided.
    if ball_radius is None:
        ball_radius = 0.28
    if ball_color_r is None:
        ball_color_r = 0.9
    if ball_color_g is None:
        ball_color_g = 0.2
    if ball_color_b is None:
        ball_color_b = 0.2
    if box_color_r is None:
        box_color_r = 0.8
    if box_color_g is None:
        box_color_g = 0.8
    if box_color_b is None:
        box_color_b = 0.85

    # The text format describes a single particle. Build the per-atom/per-bond
    # arrays that run_simulation and save_trajectory_txt read.
    ATOM_IDS = np.array([1], dtype=np.int64)
    ATOM_MASSES = np.array([float(M)], dtype=np.float64)
    ATOM_RADII = np.array([float(ball_radius)], dtype=np.float64)
    ATOM_POSITIONS = np.array([[X0, Y0, Z0]], dtype=np.float64)
    ATOM_VELOCITIES = np.array([[VX0, VY0, VZ0]], dtype=np.float64)
    ATOM_FIXED = np.zeros((1, 3), dtype=np.int64)
    BOND_PAIRS = np.zeros((0, 2), dtype=np.int64)
    BOND_NAMES = []
    BOND_STIFFNESS = np.zeros(0, dtype=np.float64)
    BOND_REST_LENGTHS = np.zeros(0, dtype=np.float64)
    PLOT_ATOM_ID = 1
    PLOT_ATOM_ROW = 0

    print(f"[compute] 参数加载完成: box_a={box_a}, alpha={alpha}, NT={NT}, DT={DT}, method={METHOD}")
|
||
|
||
def normalize_method_name(method):
    """Map a method name or alias from YAML/text config to its canonical name.

    Matching ignores case and surrounding whitespace. Hyphens and runs of
    whitespace are treated as ``_``, so "Explicit-Euler" and "leap frog" work.
    Chinese aliases are also accepted.

    Returns:
        One of "explicit_euler", "implicit_euler", "midpoint", "leapfrog".

    Raises:
        ValueError: the name is not a known alias.
    """
    aliases = {
        "explicit_euler": "explicit_euler",
        "explicit": "explicit_euler",
        "euler_explicit": "explicit_euler",
        "显式欧拉": "explicit_euler",
        "显式欧拉法": "explicit_euler",
        "implicit_euler": "implicit_euler",
        "implicit": "implicit_euler",
        "euler_implicit": "implicit_euler",
        "隐式欧拉": "implicit_euler",
        "隐式欧拉法": "implicit_euler",
        "midpoint": "midpoint",
        "mid_point": "midpoint",
        "midpoint_method": "midpoint",
        "中点法": "midpoint",
        "中点差分法": "midpoint",
        "leapfrog": "leapfrog",
        "leap_frog": "leapfrog",
        "蛙跳法": "leapfrog",
    }
    # Fold case, then turn hyphens and whitespace runs into single underscores.
    key = "_".join(str(method).strip().lower().replace("-", " ").split())
    if key not in aliases:
        valid = ", ".join(["explicit_euler", "implicit_euler", "midpoint", "leapfrog"])
        raise ValueError(f"未知算法 '{method}'。可选值: {valid}")
    return aliases[key]
|
||
|
||
|
||
def compute_force(x, y, z, vx, vy, vz, m, g, b):
    """Compute the total force on every atom from the current state.

    Current model:
    - gravity: F = m * g_vec
    - linear drag: F_drag = -B_vec * v (per component)
    - spring bonds: Hooke force k * (|d| - r0) along each bonded pair's
      separation d = r2 - r1. It pulls the two atoms together when stretched
      and pushes them apart when compressed. Pairs closer than 1e-12 have no
      defined direction and are skipped.

    Bond data come from the module globals BOND_PAIRS (row-index pairs),
    BOND_STIFFNESS and BOND_REST_LENGTHS. The bond loop is vectorized:
    np.add.at / np.subtract.at accumulate correctly when an atom appears in
    several bonds.

    Returns:
        (fx, fy, fz) force components.
    """
    fx = m * g[0] - b[0] * vx
    fy = m * g[1] - b[1] * vy
    fz = m * g[2] - b[2] * vz

    if BOND_PAIRS is None or len(BOND_PAIRS) == 0:
        return fx, fy, fz

    pairs = np.asarray(BOND_PAIRS, dtype=np.int64).reshape(-1, 2)
    first, second = pairs[:, 0], pairs[:, 1]
    dx = x[second] - x[first]
    dy = y[second] - y[first]
    dz = z[second] - z[first]
    dist = np.sqrt(dx * dx + dy * dy + dz * dz)

    # Written as ~(dist <= eps) to keep the original skip rule exactly (incl. NaN).
    active = ~(dist <= 1e-12)
    if not np.any(active):
        return fx, fy, fz

    first, second, dist = first[active], second[active], dist[active]
    stretch = dist - np.asarray(BOND_REST_LENGTHS, dtype=np.float64)[active]
    # force magnitude / distance, so scale * d == force_mag * unit_vector
    scale = np.asarray(BOND_STIFFNESS, dtype=np.float64)[active] * stretch / dist

    fx = np.array(fx, dtype=np.float64)
    fy = np.array(fy, dtype=np.float64)
    fz = np.array(fz, dtype=np.float64)
    for total, delta in ((fx, dx), (fy, dy), (fz, dz)):
        bond_force = scale * delta[active]
        np.add.at(total, first, bond_force)        # pull atom 1 toward atom 2
        np.subtract.at(total, second, bond_force)  # equal and opposite on atom 2
    return fx, fy, fz
|
||
|
||
|
||
def compute_acceleration(x, y, z, vx, vy, vz, m, g, b):
    """Return the acceleration components (ax, ay, az) = F / m.

    The force comes from ``compute_force``, so every integrator shares one
    force model.
    """
    forces = compute_force(x, y, z, vx, vy, vz, m, g, b)
    return tuple(component / m for component in forces)
|
||
|
||
|
||
def Explicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one forward (explicit) Euler step.

    The acceleration is evaluated at the start-of-step state. Positions
    advance with the old velocities; velocities advance with that acceleration.

    Returns:
        (x, y, z, vx, vy, vz) after one step of size ``dt``.
    """
    accel = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)
    velocities = (vx, vy, vz)
    new_positions = [p + v * dt for p, v in zip((x, y, z), velocities)]
    new_velocities = [v + a * dt for v, a in zip(velocities, accel)]
    return (*new_positions, *new_velocities)
|
||
|
||
|
||
def Implicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with a semi-implicit Euler scheme.

    1. Predict the end-of-step velocity, treating only gravity and linear drag
       implicitly: v_pred = (v + g*dt) / (1 + (b/m)*dt). Bond (spring) forces
       are not part of this prediction.
    2. Evaluate the full acceleration at the *current* positions with v_pred:
       gravity, drag at v_pred, and springs.
    3. Update the velocity with that acceleration, then update the positions
       with the *new* velocity.

    With no bonds, step 3 analytically reproduces v_pred, i.e. backward Euler
    for gravity plus linear drag. With bonds, the spring force enters
    explicitly (evaluated at the old positions).

    Returns:
        (x, y, z, vx, vy, vz) after one step of size ``dt``.
    """
    # Per-component drag rate b/m.
    gamma_x = b[0] / m
    gamma_y = b[1] / m
    gamma_z = b[2] / m
    # Velocity predictor: solves v_next = v + dt*(g - gamma*v_next) without springs.
    vx_next = (vx + g[0] * dt) / (1.0 + gamma_x * dt)
    vy_next = (vy + g[1] * dt) / (1.0 + gamma_y * dt)
    vz_next = (vz + g[2] * dt) / (1.0 + gamma_z * dt)
    # Full acceleration at the current positions and the predicted velocity.
    ax_next, ay_next, az_next = compute_acceleration(
        x, y, z, vx_next, vy_next, vz_next, m, g, b)

    vx = vx + ax_next * dt
    vy = vy + ay_next * dt
    vz = vz + az_next * dt
    # Positions use the already-updated velocity.
    x = x + vx * dt
    y = y + vy * dt
    z = z + vz * dt
    return x, y, z, vx, vy, vz
|
||
|
||
|
||
def Midpoint_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with the explicit midpoint (second-order Runge-Kutta) rule.

    A half-step Euler predictor gives the midpoint state. The full step then
    uses the midpoint velocity for positions and the midpoint acceleration
    for velocities.

    Returns:
        (x, y, z, vx, vy, vz) after one step of size ``dt``.
    """
    ax, ay, az = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)
    mid_pos = (x + 0.5 * vx * dt, y + 0.5 * vy * dt, z + 0.5 * vz * dt)
    mid_vel = (vx + 0.5 * ax * dt, vy + 0.5 * ay * dt, vz + 0.5 * az * dt)

    mid_ax, mid_ay, mid_az = compute_acceleration(*mid_pos, *mid_vel, m, g, b)
    return (
        x + mid_vel[0] * dt,
        y + mid_vel[1] * dt,
        z + mid_vel[2] * dt,
        vx + mid_ax * dt,
        vy + mid_ay * dt,
        vz + mid_az * dt,
    )
|
||
|
||
|
||
def Leapfrog_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step in kick-drift-kick (velocity-Verlet style) leapfrog form.

    1. Kick: half-step velocity v_half = v + 0.5*a(x, v)*dt.
    2. Drift: positions x += v_half*dt.
    3. Because drag depends on velocity, the end-of-step velocity needed to
       evaluate a(x_new, v_new) is first predicted from gravity and drag only,
       solving v_next = v_half + 0.5*dt*(g - (b/m)*v_next). Spring forces are
       not in this predictor.
    4. Kick: v = v_half + 0.5*a(x_new, v_next)*dt, where the acceleration
       includes springs at the new positions.

    Returns:
        (x, y, z, vx, vy, vz) after one step of size ``dt``.
    """
    # First half kick with the start-of-step acceleration.
    ax, ay, az = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)
    vx_half = vx + 0.5 * ax * dt
    vy_half = vy + 0.5 * ay * dt
    vz_half = vz + 0.5 * az * dt

    # Full drift with the half-step velocity.
    x = x + vx_half * dt
    y = y + vy_half * dt
    z = z + vz_half * dt

    # Per-component drag rate b/m.
    gamma_x = b[0] / m
    gamma_y = b[1] / m
    gamma_z = b[2] / m
    # Semi-implicit velocity predictor (gravity + drag only) for evaluating drag.
    vx_next = (vx_half + 0.5 * g[0] * dt) / (1.0 + 0.5 * gamma_x * dt)
    vy_next = (vy_half + 0.5 * g[1] * dt) / (1.0 + 0.5 * gamma_y * dt)
    vz_next = (vz_half + 0.5 * g[2] * dt) / (1.0 + 0.5 * gamma_z * dt)
    ax_next, ay_next, az_next = compute_acceleration(
        x, y, z, vx_next, vy_next, vz_next, m, g, b)
    # Second half kick with the end-of-step acceleration.
    vx = vx_half + 0.5 * ax_next * dt
    vy = vy_half + 0.5 * ay_next * dt
    vz = vz_half + 0.5 * az_next * dt
    return x, y, z, vx, vy, vz
|
||
|
||
def Limit_in_box(a, amin, amax, va):
    """Keep coordinates inside [amin, amax] and bounce them off the walls.

    A component beyond ``amax`` is placed on ``amax`` and given a non-positive
    velocity. A component below ``amin`` is placed on ``amin`` with a
    non-negative velocity. Other components are unchanged.

    The velocity sign is forced rather than blindly negated. Otherwise a
    particle that is outside but already moving back inward (e.g. after a
    spring force reversed it) would be flipped outward again.

    Returns:
        (a, va) as numpy arrays.
    """
    over = a > amax
    under = a < amin
    a = np.where(over, amax, np.where(under, amin, a))
    va = np.where(over, -np.abs(va), np.where(under, np.abs(va), va))
    return a, va
|
||
|
||
def apply_motion_update(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with the integrator selected by METHOD, then apply box walls.

    Raises:
        ValueError: METHOD is not one of the four supported integrators.

    Returns:
        (x, y, z, vx, vy, vz) after integration and ``Limit_in_box`` on each axis.
    """
    integrators = (
        ("explicit_euler", Explicit_Euler_Method),
        ("implicit_euler", Implicit_Euler_Method),
        ("midpoint", Midpoint_Method),
        ("leapfrog", Leapfrog_Method),
    )
    for name, integrator in integrators:
        if METHOD == name:
            break
    else:
        raise ValueError(f"未知算法: {METHOD}")

    x, y, z, vx, vy, vz = integrator(x, y, z, vx, vy, vz, dt, m, g, b)

    # Reflect off each pair of box faces independently.
    x, vx = Limit_in_box(x, X_MIN, X_MAX, vx)
    y, vy = Limit_in_box(y, Y_MIN, Y_MAX, vy)
    z, vz = Limit_in_box(z, Z_MIN, Z_MAX, vz)
    return x, y, z, vx, vy, vz
|
||
|
||
|
||
def apply_fixed_constraints(x, y, z, vx, vy, vz):
    """Pin every fixed degree of freedom to its initial coordinate with zero velocity.

    Uses the module globals ATOM_FIXED (per-atom, per-axis flags; non-zero
    means fixed) and ATOM_POSITIONS (initial coordinates).

    Returns:
        (x, y, z, vx, vy, vz) with the fixed components overwritten.
    """
    mask = ATOM_FIXED != 0
    pinned_pos = np.where(mask, ATOM_POSITIONS, np.column_stack((x, y, z)))
    pinned_vel = np.where(mask, 0.0, np.column_stack((vx, vy, vz)))
    px, py, pz = pinned_pos.T
    qx, qy, qz = pinned_vel.T
    return px, py, pz, qx, qy, qz
|
||
|
||
def wrap_position(x, y, z):
    """Periodic boundary wrapping of positions (dynamics mode).

    A coordinate strictly outside [MIN, MAX] on an axis is mapped back into
    the box with modular arithmetic. The overshoot is preserved: a point
    slightly past MAX re-enters slightly past MIN, and far-away points land
    correctly. Coordinates inside the box, including those exactly on a face,
    are unchanged.

    Uses the module globals X_MIN/X_MAX, Y_MIN/Y_MAX, Z_MIN/Z_MAX.
    """
    def _wrap(coord, lo, hi):
        width = hi - lo
        if not width > 0:
            # Degenerate box: fall back to snapping onto the opposite face.
            coord = np.where(coord > hi, lo, coord)
            return np.where(coord < lo, hi, coord)
        outside = (coord > hi) | (coord < lo)
        return np.where(outside, lo + np.mod(coord - lo, width), coord)

    return _wrap(x, X_MIN, X_MAX), _wrap(y, Y_MIN, Y_MAX), _wrap(z, Z_MIN, Z_MAX)
|
||
|
||
|
||
# ===========================================================================
|
||
# 主计算流程
|
||
# ===========================================================================
|
||
|
||
def run_simulation():
    """Integrate NT steps for all atoms and return the recorded trajectory.

    Step control:
    - the first ``warmup_steps`` steps are integrated but not recorded;
    - the remaining ``NT - warmup_steps`` frames are recorded. Each frame holds
      the state *before* that frame's update, so frame 0 is the post-warmup state.

    Reads module globals: ATOM_POSITIONS, ATOM_VELOCITIES, ATOM_MASSES,
    ATOM_IDS, ATOM_FIXED, NT, DT, G, B, warmup_steps, PLOT_ATOM_ROW.

    Returns:
        (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz), each float64 of
        shape (NT - warmup_steps, n_atoms).

    Raises:
        ValueError: warmup_steps is negative or larger than NT.
    """
    warmup = warmup_steps or 0
    record_steps = NT - warmup
    if warmup < 0 or record_steps < 0:
        raise ValueError(
            f"warmup_steps 必须满足 0 <= warmup_steps <= NT,"
            f"实际 warmup_steps={warmup_steps}, NT={NT}")

    # Initial state (copies so the loaded arrays are never mutated).
    x = ATOM_POSITIONS[:, 0].copy()
    y = ATOM_POSITIONS[:, 1].copy()
    z = ATOM_POSITIONS[:, 2].copy()
    vx = ATOM_VELOCITIES[:, 0].copy()
    vy = ATOM_VELOCITIES[:, 1].copy()
    vz = ATOM_VELOCITIES[:, 2].copy()
    x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz)

    # Warm-up phase: integrate without recording.
    if warmup > 0:
        print(f"[compute] 预热阶段: 前 {warmup_steps} 步不记录")
        for _ in trange(warmup_steps, desc="[compute] 预热"):
            x, y, z, vx, vy, vz = _advance_one_step(x, y, z, vx, vy, vz)
        print(
            f"[compute] 预热完成,展示原子位置: "
            f"({x[PLOT_ATOM_ROW]:.4f}, {y[PLOT_ATOM_ROW]:.4f}, {z[PLOT_ATOM_ROW]:.4f})"
        )

    # Recording phase.
    n_atoms = len(ATOM_IDS)
    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = (
        np.zeros((record_steps, n_atoms), dtype=np.float64) for _ in range(6))

    for step in trange(record_steps, desc="[compute] 计算中"):
        traj_x[step] = x
        traj_y[step] = y
        traj_z[step] = z
        traj_vx[step] = vx
        traj_vy[step] = vy
        traj_vz[step] = vz
        x, y, z, vx, vy, vz = _advance_one_step(x, y, z, vx, vy, vz)

    return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz


def _advance_one_step(x, y, z, vx, vy, vz):
    """Integrate one DT step, wrap periodic bounds, and re-impose fixed DOFs."""
    x, y, z, vx, vy, vz = apply_motion_update(x, y, z, vx, vy, vz, DT, ATOM_MASSES, G, B)
    x, y, z = wrap_position(x, y, z)
    return apply_fixed_constraints(x, y, z, vx, vy, vz)
|
||
|
||
|
||
def save_trajectory_table_txt(txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT):
    """Write the trajectory as a human-readable per-step table (for teaching).

    For 2-D (frames x atoms) input, only the column of PLOT_ATOM_ROW (or
    column 0 when it is None) is written. At most ``NT`` rows are written,
    and never more than the trajectory holds. Warm-up steps are not stored,
    so the data may have fewer than NT frames.

    Row format: ``step(1-based), time = step_index * DT, x, y, z, vx, vy, vz``.
    """
    if np.ndim(traj_x) > 1:
        row = PLOT_ATOM_ROW if PLOT_ATOM_ROW is not None else 0
        traj_x = traj_x[:, row]
        traj_y = traj_y[:, row]
        traj_z = traj_z[:, row]
        traj_vx = traj_vx[:, row]
        traj_vy = traj_vy[:, row]
        traj_vz = traj_vz[:, row]

    # Guard against indexing past the recorded frames (NT counts warm-up too).
    n_rows = min(int(NT), len(traj_x))

    lines = ["# 步数, 时间(s), x, y, z, vx, vy, vz\n"]
    for step in range(n_rows):
        t = step * DT
        lines.append(
            f"{step+1:6d}, {t:.6f}, {traj_x[step]:.6f}, {traj_y[step]:.6f}, "
            f"{traj_z[step]:.6f}, {traj_vx[step]:.6f}, {traj_vy[step]:.6f}, "
            f"{traj_vz[step]:.6f}\n")
    with open(txt_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)
    print(f"[compute] 轨迹文本文件已保存至: {txt_path}")
|
||
|
||
|
||
def main():
    """Command-line entry point for the legacy text-parameter workflow.

    Usage: ``python compute.py [param_file]``. ``param_file`` defaults to
    ``<input dir>/input.txt``. The script:
    - loads the parameters;
    - runs the simulation;
    - writes ``trajectory.txt`` (JSON with metadata) and
      ``trajectory_table.txt`` (readable table) into the output directory.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    param_file = os.path.join(get_input_dir(script_dir), "input.txt")
    if len(sys.argv) > 1:
        param_file = sys.argv[1]

    load_parameters(param_file)
    output_dir = get_output_dir(script_dir)

    print(f"[compute] 开始计算 NT={NT} DT={DT} ")
    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = run_simulation()
    print(f"[compute] 计算完成,共 {NT} 步")

    save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, script_dir)

    # Also write a row-per-step table. Pass the number of frames actually
    # recorded: warm-up steps are not stored, so it can be smaller than NT.
    txt_path = os.path.join(output_dir, "trajectory_table.txt")
    save_trajectory_table_txt(txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
                              len(traj_x), DT)


if __name__ == "__main__":
    main()
|