Files
dynamics/compute.py
T
admin dc7bc00616 feat: C/C++ 引擎支持 save_trajectory=0 时直接写 display.txt
所有引擎(Python/C/C++)在 save_trajectory=0 时行为一致:
- 计算时按 NSTEP 抽帧,只存 sampled 缓冲区
- 直接写入 display.txt(新文本格式)
- 不生成 trajectory.txt

Python 引擎:run_simulation 已支持 
C 引擎:采样缓冲区 + write_display_txt 
C++ 引擎:采样缓冲区 + write_display_txt 
Fortran 引擎:待完成

compute.py run_engine:save_trajectory=0 时跳过 trajectory.txt 加载
dynamics.py:引擎直接输出 display.txt 时跳过抽帧步骤
2026-06-12 08:25:27 +08:00

1606 lines
61 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
compute.py
----------
纯 Python 计算脚本,不依赖 VisPy。
功能:
1. 运行 NT 步物理模拟( kinematics / dynamics 等运动模式)
2. 按 NSTEP 抽帧,输出 output/display.txt(新文本格式)
3. 可选(save_trajectory=1)保留完整轨迹 output/trajectory.txtJSON
"""
import json
import re
import numpy as np
import os
import platform
import subprocess
import sys
import time
import datetime
from tqdm import trange
# ===========================================================================
# Physical parameters (must stay consistent with the boundary constants
# loaded by draw.py)
# ===========================================================================
# Module-level parameter globals. They start as None / defaults and are
# filled in at runtime: run_from_config assigns them through `global`
# statements (the original note says load_parameters fills them as well).
box_a = None
alpha = None
COORD_FILE = None
# Initial state of the plot atom (set from the coord-file row of plot_atom)
X0 = None
Y0 = None
Z0 = None
VX0 = None
VY0 = None
VZ0 = None
M = None
# Per-atom arrays returned by load_coord_file
ATOM_IDS = None
ATOM_MASSES = None
ATOM_RADII = None
ATOM_POSITIONS = None
ATOM_VELOCITIES = None
ATOM_FIXED = None
# Bond input files and the arrays returned by load_bond_connections
BOND_CONNECTION_FILE = None
BOND_PARAMETER_FILE = None
BOND_PAIRS = None
BOND_NAMES = None
BOND_STIFFNESS = None
BOND_REST_LENGTHS = None
# Atom id selected by the "plot_atom" setting and its row index in ATOM_IDS
PLOT_ATOM_ID = None
PLOT_ATOM_ROW = None
# Integration settings: G is the gravity vector, B the damping vector
# (see parse_gravity_vector / parse_damping_vector)
G = None
B = None
METHOD = None
NT = None
DT = None
NSTEP = None
warmup_steps = None  # number of warm-up steps (skipped, not saved)
sample_start = None  # start index of the sampled frame window
sample_end = None  # end index of the sampled frame window
# Rendering settings
ball_radius = None
ball_color_r = None
ball_color_g = None
ball_color_b = None
box_color_r = None
box_color_g = None
box_color_b = None
use_marker = 0
camera_keyframes_raw = ""  # JSON text from move_camera.txt, "" when unused
# Force switches (read from config as ints and used as on/off flags)
GRAVITY_FIELD = 1  # uniform gravity field
GRAVITY_INTERACTION = 0  # pairwise gravitational attraction between atoms
ELASTIC_FORCE = 1  # spring (bond) force
DAMPING_FORCE = 0  # damping
DRIVING_FORCE = 0  # prescribed driving motion
GRAVITY_STRENGTH = 1.0
# Driving-force data
DRIVER_DATA = None  # loaded from driver.txt (see load_driver_file)
# Derived box bounds (computed from box_a as [-box_a, box_a] per axis)
X_MIN = None
X_MAX = None
Y_MIN = None
Y_MAX = None
Z_MIN = None
Z_MAX = None
def _load_camera_motion(path):
    """Read move_camera.txt (velocity-segment format) and return it as a JSON string.

    Each non-blank line that does not start with ``#`` describes one motion
    segment::

        start-end vx=f1 vy=f2 vz=f3 rx=d1 ry=d2 rz=d3
        all vx=f1 ...            (whole run: frames 0 .. 10**9)

    Examples::

        1-60 vx=1.0 rx=10
        30-90 vy=2.0 ry=20 rz=10

    Missing components default to 0. Values may carry a sign and an exponent
    (``+2``, ``-0.5``, ``1e-3``). Segments whose velocity and rotation are all
    zero are dropped, as are lines without a recognisable frame range.

    Args:
        path: path of the camera motion file.

    Returns:
        JSON string ``[{"start": N, "end": N, "v": [x, y, z], "r": [x, y, z]}, ...]``,
        or ``""`` when the file is missing or holds no usable segment.
    """
    if not os.path.exists(path):
        print(f"[compute] 警告: 未找到 {path},跳过运动相机")
        return ""
    # Signed decimal with optional exponent. The former ``[-\d.]+`` silently
    # skipped "vx=+2" and truncated "1e-3" to 1.0.
    number = r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    range_re = re.compile(r'(\d+)\s*-\s*(\d+)')
    component_res = {
        prefix: [re.compile(prefix + axis + r'\s*=\s*' + number) for axis in "xyz"]
        for prefix in ("v", "r")
    }
    segments = []
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            # Frame range: "all" (whole run) or "N-M"
            if line.lower().startswith("all"):
                start, end = 0, 10**9
            else:
                m = range_re.match(line)
                if not m:
                    continue
                start, end = int(m.group(1)), int(m.group(2))
            values = {}
            for prefix, patterns in component_res.items():
                comps = [0.0, 0.0, 0.0]
                for i, pattern in enumerate(patterns):
                    found = pattern.search(line)
                    if found:
                        comps[i] = float(found.group(1))
                values[prefix] = comps
            v, r = values["v"], values["r"]
            if any(v) or any(r):
                segments.append({"start": start, "end": end, "v": v, "r": r})
    if not segments:
        return ""
    return json.dumps(segments)
def _to_text_value(value):
    """Recursively convert numpy-heavy objects into JSON-friendly Python values.

    - ``np.ndarray`` -> nested lists (``tolist``)
    - numpy scalars (``np.generic``: integers, floats, ``np.bool_``, ...) ->
      the matching Python scalar (``item``)
    - dict -> dict with converted values (keys unchanged)
    - list / tuple -> list with converted items
    - anything else is returned unchanged
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    # np.generic also covers np.bool_, which is not an np.integer subclass and
    # which json.dump cannot serialise.
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, dict):
        return {key: _to_text_value(val) for key, val in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_text_value(item) for item in value]
    return value
def _from_text_value(value):
    """Turn nested JSON values back into numpy arrays where that makes sense.

    - dict: converted value by value (keys untouched)
    - empty list: empty float64 array
    - flat list of numbers/bools: ``np.array``; flat list of anything else
      (e.g. strings or None) is returned as-is
    - list made only of lists: ``np.array`` of the whole thing
    - any other list: converted item by item
    - non-container values are returned unchanged
    """
    if isinstance(value, dict):
        return {name: _from_text_value(item) for name, item in value.items()}
    if not isinstance(value, list):
        return value
    if len(value) == 0:
        return np.array([], dtype=np.float64)
    has_nested = any(isinstance(item, (list, dict)) for item in value)
    if not has_nested:
        numeric = all(isinstance(item, (int, float, bool)) for item in value)
        return np.array(value) if numeric else value
    if all(isinstance(item, list) for item in value):
        return np.array(value)
    return [_from_text_value(item) for item in value]
def save_text_data(path, data):
    """Save structured simulation data as indented JSON text (legacy format, kept for compatibility).

    numpy arrays/scalars inside *data* are converted with ``_to_text_value``
    first; missing parent directories are created.

    Args:
        path: destination file path; may be a bare file name.
        data: dict (or other JSON-compatible structure) to write.

    Returns:
        *path*, unchanged.
    """
    directory = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a real parent.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(_to_text_value(data), f, ensure_ascii=False, indent=2)
    return path
def load_text_data(path):
    """Read legacy JSON simulation data and restore numeric lists as numpy arrays."""
    with open(path, encoding="utf-8") as handle:
        raw = json.load(handle)
    return _from_text_value(raw)
# ========================================================================
# 新 display.txt 格式:纯文本,按帧分块
# 第1行: number of frames: N (总帧数,可含未采样帧)
# 第2行: number of particles: M
# 可选: 若干 "key: value" 元数据行 (header_fields)
# 每帧: 先一个空行,然后
#   frame: K (K 从 1 开始)
#   n x y z vx vy vz (列头)
#   每个原子一行数据
# 重复直到所有采样帧写完
# ========================================================================
def save_display_txt(path, frames_x, frames_y, frames_z,
                     frames_vx, frames_vy, frames_vz,
                     atom_ids, n_total_frames, n_total_particles,
                     header_fields=None):
    """Write display.txt in the new plain-text, frame-blocked format.

    Layout::

        number of frames: <n_total_frames>
        number of particles: <n_total_particles>
        <key>: <value>                      (one per header_fields entry)

        frame: 1
        n x y z vx vy vz
        <id> <x> <y> <z> <vx> <vy> <vz>     (one line per atom)
        ...                                 (blank line, next frame, ...)

    Numbers are written in 13-wide, 6-decimal columns. Each value is
    guaranteed to be preceded by at least one space, so a value wider than
    its column (e.g. -123456.500000) no longer runs into the previous field.

    Args:
        path: output file path; parent directories are created.
        frames_x/y/z/vx/vy/vz: (n_frames, n_atoms) arrays.
        atom_ids: (n_atoms,) integer atom ids.
        n_total_frames: frame count written to the header (may include
            frames that were not sampled).
        n_total_particles: particle count written to the header.
        header_fields: optional dict of extra metadata lines written after
            the two count lines.

    Returns:
        *path*.
    """
    directory = os.path.dirname(path)
    # os.makedirs("") raises, so only create a real parent directory.
    if directory:
        os.makedirs(directory, exist_ok=True)
    n_frames = frames_x.shape[0]
    n_atoms = frames_x.shape[1]
    columns = (frames_x, frames_y, frames_z, frames_vx, frames_vy, frames_vz)

    def fmt(v):
        # " " + 12-wide equals the old 13-wide field whenever the value fits,
        # and still keeps a separator when it does not.
        return f" {v:12.6f}"

    with open(path, "w", encoding="utf-8") as f:
        f.write(f"number of frames: {n_total_frames}\n")
        f.write(f"number of particles: {n_total_particles}\n")
        # Extra metadata lines
        if header_fields:
            for k, v in header_fields.items():
                f.write(f"{k}: {v}\n")
        for fr in range(n_frames):
            # Build each frame in memory and write it in one call.
            lines = [f"\nframe: {fr + 1}\n", "n x y z vx vy vz\n"]
            for a in range(n_atoms):
                values = "".join(fmt(col[fr, a]) for col in columns)
                lines.append(f"{atom_ids[a]:d}{values}\n")
            f.writelines(lines)
    return path
