Files
dynamics/compute.py
T
admin 854f00ae44 feat: 增加驱动力系统、Marker渲染模式、动画防闪退、案例文档
- 新增 driving_force 驱动力系统(driver.txt 定义,支持周期控制)
- 新增 use_marker 渲染开关(GPU实例化点精灵,提升大量原子性能)
- 修复动画闪退:独立控制台、错误日志、启动存活检测
- 重绘 draw.py 架构:双渲染模式 + 预分配键线缓冲区
- 修复 raw trajectory 采样时间变量遮蔽 bug
- 重构 case05: 60原子一维链 + 驱动力 + 完整案例文档
- 修复所有案例 Readme.md 编码(GBK → UTF-8)
- 所有 input.txt 新增 driver_file / driving_force / use_marker 参数
2026-06-10 15:34:53 +08:00

1307 lines
49 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
compute.py
----------
纯 Python 计算脚本,不依赖 VisPy。
功能:
1. 运行 NT 步物理模拟( kinematics / dynamics 等运动模式)
2. 将每一步的 (x, y, z, vx, vy, vz) 保存到 output/trajectory.txt
3. 同时保存所有模拟参数元数据
"""
import json
import numpy as np
import os
import platform
import subprocess
import sys
import time
import datetime
from tqdm import trange
# ===========================================================================
# Physical parameters (must stay consistent with the boundary constants
# loaded by draw.py)
# ===========================================================================
# Module-level parameter state; populated by load_parameters / run_from_config
box_a = None
alpha = None
COORD_FILE = None
X0 = None
Y0 = None
Z0 = None
VX0 = None
VY0 = None
VZ0 = None
M = None
ATOM_IDS = None
ATOM_MASSES = None
ATOM_RADII = None
ATOM_POSITIONS = None
ATOM_VELOCITIES = None
ATOM_FIXED = None
BOND_CONNECTION_FILE = None
BOND_PARAMETER_FILE = None
BOND_PAIRS = None
BOND_NAMES = None
BOND_STIFFNESS = None
BOND_REST_LENGTHS = None
PLOT_ATOM_ID = None
PLOT_ATOM_ROW = None
G = None
B = None
METHOD = None
NT = None
DT = None
NSTEP = None
warmup_steps = None # Number of warm-up steps (skipped, not saved)
sample_start = None # Start index for frame sampling
sample_end = None # End index for frame sampling
ball_radius = None
ball_color_r = None
ball_color_g = None
ball_color_b = None
box_color_r = None
box_color_g = None
box_color_b = None
# Force switches
GRAVITY_FIELD = 1 # Uniform gravity field
GRAVITY_INTERACTION = 0 # Newtonian gravity between atoms
ELASTIC_FORCE = 1 # Spring bond force
DAMPING_FORCE = 0 # Damping
DRIVING_FORCE = 0 # Driving force
GRAVITY_STRENGTH = 1.0
# Driving-force data
DRIVER_DATA = None # Loaded from driver.txt
# Derived bounds (computed from box_a)
X_MIN = None
X_MAX = None
Y_MIN = None
Y_MAX = None
Z_MIN = None
Z_MAX = None
def _to_text_value(value):
"""Convert numpy-heavy objects into JSON-friendly plain Python values."""
if isinstance(value, np.ndarray):
return value.tolist()
if isinstance(value, (np.integer,)):
return int(value)
if isinstance(value, (np.floating,)):
return float(value)
if isinstance(value, dict):
return {key: _to_text_value(val) for key, val in value.items()}
if isinstance(value, (list, tuple)):
return [_to_text_value(item) for item in value]
return value
def _from_text_value(value):
"""Convert nested numeric lists back into numpy arrays when appropriate."""
if isinstance(value, list):
if not value:
return np.array([], dtype=np.float64)
if all(not isinstance(item, (list, dict)) for item in value):
if all(isinstance(item, (int, float, bool)) for item in value):
return np.array(value)
return value
if all(isinstance(item, list) for item in value):
return np.array(value)
return [_from_text_value(item) for item in value]
if isinstance(value, dict):
return {key: _from_text_value(val) for key, val in value.items()}
return value
def save_text_data(path, data):
    """Save structured simulation data as formatted JSON text.

    The parent directory of *path* is created when it does not exist.
    *data* is passed through ``_to_text_value`` before encoding.

    Returns:
        The *path* that was written.
    """
    parent = os.path.dirname(path)
    # A bare file name has an empty dirname, and os.makedirs("") raises;
    # the current directory needs no creation.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(_to_text_value(data), f, ensure_ascii=False, indent=2)
    return path
def load_text_data(path):
    """Read a JSON text file and return its content via ``_from_text_value``."""
    with open(path, "r", encoding="utf-8") as handle:
        decoded = json.load(handle)
    return _from_text_value(decoded)
def get_output_dir(base_dir=None):
    """Resolve (and create) the directory that receives generated artifacts.

    A non-empty ``DYNAMICS_OUTPUT_DIR`` environment variable wins and is
    used as an absolute path. Otherwise the result is ``<base_dir>/output``,
    where *base_dir* defaults to the directory containing this file.
    """
    env_dir = os.environ.get("DYNAMICS_OUTPUT_DIR")
    if env_dir:
        target = os.path.abspath(env_dir)
    else:
        root = base_dir
        if root is None:
            root = os.path.dirname(os.path.abspath(__file__))
        target = os.path.join(root, "output")
    os.makedirs(target, exist_ok=True)
    return target
def get_input_dir(base_dir=None):
    """Resolve (and create) the directory holding configuration and source data.

    A non-empty ``DYNAMICS_INPUT_DIR`` environment variable wins and is used
    as an absolute path. Otherwise the result is ``<base_dir>/input``, where
    *base_dir* defaults to the directory containing this file.
    """
    env_dir = os.environ.get("DYNAMICS_INPUT_DIR")
    if not env_dir:
        if base_dir is None:
            base_dir = os.path.dirname(os.path.abspath(__file__))
        resolved = os.path.join(base_dir, "input")
    else:
        resolved = os.path.abspath(env_dir)
    os.makedirs(resolved, exist_ok=True)
    return resolved
def load_coord_file(coord_path):
    """Load atom ids, masses, radii, coordinates, velocities and fixed flags.

    The first line must be the header ``n mass radius x y z vx vy vz``,
    optionally followed by ``fix_x fix_y fix_z`` (or the older spelling
    ``fixed_x fixed_y fixed_z``). Blank lines and lines starting with ``#``
    are skipped.

    Returns:
        ``(atom_ids, masses, radii, positions, velocities, fixed)`` where
        ``positions`` / ``velocities`` have shape ``(n_atoms, 3)`` and
        ``fixed`` is an int64 ``(n_atoms, 3)`` array (all zeros when the
        file has no fix columns).

    Raises:
        FileNotFoundError: *coord_path* does not exist.
        ValueError: bad header, wrong column count, unparsable number,
            no atom rows, or a non-positive mass / radius.
    """
    if not os.path.exists(coord_path):
        raise FileNotFoundError(f"坐标文件不存在: {coord_path}")

    expected = ["n", "mass", "radius", "x", "y", "z", "vx", "vy", "vz"]
    legacy = expected + ["fixed_x", "fixed_y", "fixed_z"]
    legacy_new = expected + ["fix_x", "fix_y", "fix_z"]

    ids = []
    numeric_rows = []  # mass, radius, x, y, z, vx, vy, vz per atom
    fixed_rows = []
    with open(coord_path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split()
        if header not in (expected, legacy, legacy_new):
            # The two halves of the message need an explicit separator.
            raise ValueError(
                f"坐标文件表头应为: {' '.join(expected)} 或加三列 fix_x fix_y fix_z,"
                f"实际为: {' '.join(header)}")
        has_fixed = header != expected
        ncols = 12 if has_fixed else 9
        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) != ncols:
                raise ValueError(
                    f"{coord_path}:{line_no} 应有 {ncols} 列,实际为 {len(parts)}")
            try:
                atom_id = int(parts[0])
                numbers = [float(item) for item in parts[1:9]]
                flags = [int(item) for item in parts[9:12]]
            except ValueError as err:
                # Re-raise with the file position so the bad row can be found.
                raise ValueError(
                    f"{coord_path}:{line_no} 含有无法解析的数值: {err}") from err
            ids.append(atom_id)
            numeric_rows.append(numbers)
            fixed_rows.append(flags)

    if not ids:
        raise ValueError(f"坐标文件没有原子数据: {coord_path}")

    atom_ids = np.array(ids, dtype=np.int64)
    numeric = np.array(numeric_rows, dtype=np.float64)
    masses = numeric[:, 0].copy()
    radii = numeric[:, 1].copy()
    positions = numeric[:, 2:5].copy()
    velocities = numeric[:, 5:8].copy()
    if has_fixed:
        fixed = np.array(fixed_rows, dtype=np.int64)
    else:
        fixed = np.zeros((len(ids), 3), dtype=np.int64)

    if np.any(masses <= 0):
        raise ValueError("坐标文件中的质量必须为正数")
    if np.any(radii <= 0):
        raise ValueError("坐标文件中的半径必须为正数")
    print(f"[compute] 已加载坐标文件: {coord_path},原子数={len(atom_ids)}")
    return atom_ids, masses, radii, positions, velocities, fixed
def load_bond_parameters(bond_path):
    """Read per-bond-type stiffness and optional rest length.

    Accepted headers:
        ``bond_name k``              -- rest length left as ``None``
        ``bond_name k rest_length``  -- rest length given explicitly

    Blank lines and ``#`` lines are skipped.

    Returns:
        ``{bond_name: {"stiffness": float, "rest_length": float | None}}``

    Raises:
        FileNotFoundError: the file does not exist.
        ValueError: bad header, wrong column count, negative stiffness,
            non-positive rest length, or no data rows.
    """
    if not os.path.exists(bond_path):
        raise FileNotFoundError(f"键参数文件不存在: {bond_path}")

    table = {}
    with open(bond_path, "r", encoding="utf-8") as handle:
        columns = handle.readline().strip().split()
        if columns not in (["bond_name", "k"], ["bond_name", "k", "rest_length"]):
            wanted = ["bond_name", "k"]
            if len(columns) == 3:
                wanted = wanted + ["rest_length"]
            raise ValueError(
                f"键参数文件表头应为: {' '.join(wanted)},实际为: {' '.join(columns)}")

        width = len(columns)
        explicit_length = width == 3
        for number, raw in enumerate(handle, start=2):
            text = raw.strip()
            if text == "" or text.startswith("#"):
                continue
            fields = text.split()
            if len(fields) != width:
                raise ValueError(
                    f"{bond_path}:{number} 应有 {width} 列,实际为 {len(fields)}")
            k_value = float(fields[1])
            if k_value < 0:
                raise ValueError(f"{bond_path}:{number} 劲度系数必须非负")
            length = float(fields[2]) if explicit_length else None
            if length is not None and length <= 0:
                raise ValueError(f"{bond_path}:{number} 键长必须为正数")
            table[fields[0]] = {"stiffness": k_value, "rest_length": length}

    if len(table) == 0:
        raise ValueError(f"键参数文件没有有效数据: {bond_path}")
    return table
def load_bond_connections(connection_path, atom_ids, atom_positions, bond_map):
    """Read bonded atom pairs and work out each bond's rest length.

    Rest-length priority:
        1. the ``rest_length`` stored in *bond_map* for that bond type;
        2. otherwise the distance between the two atoms in *atom_positions*.

    The file header must be ``n1 n2 bond_name``; blank and ``#`` lines are
    skipped. Atom numbers are translated into row indices of *atom_ids*.

    Returns:
        ``(pairs, bond_names, stiffness, rest_lengths)`` -- ``pairs`` is an
        int64 ``(n_bonds, 2)`` array of row indices, ``bond_names`` a list,
        the other two float64 arrays. With no bonds the arrays are empty and
        ``pairs`` has shape ``(0, 2)``.

    Raises:
        FileNotFoundError: the file does not exist.
        ValueError: bad header, wrong column count, unknown atom number or
            unknown bond type.
    """
    if not os.path.exists(connection_path):
        raise FileNotFoundError(f"成键连接文件不存在: {connection_path}")

    row_of = {int(aid): row for row, aid in enumerate(atom_ids)}
    records = []  # one (row_a, row_b, name, stiffness, rest_length) per bond
    with open(connection_path, "r", encoding="utf-8") as handle:
        columns = handle.readline().strip().split()
        wanted = ["n1", "n2", "bond_name"]
        if columns != wanted:
            raise ValueError(
                f"成键连接文件表头应为: {' '.join(wanted)},实际为: {' '.join(columns)}")
        for number, raw in enumerate(handle, start=2):
            text = raw.strip()
            if not text or text.startswith("#"):
                continue
            fields = text.split()
            if len(fields) != 3:
                raise ValueError(f"{connection_path}:{number} 应有 3 列,实际为 {len(fields)}")
            first = int(fields[0])
            second = int(fields[1])
            name = fields[2]
            if not (first in row_of and second in row_of):
                raise ValueError(
                    f"{connection_path}:{number} 中的原子序号 {first}, {second} 不在坐标文件里")
            if name not in bond_map:
                raise ValueError(
                    f"{connection_path}:{number} 使用了未定义的键类型 {name}")
            row_a, row_b = row_of[first], row_of[second]
            spec = bond_map[name]
            length = spec["rest_length"]
            if length is None:
                # No explicit length: fall back to the initial separation.
                length = float(np.linalg.norm(atom_positions[row_b] - atom_positions[row_a]))
            records.append((row_a, row_b, name, spec["stiffness"], length))

    if not records:
        # Explicit shapes so callers always get an (0, 2) pair array.
        return (
            np.zeros((0, 2), dtype=np.int64),
            [],
            np.zeros(0, dtype=np.float64),
            np.zeros(0, dtype=np.float64),
        )
    return (
        np.array([(rec[0], rec[1]) for rec in records], dtype=np.int64),
        [rec[2] for rec in records],
        np.array([rec[3] for rec in records], dtype=np.float64),
        np.array([rec[4] for rec in records], dtype=np.float64),
    )
# ===========================================================================
# 驱动力加载 & 应用
# ===========================================================================
def load_driver_file(driver_path, atom_ids):
    """Load driving-force definitions from a driver.txt file.

    Expected header (one data row per driven atom)::

        n amp_x amp_y amp_z freq_x freq_y freq_z phi_x phi_y phi_z period

    The drive is ``position = A * cos(2*pi*f*t + phi)``; ``phi`` is given in
    degrees in the file and stored in radians. ``period`` is either ``all``
    or a number of cycles. Blank lines and ``#`` lines are skipped.

    Returns:
        A list with one dict per driver (keys ``atom_index``, ``atom_id``,
        ``amp``, ``freq``, ``phi``, ``period_str``, ``period_cycles``,
        ``freeze_step``, ``freeze_pos``), or ``None`` when the file is
        missing or holds no data rows (a warning is printed in both cases).

    Raises:
        ValueError: bad header, wrong column count, or an atom number that
            is not in *atom_ids*.
    """
    if not os.path.exists(driver_path):
        print(f"[compute] 警告: 驱动力文件不存在: {driver_path}")
        return None

    atom_index = {int(aid): idx for idx, aid in enumerate(atom_ids)}
    expected = ["n", "amp_x", "amp_y", "amp_z",
                "freq_x", "freq_y", "freq_z",
                "phi_x", "phi_y", "phi_z", "period"]
    ncols = len(expected)
    drivers = []
    # Phases in degrees exactly as written in the file, parallel to `drivers`.
    # Kept per driver so the summary below never reads a stale loop variable.
    phases_deg = []
    with open(driver_path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split()
        if header != expected:
            raise ValueError(
                f"driver.txt 表头应为: {' '.join(expected)},实际为: {' '.join(header)}")
        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) != ncols:
                raise ValueError(
                    f"{driver_path}:{line_no} 应有 {ncols} 列,实际为 {len(parts)}")
            n = int(parts[0])
            if n not in atom_index:
                raise ValueError(f"{driver_path}:{line_no} 原子序号 {n} 不在 coord.txt 中")
            amp = np.array([float(v) for v in parts[1:4]], dtype=np.float64)
            freq = np.array([float(v) for v in parts[4:7]], dtype=np.float64)
            phi_deg = np.array([float(v) for v in parts[7:10]], dtype=np.float64)
            phi_rad = phi_deg * np.pi / 180.0
            period_str = parts[10]
            drivers.append({
                "atom_index": atom_index[n],
                "atom_id": n,
                "amp": amp,
                "freq": freq,
                "phi": phi_rad,
                "period_str": period_str,
                "period_cycles": None if period_str == "all" else float(period_str),
                # Filled in during the simulation: freeze step index / position
                "freeze_step": None,
                "freeze_pos": None,
            })
            phases_deg.append(phi_deg)

    if not drivers:
        print("[compute] 警告: driver.txt 中没有有效数据")
        return None

    print(f"[compute] 已加载驱动力: {len(drivers)} 条定义")
    for d, phi_deg in zip(drivers, phases_deg):
        # Report the phase of the axis with the largest amplitude.
        dominant = int(np.argmax(d["amp"]))
        print(f" 原子 {d['atom_id']}: "
              f"A=({d['amp'][0]},{d['amp'][1]},{d['amp'][2]}), "
              f"f=({d['freq'][0]},{d['freq'][1]},{d['freq'][2]}), "
              f"φ=({phi_deg[dominant]}° 等), "
              f"period={d['period_str']}")
    return drivers
def apply_driving_force(x, y, z, vx, vy, vz, t, step, drivers, dt):
    """Overwrite position / velocity of driven atoms with the drive function.

    Drive law, per axis::

        pos = A * cos(2*pi*f*t + phi)
        vel = -A * 2*pi*f * sin(2*pi*f*t + phi)

    A driver with ``period_cycles`` set is active for
    ``period_cycles / max(|freq|)`` time units (converted to whole steps of
    *dt*); afterwards the atom is pinned at its last driven position with
    zero velocity. A driver with ``period_cycles is None`` is always active.

    The coordinate / velocity arrays are modified in place and also
    returned. Each driver dict's ``freeze_pos`` entry is updated in place.

    Args:
        x, y, z, vx, vy, vz: per-atom arrays, indexed by ``atom_index``.
        t: simulation time used in the drive law.
        step: integer step counter compared against the drive duration.
        drivers: list of driver dicts (see ``load_driver_file``) or ``None``.
        dt: time step length.
    """
    if drivers is None:
        return x, y, z, vx, vy, vz

    two_pi = 2.0 * np.pi
    for d in drivers:
        idx = d["atom_index"]
        period_steps = None  # None means: driven for the whole run
        if d["period_cycles"] is not None:
            max_freq = np.max(np.abs(d["freq"]))
            if max_freq > 1e-12:
                period_duration = d["period_cycles"] / max_freq
                # Plain int() truncation turns round-off such as
                # 0.3 / 0.1 == 2.9999999999999996 into 2 and ends the drive
                # one step early; the small tolerance absorbs that error.
                period_steps = int(np.floor(period_duration / dt + 1e-9))
            else:
                period_steps = 0
            if step > period_steps:
                # Drive finished: hold the atom still at its last position.
                if d["freeze_pos"] is None:
                    # The final driven step was never observed (e.g. the
                    # first call already lies past the drive window), so
                    # latch the current position instead of letting the
                    # atom drift.
                    d["freeze_pos"] = (float(x[idx]), float(y[idx]), float(z[idx]))
                x[idx], y[idx], z[idx] = d["freeze_pos"]
                vx[idx] = vy[idx] = vz[idx] = 0.0
                continue

        # Position / velocity prescribed by the drive at time t.
        phase = two_pi * d["freq"] * t + d["phi"]
        pos_drive = d["amp"] * np.cos(phase)
        vel_drive = -d["amp"] * two_pi * d["freq"] * np.sin(phase)
        x[idx], y[idx], z[idx] = pos_drive
        vx[idx], vy[idx], vz[idx] = vel_drive

        # Finite drive: remember the position of the last driven step.
        if period_steps is not None and step == period_steps:
            d["freeze_pos"] = (float(pos_drive[0]), float(pos_drive[1]), float(pos_drive[2]))
    return x, y, z, vx, vy, vz
def parse_gravity_vector(value):
    """Turn a gravity setting into a 3-component acceleration vector.

    Accepted forms:
        - scalar ``G: 9.8``           -> ``[0, 0, -9.8]`` (legacy form)
        - vector ``G: [gx, gy, gz]``  -> ``[gx, gy, gz]``

    Raises:
        ValueError: a non-scalar value whose shape is not ``(3,)``.
    """
    scalar_types = (int, float, np.integer, np.floating)
    if not isinstance(value, scalar_types):
        components = np.asarray(value, dtype=np.float64)
        if components.shape == (3,):
            return components
        raise ValueError(f"G 必须是长度为 3 的分量数组,实际为 {value}")
    # Legacy scalar: magnitude pointing along -z.
    return np.array([0.0, 0.0, -float(value)], dtype=np.float64)
def parse_damping_vector(value):
    """Turn a damping setting into a 3-component coefficient vector.

    Accepted forms:
        - scalar ``B: 0.5``           -> ``[0.5, 0.5, 0.5]`` (legacy form)
        - vector ``B: [bx, by, bz]``  -> ``[bx, by, bz]``

    Raises:
        ValueError: a non-scalar value whose shape is not ``(3,)``.
    """
    scalar_types = (int, float, np.integer, np.floating)
    if not isinstance(value, scalar_types):
        components = np.asarray(value, dtype=np.float64)
        if components.shape == (3,):
            return components
        raise ValueError(f"B 必须是长度为 3 的分量数组,实际为 {value}")
    # Legacy scalar: same coefficient on every axis.
    return np.full(3, float(value), dtype=np.float64)
def run_from_config(config, out_dir=None):
    """Configure the module state from a dict, run the simulation, write a log.

    Config keys read here:
        box_a, alpha, coord_file, connection_file, bond_file, plot_atom,
        G, B, method, NT, DT, NSTEP (G, B, NT, DT, NSTEP are required),
        warmup_steps, sample_start, sample_end (optional),
        ball_radius, ball_color_r/g/b, box_color_r/g/b (optional),
        gravity_field, gravity_interaction, elastic_force, damping_force,
        gravity_strength, driving_force, driver_file (optional),
        _skip_run (optional; when truthy only the state is set up).

    Args:
        config: the configuration dict.
        out_dir: base directory; relative input file paths are joined to it
            and it is handed to ``get_output_dir`` for the log location.

    Returns:
        ``(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz)`` from
        ``run_simulation``, or six ``None`` values when ``_skip_run`` is set.
    """
    global box_a, alpha, COORD_FILE, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP
    global PLOT_ATOM_ID, PLOT_ATOM_ROW
    global ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS, ATOM_VELOCITIES, ATOM_FIXED
    global BOND_CONNECTION_FILE, BOND_PARAMETER_FILE, BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS, BOND_REST_LENGTHS
    global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX
    global ball_radius, ball_color_r, ball_color_g, ball_color_b
    global box_color_r, box_color_g, box_color_b
    global warmup_steps, sample_start, sample_end
    global GRAVITY_FIELD, GRAVITY_INTERACTION, ELASTIC_FORCE, DAMPING_FORCE, GRAVITY_STRENGTH
    global DRIVING_FORCE, DRIVER_DATA

    def _anchored(rel_path):
        # Relative input paths are resolved against out_dir when one is given.
        if out_dir is not None and not os.path.isabs(rel_path):
            return os.path.join(out_dir, rel_path)
        return rel_path

    box_a = float(config.get("box_a", 10.0))
    alpha_cfg = config.get("alpha", 0.2)
    if isinstance(alpha_cfg, (list, tuple)):
        alpha = [float(item) for item in alpha_cfg]
    else:
        alpha = float(alpha_cfg)

    # --- atoms ------------------------------------------------------------
    COORD_FILE = str(config.get("coord_file", os.path.join("input", "coord.txt")))
    (ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS,
     ATOM_VELOCITIES, ATOM_FIXED) = load_coord_file(_anchored(COORD_FILE))

    # --- bonds ------------------------------------------------------------
    BOND_CONNECTION_FILE = str(config.get("connection_file", os.path.join("input", "connection.txt")))
    BOND_PARAMETER_FILE = str(config.get("bond_file", os.path.join("input", "bond.txt")))
    bond_map = load_bond_parameters(_anchored(BOND_PARAMETER_FILE))
    (BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS,
     BOND_REST_LENGTHS) = load_bond_connections(
        _anchored(BOND_CONNECTION_FILE), ATOM_IDS, ATOM_POSITIONS, bond_map)

    # --- atom selected for plotting ----------------------------------------
    PLOT_ATOM_ID = int(config.get("plot_atom", int(ATOM_IDS[0])))
    hit_rows = np.where(ATOM_IDS == PLOT_ATOM_ID)[0]
    if len(hit_rows) == 0:
        raise ValueError(f"plot_atom={PLOT_ATOM_ID} 不在坐标文件原子序号中")
    PLOT_ATOM_ROW = int(hit_rows[0])
    M = float(ATOM_MASSES[PLOT_ATOM_ROW])
    X0, Y0, Z0 = (float(c) for c in ATOM_POSITIONS[PLOT_ATOM_ROW])
    VX0, VY0, VZ0 = (float(c) for c in ATOM_VELOCITIES[PLOT_ATOM_ROW])

    # --- integration settings ----------------------------------------------
    G = parse_gravity_vector(config["G"])
    B = parse_damping_vector(config["B"])
    METHOD = normalize_method_name(config.get("method", "explicit_euler"))
    NT = int(config["NT"])
    DT = float(config["DT"])
    NSTEP = int(config["NSTEP"])

    # Step-control options (optional, with defaults)
    warmup_steps = int(config.get("warmup_steps", 0))
    sample_start = config.get("sample_start")  # None: start from the beginning
    sample_end = config.get("sample_end")      # None: run to the end

    X_MIN, X_MAX = -box_a, box_a
    Y_MIN, Y_MAX = -box_a, box_a
    Z_MIN, Z_MAX = -box_a, box_a

    # --- drawing options -----------------------------------------------------
    ball_radius = float(config.get("ball_radius", ATOM_RADII[PLOT_ATOM_ROW]))
    ball_color_r = float(config.get("ball_color_r", 0.9))
    ball_color_g = float(config.get("ball_color_g", 0.2))
    ball_color_b = float(config.get("ball_color_b", 0.2))
    box_color_r = float(config.get("box_color_r", 0.8))
    box_color_g = float(config.get("box_color_g", 0.8))
    box_color_b = float(config.get("box_color_b", 0.85))

    # --- force switches ------------------------------------------------------
    GRAVITY_FIELD = int(config.get("gravity_field", 1))
    GRAVITY_INTERACTION = int(config.get("gravity_interaction", 0))
    ELASTIC_FORCE = int(config.get("elastic_force", 1))
    DAMPING_FORCE = int(config.get("damping_force", 0))
    GRAVITY_STRENGTH = float(config.get("gravity_strength", 1.0))
    DRIVING_FORCE = int(config.get("driving_force", 0))

    # Driver definitions are only loaded when the driving force is enabled.
    DRIVER_DATA = None
    if DRIVING_FORCE:
        driver_rel = str(config.get("driver_file", os.path.join("input", "driver.txt")))
        DRIVER_DATA = load_driver_file(_anchored(driver_rel), ATOM_IDS)

    print(f"[compute] 使用算法: {METHOD}")
    print(f"[compute] 已加载成键信息: {len(BOND_PAIRS)} 条键")
    if config.get("_skip_run", False):
        return None, None, None, None, None, None

    stamp_format = "%Y-%m-%d %H:%M:%S"
    t_start = time.time()
    t_start_str = datetime.datetime.now().strftime(stamp_format)
    trajectory = run_simulation()
    t_end = time.time()
    t_end_str = datetime.datetime.now().strftime(stamp_format)
    elapsed = t_end - t_start

    # Write dynamics.log
    log_path = os.path.join(get_output_dir(out_dir), "dynamics.log")
    heavy_rule = "=" * 50 + "\n"
    light_rule = "-" * 50 + "\n"
    with open(log_path, "w", encoding="utf-8") as f:
        f.writelines([
            heavy_rule,
            "Dynamics 计算日志\n",
            heavy_rule,
            f"计算开始时间: {t_start_str}\n",
            f"计算结束时间: {t_end_str}\n",
            f"计算耗时: {elapsed:.3f} s\n",
            light_rule,
            "计算参数:\n",
            f" box_a: {box_a}\n",
            f" alpha: {alpha}\n",
            f" coord_file: {COORD_FILE}\n",
            f" method: {METHOD}\n",
            f" NT: {NT}\n",
            f" DT: {DT}\n",
            f" NSTEP: {NSTEP}\n",
            f" G: ({G[0]:.2f}, {G[1]:.2f}, {G[2]:.2f})\n",
            f" B: ({B[0]:.2f}, {B[1]:.2f}, {B[2]:.2f})\n",
            f" warmup: {warmup_steps}\n",
            f" 原子数: {len(ATOM_IDS)}\n",
            f" 键数: {len(BOND_PAIRS)}\n",
            light_rule,
            "初始状态 (plot_atom):\n",
            f" position: ({X0:.6f}, {Y0:.6f}, {Z0:.6f})\n",
            f" velocity: ({VX0:.6f}, {VY0:.6f}, {VZ0:.6f})\n",
            f" mass: {M}\n",
            heavy_rule,
        ])
    print(f"[compute] 日志已保存至: {log_path}")

    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = trajectory
    record = len(traj_x)
    print(f"[compute] 计算完成: {record} 步, {elapsed:.3f} s")
    return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz
def run_engine(engine, input_dir, output_dir, config):
    """Run an external compute engine (C / C++ / Fortran) and load its trajectory.

    Steps: locate the engine executable, write ``param.json`` next to the
    engine sources, do a short calibration run to estimate the run time for
    the progress bar, run the engine, write ``dynamics.log`` (header plus
    the engine's stdout) into *output_dir*, and load ``trajectory.txt``.

    Args:
        engine: engine name, one of ``"c"``, ``"cpp"``, ``"fortran"``.
        input_dir: input directory (coord.txt, connection.txt, bond.txt).
        output_dir: directory the engine writes ``trajectory.txt`` to.
        config: configuration dict; ``NT`` and ``DT`` are required.

    Returns:
        ``(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz)``

    Raises:
        ValueError: unknown engine name.
        FileNotFoundError: no engine executable, or no trajectory produced.
        RuntimeError: the engine exited with a non-zero return code.
    """
    import tempfile  # only needed here; keeps the module import block untouched

    script_dir = os.path.dirname(os.path.abspath(__file__))
    system = platform.system().lower()
    engine_map = {
        "c": "engines/c/build/dynamics_c",
        "cpp": "engines/cpp/build/dynamics_cpp",
        "fortran": "engines/fortran/build/dynamics_f90",
    }
    if engine not in engine_map:
        raise ValueError(f"不支持的引擎: {engine},可选: {list(engine_map.keys())}")
    engine_path = os.path.join(script_dir, engine_map[engine])

    # Probe the executable name with and without platform-specific suffixes.
    candidates = [
        engine_path,                      # no suffix
        engine_path + ".exe",             # Windows .exe
        engine_path + f"_{system}.exe",   # platform-specific build
    ]
    found = next((p for p in candidates if os.path.exists(p)), None)
    if found is None:
        raise FileNotFoundError(
            f"引擎可执行文件不存在: 尝试了 {candidates}\n"
            f"请先编译: cd engines/{engine} && make\n"
            f"或安装交叉编译器后: cd engines/{engine} && make {system}")

    # Build param.json (numeric parameters)
    gravity_vec = parse_gravity_vector(config.get("G", [0, 0, -9.8]))
    damping_vec = parse_damping_vector(config.get("B", [0, 0, 0]))
    param_json = {
        "box_a": float(config.get("box_a", 10.0)),
        "NT": int(config["NT"]),
        "DT": float(config["DT"]),
        "NSTEP": int(config.get("NSTEP", 1)),
        "warmup_steps": int(config.get("warmup_steps", 0)),
        "method": normalize_method_name(config.get("method", "leapfrog")),
        "G": [float(v) for v in gravity_vec],
        "B": [float(v) for v in damping_vec],
        "gravity_field": int(config.get("gravity_field", 1)),
        "gravity_interaction": int(config.get("gravity_interaction", 0)),
        "elastic_force": int(config.get("elastic_force", 1)),
        "damping_force": int(config.get("damping_force", 0)),
        "gravity_strength": float(config.get("gravity_strength", 1.0)),
    }
    engine_dir = os.path.join(script_dir, "engines", engine)
    param_path = os.path.join(engine_dir, "param.json")
    os.makedirs(engine_dir, exist_ok=True)
    with open(param_path, "w", encoding="utf-8") as f:
        json.dump(param_json, f, indent=2)

    os.makedirs(output_dir, exist_ok=True)
    print(f"[compute] 引擎: {engine} ({os.path.basename(found)})")
    print(f"[compute] input: {input_dir}")
    print(f"[compute] output: {output_dir}")

    # -- Calibration: a short run whose wall time feeds the progress estimate.
    # It is best-effort; a failure only degrades the progress bar.
    total_steps = int(config["NT"]) - int(config.get("warmup_steps", 0))
    calib_nt = min(1000, max(100, total_steps // 10))
    calib_param = dict(param_json, NT=calib_nt)
    calib_path = os.path.join(engine_dir, "_calib.json")
    calib_t0 = time.time()
    try:
        with open(calib_path, "w", encoding="utf-8") as cf:
            json.dump(calib_param, cf, indent=2)
        calib_t0 = time.time()
        # NOTE(review): os.devnull is passed as the engine's output directory;
        # confirm the engines tolerate that, otherwise the timing is meaningless.
        subprocess.run(
            [found, os.path.abspath(input_dir), os.devnull, calib_path],
            capture_output=True, timeout=60)
    except (subprocess.TimeoutExpired, OSError) as err:
        print(f"[compute] 警告: 预校准失败,进度估计可能不准确: {err}")
    finally:
        # Always remove the temporary calibration parameters.
        try:
            os.remove(calib_path)
        except OSError:
            pass
    calib_elapsed = max(time.time() - calib_t0, 0.001)
    overhead = calib_elapsed * 0.15
    step_time = max(calib_elapsed - overhead, 0.0001) / calib_nt
    est_total = max(calib_elapsed, overhead + step_time * total_steps)

    # -- Real run. stdout / stderr go to temporary files instead of pipes so
    # a chatty engine can never fill a pipe buffer and block.
    cmd = [found, os.path.abspath(input_dir), os.path.abspath(output_dir), param_path]
    t_start = time.time()
    t_start_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    pbar = None
    with tempfile.TemporaryFile(mode="w+", encoding="utf-8", errors="replace") as out_f, \
            tempfile.TemporaryFile(mode="w+", encoding="utf-8", errors="replace") as err_f:
        proc = subprocess.Popen(cmd, stdout=out_f, stderr=err_f)
        try:
            try:
                from tqdm import tqdm as _tqdm
                pbar = _tqdm(
                    total=total_steps, desc=f"[compute] 引擎 {engine}", unit="",
                    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
            except ImportError:
                pbar = None
            while True:
                try:
                    proc.wait(timeout=0.2)
                    break
                except subprocess.TimeoutExpired:
                    # Still running: advance the bar from the time estimate.
                    if pbar is not None and est_total > 0:
                        fraction = min((time.time() - t_start) / est_total, 0.99)
                        pbar.n = int(fraction * total_steps)
                        pbar.refresh()
        finally:
            if pbar is not None:
                pbar.n = total_steps
                pbar.refresh()
                pbar.close()
            if proc.poll() is None:
                # Interrupted (e.g. Ctrl-C): do not leave the engine running.
                proc.kill()
                proc.wait(timeout=5)
        out_f.seek(0)
        engine_lines = [ln.strip() for ln in out_f if ln.strip()]
        err_f.seek(0)
        err_text = err_f.read()

    rc = proc.returncode
    t_end = time.time()
    t_end_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    elapsed = t_end - t_start
    if rc != 0:
        if err_text:
            print(f"[compute] 引擎错误:\n{err_text}")
        raise RuntimeError(f"引擎 {engine} 返回错误码 {rc}")

    # Write the log in one pass: header first, engine output after it, so
    # the engine output is not lost to a later truncating write.
    log_path = os.path.join(output_dir, "dynamics.log")
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("=" * 50 + "\n")
        f.write("Dynamics 计算日志\n")
        f.write("=" * 50 + "\n")
        f.write(f"引擎: {engine}\n")
        f.write(f"计算开始时间: {t_start_str}\n")
        f.write(f"计算结束时间: {t_end_str}\n")
        f.write(f"计算耗时: {elapsed:.3f} s\n")
        if engine_lines:
            f.write("\n--- 引擎输出 ---\n")
            f.writelines(ln + "\n" for ln in engine_lines)

    # Load the trajectory.txt produced by the engine
    traj_path = os.path.join(os.path.abspath(output_dir), "trajectory.txt")
    if not os.path.exists(traj_path):
        raise FileNotFoundError(f"引擎未生成 trajectory.txt: {traj_path}")
    data = load_text_data(traj_path)
    traj_x = data["traj_x"]
    traj_y = data["traj_y"]
    traj_z = data["traj_z"]
    traj_vx = data["traj_vx"]
    traj_vy = data["traj_vy"]
    traj_vz = data["traj_vz"]

    print(f"[compute] 引擎完成: {len(traj_x)} 步, {traj_x.shape[1]} 原子, {elapsed:.3f} s")
    return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz
def save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, out_dir=None):
    """Write the trajectory plus the current module-level parameters to disk.

    The file is ``trajectory.txt`` inside ``get_output_dir(out_dir)``;
    *out_dir* defaults to the directory containing this file. Unset
    ``sample_start`` / ``sample_end`` are stored as ``-1``.

    Returns:
        The path of the written file.
    """
    base = out_dir
    if base is None:
        base = os.path.dirname(os.path.abspath(__file__))
    out_path = os.path.join(get_output_dir(base), "trajectory.txt")

    # Number of steps actually recorded (warm-up already excluded)
    record_steps = len(traj_x)
    sample_start_to_save = int(sample_start) if sample_start is not None else -1
    sample_end_to_save = int(sample_end) if sample_end is not None else -1

    payload = {}
    # Trajectory arrays
    payload.update({
        "traj_x": traj_x, "traj_y": traj_y, "traj_z": traj_z,
        "traj_vx": traj_vx, "traj_vy": traj_vy, "traj_vz": traj_vz,
    })
    # Integration settings and input file names
    payload.update({
        "NT": record_steps, "DT": DT, "NSTEP": NSTEP,
        "method": METHOD,
        "coord_file": COORD_FILE,
        "connection_file": BOND_CONNECTION_FILE,
        "bond_file": BOND_PARAMETER_FILE,
    })
    # Atom and bond tables
    payload.update({
        "atom_ids": ATOM_IDS,
        "atom_masses": ATOM_MASSES,
        "atom_radii": ATOM_RADII,
        "atom_positions": ATOM_POSITIONS,
        "atom_velocities": ATOM_VELOCITIES,
        "atom_fixed": ATOM_FIXED,
        "bond_pairs": BOND_PAIRS,
        "bond_names": BOND_NAMES,
        "bond_stiffness": BOND_STIFFNESS,
        "bond_rest_lengths": BOND_REST_LENGTHS,
    })
    # Fields, plotted atom, sampling window
    payload.update({
        "G": None if G is None else G.tolist(),
        "B": None if B is None else B.tolist(),
        "plot_atom_id": PLOT_ATOM_ID,
        "plot_atom_row": PLOT_ATOM_ROW,
        "warmup_steps": warmup_steps or 0,
        "sample_start": sample_start_to_save,
        "sample_end": sample_end_to_save,
    })
    # Box bounds and initial state of the plotted atom
    payload.update({
        "X_MIN": X_MIN, "X_MAX": X_MAX,
        "Y_MIN": Y_MIN, "Y_MAX": Y_MAX,
        "Z_MIN": Z_MIN, "Z_MAX": Z_MAX,
        "X0": X0, "Y0": Y0, "Z0": Z0,
        "VX0": VX0, "VY0": VY0, "VZ0": VZ0,
        "M": M,
        "alpha": alpha,
    })
    # Drawing options
    payload.update({
        "ball_radius": ball_radius,
        "ball_color_r": ball_color_r,
        "ball_color_g": ball_color_g,
        "ball_color_b": ball_color_b,
        "box_color_r": box_color_r,
        "box_color_g": box_color_g,
        "box_color_b": box_color_b,
    })
    # Force switches
    payload.update({
        "gravity_field": GRAVITY_FIELD,
        "gravity_interaction": GRAVITY_INTERACTION,
        "elastic_force": ELASTIC_FORCE,
        "damping_force": DAMPING_FORCE,
        "gravity_strength": GRAVITY_STRENGTH,
    })

    save_text_data(out_path, payload)
    print(f"[compute] 轨迹数据已保存至: {out_path}")
    return out_path
def load_parameters(param_file):
    """Load parameters from a text file with ``name = value`` lines.

    Blank lines, ``#`` comment lines and lines without ``=`` are skipped;
    a trailing ``# ...`` comment after the value is removed. Values are
    converted to int or float when possible, otherwise kept as strings.
    Each recognised name is stored in the module-level variable of the same
    name. Only existing data variables can be set: names starting with an
    underscore, functions / classes and imported modules are rejected with
    the same warning as unknown names.

    After loading, the box bounds are derived from ``box_a``, the method
    name is normalised, scalar ``G`` / ``B`` values are expanded to
    3-component vectors, and drawing defaults are filled in.

    Exits the process with status 1 when ``box_a`` or any other required
    parameter is missing.
    """
    import types  # only needed for the module check below

    global box_a, alpha, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP
    global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX
    global ball_radius, ball_color_r, ball_color_g, ball_color_b
    global box_color_r, box_color_g, box_color_b

    print(f"[compute] 正在加载参数文件: {param_file}")
    with open(param_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    module_vars = globals()
    for line in lines:
        line = line.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        # Split key / value and drop an inline comment after the value
        key, value = line.split('=', 1)
        key = key.strip()
        value = value.split('#', 1)[0].strip()
        # Try int / float conversion; anything else stays a string
        try:
            if '.' in value or 'e' in value.lower():
                value = float(value)
            else:
                value = int(value)
        except ValueError:
            pass

        # A parameter file must never be able to replace functions, imported
        # modules or interpreter-managed names such as __name__.
        settable = (
            key in module_vars
            and not key.startswith('_')
            and not callable(module_vars[key])
            and not isinstance(module_vars[key], types.ModuleType)
        )
        if settable:
            module_vars[key] = value
        else:
            print(f"警告: 未知参数 '{key}',忽略")

    # Derived bounds
    if box_a is None:
        print("错误: 未找到 box_a 参数")
        sys.exit(1)
    X_MIN = -box_a
    X_MAX = box_a
    Y_MIN = -box_a
    Y_MAX = box_a
    Z_MIN = -box_a
    Z_MAX = box_a

    # Required parameters
    required = ['alpha', 'X0', 'Y0', 'Z0', 'VX0', 'VY0', 'VZ0', 'M', 'G', 'B', 'NT', 'DT', 'NSTEP']
    for param in required:
        if module_vars[param] is None:
            print(f"错误: 未找到必需参数 '{param}'")
            sys.exit(1)

    METHOD = normalize_method_name(METHOD or "explicit_euler")

    # The force and integrator code indexes g[0..2] / b[0..2] and the saver
    # calls G.tolist(), so scalar values are expanded to vectors here.
    if isinstance(G, (int, float)):
        G = parse_gravity_vector(G)
    if isinstance(B, (int, float)):
        B = parse_damping_vector(B)

    # Defaults for drawing parameters that were not provided
    drawing_defaults = {
        "ball_radius": 0.28,
        "ball_color_r": 0.9,
        "ball_color_g": 0.2,
        "ball_color_b": 0.2,
        "box_color_r": 0.8,
        "box_color_g": 0.8,
        "box_color_b": 0.85,
    }
    for name, fallback in drawing_defaults.items():
        if module_vars[name] is None:
            module_vars[name] = fallback

    print(f"[compute] 参数加载完成: box_a={box_a}, alpha={alpha}, NT={NT}, DT={DT}, method={METHOD}")
def normalize_method_name(method):
    """Map an integrator alias from YAML / text config to its canonical name.

    The input is converted to ``str``, stripped and lower-cased before the
    lookup. Canonical names are ``explicit_euler``, ``implicit_euler``,
    ``midpoint`` and ``leapfrog``.

    Raises:
        ValueError: the alias is not recognised.
    """
    alias_groups = (
        ("explicit_euler",
         ("explicit_euler", "explicit", "euler_explicit", "显式欧拉", "显式欧拉法")),
        ("implicit_euler",
         ("implicit_euler", "implicit", "euler_implicit", "隐式欧拉", "隐式欧拉法")),
        ("midpoint",
         ("midpoint", "mid_point", "midpoint_method", "中点法", "中点差分法")),
        ("leapfrog",
         ("leapfrog", "leap_frog", "蛙跳法")),
    )
    lookup = {
        alias: canonical
        for canonical, names in alias_groups
        for alias in names
    }
    token = str(method).strip().lower()
    resolved = lookup.get(token)
    if resolved is None:
        valid = ", ".join(canonical for canonical, _ in alias_groups)
        raise ValueError(f"未知算法 '{method}'。可选值: {valid}")
    return resolved
def compute_force(x, y, z, vx, vy, vz, m, g, b):
    """Sum every enabled force for the current state.

    Each contribution is gated by a module-level switch:
        - GRAVITY_FIELD:       uniform field, ``F = m * g``
        - DAMPING_FORCE:       linear drag, ``F = -b * v`` per axis
        - ELASTIC_FORCE:       springs along the pairs in BOND_PAIRS
        - GRAVITY_INTERACTION: pairwise attraction
          ``GRAVITY_STRENGTH * mi * mj / r^2``

    Returns:
        ``(fx, fy, fz)`` arrays shaped like ``x``, ``y``, ``z``.
    """
    fx, fy, fz = np.zeros_like(x), np.zeros_like(y), np.zeros_like(z)

    if GRAVITY_FIELD:
        fx += m * g[0]
        fy += m * g[1]
        fz += m * g[2]

    if DAMPING_FORCE:
        fx -= b[0] * vx
        fy -= b[1] * vy
        fz -= b[2] * vz

    if ELASTIC_FORCE and BOND_PAIRS is not None and len(BOND_PAIRS) > 0:
        for bond_no, (first, second) in enumerate(BOND_PAIRS):
            sep_x = x[second] - x[first]
            sep_y = y[second] - y[first]
            sep_z = z[second] - z[first]
            gap = np.sqrt(sep_x * sep_x + sep_y * sep_y + sep_z * sep_z)
            if gap <= 1e-12:
                # Coincident atoms: the bond direction is undefined, skip.
                continue
            # Positive tension (stretched bond) pulls the two atoms together.
            tension = BOND_STIFFNESS[bond_no] * (gap - BOND_REST_LENGTHS[bond_no])
            pull_x = tension * (sep_x / gap)
            pull_y = tension * (sep_y / gap)
            pull_z = tension * (sep_z / gap)
            fx[first] += pull_x
            fy[first] += pull_y
            fz[first] += pull_z
            fx[second] -= pull_x
            fy[second] -= pull_y
            fz[second] -= pull_z

    if GRAVITY_INTERACTION:
        count = len(m)
        for i in range(count):
            for j in range(i + 1, count):
                sep_x = x[j] - x[i]
                sep_y = y[j] - y[i]
                sep_z = z[j] - z[i]
                dist_sq = sep_x * sep_x + sep_y * sep_y + sep_z * sep_z
                if dist_sq <= 1e-12:
                    continue
                strength = GRAVITY_STRENGTH * m[i] * m[j] / dist_sq
                dist = np.sqrt(dist_sq)
                pull_x = strength * sep_x / dist
                pull_y = strength * sep_y / dist
                pull_z = strength * sep_z / dist
                # Equal and opposite: i is pulled towards j and vice versa.
                fx[i] += pull_x
                fy[i] += pull_y
                fz[i] += pull_z
                fx[j] -= pull_x
                fy[j] -= pull_y
                fz[j] -= pull_z

    return fx, fy, fz
def compute_acceleration(x, y, z, vx, vy, vz, m, g, b):
    """Return ``(ax, ay, az)``: the force from ``compute_force`` divided by ``m``."""
    force_x, force_y, force_z = compute_force(x, y, z, vx, vy, vz, m, g, b)
    ax = force_x / m
    ay = force_y / m
    az = force_z / m
    return ax, ay, az
def Explicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with explicit (forward) Euler.

    Positions move with the current velocity and velocities with the
    current acceleration; both use the state at the start of the step.

    Returns:
        ``(x, y, z, vx, vy, vz)`` after the step.
    """
    acc_x, acc_y, acc_z = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)
    next_position = (x + vx * dt, y + vy * dt, z + vz * dt)
    next_velocity = (vx + acc_x * dt, vy + acc_y * dt, vz + acc_z * dt)
    return (*next_position, *next_velocity)
def Implicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance the state by one semi-implicit Euler step of size ``dt``.

    A predicted end-of-step velocity is first obtained from a backward-Euler
    update that accounts only for the uniform field ``g`` and the linear
    damping ``b``.  The full acceleration is then evaluated at the current
    positions with that predicted velocity, the velocity is advanced with it,
    and finally the positions are advanced with the *new* velocity.

    Returns:
        Tuple ``(x, y, z, vx, vy, vz)`` holding the new state.
    """
    # NOTE(review): the predictor always uses g and b, regardless of the
    # global force switches consulted by compute_force -- confirm intended.
    g_dt_x, g_dt_y, g_dt_z = g[0] * dt, g[1] * dt, g[2] * dt
    pred_vx = (vx + g_dt_x) / (1.0 + (b[0] / m) * dt)
    pred_vy = (vy + g_dt_y) / (1.0 + (b[1] / m) * dt)
    pred_vz = (vz + g_dt_z) / (1.0 + (b[2] / m) * dt)

    acc_x, acc_y, acc_z = compute_acceleration(
        x, y, z, pred_vx, pred_vy, pred_vz, m, g, b)

    new_vx = vx + acc_x * dt
    new_vy = vy + acc_y * dt
    new_vz = vz + acc_z * dt

    # Positions use the freshly updated velocity.
    return (
        x + new_vx * dt, y + new_vy * dt, z + new_vz * dt,
        new_vx, new_vy, new_vz,
    )
def Midpoint_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance the state by one explicit midpoint step of size ``dt``.

    The state is first extrapolated half a step using start-of-step values;
    the velocity and acceleration at that midpoint then drive the full step.

    Returns:
        Tuple ``(x, y, z, vx, vy, vz)`` holding the new state.
    """
    half_dt = 0.5
    acc_x, acc_y, acc_z = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)

    # Half-step estimate of positions and velocities.
    mid_x = x + half_dt * vx * dt
    mid_y = y + half_dt * vy * dt
    mid_z = z + half_dt * vz * dt
    mid_vx = vx + half_dt * acc_x * dt
    mid_vy = vy + half_dt * acc_y * dt
    mid_vz = vz + half_dt * acc_z * dt

    mid_ax, mid_ay, mid_az = compute_acceleration(
        mid_x, mid_y, mid_z, mid_vx, mid_vy, mid_vz, m, g, b)

    # Full step driven by the midpoint velocity / acceleration.
    return (
        x + mid_vx * dt, y + mid_vy * dt, z + mid_vz * dt,
        vx + mid_ax * dt, vy + mid_ay * dt, vz + mid_az * dt,
    )