def load_display_txt(path):
    """Read a display.txt file (new text format) into numpy arrays.

    Expected layout: "number of frames: N" and "number of particles: M"
    header lines, optional "key: value" metadata lines, then one block per
    frame made of a "frame: K" line, a column-header line, and one
    "n x y z vx vy vz" row per atom.

    Args:
        path: file to read.

    Returns:
        dict with keys ``frames_x/y/z/vx/vy/vz`` ((n_frames, n_atoms) float64
        arrays), ``atom_ids`` (int64 ids from the first frame),
        ``n_total_frames`` / ``n_total_particles`` (header values, 0 when
        absent) and ``header_fields`` (other header "key: value" pairs, as
        strings).

    Raises:
        ValueError: no frame data, data rows that cannot be split evenly into
            frames, or rows that do not have 7 columns.
    """
    with open(path, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()

    # Header: everything before the first "frame:" line.
    header_fields = {}
    n_total_frames = 0
    n_total_particles = 0
    data_start = len(lines)
    for i, line in enumerate(lines):
        line_stripped = line.strip()
        if line_stripped.startswith("number of frames:"):
            n_total_frames = int(line_stripped.split(":")[1].strip())
        elif line_stripped.startswith("number of particles:"):
            n_total_particles = int(line_stripped.split(":")[1].strip())
        elif line_stripped.startswith("frame:"):
            data_start = i
            break
        elif ":" in line_stripped:
            k, v = line_stripped.split(":", 1)
            header_fields[k.strip()] = v.strip()

    # Collect data rows, skipping each "frame: K" line and its column header.
    data_rows = []
    n_frames = 0
    i = data_start
    while i < len(lines):
        line = lines[i].strip()
        if line.startswith("frame:"):
            n_frames += 1
            i += 2  # skip "frame: K" and the column-header line
            continue
        if line:
            data_rows.append(line)
        i += 1
    if n_frames == 0 or not data_rows:
        raise ValueError(f"{path} 中没有有效帧数据")
    if len(data_rows) % n_frames:
        raise ValueError(
            f"{path} 中数据行数 {len(data_rows)} 无法按 {n_frames} 帧均分")
    atoms_per_frame = len(data_rows) // n_frames

    # Parse all rows in one vectorised call; ndmin=2 keeps a single row 2-D.
    data_array = np.loadtxt(data_rows, dtype=np.float64, ndmin=2)
    if data_array.shape[1] != 7:
        raise ValueError(
            f"{path} 中数据行应有 7 列,实际为 {data_array.shape[1]}")
    # Columns: n, x, y, z, vx, vy, vz. Ids come from the first frame.
    atom_ids = data_array[:atoms_per_frame, 0].astype(np.int64)
    all_data = data_array[:, 1:].reshape(n_frames, atoms_per_frame, 6)
    return {
        "frames_x": all_data[:, :, 0],
        "frames_y": all_data[:, :, 1],
        "frames_z": all_data[:, :, 2],
        "frames_vx": all_data[:, :, 3],
        "frames_vy": all_data[:, :, 4],
        "frames_vz": all_data[:, :, 5],
        "atom_ids": atom_ids,
        "n_total_frames": n_total_frames,
        "n_total_particles": n_total_particles,
        "header_fields": header_fields,
    }
def _resolve_data_dir(env_var, subdir, base_dir):
    """Return (and create) the directory named by *env_var*, else ``<base_dir>/<subdir>``.

    A non-empty environment override is made absolute. Otherwise *base_dir*
    defaults to the directory containing this file.
    """
    override = os.environ.get(env_var)
    if override:
        directory = os.path.abspath(override)
    else:
        if base_dir is None:
            base_dir = os.path.dirname(os.path.abspath(__file__))
        directory = os.path.join(base_dir, subdir)
    os.makedirs(directory, exist_ok=True)
    return directory


def get_output_dir(base_dir=None):
    """Return the output directory used for generated artifacts.

    Uses $DYNAMICS_OUTPUT_DIR when set, otherwise ``<base_dir>/output``
    (base_dir defaults to this script's directory). The directory is created
    if missing.
    """
    return _resolve_data_dir("DYNAMICS_OUTPUT_DIR", "output", base_dir)


def get_input_dir(base_dir=None):
    """Return the input directory used for configuration and source data.

    Uses $DYNAMICS_INPUT_DIR when set, otherwise ``<base_dir>/input``
    (base_dir defaults to this script's directory). The directory is created
    if missing.
    """
    return _resolve_data_dir("DYNAMICS_INPUT_DIR", "input", base_dir)
def load_coord_file(coord_path):
    """Load atom ids, masses, radii, coordinates, velocities and fix flags.

    The first line must be the header ``n mass radius x y z vx vy vz``,
    optionally followed by three fix-flag columns named ``fix_x fix_y fix_z``
    (or the legacy ``fixed_x fixed_y fixed_z``). Blank lines and lines
    starting with ``#`` are skipped.

    Args:
        coord_path: path of the coordinate file.

    Returns:
        (atom_ids int64 (N,), masses float64 (N,), radii float64 (N,),
         positions float64 (N, 3), velocities float64 (N, 3),
         fixed int64 (N, 3) -- all zeros when the file has no fix columns)

    Raises:
        FileNotFoundError: the file does not exist.
        ValueError: bad header, wrong column count, no atom rows, or a
            non-positive mass or radius.
    """
    if not os.path.exists(coord_path):
        raise FileNotFoundError(f"坐标文件不存在: {coord_path}")
    expected = ["n", "mass", "radius", "x", "y", "z", "vx", "vy", "vz"]
    with_fix_headers = (
        expected + ["fixed_x", "fixed_y", "fixed_z"],  # legacy names
        expected + ["fix_x", "fix_y", "fix_z"],
    )
    rows = []
    with open(coord_path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split()
        if header != expected and header not in with_fix_headers:
            raise ValueError(
                f"坐标文件表头应为: {' '.join(expected)} 或加三列 fix_x fix_y fix_z,"
                f"实际为: {' '.join(header)}")
        has_fix = header != expected
        ncols = 12 if has_fix else 9
        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) != ncols:
                raise ValueError(f"{coord_path}:{line_no} 应有 {ncols} 列,实际为 {len(parts)}")
            rows.append(parts)
    if not rows:
        raise ValueError(f"坐标文件没有原子数据: {coord_path}")
    atom_ids = np.array([int(row[0]) for row in rows], dtype=np.int64)
    masses = np.array([float(row[1]) for row in rows], dtype=np.float64)
    radii = np.array([float(row[2]) for row in rows], dtype=np.float64)
    positions = np.array([[float(v) for v in row[3:6]] for row in rows],
                         dtype=np.float64)
    velocities = np.array([[float(v) for v in row[6:9]] for row in rows],
                          dtype=np.float64)
    if has_fix:
        fixed = np.array([[int(v) for v in row[9:12]] for row in rows],
                         dtype=np.int64)
    else:
        fixed = np.zeros((len(rows), 3), dtype=np.int64)
    if np.any(masses <= 0):
        raise ValueError("坐标文件中的质量必须为正数")
    if np.any(radii <= 0):
        raise ValueError("坐标文件中的半径必须为正数")
    print(f"[compute] 已加载坐标文件: {coord_path},原子数={len(atom_ids)}")
    return atom_ids, masses, radii, positions, velocities, fixed
def load_bond_parameters(bond_path):
    """Map each bond name to its stiffness and optional rest length.

    Accepted headers::

        bond_name k               2 columns: rest length derived later from coordinates
        bond_name k rest_length   3 columns: explicit rest length

    Blank lines and ``#`` lines are skipped; a later line with the same bond
    name replaces an earlier one.

    Returns:
        dict ``{bond_name: {"stiffness": float, "rest_length": float or None}}``.

    Raises:
        FileNotFoundError: the file does not exist.
        ValueError: bad header, wrong column count, negative stiffness,
            non-positive rest length, or no data rows.
    """
    if not os.path.exists(bond_path):
        raise FileNotFoundError(f"键参数文件不存在: {bond_path}")
    layouts = {
        ("bond_name", "k"): False,
        ("bond_name", "k", "rest_length"): True,
    }
    table = {}
    with open(bond_path, "r", encoding="utf-8") as handle:
        columns = handle.readline().strip().split()
        layout_key = tuple(columns)
        if layout_key not in layouts:
            wanted = ["bond_name", "k"]
            if len(columns) == 3:
                wanted.append("rest_length")
            raise ValueError(
                f"键参数文件表头应为: {' '.join(wanted)},实际为: {' '.join(columns)}")
        explicit_rest = layouts[layout_key]
        width = 3 if explicit_rest else 2
        for number, text in enumerate(handle, start=2):
            body = text.strip()
            if not body or body.startswith("#"):
                continue
            fields = body.split()
            if len(fields) != width:
                raise ValueError(
                    f"{bond_path}:{number} 应有 {width} 列,实际为 {len(fields)}")
            k_value = float(fields[1])
            if k_value < 0:
                raise ValueError(f"{bond_path}:{number} 劲度系数必须非负")
            rest = None
            if explicit_rest:
                rest = float(fields[2])
                if rest <= 0:
                    raise ValueError(f"{bond_path}:{number} 键长必须为正数")
            table[fields[0]] = {"stiffness": k_value, "rest_length": rest}
    if not table:
        raise ValueError(f"键参数文件没有有效数据: {bond_path}")
    return table
def load_bond_connections(connection_path, atom_ids, atom_positions, bond_map):
    """Read bonded atom pairs and resolve each bond's stiffness and rest length.

    The rest length comes from ``bond_map[name]["rest_length"]`` when it is
    set (third column of bond.txt); otherwise it is the initial distance
    between the two atoms in *atom_positions*.

    Args:
        connection_path: file with header ``n1 n2 bond_name``.
        atom_ids: atom ids; a pair refers to atoms by id.
        atom_positions: per-atom positions indexed like *atom_ids*.
        bond_map: result of load_bond_parameters.

    Returns:
        (pairs int64 (B, 2) row indices, bond names list,
         stiffness float64 (B,), rest lengths float64 (B,))

    Raises:
        FileNotFoundError: the file does not exist.
        ValueError: bad header, wrong column count, unknown atom id, or
            undefined bond type.
    """
    if not os.path.exists(connection_path):
        raise FileNotFoundError(f"成键连接文件不存在: {connection_path}")
    row_of = {int(aid): row for row, aid in enumerate(atom_ids)}
    records = []
    with open(connection_path, "r", encoding="utf-8") as handle:
        columns = handle.readline().strip().split()
        required = ["n1", "n2", "bond_name"]
        if columns != required:
            raise ValueError(
                f"成键连接文件表头应为: {' '.join(required)},实际为: {' '.join(columns)}")
        for number, text in enumerate(handle, start=2):
            body = text.strip()
            if not body or body.startswith("#"):
                continue
            fields = body.split()
            if len(fields) != 3:
                raise ValueError(f"{connection_path}:{number} 应有 3 列,实际为 {len(fields)}")
            first, second, name = int(fields[0]), int(fields[1]), fields[2]
            if first not in row_of or second not in row_of:
                raise ValueError(
                    f"{connection_path}:{number} 中的原子序号 {first}, {second} 不在坐标文件里")
            if name not in bond_map:
                raise ValueError(
                    f"{connection_path}:{number} 使用了未定义的键类型 {name}")
            i, j = row_of[first], row_of[second]
            entry = bond_map[name]
            length = entry["rest_length"]
            if length is None:
                # No explicit rest length: use the initial interatomic distance.
                length = float(np.linalg.norm(atom_positions[j] - atom_positions[i]))
            records.append(((i, j), name, entry["stiffness"], length))
    if not records:
        return (
            np.zeros((0, 2), dtype=np.int64),
            [],
            np.zeros(0, dtype=np.float64),
            np.zeros(0, dtype=np.float64),
        )
    pairs, names, ks, lengths = zip(*records)
    return (
        np.array(pairs, dtype=np.int64),
        list(names),
        np.array(ks, dtype=np.float64),
        np.array(lengths, dtype=np.float64),
    )
# ===========================================================================
# 驱动力加载 & 应用
# ===========================================================================
def load_driver_file(driver_path, atom_ids):
    """Load driving-motion definitions from driver.txt.

    Format::

        n amp_x amp_y amp_z freq_x freq_y freq_z phi_x phi_y phi_z period
        <id> 0 0 0 0 0 10 0 0 90 all

    Each driven atom follows ``position = A * cos(2π f t + φ)``. φ is given
    in degrees and stored in radians. ``period`` is ``all`` (driven for the
    whole run) or a number of cycles, after which the atom is held still.

    Args:
        driver_path: path of driver.txt.
        atom_ids: atom ids from the coordinate file.

    Returns:
        list of driver dicts (keys: atom_index, atom_id, amp, freq, phi
        [rad], phi_deg, period_str, period_cycles, freeze_step, freeze_pos),
        or None when the file is missing or has no data rows.

    Raises:
        ValueError: bad header, wrong column count, unknown atom id, or a
            non-numeric period other than ``all``.
    """
    if not os.path.exists(driver_path):
        print(f"[compute] 警告: 驱动力文件不存在: {driver_path}")
        return None
    atom_index = {int(aid): idx for idx, aid in enumerate(atom_ids)}
    expected = ["n", "amp_x", "amp_y", "amp_z",
                "freq_x", "freq_y", "freq_z",
                "phi_x", "phi_y", "phi_z", "period"]
    ncols = len(expected)
    drivers = []
    with open(driver_path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split()
        if header != expected:
            raise ValueError(
                f"driver.txt 表头应为: {' '.join(expected)},实际为: {' '.join(header)}")
        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) != ncols:
                raise ValueError(
                    f"{driver_path}:{line_no} 应有 {ncols} 列,实际为 {len(parts)}")
            n = int(parts[0])
            if n not in atom_index:
                raise ValueError(f"{driver_path}:{line_no} 原子序号 {n} 不在 coord.txt 中")
            amp = np.array([float(p) for p in parts[1:4]], dtype=np.float64)
            freq = np.array([float(p) for p in parts[4:7]], dtype=np.float64)
            phi_deg = np.array([float(p) for p in parts[7:10]], dtype=np.float64)
            period_str = parts[10]
            if period_str == "all":
                period_cycles = None
            else:
                try:
                    period_cycles = float(period_str)
                except ValueError:
                    raise ValueError(
                        f"{driver_path}:{line_no} period 应为 all 或数值,实际为 {period_str}"
                    ) from None
            drivers.append({
                "atom_index": atom_index[n],
                "atom_id": n,
                "amp": amp,
                "freq": freq,
                "phi": phi_deg * np.pi / 180.0,
                # Keep the degrees as given, for reporting.
                "phi_deg": phi_deg,
                "period_str": period_str,
                "period_cycles": period_cycles,
                # Filled in during the simulation: freeze step / freeze position
                "freeze_step": None,
                "freeze_pos": None,
            })
    if not drivers:
        print(f"[compute] 警告: driver.txt 中没有有效数据")
        return None
    print(f"[compute] 已加载驱动力: {len(drivers)} 条定义")
    for d in drivers:
        # Report each driver's own phases (the old code printed the last row's).
        print(f" 原子 {d['atom_id']}: "
              f"A=({d['amp'][0]},{d['amp'][1]},{d['amp'][2]}), "
              f"f=({d['freq'][0]},{d['freq'][1]},{d['freq'][2]}), "
              f"φ=({d['phi_deg'][0]},{d['phi_deg'][1]},{d['phi_deg'][2]})°, "
              f"period={d['period_str']}")
    return drivers
def apply_driving_force(x, y, z, vx, vy, vz, t, step, drivers, dt):
    """Overwrite the position/velocity of driven atoms with the prescribed motion.

    For each driver ``d`` (component-wise, A = d["amp"], f = d["freq"],
    φ = d["phi"] in radians)::

        pos = A * cos(2π f t + φ)
        vel = -A * 2π f * sin(2π f t + φ)

    When ``d["period_cycles"]`` is set, driving lasts
    ``int(period_cycles / max|f| / dt)`` steps (0 when every frequency is
    ~0). After that the atom is pinned at its last driven position with zero
    velocity.

    The arrays are modified in place and also returned. ``d["freeze_pos"]``
    is updated as a side effect for finite-period drivers.

    Args:
        x, y, z, vx, vy, vz: per-atom state arrays (indexed by atom row).
        t: current time.
        step: current step index, compared against the driving period.
        drivers: list from load_driver_file, or None.
        dt: time step.

    Returns:
        (x, y, z, vx, vy, vz)
    """
    if drivers is None:
        return x, y, z, vx, vy, vz
    for d in drivers:
        idx = d["atom_index"]
        period_steps = None  # None -> driven for the whole run
        if d["period_cycles"] is not None:
            max_freq = np.max(np.abs(d["freq"]))
            if max_freq > 1e-12:
                period_steps = int(d["period_cycles"] / max_freq / dt)
            else:
                period_steps = 0
            if step > period_steps:
                # Driving has ended: hold the last driven position, zero velocity.
                if d["freeze_pos"] is not None:
                    x[idx], y[idx], z[idx] = d["freeze_pos"]
                vx[idx] = vy[idx] = vz[idx] = 0.0
                continue
        phase = 2.0 * np.pi * d["freq"] * t + d["phi"]
        pos = d["amp"] * np.cos(phase)
        vel = -d["amp"] * 2.0 * np.pi * d["freq"] * np.sin(phase)
        x[idx], y[idx], z[idx] = pos
        vx[idx], vy[idx], vz[idx] = vel
        if period_steps is not None:
            # Remember the latest driven position on every driven step, not
            # only when step == period_steps exactly. Otherwise a caller that
            # skips that step index would leave the atom unpinned afterwards.
            d["freeze_pos"] = (float(pos[0]), float(pos[1]), float(pos[2]))
    return x, y, z, vx, vy, vz
def parse_gravity_vector(value):
    """Turn the ``G`` setting into a 3-component float64 acceleration vector.

    Backward-compatible inputs:
      - a plain number ``G: 9.8`` means gravity along -z -> [0, 0, -9.8]
      - a 3-element sequence ``G: [gx, gy, gz]`` is used as-is

    Raises:
        ValueError: a non-scalar value does not have exactly 3 components.
    """
    scalar_types = (int, float, np.integer, np.floating)
    if not isinstance(value, scalar_types):
        components = np.asarray(value, dtype=np.float64)
        if components.shape == (3,):
            return components
        raise ValueError(f"G 必须是长度为 3 的分量数组,实际为 {value}")
    magnitude = float(value)
    return np.array([0.0, 0.0, -magnitude], dtype=np.float64)
def parse_damping_vector(value):
    """Turn the ``B`` setting into a 3-component float64 damping vector.

    Backward-compatible inputs:
      - a plain number ``B: 0.5`` applies to every axis -> [0.5, 0.5, 0.5]
      - a 3-element sequence ``B: [bx, by, bz]`` is used as-is

    Raises:
        ValueError: a non-scalar value does not have exactly 3 components.
    """
    if not isinstance(value, (int, float, np.integer, np.floating)):
        components = np.asarray(value, dtype=np.float64)
        if components.shape != (3,):
            raise ValueError(f"B 必须是长度为 3 的分量数组,实际为 {value}")
        return components
    return np.full(3, float(value), dtype=np.float64)
def _resolve_config_path(path, base_dir):
    """Join a relative *path* onto *base_dir*; absolute paths or base_dir=None pass through."""
    if base_dir is not None and not os.path.isabs(path):
        return os.path.join(base_dir, path)
    return path


def run_from_config(config, out_dir=None):
    """Run the Python simulation directly from a configuration dict and save a log.

    Config keys read here:
        box_a, alpha, coord_file, connection_file, bond_file, plot_atom,
        G, B, method, NT, DT, NSTEP (G, B, NT, DT, NSTEP are required),
        warmup_steps, sample_start, sample_end,
        ball_radius, ball_color_r/g/b, box_color_r/g/b, use_marker,
        gravity_field, gravity_interaction, elastic_force, damping_force,
        gravity_strength, driving_force, driver_file,
        move_camera, move_camera_file, save_trajectory,
        _skip_run (load and parse inputs only, do not simulate).
    Relative input paths are resolved against *out_dir* when it is given,
    otherwise against the current working directory.

    Side effects: assigns this module's parameter globals. Unless
    ``_skip_run`` is set, it also calls ``run_simulation`` and writes
    ``dynamics.log`` into ``get_output_dir(out_dir)``.

    Args:
        config: configuration dict.
        out_dir: base directory for relative input paths and the output dir.

    Returns:
        (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz) as returned by
        ``run_simulation``; six ``None`` when ``_skip_run`` is set.
    """
    global box_a, alpha, COORD_FILE, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP
    global PLOT_ATOM_ID, PLOT_ATOM_ROW
    global ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS, ATOM_VELOCITIES, ATOM_FIXED
    global BOND_CONNECTION_FILE, BOND_PARAMETER_FILE, BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS, BOND_REST_LENGTHS
    global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX
    global ball_radius, ball_color_r, ball_color_g, ball_color_b
    global box_color_r, box_color_g, box_color_b
    global use_marker, camera_keyframes_raw
    global warmup_steps, sample_start, sample_end
    global GRAVITY_FIELD, GRAVITY_INTERACTION, ELASTIC_FORCE, DAMPING_FORCE, GRAVITY_STRENGTH
    global DRIVING_FORCE, DRIVER_DATA

    box_a = float(config.get("box_a", 10.0))
    raw_alpha = config.get("alpha", 0.2)
    if isinstance(raw_alpha, (list, tuple)):
        alpha = [float(a) for a in raw_alpha]
    else:
        alpha = float(raw_alpha)

    # --- atoms and bonds ---
    COORD_FILE = str(config.get("coord_file", os.path.join("input", "coord.txt")))
    (ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS,
     ATOM_VELOCITIES, ATOM_FIXED) = load_coord_file(_resolve_config_path(COORD_FILE, out_dir))
    BOND_CONNECTION_FILE = str(config.get("connection_file", os.path.join("input", "connection.txt")))
    BOND_PARAMETER_FILE = str(config.get("bond_file", os.path.join("input", "bond.txt")))
    bond_map = load_bond_parameters(_resolve_config_path(BOND_PARAMETER_FILE, out_dir))
    (BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS,
     BOND_REST_LENGTHS) = load_bond_connections(
        _resolve_config_path(BOND_CONNECTION_FILE, out_dir), ATOM_IDS, ATOM_POSITIONS, bond_map)

    # --- plot atom (its initial state goes into the log) ---
    PLOT_ATOM_ID = int(config.get("plot_atom", int(ATOM_IDS[0])))
    matches = np.where(ATOM_IDS == PLOT_ATOM_ID)[0]
    if len(matches) == 0:
        raise ValueError(f"plot_atom={PLOT_ATOM_ID} 不在坐标文件原子序号中")
    PLOT_ATOM_ROW = int(matches[0])
    M = float(ATOM_MASSES[PLOT_ATOM_ROW])
    X0, Y0, Z0 = [float(v) for v in ATOM_POSITIONS[PLOT_ATOM_ROW]]
    VX0, VY0, VZ0 = [float(v) for v in ATOM_VELOCITIES[PLOT_ATOM_ROW]]

    # --- integration settings ---
    G = parse_gravity_vector(config["G"])
    B = parse_damping_vector(config["B"])
    METHOD = normalize_method_name(config.get("method", "explicit_euler"))
    NT = int(config["NT"])
    DT = float(config["DT"])
    NSTEP = int(config["NSTEP"])
    warmup_steps = int(config.get("warmup_steps", 0))
    sample_start = config.get("sample_start")  # None -> from the beginning
    sample_end = config.get("sample_end")  # None -> to the end
    X_MIN, X_MAX = -box_a, box_a
    Y_MIN, Y_MAX = -box_a, box_a
    Z_MIN, Z_MAX = -box_a, box_a

    # --- rendering settings ---
    ball_radius = float(config.get("ball_radius", ATOM_RADII[PLOT_ATOM_ROW]))
    ball_color_r = float(config.get("ball_color_r", 0.9))
    ball_color_g = float(config.get("ball_color_g", 0.2))
    ball_color_b = float(config.get("ball_color_b", 0.2))
    box_color_r = float(config.get("box_color_r", 0.8))
    box_color_g = float(config.get("box_color_g", 0.8))
    box_color_b = float(config.get("box_color_b", 0.85))
    use_marker = int(config.get("use_marker", 0))

    # --- force switches ---
    GRAVITY_FIELD = int(config.get("gravity_field", 1))
    GRAVITY_INTERACTION = int(config.get("gravity_interaction", 0))
    ELASTIC_FORCE = int(config.get("elastic_force", 1))
    DAMPING_FORCE = int(config.get("damping_force", 0))
    GRAVITY_STRENGTH = float(config.get("gravity_strength", 1.0))
    DRIVING_FORCE = int(config.get("driving_force", 0))

    # Driving-force definitions
    DRIVER_DATA = None
    if DRIVING_FORCE:
        driver_rel = str(config.get("driver_file", os.path.join("input", "driver.txt")))
        DRIVER_DATA = load_driver_file(_resolve_config_path(driver_rel, out_dir), ATOM_IDS)

    # Moving-camera segments
    camera_keyframes_raw = ""
    if int(config.get("move_camera", 0)):
        cam_rel = str(config.get("move_camera_file", os.path.join("input", "move_camera.txt")))
        camera_keyframes_raw = _load_camera_motion(_resolve_config_path(cam_rel, out_dir))

    print(f"[compute] 使用算法: {METHOD}")
    print(f"[compute] 已加载成键信息: {len(BOND_PAIRS)} 条键")
    if config.get("_skip_run", False):
        return None, None, None, None, None, None

    t_start = time.time()
    t_start_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = run_simulation(
        save_trajectory=int(config.get("save_trajectory", 0)))
    t_end = time.time()
    t_end_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    elapsed = t_end - t_start

    # Write dynamics.log
    log_path = os.path.join(get_output_dir(out_dir), "dynamics.log")
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("=" * 50 + "\n")
        f.write("Dynamics 计算日志\n")
        f.write("=" * 50 + "\n")
        f.write(f"计算开始时间: {t_start_str}\n")
        f.write(f"计算结束时间: {t_end_str}\n")
        f.write(f"计算耗时: {elapsed:.3f} s\n")
        f.write("-" * 50 + "\n")
        f.write("计算参数:\n")
        f.write(f" box_a: {box_a}\n")
        f.write(f" alpha: {alpha}\n")
        f.write(f" coord_file: {COORD_FILE}\n")
        f.write(f" method: {METHOD}\n")
        f.write(f" NT: {NT}\n")
        f.write(f" DT: {DT}\n")
        f.write(f" NSTEP: {NSTEP}\n")
        f.write(f" G: ({G[0]:.2f}, {G[1]:.2f}, {G[2]:.2f})\n")
        f.write(f" B: ({B[0]:.2f}, {B[1]:.2f}, {B[2]:.2f})\n")
        f.write(f" warmup: {warmup_steps}\n")
        f.write(f" 原子数: {len(ATOM_IDS)}\n")
        f.write(f" 键数: {len(BOND_PAIRS)}\n")
        f.write("-" * 50 + "\n")
        f.write("初始状态 (plot_atom):\n")
        f.write(f" position: ({X0:.6f}, {Y0:.6f}, {Z0:.6f})\n")
        f.write(f" velocity: ({VX0:.6f}, {VY0:.6f}, {VZ0:.6f})\n")
        f.write(f" mass: {M}\n")
        f.write("=" * 50 + "\n")
    print(f"[compute] 日志已保存至: {log_path}")
    if traj_x is None:
        # run_simulation may return no trajectory (e.g. with save_trajectory=0);
        # avoid len(None).
        print(f"[compute] 计算完成: {elapsed:.3f} s")
    else:
        print(f"[compute] 计算完成: {len(traj_x)} 步, {elapsed:.3f} s")
    return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz
def _find_engine_executable(script_dir, engine):
    """Return the path of the first existing build of *engine*.

    Tries ``engines/<engine>/build/<exe>`` with no suffix, with ``.exe`` and
    with the platform suffix ``_<system>.exe`` (e.g. ``_linux.exe``).

    Raises:
        ValueError: unknown engine name.
        FileNotFoundError: none of the candidates exists.
    """
    system = platform.system().lower()
    engine_map = {
        "c": "engines/c/build/dynamics_c",
        "cpp": "engines/cpp/build/dynamics_cpp",
        "fortran": "engines/fortran/build/dynamics_f90",
    }
    if engine not in engine_map:
        raise ValueError(f"不支持的引擎: {engine},可选: {list(engine_map.keys())}")
    engine_path = os.path.join(script_dir, engine_map[engine])
    candidates = [
        engine_path,                      # no suffix
        engine_path + ".exe",             # Windows .exe
        engine_path + f"_{system}.exe",   # platform-specific build
    ]
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(
        f"引擎可执行文件不存在: 尝试了 {candidates}\n"
        f"请先编译: cd engines/{engine} && make\n"
        f"或安装交叉编译器后: cd engines/{engine} && make {system}")


def _stream_engine_output(proc, engine, total_steps):
    """Read *proc*'s (merged) output until EOF, driving a progress bar.

    Lines containing ``progress: N/total`` advance the bar; every non-empty
    line is collected. If reading is interrupted, the process is killed.

    Returns:
        list of stripped, non-empty output lines.
    """
    progress_re = re.compile(r'progress:\s*(\d+)/(\d+)')
    lines = []
    try:
        from tqdm import tqdm as _tqdm
        pbar = _tqdm(total=total_steps, desc=f"[compute] 引擎 {engine}",
                     unit="", bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
    except ImportError:
        pbar = None
    try:
        for raw in proc.stdout:
            line = raw.strip()
            if not line:
                continue
            lines.append(line)
            match = progress_re.search(line)
            if pbar is not None and match and int(match.group(2)) > 0:
                # Monotonic: never move the bar backwards.
                pbar.n = max(pbar.n, min(int(match.group(1)), total_steps))
                pbar.refresh()
        proc.wait()
    finally:
        if pbar is not None:
            pbar.n = total_steps
            pbar.refresh()
            pbar.close()
        if proc.poll() is None:
            proc.kill()
            proc.wait(timeout=5)
    return lines


def _write_engine_log(output_dir, engine, t_start_str, t_end_str, elapsed, engine_lines):
    """Write dynamics.log (run header followed by the engine output) into *output_dir*.

    A failure to write is reported but does not abort the run.
    """
    log_path = os.path.join(output_dir, "dynamics.log")
    try:
        with open(log_path, "w", encoding="utf-8") as f:
            f.write("=" * 50 + "\n")
            f.write("Dynamics 计算日志\n")
            f.write("=" * 50 + "\n")
            f.write(f"引擎: {engine}\n")
            f.write(f"计算开始时间: {t_start_str}\n")
            f.write(f"计算结束时间: {t_end_str}\n")
            f.write(f"计算耗时: {elapsed:.3f} s\n")
            if engine_lines:
                f.write("\n--- 引擎输出 ---\n")
                f.writelines(line + "\n" for line in engine_lines)
    except OSError as exc:
        print(f"[compute] 警告: 无法写入日志 {log_path}: {exc}")
    return log_path


def run_engine(engine, input_dir, output_dir, config):
    """Run an external compute engine (C / C++ / Fortran) and collect its results.

    Writes the numeric parameters to ``engines/<engine>/param.json`` and runs
    ``<exe> <input_dir> <output_dir> <param.json>``. The engine's
    ``progress: N/total`` lines are shown as a progress bar, and its output
    is recorded in ``<output_dir>/dynamics.log``.

    With save_trajectory = 0 the engine is expected to write only
    ``display.txt``. With 1 it must write ``trajectory.txt``, which is loaded.

    Args:
        engine: engine name ("c", "cpp", "fortran").
        input_dir: directory holding coord.txt, connection.txt, bond.txt, ...
        output_dir: directory the engine writes its output files into.
        config: YAML configuration dict (NT and DT are required).

    Returns:
        (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz) loaded from
        trajectory.txt, or six ``None`` when save_trajectory is 0.

    Raises:
        ValueError: unknown engine.
        FileNotFoundError: engine executable or expected output file missing.
        RuntimeError: the engine exited with a non-zero status.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # Run the executable that was actually found (it may carry a .exe suffix).
    engine_exe = _find_engine_executable(script_dir, engine)

    # Build param.json (numeric parameters)
    gravity = parse_gravity_vector(config.get("G", [0, 0, -9.8]))
    damping = parse_damping_vector(config.get("B", [0, 0, 0]))
    param_json = {
        "box_a": float(config.get("box_a", 10.0)),
        "NT": int(config["NT"]),
        "DT": float(config["DT"]),
        "NSTEP": int(config.get("NSTEP", 1)),
        "warmup_steps": int(config.get("warmup_steps", 0)),
        "method": normalize_method_name(config.get("method", "leapfrog")),
        "G": [float(v) for v in gravity],
        "B": [float(v) for v in damping],
        "gravity_field": int(config.get("gravity_field", 1)),
        "gravity_interaction": int(config.get("gravity_interaction", 0)),
        "elastic_force": int(config.get("elastic_force", 1)),
        "damping_force": int(config.get("damping_force", 0)),
        "gravity_strength": float(config.get("gravity_strength", 1.0)),
        "driving_force": int(config.get("driving_force", 0)),
        "save_trajectory": int(config.get("save_trajectory", 0)),
    }
    param_path = os.path.join(script_dir, "engines", engine, "param.json")
    os.makedirs(os.path.dirname(param_path), exist_ok=True)
    with open(param_path, "w", encoding="utf-8") as f:
        json.dump(param_json, f, indent=2)

    os.makedirs(output_dir, exist_ok=True)
    print(f"[compute] 引擎: {engine} ({os.path.basename(engine_exe)})")
    print(f"[compute] input: {input_dir}")
    print(f"[compute] output: {output_dir}")

    total_steps = param_json["NT"] - param_json["warmup_steps"]
    t_start = time.time()
    t_start_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # stderr is merged into stdout. Draining only one of two pipes lets the
    # other fill up, which would block (deadlock) the engine.
    with subprocess.Popen(
            [engine_exe, os.path.abspath(input_dir), os.path.abspath(output_dir), param_path],
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True, encoding='utf-8', errors='replace') as proc:
        engine_lines = _stream_engine_output(proc, engine, total_steps)
    rc = proc.returncode
    t_end = time.time()
    t_end_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    elapsed = t_end - t_start

    # Written once, before any failure: the header and engine output are never
    # overwritten by a later "w" open.
    _write_engine_log(output_dir, engine, t_start_str, t_end_str, elapsed, engine_lines)
    if rc != 0:
        if engine_lines:
            print("[compute] 引擎错误:\n" + "\n".join(engine_lines[-40:]))
        raise RuntimeError(f"引擎 {engine} 返回错误码 {rc}")

    out_abs = os.path.abspath(output_dir)
    if not param_json["save_trajectory"]:
        # save_trajectory=0: the engine writes display.txt only
        disp_path = os.path.join(out_abs, "display.txt")
        if not os.path.exists(disp_path):
            raise FileNotFoundError(f"引擎未生成 display.txt: {disp_path}")
        print(f"[compute] 引擎已生成 {disp_path}")
        return None, None, None, None, None, None

    # save_trajectory=1: load the full trajectory.txt
    traj_path = os.path.join(out_abs, "trajectory.txt")
    if not os.path.exists(traj_path):
        raise FileNotFoundError(f"引擎未生成 trajectory.txt: {traj_path}")
    data = load_text_data(traj_path)
    traj = tuple(data[key] for key in
                 ("traj_x", "traj_y", "traj_z", "traj_vx", "traj_vy", "traj_vz"))
    traj_x = traj[0]
    print(f"[compute] 引擎完成: {len(traj_x)} 步, {traj_x.shape[1]} 原子, {elapsed:.3f} s")
    return traj
def save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, out_dir=None):
    """Save the full recorded trajectory plus run metadata to trajectory.txt.

    The file is written with save_text_data into ``get_output_dir(out_dir)``.
    When ``out_dir`` is None, the directory containing this script is used as
    the base. Besides the six per-step arrays, the payload stores:
    - the integration settings;
    - the atom and bond tables;
    - the box bounds;
    - the force switches;
    - the plotting parameters.
    All of these are read from the module-level globals.

    ``NT`` in the payload is the number of *recorded* rows (``len(traj_x)``),
    i.e. warm-up steps are excluded. Unset sample_start/sample_end are stored
    as -1.

    Returns:
        The path of the written file.
    """
    if out_dir is None:
        out_dir = os.path.dirname(os.path.abspath(__file__))
    out_path = os.path.join(get_output_dir(out_dir), "trajectory.txt")
    # Number of rows actually recorded (warm-up already excluded).
    record_steps = len(traj_x)
    sample_start_to_save = -1 if sample_start is None else int(sample_start)
    sample_end_to_save = -1 if sample_end is None else int(sample_end)
    payload = {
        "traj_x": traj_x, "traj_y": traj_y, "traj_z": traj_z,
        "traj_vx": traj_vx, "traj_vy": traj_vy, "traj_vz": traj_vz,
        "NT": record_steps, "DT": DT, "NSTEP": NSTEP,
        "method": METHOD,
        "coord_file": COORD_FILE,
        "connection_file": BOND_CONNECTION_FILE,
        "bond_file": BOND_PARAMETER_FILE,
        "atom_ids": ATOM_IDS,
        "atom_masses": ATOM_MASSES,
        "atom_radii": ATOM_RADII,
        "atom_positions": ATOM_POSITIONS,
        "atom_velocities": ATOM_VELOCITIES,
        "atom_fixed": ATOM_FIXED,
        "bond_pairs": BOND_PAIRS,
        "bond_names": BOND_NAMES,
        "bond_stiffness": BOND_STIFFNESS,
        "bond_rest_lengths": BOND_REST_LENGTHS,
        # np.asarray lets this work whether G/B are ndarrays or plain Python
        # scalars/lists (load_parameters stores parsed int/float values, which
        # have no .tolist()).
        "G": np.asarray(G).tolist() if G is not None else None,
        "B": np.asarray(B).tolist() if B is not None else None,
        "plot_atom_id": PLOT_ATOM_ID,
        "plot_atom_row": PLOT_ATOM_ROW,
        "warmup_steps": warmup_steps or 0,
        "sample_start": sample_start_to_save,
        "sample_end": sample_end_to_save,
        "X_MIN": X_MIN, "X_MAX": X_MAX,
        "Y_MIN": Y_MIN, "Y_MAX": Y_MAX,
        "Z_MIN": Z_MIN, "Z_MAX": Z_MAX,
        "X0": X0, "Y0": Y0, "Z0": Z0,
        "VX0": VX0, "VY0": VY0, "VZ0": VZ0,
        "M": M,
        "alpha": alpha,
        "ball_radius": ball_radius,
        "ball_color_r": ball_color_r,
        "ball_color_g": ball_color_g,
        "ball_color_b": ball_color_b,
        "box_color_r": box_color_r,
        "box_color_g": box_color_g,
        "box_color_b": box_color_b,
        "gravity_field": GRAVITY_FIELD,
        "gravity_interaction": GRAVITY_INTERACTION,
        "elastic_force": ELASTIC_FORCE,
        "damping_force": DAMPING_FORCE,
        "gravity_strength": GRAVITY_STRENGTH,
    }
    save_text_data(out_path, payload)
    print(f"[compute] 轨迹数据已保存至: {out_path}")
    return out_path
def load_parameters(param_file):
    """Load simulation parameters from a ``key = value`` text file.

    Parsing rules:
    - Blank lines, lines starting with ``#`` and lines without ``=`` are
      skipped.
    - Anything after ``#`` in a value is an inline comment and is dropped.
    - A value containing ``.`` or ``e``/``E`` is converted with ``float``;
      any other value is converted with ``int``. Values that fail
      conversion stay strings.

    A key that names an existing module-level *data* global overwrites that
    global. All other keys are reported with a warning and ignored:
    - unknown keys;
    - private/dunder keys (leading underscore);
    - keys that would clobber a function, class or imported module.
    This means a stray line such as ``np = 1`` cannot break the module.

    After parsing:
    - the box bounds X/Y/Z_MIN/MAX are derived as +/- ``box_a``;
    - the required parameters are checked;
    - METHOD is normalized (default ``explicit_euler``);
    - missing plotting parameters get defaults.

    Exits the process with status 1 when ``box_a`` or a required parameter
    is missing. normalize_method_name raises ValueError for an unknown METHOD.
    """
    global box_a, alpha, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP
    global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX
    global ball_radius, ball_color_r, ball_color_g, ball_color_b
    global box_color_r, box_color_g, box_color_b
    import types

    print(f"[compute] 正在加载参数文件: {param_file}")
    with open(param_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    module_vars = globals()
    for raw_line in lines:
        line = raw_line.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, value = line.split('=', 1)
        key = key.strip()
        # Drop an inline comment, e.g. "NT = 1000  # number of steps".
        value = value.split('#', 1)[0].strip()
        # Try numeric conversion; non-numeric text (e.g. METHOD names) stays a string.
        try:
            if '.' in value or 'e' in value.lower():
                value = float(value)
            else:
                value = int(value)
        except ValueError:
            pass
        current = module_vars.get(key)
        # Only plain data globals are parameters; never overwrite code,
        # imported modules or private/dunder module attributes.
        is_parameter = (
            key in module_vars
            and not key.startswith('_')
            and not callable(current)
            and not isinstance(current, types.ModuleType)
        )
        if is_parameter:
            module_vars[key] = value
        else:
            print(f"警告: 未知参数 '{key}',忽略")

    # Derived cubic box bounds.
    if box_a is None:
        print("错误: 未找到 box_a 参数")
        sys.exit(1)
    X_MIN, X_MAX = -box_a, box_a
    Y_MIN, Y_MAX = -box_a, box_a
    Z_MIN, Z_MAX = -box_a, box_a

    # Required parameters.
    required = ['alpha', 'X0', 'Y0', 'Z0', 'VX0', 'VY0', 'VZ0', 'M', 'G', 'B', 'NT', 'DT', 'NSTEP']
    for param in required:
        if globals()[param] is None:
            print(f"错误: 未找到必需参数 '{param}'")
            sys.exit(1)

    METHOD = normalize_method_name(METHOD or "explicit_euler")

    # Plotting defaults for anything the file did not provide.
    plot_defaults = {
        "ball_radius": 0.28,
        "ball_color_r": 0.9,
        "ball_color_g": 0.2,
        "ball_color_b": 0.2,
        "box_color_r": 0.8,
        "box_color_g": 0.8,
        "box_color_b": 0.85,
    }
    for name, default in plot_defaults.items():
        if module_vars[name] is None:
            module_vars[name] = default

    print(f"[compute] 参数加载完成: box_a={box_a}, alpha={alpha}, NT={NT}, DT={DT}, method={METHOD}")
def normalize_method_name(method):
    """Map an integrator name or alias to its canonical name.

    ``method`` is converted with ``str``, stripped of surrounding whitespace
    and lower-cased before matching. Chinese aliases are accepted too.

    Returns:
        One of ``explicit_euler``, ``implicit_euler``, ``midpoint`` or
        ``leapfrog``.

    Raises:
        ValueError: for any unrecognised name; the message lists the valid
            canonical names.
    """
    spellings = {
        "explicit_euler": ("explicit_euler", "explicit", "euler_explicit", "显式欧拉", "显式欧拉法"),
        "implicit_euler": ("implicit_euler", "implicit", "euler_implicit", "隐式欧拉", "隐式欧拉法"),
        "midpoint": ("midpoint", "mid_point", "midpoint_method", "中点法", "中点差分法"),
        "leapfrog": ("leapfrog", "leap_frog", "蛙跳法"),
    }
    wanted = str(method).strip().lower()
    for canonical, names in spellings.items():
        if wanted in names:
            return canonical
    valid = ", ".join(spellings)
    raise ValueError(f"未知算法 '{method}'。可选值: {valid}")
def compute_force(x, y, z, vx, vy, vz, m, g, b):
    """Return the total force ``(fx, fy, fz)`` acting on every atom.

    Contributions are added in this order, each gated by its own
    module-level switch:

    1. GRAVITY_FIELD       - uniform field, F = m * g (per component).
    2. DAMPING_FORCE       - linear drag, F = -b * v (per component).
    3. ELASTIC_FORCE       - springs along BOND_PAIRS with magnitude
                             BOND_STIFFNESS[k] * (distance - BOND_REST_LENGTHS[k]).
    4. GRAVITY_INTERACTION - pairwise attraction GRAVITY_STRENGTH*mi*mj/r^2.

    Skips near-coincident partners to avoid dividing by zero:
    - bonds whose length is <= 1e-12;
    - atom pairs whose squared separation is <= 1e-12.

    Each bond or pair applies equal and opposite forces to its two members.
    """
    fx = np.zeros_like(x)
    fy = np.zeros_like(y)
    fz = np.zeros_like(z)

    if GRAVITY_FIELD:
        fx += m * g[0]
        fy += m * g[1]
        fz += m * g[2]

    if DAMPING_FORCE:
        fx -= b[0] * vx
        fy -= b[1] * vy
        fz -= b[2] * vz

    if ELASTIC_FORCE and BOND_PAIRS is not None and len(BOND_PAIRS) > 0:
        for k, (i, j) in enumerate(BOND_PAIRS):
            rx = x[j] - x[i]
            ry = y[j] - y[i]
            rz = z[j] - z[i]
            length = np.sqrt(rx * rx + ry * ry + rz * rz)
            if length <= 1e-12:
                continue
            # Positive tension (stretched bond) pulls i towards j and j towards i.
            tension = BOND_STIFFNESS[k] * (length - BOND_REST_LENGTHS[k])
            px = tension * (rx / length)
            py = tension * (ry / length)
            pz = tension * (rz / length)
            fx[i] += px
            fx[j] -= px
            fy[i] += py
            fy[j] -= py
            fz[i] += pz
            fz[j] -= pz

    if GRAVITY_INTERACTION:
        count = len(m)
        for i in range(count):
            for j in range(i + 1, count):
                rx = x[j] - x[i]
                ry = y[j] - y[i]
                rz = z[j] - z[i]
                r2 = rx * rx + ry * ry + rz * rz
                if r2 <= 1e-12:
                    continue
                pull = GRAVITY_STRENGTH * m[i] * m[j] / r2
                r = np.sqrt(r2)
                px = pull * rx / r
                py = pull * ry / r
                pz = pull * rz / r
                fx[i] += px
                fx[j] -= px
                fy[i] += py
                fy[j] -= py
                fz[i] += pz
                fz[j] -= pz

    return fx, fy, fz
def compute_acceleration(x, y, z, vx, vy, vz, m, g, b):
    """Return ``(ax, ay, az)``: the compute_force result divided by mass ``m``."""
    force = compute_force(x, y, z, vx, vy, vz, m, g, b)
    return tuple(component / m for component in force)
def Explicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with forward (explicit) Euler.

    Both the position and the velocity updates use quantities evaluated at the
    start of the step: ``p' = p + v*dt`` and ``v' = v + a(p, v)*dt``.
    Returns the updated x, y, z, vx, vy, vz.
    """
    accel = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)
    velocity = (vx, vy, vz)
    new_pos = [p + v * dt for p, v in zip((x, y, z), velocity)]
    new_vel = [v + a * dt for v, a in zip(velocity, accel)]
    return (*new_pos, *new_vel)
def Implicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with a semi-implicit (predictor-corrector) Euler scheme.

    1. Predictor: the linear part of the model is treated implicitly. That
       part is uniform gravity (only when GRAVITY_FIELD is on) plus linear
       drag ``b*v``. Per component this gives
       ``v* = (v + g*dt) / (1 + (b/m)*dt)``.
    2. Corrector: evaluate the full acceleration at the current positions
       with ``v*``, update ``v`` with it, then advance positions using the
       *new* velocity.

    Returns the updated x, y, z, vx, vy, vz.
    """
    # Gate the predictor's gravity term on the same switch compute_force uses.
    # Otherwise a disabled field would still leak into v* and, through v*,
    # into the drag force evaluated below.
    if GRAVITY_FIELD:
        gx, gy, gz = g[0], g[1], g[2]
    else:
        gx = gy = gz = 0.0
    gamma_x = b[0] / m
    gamma_y = b[1] / m
    gamma_z = b[2] / m
    vx_pred = (vx + gx * dt) / (1.0 + gamma_x * dt)
    vy_pred = (vy + gy * dt) / (1.0 + gamma_y * dt)
    vz_pred = (vz + gz * dt) / (1.0 + gamma_z * dt)
    ax_next, ay_next, az_next = compute_acceleration(
        x, y, z, vx_pred, vy_pred, vz_pred, m, g, b)
    vx = vx + ax_next * dt
    vy = vy + ay_next * dt
    vz = vz + az_next * dt
    x = x + vx * dt
    y = y + vy * dt
    z = z + vz * dt
    return x, y, z, vx, vy, vz
def Midpoint_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with the explicit midpoint method.

    Procedure:
    1. Take a half Euler step to the midpoint state.
    2. Evaluate the acceleration there.
    3. Advance the full step from the start state, using the midpoint
       velocity for positions and the midpoint acceleration for velocities.

    Returns the updated x, y, z, vx, vy, vz.
    """
    start_pos = (x, y, z)
    start_vel = (vx, vy, vz)
    start_acc = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)

    mid_pos = tuple(p + 0.5 * v * dt for p, v in zip(start_pos, start_vel))
    mid_vel = tuple(v + 0.5 * a * dt for v, a in zip(start_vel, start_acc))
    mid_acc = compute_acceleration(*mid_pos, *mid_vel, m, g, b)

    end_pos = tuple(p + vm * dt for p, vm in zip(start_pos, mid_vel))
    end_vel = tuple(v + am * dt for v, am in zip(start_vel, mid_acc))
    return (*end_pos, *end_vel)
def Leapfrog_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with a velocity-Verlet (leapfrog) scheme.

    Procedure:
    1. Half-kick the velocity with a(t).
    2. Drift the positions with the half-step velocity.
    3. Predict the end velocity explicitly: v_half + 0.5*a(t)*dt. The
       prediction is needed so that velocity-dependent forces (drag) can be
       evaluated at the new positions.
    4. Finish with a second half-kick using that new acceleration.

    Returns the updated x, y, z, vx, vy, vz.
    """
    acc_now = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)
    vel_half = tuple(v + 0.5 * a * dt for v, a in zip((vx, vy, vz), acc_now))
    new_pos = tuple(p + vh * dt for p, vh in zip((x, y, z), vel_half))

    # Explicit predictor for the end-of-step velocity.
    vel_pred = tuple(vh + 0.5 * a * dt for vh, a in zip(vel_half, acc_now))
    acc_next = compute_acceleration(*new_pos, *vel_pred, m, g, b)

    new_vel = tuple(vh + 0.5 * a * dt for vh, a in zip(vel_half, acc_next))
    return (*new_pos, *new_vel)
def Limit_in_box(a, amin, amax, va):
    """Clamp coordinates ``a`` into [amin, amax] and reflect velocity on contact.

    Entries above ``amax`` are set to ``amax``; entries below ``amin`` are set
    to ``amin``. Wherever either happened, the matching entry of ``va`` has
    its sign flipped. Returns the clamped coordinates and the new velocities.
    """
    too_high = a > amax
    too_low = a < amin
    hit_wall = too_high | too_low
    clamped = np.where(too_low, amin, np.where(too_high, amax, a))
    bounced = np.where(hit_wall, -va, va)
    return clamped, bounced
def apply_motion_update(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with the integrator selected by METHOD, then clamp to the box.

    After integration every axis goes through Limit_in_box with that axis's
    X/Y/Z_MIN/MAX bounds. Raises ValueError if METHOD names no known
    integrator.
    """
    schemes = (
        ("explicit_euler", Explicit_Euler_Method),
        ("implicit_euler", Implicit_Euler_Method),
        ("midpoint", Midpoint_Method),
        ("leapfrog", Leapfrog_Method),
    )
    for scheme_name, stepper in schemes:
        if METHOD == scheme_name:
            break
    else:
        raise ValueError(f"未知算法: {METHOD}")

    x, y, z, vx, vy, vz = stepper(x, y, z, vx, vy, vz, dt, m, g, b)
    x, vx = Limit_in_box(x, X_MIN, X_MAX, vx)
    y, vy = Limit_in_box(y, Y_MIN, Y_MAX, vy)
    z, vz = Limit_in_box(z, Z_MIN, Z_MAX, vz)
    return x, y, z, vx, vy, vz