def Leapfrog_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance the state by one kick-drift-kick (velocity-Verlet style) step.

    Sequence: half kick with the start-of-step acceleration, full drift with
    the half-step velocity, then a second half kick with the acceleration
    evaluated at the new positions.

    Returns:
        Tuple ``(x, y, z, vx, vy, vz)`` holding the new state.
    """
    acc_x, acc_y, acc_z = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)

    # First half kick.
    half_vx = vx + 0.5 * acc_x * dt
    half_vy = vy + 0.5 * acc_y * dt
    half_vz = vz + 0.5 * acc_z * dt

    # Drift.
    new_x = x + half_vx * dt
    new_y = y + half_vy * dt
    new_z = z + half_vz * dt

    # Explicit predictor: extrapolate the half-step velocity by another half
    # step using the start-of-step acceleration, so that velocity-dependent
    # forces can be evaluated at the new positions.
    guess_vx = half_vx + 0.5 * acc_x * dt
    guess_vy = half_vy + 0.5 * acc_y * dt
    guess_vz = half_vz + 0.5 * acc_z * dt
    end_ax, end_ay, end_az = compute_acceleration(
        new_x, new_y, new_z, guess_vx, guess_vy, guess_vz, m, g, b)

    # Second half kick.
    return (
        new_x, new_y, new_z,
        half_vx + 0.5 * end_ax * dt,
        half_vy + 0.5 * end_ay * dt,
        half_vz + 0.5 * end_az * dt,
    )
def Limit_in_box(a, amin, amax, va):
    """Clamp coordinates to ``[amin, amax]`` and bounce escaped entries.

    Entries of ``a`` above ``amax`` are set to ``amax``, entries below
    ``amin`` are set to ``amin``; the matching entries of ``va`` have their
    sign flipped.

    Returns:
        Tuple ``(a, va)`` with the clamped coordinates and updated velocities.
    """
    too_high = a > amax
    too_low = a < amin
    escaped = too_high | too_low

    clamped = np.where(too_low, amin, np.where(too_high, amax, a))
    bounced = np.where(escaped, -va, va)
    return clamped, bounced
def apply_motion_update(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with the integrator named by the global ``METHOD``.

    After integration every axis is passed through ``Limit_in_box`` using the
    module-level ``X_MIN``/``X_MAX``, ``Y_MIN``/``Y_MAX`` and
    ``Z_MIN``/``Z_MAX`` bounds.

    Raises:
        ValueError: if ``METHOD`` is not one of the supported integrators.

    Returns:
        Tuple ``(x, y, z, vx, vy, vz)`` holding the new state.
    """
    integrators = {
        "explicit_euler": Explicit_Euler_Method,
        "implicit_euler": Implicit_Euler_Method,
        "midpoint": Midpoint_Method,
        "leapfrog": Leapfrog_Method,
    }
    integrator = integrators.get(METHOD)
    if integrator is None:
        raise ValueError(f"未知算法: {METHOD}")

    x, y, z, vx, vy, vz = integrator(x, y, z, vx, vy, vz, dt, m, g, b)

    # Keep each axis inside the box, reflecting the velocity on contact.
    x, vx = Limit_in_box(x, X_MIN, X_MAX, vx)
    y, vy = Limit_in_box(y, Y_MIN, Y_MAX, vy)
    z, vz = Limit_in_box(z, Z_MIN, Z_MAX, vz)
    return x, y, z, vx, vy, vz