def apply_fixed_constraints(x, y, z, vx, vy, vz):
    """Pin fixed degrees of freedom to their initial coordinate with zero velocity.

    Wherever ATOM_FIXED is non-zero, two things happen:
    - the coordinate is replaced by the matching entry of ATOM_POSITIONS;
    - the velocity component is replaced by 0.0.
    All other entries pass through unchanged. ATOM_FIXED is broadcast
    against (n_atoms, 3) stacks, presumably as a per-axis mask (TODO
    confirm with the loader).

    Returns x, y, z, vx, vy, vz.
    """
    locked = ATOM_FIXED != 0
    pos = np.where(locked, ATOM_POSITIONS, np.column_stack((x, y, z)))
    vel = np.where(locked, 0.0, np.column_stack((vx, vy, vz)))
    return pos[:, 0], pos[:, 1], pos[:, 2], vel[:, 0], vel[:, 1], vel[:, 2]
def wrap_position(x, y, z):
    """Wrap positions that left the box to the opposite face.

    Per axis:
    - a coordinate above the max bound is moved to the min bound;
    - a coordinate below the min bound is moved to the max bound.
    """
    def _wrap(coord, low, high):
        coord = np.where(coord > high, low, coord)
        return np.where(coord < low, high, coord)

    return _wrap(x, X_MIN, X_MAX), _wrap(y, Y_MIN, Y_MAX), _wrap(z, Z_MIN, Z_MAX)
# ===========================================================================
# 主计算流程
# ===========================================================================
def run_simulation(save_trajectory=0):
    """Integrate NT steps, sample every NSTEP-th step and write display.txt.

    Phases:
    - Warm-up: the first ``warmup_steps`` steps are integrated but not
      recorded.
    - Recording: the remaining ``NT - warmup_steps`` steps are integrated.
      The state *before* each update is sampled whenever the recording-step
      index is a multiple of NSTEP.
    - Output: the sampled frames are written to ``display.txt`` in the output
      directory. When ``save_trajectory`` is truthy, every recorded step is
      also buffered and saved via save_trajectory_txt.

    Args:
        save_trajectory: when truthy, additionally keep and save the full
            recorded trajectory.

    Returns:
        Six arrays (x, y, z, vx, vy, vz) of sampled frames, each of shape
        (n_frames, n_atoms).

    Raises:
        ValueError: if NSTEP < 1 or warmup_steps exceeds NT. This is checked
            before any integration work is done.
    """
    warmup = warmup_steps or 0
    # Validate up front: previously NSTEP == 0 raised ZeroDivisionError and
    # warmup > NT only failed (after the whole warm-up ran) with an opaque
    # "negative dimensions" error from np.zeros.
    if NSTEP is None or NSTEP < 1:
        raise ValueError(f"NSTEP must be >= 1, got {NSTEP!r}")
    record_steps = NT - warmup
    if record_steps < 0:
        raise ValueError(f"warmup_steps ({warmup}) must not exceed NT ({NT})")

    # Initial state (copies, so the global tables are never mutated).
    x = ATOM_POSITIONS[:, 0].copy()
    y = ATOM_POSITIONS[:, 1].copy()
    z = ATOM_POSITIONS[:, 2].copy()
    vx = ATOM_VELOCITIES[:, 0].copy()
    vy = ATOM_VELOCITIES[:, 1].copy()
    vz = ATOM_VELOCITIES[:, 2].copy()
    x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz)
    x, y, z, vx, vy, vz = apply_driving_force(x, y, z, vx, vy, vz, 0.0, 0, DRIVER_DATA, DT)

    # Warm-up phase: integrate without recording.
    if warmup > 0:
        print(f"[compute] 预热阶段: 前 {warmup_steps} 步不记录")
        for step in trange(warmup, desc="[compute] 预热"):
            t = (step + 1) * DT
            x, y, z, vx, vy, vz = apply_driving_force(x, y, z, vx, vy, vz, t, step, DRIVER_DATA, DT)
            x, y, z, vx, vy, vz = apply_motion_update(x, y, z, vx, vy, vz, DT, ATOM_MASSES, G, B)
            x, y, z = wrap_position(x, y, z)
            x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz)
        print(
            f"[compute] 预热完成,展示原子位置: "
            f"({x[PLOT_ATOM_ROW]:.4f}, {y[PLOT_ATOM_ROW]:.4f}, {z[PLOT_ATOM_ROW]:.4f})"
        )

    # Recording phase: sample every NSTEP-th step.
    n_atoms = len(ATOM_IDS)
    n_frames = (record_steps + NSTEP - 1) // NSTEP  # ceil(record_steps / NSTEP)
    frame_indices = []
    # Sampled-frame buffers (much smaller than the full trajectory).
    sampled_x = np.zeros((n_frames, n_atoms), dtype=np.float64)
    sampled_y = np.zeros((n_frames, n_atoms), dtype=np.float64)
    sampled_z = np.zeros((n_frames, n_atoms), dtype=np.float64)
    sampled_vx = np.zeros((n_frames, n_atoms), dtype=np.float64)
    sampled_vy = np.zeros((n_frames, n_atoms), dtype=np.float64)
    sampled_vz = np.zeros((n_frames, n_atoms), dtype=np.float64)
    # Full-trajectory buffers, only when requested.
    if save_trajectory:
        traj_x = np.zeros((record_steps, n_atoms), dtype=np.float64)
        traj_y = np.zeros((record_steps, n_atoms), dtype=np.float64)
        traj_z = np.zeros((record_steps, n_atoms), dtype=np.float64)
        traj_vx = np.zeros((record_steps, n_atoms), dtype=np.float64)
        traj_vy = np.zeros((record_steps, n_atoms), dtype=np.float64)
        traj_vz = np.zeros((record_steps, n_atoms), dtype=np.float64)
    for step in trange(record_steps, desc="[compute] 计算中"):
        t = (step + warmup) * DT
        # NOTE(review): t includes the warm-up offset but the step index passed
        # here restarts at 0, and the warm-up loop uses t = (step + 1) * DT.
        # Confirm this matches what apply_driving_force expects.
        x, y, z, vx, vy, vz = apply_driving_force(x, y, z, vx, vy, vz, t, step, DRIVER_DATA, DT)
        if save_trajectory:
            traj_x[step] = x
            traj_y[step] = y
            traj_z[step] = z
            traj_vx[step] = vx
            traj_vy[step] = vy
            traj_vz[step] = vz
        # Sample the pre-update state every NSTEP steps.
        if step % NSTEP == 0:
            fi = step // NSTEP
            sampled_x[fi] = x
            sampled_y[fi] = y
            sampled_z[fi] = z
            sampled_vx[fi] = vx
            sampled_vy[fi] = vy
            sampled_vz[fi] = vz
            frame_indices.append(step)
        x, y, z, vx, vy, vz = apply_motion_update(x, y, z, vx, vy, vz, DT, ATOM_MASSES, G, B)
        x, y, z = wrap_position(x, y, z)
        x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz)

    # Write display.txt (text format with a key/value header).
    output_dir = get_output_dir()
    disp_path = os.path.join(output_dir, "display.txt")
    n_frames_actual = len(frame_indices)
    save_display_txt(
        disp_path,
        sampled_x[:n_frames_actual], sampled_y[:n_frames_actual], sampled_z[:n_frames_actual],
        sampled_vx[:n_frames_actual], sampled_vy[:n_frames_actual], sampled_vz[:n_frames_actual],
        np.array(ATOM_IDS), n_frames_actual, n_atoms,
        header_fields={"DT": str(DT), "NSTEP": str(NSTEP), "method": str(METHOD),
                       "warmup_steps": str(warmup),
                       "dynamic_steps": str(record_steps),
                       "T_total": str(NT * DT),
                       "X_MAX": str(X_MAX), "X_MIN": str(X_MIN),
                       "Y_MAX": str(Y_MAX), "Y_MIN": str(Y_MIN),
                       "Z_MAX": str(Z_MAX), "Z_MIN": str(Z_MIN),
                       "ball_radius": str(ball_radius),
                       "ball_color_r": str(ball_color_r),
                       "ball_color_g": str(ball_color_g),
                       "ball_color_b": str(ball_color_b),
                       "box_color_r": str(box_color_r),
                       "box_color_g": str(box_color_g),
                       "box_color_b": str(box_color_b),
                       "gravity_field": str(GRAVITY_FIELD),
                       "gravity_interaction": str(GRAVITY_INTERACTION),
                       "elastic_force": str(ELASTIC_FORCE),
                       "damping_force": str(DAMPING_FORCE),
                       "gravity_strength": str(GRAVITY_STRENGTH),
                       "driving_force": str(DRIVING_FORCE),
                       "use_marker": str(use_marker),
                       "alpha": ",".join(str(a) for a in (alpha if isinstance(alpha, list) else [alpha])),
                       "atom_radii": ",".join(str(r) for r in ATOM_RADII),
                       "camera_distance": str(config.get("camera_distance", 40.0)),
                       "camera_elevation": str(config.get("camera_elevation", 0)),
                       "camera_azimuth": str(config.get("camera_azimuth", 0)),
                       "camera_keyframes": str(camera_keyframes_raw)}
    )
    print(f"[compute] display.txt 已保存至: {disp_path} ({n_frames_actual} 帧)")

    # Optional: save the full trajectory.txt.
    if save_trajectory:
        save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz)
        print(f"[compute] trajectory.txt 已保存(完整轨迹)")
    return sampled_x[:n_frames_actual], sampled_y[:n_frames_actual], sampled_z[:n_frames_actual], \
        sampled_vx[:n_frames_actual], sampled_vy[:n_frames_actual], sampled_vz[:n_frames_actual]