def apply_fixed_constraints(x, y, z, vx, vy, vz):
    """Pin fixed degrees of freedom to their initial value with zero velocity.

    Wherever the global ``ATOM_FIXED`` is non-zero, the position is replaced
    by the matching entry of ``ATOM_POSITIONS`` and the velocity by ``0.0``.
    (Presumably ``ATOM_FIXED`` carries one flag per atom and axis -- confirm
    against the loader.)

    Returns:
        Tuple ``(x, y, z, vx, vy, vz)`` of per-axis column arrays.
    """
    is_fixed = ATOM_FIXED != 0

    pos = np.where(is_fixed, ATOM_POSITIONS, np.column_stack((x, y, z)))
    vel = np.where(is_fixed, 0.0, np.column_stack((vx, vy, vz)))

    pinned_x, pinned_y, pinned_z = pos[:, 0], pos[:, 1], pos[:, 2]
    pinned_vx, pinned_vy, pinned_vz = vel[:, 0], vel[:, 1], vel[:, 2]
    return pinned_x, pinned_y, pinned_z, pinned_vx, pinned_vy, pinned_vz
def wrap_position(x, y, z):
    """Wrap positions across the box boundaries (dynamics mode).

    On each axis, a coordinate beyond the upper bound is moved to the lower
    bound and a coordinate beyond the lower bound is moved to the upper
    bound, using the module-level ``*_MIN`` / ``*_MAX`` limits.
    """
    def _wrap_axis(coord, low, high):
        # The second test runs on the already-wrapped values, as a
        # sequential pair of replacements.
        coord = np.where(coord > high, low, coord)
        return np.where(coord < low, high, coord)

    wrapped_x = _wrap_axis(x, X_MIN, X_MAX)
    wrapped_y = _wrap_axis(y, Y_MIN, Y_MAX)
    wrapped_z = _wrap_axis(z, Z_MIN, Z_MAX)
    return wrapped_x, wrapped_y, wrapped_z