def save_trajectory_table_txt(txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT):
    """Write one atom's trajectory as a comma-separated text table for inspection.

    If ``traj_x`` is 2-D (rows x atoms), the column for PLOT_ATOM_ROW (atom 0
    when unset) is taken from all six arrays first.

    Args:
        txt_path: destination file; overwritten.
        traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz: per-row positions
            and velocities.
        NT: number of rows requested. Clamped to the rows actually present,
            so passing fewer (e.g. NSTEP-sampled) frames than NT no longer
            raises IndexError.
        DT: time spacing between consecutive rows; row k is stamped ``k*DT``.

    Output:
        Each line is ``step, time, x, y, z, vx, vy, vz``. The step number is
        1-based.
    """
    if np.ndim(traj_x) > 1:
        row = PLOT_ATOM_ROW if PLOT_ATOM_ROW is not None else 0
        traj_x = traj_x[:, row]
        traj_y = traj_y[:, row]
        traj_z = traj_z[:, row]
        traj_vx = traj_vx[:, row]
        traj_vy = traj_vy[:, row]
        traj_vz = traj_vz[:, row]
    columns = (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz)
    n_rows = min([int(NT)] + [len(col) for col in columns])
    with open(txt_path, 'w', encoding='utf-8') as f:
        # Header line.
        f.write("# 步数, 时间(s), x, y, z, vx, vy, vz\n")
        for step in range(n_rows):
            t = step * DT
            x = traj_x[step]
            y = traj_y[step]
            z = traj_z[step]
            vx = traj_vx[step]
            vy = traj_vy[step]
            vz = traj_vz[step]
            f.write(f"{step+1:6d}, {t:.6f}, {x:.6f}, {y:.6f}, {z:.6f}, {vx:.6f}, {vy:.6f}, {vz:.6f}\n")
    print(f"[compute] 轨迹文本文件已保存至: {txt_path}")