# ===========================================================================
# 主计算流程
# ===========================================================================
def run_simulation():
    """Integrate the trajectory and return the recorded positions/velocities.

    Step control:
      - ``warmup_steps``: leading steps that are simulated but not recorded.
      - Recorded steps = ``NT - warmup_steps``.

    Returns:
        Tuple ``(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz)``; each
        array has shape ``(recorded steps, number of atoms)``.
    """
    skipped = warmup_steps or 0

    def _drive(state, t, step):
        # Driving force is applied before integration on every step.
        return tuple(apply_driving_force(*state, t, step, DRIVER_DATA, DT))

    def _integrate(state):
        # Integrate, wrap into the box, then re-pin fixed degrees of freedom.
        x, y, z, vx, vy, vz = apply_motion_update(*state, DT, ATOM_MASSES, G, B)
        x, y, z = wrap_position(x, y, z)
        return apply_fixed_constraints(x, y, z, vx, vy, vz)

    # Initial state: copies of the loaded per-axis coordinates and velocities.
    state = (
        ATOM_POSITIONS[:, 0].copy(),
        ATOM_POSITIONS[:, 1].copy(),
        ATOM_POSITIONS[:, 2].copy(),
        ATOM_VELOCITIES[:, 0].copy(),
        ATOM_VELOCITIES[:, 1].copy(),
        ATOM_VELOCITIES[:, 2].copy(),
    )
    state = apply_fixed_constraints(*state)
    # Driving force at t = 0 (driven atoms take their position from the
    # driver rather than from the coordinate file).
    state = _drive(state, 0.0, 0)

    # Warm-up phase: simulate without recording.
    if skipped > 0:
        print(f"[compute] 预热阶段: 前 {warmup_steps} 步不记录")
        for step in trange(warmup_steps, desc="[compute] 预热"):
            state = _drive(state, (step + 1) * DT, step)
            state = _integrate(state)
        x, y, z = state[0], state[1], state[2]
        print(
            f"[compute] 预热完成,展示原子位置: "
            f"({x[PLOT_ATOM_ROW]:.4f}, {y[PLOT_ATOM_ROW]:.4f}, {z[PLOT_ATOM_ROW]:.4f})"
        )

    # Recording phase.
    record_steps = NT - skipped
    n_atoms = len(ATOM_IDS)
    tracks = tuple(
        np.zeros((record_steps, n_atoms), dtype=np.float64) for _ in range(6)
    )
    for step in trange(record_steps, desc="[compute] 计算中"):
        # Apply the driving force first so the recorded frame and the spring
        # forces both see the driven atoms at their prescribed state.
        state = _drive(state, (step + skipped) * DT, step)
        for track, values in zip(tracks, state):
            track[step] = values
        state = _integrate(state)
    return tracks
def save_trajectory_table_txt(txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT):
    """Write the trajectory as a row-per-step text table for easy inspection.

    Args:
        txt_path: Destination file path (written as UTF-8).
        traj_x, traj_y, traj_z: Position histories. 2-D inputs are reduced to
            the column selected by the global ``PLOT_ATOM_ROW`` (column 0
            when it is ``None``).
        traj_vx, traj_vy, traj_vz: Velocity histories, same layout.
        NT: Requested number of rows. It is clamped to the number of rows
            actually available, because the recorded trajectory can be
            shorter than NT (e.g. when warm-up steps are skipped).
        DT: Time step; the time column is ``row_index * DT`` counted from the
            first written row.
    """
    columns = [traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz]
    if np.ndim(traj_x) > 1:
        row = PLOT_ATOM_ROW if PLOT_ATOM_ROW is not None else 0
        columns = [col[:, row] for col in columns]

    # Never index past the recorded data: previously range(NT) raised
    # IndexError whenever fewer than NT rows had been recorded.
    n_rows = min([NT] + [len(col) for col in columns])

    with open(txt_path, 'w', encoding='utf-8') as f:
        # Header row.
        f.write("# 步数, 时间(s), x, y, z, vx, vy, vz\n")
        for step in range(n_rows):
            t = step * DT
            x, y, z, vx, vy, vz = (col[step] for col in columns)
            f.write(f"{step+1:6d}, {t:.6f}, {x:.6f}, {y:.6f}, {z:.6f}, {vx:.6f}, {vy:.6f}, {vz:.6f}\n")
    print(f"[compute] 轨迹文本文件已保存至: {txt_path}")
def main():
    """Entry point: load parameters, run the simulation, save the outputs.

    The parameter file defaults to ``input.txt`` inside the directory
    returned by ``get_input_dir``; a path given as the first command-line
    argument overrides it.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Default parameter file, optionally overridden from the command line.
    param_file = os.path.join(get_input_dir(script_dir), "input.txt")
    if len(sys.argv) > 1:
        param_file = sys.argv[1]

    # Populate the module-level simulation parameters.
    load_parameters(param_file)
    output_dir = get_output_dir(script_dir)

    print(f"[compute] 开始计算 NT={NT} DT={DT} ")
    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = run_simulation()
    print(f"[compute] 计算完成,共 {NT}")
    save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, script_dir)

    # Also save a row-per-step table for direct inspection.  The trajectory
    # holds only the recorded (post-warm-up) steps, so pass its real length
    # rather than NT to avoid indexing past the end of the arrays.
    recorded_steps = len(traj_x)
    txt_path = os.path.join(output_dir, "trajectory_table.txt")
    save_trajectory_table_txt(
        txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
        recorded_steps, DT,
    )


if __name__ == "__main__":
    main()