def main():
    """Command-line entry point: ``python compute.py [param_file]``.

    Steps:
    1. Load parameters from ``param_file``, or by default from
       ``input.txt`` in the input directory next to this script.
    2. Run the simulation without keeping the full trajectory;
       run_simulation writes display.txt itself.
    3. Write the returned sampled frames of the plotted atom to
       ``trajectory_table.txt`` in the output directory.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    param_file = os.path.join(get_input_dir(script_dir), "input.txt")
    if len(sys.argv) > 1:
        param_file = sys.argv[1]

    load_parameters(param_file)
    output_dir = get_output_dir(script_dir)

    print(f"[compute] 开始计算 NT={NT} DT={DT} ")
    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = run_simulation(save_trajectory=0)
    print(f"[compute] 计算完成,共 {NT}")
    print(f"[compute] display.txt 已在 run_simulation 中保存")
    # For the full per-step trajectory, call run_simulation(save_trajectory=1).

    # run_simulation returns only the NSTEP-sampled frames, not NT rows.
    # Size the table by the frames actually returned (passing NT caused an
    # IndexError). Consecutive frames are DT * NSTEP apart in time. The time
    # column is relative to the first recorded frame, so warm-up steps are
    # not included.
    txt_path = os.path.join(output_dir, "trajectory_table.txt")
    save_trajectory_table_txt(
        txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
        len(traj_x), DT * NSTEP,
    )
# Script entry: `python compute.py [param_file]` (main reads sys.argv[1] if given).
if __name__ == "__main__":
    main()