dc7bc00616
所有引擎(Python/C/C++)在 save_trajectory=0 时行为一致: - 计算时按 NSTEP 抽帧,只存 sampled 缓冲区 - 直接写入 display.txt(新文本格式) - 不生成 trajectory.txt Python 引擎:run_simulation 已支持 ✅ C 引擎:采样缓冲区 + write_display_txt ✅ C++ 引擎:采样缓冲区 + write_display_txt ✅ Fortran 引擎:待完成 compute.py run_engine:save_trajectory=0 时跳过 trajectory.txt 加载 dynamics.py:引擎直接输出 display.txt 时跳过抽帧步骤
1606 lines
61 KiB
Python
1606 lines
61 KiB
Python
"""
|
||
compute.py
|
||
----------
|
||
纯 Python 计算脚本,不依赖 VisPy。
|
||
|
||
功能:
|
||
1. 运行 NT 步物理模拟( kinematics / dynamics 等运动模式)
|
||
2. 按 NSTEP 抽帧,输出 output/display.txt(新文本格式)
|
||
3. 可选(save_trajectory=1)保留完整轨迹 output/trajectory.txt(JSON)
|
||
"""
|
||
|
||
import json
|
||
import re
|
||
import numpy as np
|
||
import os
|
||
import platform
|
||
import subprocess
|
||
import sys
|
||
import time
|
||
import datetime
|
||
from tqdm import trange
|
||
|
||
# ===========================================================================
# Physical parameters (must stay consistent with the boundary constants that
# draw.py loads). Everything below starts as a placeholder and is assigned at
# runtime by run_from_config() through `global` declarations.
# ===========================================================================
box_a = alpha = None
COORD_FILE = None
X0 = Y0 = Z0 = None
VX0 = VY0 = VZ0 = None
M = None
ATOM_IDS = ATOM_MASSES = ATOM_RADII = None
ATOM_POSITIONS = ATOM_VELOCITIES = ATOM_FIXED = None
BOND_CONNECTION_FILE = BOND_PARAMETER_FILE = None
BOND_PAIRS = BOND_NAMES = BOND_STIFFNESS = BOND_REST_LENGTHS = None
PLOT_ATOM_ID = PLOT_ATOM_ROW = None
G = B = None
METHOD = None
NT = DT = NSTEP = None
warmup_steps = None  # warm-up steps (run but not saved)
sample_start = None  # first index of the sampled frame range
sample_end = None    # last index of the sampled frame range
ball_radius = None
ball_color_r = ball_color_g = ball_color_b = None
box_color_r = box_color_g = box_color_b = None
use_marker = 0
camera_keyframes_raw = ""

# Force switches (integers; run_from_config overwrites them from the config)
GRAVITY_FIELD = 1        # uniform gravitational field
GRAVITY_INTERACTION = 0  # pairwise gravitation between atoms
ELASTIC_FORCE = 1        # spring (bond) forces
DAMPING_FORCE = 0        # damping
DRIVING_FORCE = 0        # prescribed driving motion
GRAVITY_STRENGTH = 1.0

# Driving-motion definitions (loaded from driver.txt)
DRIVER_DATA = None

# Derived box bounds (computed from box_a)
X_MIN = X_MAX = None
Y_MIN = Y_MAX = None
Z_MIN = Z_MAX = None
|
||
|
||
|
||
def _load_camera_motion(path):
    """Read move_camera.txt (velocity-segment format) and return it as a JSON string.

    Every non-blank line that does not start with ``#`` describes one motion
    segment::

        start-end vx=f1 vy=f2 vz=f3 rx=d1 ry=d2 rz=d3

    The range may instead be ``all`` (whole run), stored as start=0,
    end=10**9. Lines without a recognizable range are ignored. Every key is
    optional and defaults to 0; values may be signed and may use scientific
    notation (``vx=1e-2``). Example::

        1-60 vx=1.0 rx=10
        30-90 vy=2.0 ry=20 rz=10

    Args:
        path: path of the camera-motion file.

    Returns:
        ``json.dumps([{"start": N, "end": N, "v": [x, y, z], "r": [x, y, z]}, ...])``,
        or ``""`` when the file is missing or no segment has a non-zero
        velocity/rotation component.
    """
    if not os.path.exists(path):
        print(f"[compute] 警告: 未找到 {path},跳过运动相机")
        return ""

    # Signed decimal with optional exponent. The former pattern [-\d.]+ read
    # "1e-3" as 1.0 and accepted junk like "-" or "1.2.3" that crashed float().
    number = r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    # The lookbehind stops a key from matching the tail of a longer word.
    key_patterns = {
        key: re.compile(r'(?<![A-Za-z])' + key + r'\s*=\s*' + number)
        for key in ("vx", "vy", "vz", "rx", "ry", "rz")
    }

    segments = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            # Frame range: "all" (whole run) or "N-M".
            if line.lower().startswith("all"):
                start, end = 0, 10**9
            else:
                m = re.match(r'(\d+)\s*-\s*(\d+)', line)
                if not m:
                    continue
                start, end = int(m.group(1)), int(m.group(2))

            values = {}
            for key, pattern in key_patterns.items():
                found = pattern.search(line)
                values[key] = float(found.group(1)) if found else 0.0
            v = [values["vx"], values["vy"], values["vz"]]
            r = [values["rx"], values["ry"], values["rz"]]
            # Segments that neither translate nor rotate carry no motion.
            if any(v) or any(r):
                segments.append({"start": start, "end": end, "v": v, "r": r})

    if not segments:
        return ""
    return json.dumps(segments)
|
||
|
||
|
||
def _to_text_value(value):
    """Recursively convert numpy-heavy objects into JSON-friendly plain Python values.

    ndarrays become (nested) lists; numpy bool, integer and floating scalars
    become ``bool``/``int``/``float``; dicts, lists and tuples are walked
    recursively (tuples come back as lists). Anything else is returned as is.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    # np.bool_ is not an np.integer subclass, and json cannot serialize it.
    if isinstance(value, np.bool_):
        return bool(value)
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, dict):
        return {key: _to_text_value(val) for key, val in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_text_value(item) for item in value]
    return value
|
||
|
||
|
||
def _from_text_value(value):
|
||
"""Convert nested numeric lists back into numpy arrays when appropriate."""
|
||
if isinstance(value, list):
|
||
if not value:
|
||
return np.array([], dtype=np.float64)
|
||
if all(not isinstance(item, (list, dict)) for item in value):
|
||
if all(isinstance(item, (int, float, bool)) for item in value):
|
||
return np.array(value)
|
||
return value
|
||
if all(isinstance(item, list) for item in value):
|
||
return np.array(value)
|
||
return [_from_text_value(item) for item in value]
|
||
if isinstance(value, dict):
|
||
return {key: _from_text_value(val) for key, val in value.items()}
|
||
return value
|
||
|
||
|
||
def save_text_data(path, data):
    """Save structured simulation data as indented JSON text (legacy format, kept for compatibility).

    numpy values are converted with ``_to_text_value`` before dumping, and
    non-ASCII text is written as-is (``ensure_ascii=False``).

    Args:
        path: output file path; its parent directory is created if needed.
        data: nested dict/list structure, possibly containing numpy values.

    Returns:
        ``path``.
    """
    parent = os.path.dirname(path)
    # A bare filename has no directory part, and os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(_to_text_value(data), f, ensure_ascii=False, indent=2)
    return path
|
||
|
||
|
||
def load_text_data(path):
    """Read JSON written by save_text_data (legacy format) and restore numpy arrays.

    Nested numeric lists are turned back into arrays by ``_from_text_value``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    restored = _from_text_value(parsed)
    return restored
|
||
|
||
|
||
# ========================================================================
# New display.txt format: plain text, one block per frame
#   line 1: number of frames: N
#   line 2: number of particles: M
#   optional "key: value" metadata lines (header_fields)
#   then, for every frame, a blank line followed by:
#     frame: K                (K counts from 1)
#     n x y z vx vy vz        (column header)
#     one data line per atom
# ========================================================================


def save_display_txt(path, frames_x, frames_y, frames_z,
                     frames_vx, frames_vy, frames_vz,
                     atom_ids, n_total_frames, n_total_particles,
                     header_fields=None):
    """Write display.txt in the new text format.

    Args:
        path: output file path; its parent directory is created if needed.
        frames_x/y/z/vx/vy/vz: arrays indexed ``[frame, atom]``; the number of
            frames and atoms written is taken from ``frames_x.shape``.
        atom_ids: per-atom integer ids, written as the first column.
        n_total_frames: value for the "number of frames" line (the total
            frame count, including frames that were not sampled).
        n_total_particles: value for the "number of particles" line.
        header_fields: optional dict of extra metadata, written as
            ``key: value`` lines right after the two count lines.

    Returns:
        ``path``.
    """
    parent = os.path.dirname(path)
    # A bare filename has no directory part, and os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)

    n_frames = frames_x.shape[0]
    n_atoms = frames_x.shape[1]
    columns = (frames_x, frames_y, frames_z, frames_vx, frames_vy, frames_vz)

    with open(path, "w", encoding="utf-8") as f:
        f.write(f"number of frames: {n_total_frames}\n")
        f.write(f"number of particles: {n_total_particles}\n")
        if header_fields:
            for k, v in header_fields.items():
                f.write(f"{k}: {v}\n")

        for fr in range(n_frames):
            # Assemble the whole frame and write it once instead of one
            # write() call per atom.
            lines = [f"\nframe: {fr + 1}\n", "n x y z vx vy vz\n"]
            for a in range(n_atoms):
                # Fixed-width columns with 6 decimal places.
                values = "".join(f"{col[fr, a]:13.6f}" for col in columns)
                lines.append(f"{atom_ids[a]:d}{values}\n")
            f.write("".join(lines))
    return path
|
||
|
||
|
||
def load_display_txt(path):
    """Read a display.txt file (new text format) into numpy arrays.

    Header lines before the first ``frame:`` line are parsed for the two count
    lines and any extra ``key: value`` fields. All data lines are then parsed
    in one vectorized ``np.genfromtxt`` call. The per-frame atom count is
    derived from the data itself, so a wrong or missing "number of
    particles" header does not break the reshape.

    Returns:
        dict with keys
        ``frames_x/y/z/vx/vy/vz`` (each ``(n_frames, atoms_per_frame)``),
        ``atom_ids`` (ids listed in the first frame, int64),
        ``n_total_frames`` / ``n_total_particles`` (header values, 0 if absent),
        ``header_fields`` (other header fields, values as strings).

    Raises:
        ValueError: if there is no frame data, or the data lines cannot be
            divided evenly among the frames.
    """
    with open(path, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()

    # --- header ---------------------------------------------------------
    header_fields = {}
    n_total_frames = 0
    n_total_particles = 0
    data_start = 0
    for i, line in enumerate(lines):
        line_stripped = line.strip()
        if line_stripped.startswith("number of frames:"):
            n_total_frames = int(line_stripped.split(":")[1].strip())
        elif line_stripped.startswith("number of particles:"):
            n_total_particles = int(line_stripped.split(":")[1].strip())
        elif line_stripped.startswith("frame:"):
            data_start = i
            break
        elif ":" in line_stripped:
            k, v = line_stripped.split(":", 1)
            header_fields[k.strip()] = v.strip()

    # --- collect data lines (7 fields each: n x y z vx vy vz) -----------
    data_text = []
    n_frames = 0
    i = data_start
    while i < len(lines):
        line = lines[i].strip()
        if line.startswith("frame:"):
            n_frames += 1
            i += 2  # skip "frame: N" and the column-header line
            continue
        if line:
            data_text.append(line)
        i += 1

    if n_frames == 0 or not data_text:
        raise ValueError(f"{path} 中没有有效帧数据")

    atoms_per_frame, remainder = divmod(len(data_text), n_frames)
    if remainder:
        raise ValueError(
            f"{path} 中数据行数 {len(data_text)} 无法均分到 {n_frames} 帧")

    # Vectorized parse; atleast_2d because a single data line comes back 1-D.
    data_array = np.atleast_2d(np.genfromtxt(data_text, dtype=np.float64))
    # Columns: n, x, y, z, vx, vy, vz

    # Atom ids from the first frame.
    atom_ids = data_array[:atoms_per_frame, 0].astype(np.int64)

    # Drop the id column and reshape to (n_frames, atoms_per_frame, 6).
    all_data = data_array[:, 1:].reshape(n_frames, atoms_per_frame, 6)

    return {
        "frames_x": all_data[:, :, 0],
        "frames_y": all_data[:, :, 1],
        "frames_z": all_data[:, :, 2],
        "frames_vx": all_data[:, :, 3],
        "frames_vy": all_data[:, :, 4],
        "frames_vz": all_data[:, :, 5],
        "atom_ids": atom_ids,
        "n_total_frames": n_total_frames,
        "n_total_particles": n_total_particles,
        "header_fields": header_fields,
    }
|
||
|
||
|
||
def get_output_dir(base_dir=None):
    """Return the directory for generated artifacts, creating it if missing.

    A non-empty ``DYNAMICS_OUTPUT_DIR`` environment variable takes priority
    and is made absolute. Otherwise the result is ``<base_dir>/output``, with
    ``base_dir`` defaulting to the directory that contains this file.
    """
    env_dir = os.environ.get("DYNAMICS_OUTPUT_DIR")
    if not env_dir:
        root = base_dir if base_dir is not None else os.path.dirname(os.path.abspath(__file__))
        target = os.path.join(root, "output")
    else:
        target = os.path.abspath(env_dir)
    os.makedirs(target, exist_ok=True)
    return target
|
||
|
||
|
||
def get_input_dir(base_dir=None):
    """Return the directory holding configuration and source data, creating it if missing.

    A non-empty ``DYNAMICS_INPUT_DIR`` environment variable takes priority
    and is made absolute. Otherwise the result is ``<base_dir>/input``, with
    ``base_dir`` defaulting to the directory that contains this file.
    """
    env_dir = os.environ.get("DYNAMICS_INPUT_DIR")
    if not env_dir:
        root = base_dir if base_dir is not None else os.path.dirname(os.path.abspath(__file__))
        target = os.path.join(root, "input")
    else:
        target = os.path.abspath(env_dir)
    os.makedirs(target, exist_ok=True)
    return target
|
||
|
||
|
||
def load_coord_file(coord_path):
    """Load atom ids, masses, radii, positions and velocities from a coordinate file.

    The first line must be the header ``n mass radius x y z vx vy vz``,
    optionally extended by three fix-flag columns named ``fix_x fix_y fix_z``
    (or the older ``fixed_x fixed_y fixed_z``). Blank lines and lines starting
    with ``#`` are skipped.

    Returns:
        ``(atom_ids, masses, radii, positions, velocities, fixed)``:
        int64 ids, float64 masses and radii, float64 ``(n_atoms, 3)``
        positions and velocities, and int64 ``(n_atoms, 3)`` fix flags
        (all zeros when the file has no fix columns).

    Raises:
        FileNotFoundError: if ``coord_path`` does not exist.
        ValueError: on a bad header, a wrong column count, no atom rows,
            a non-positive mass or radius, or duplicate atom ids.
    """
    if not os.path.exists(coord_path):
        raise FileNotFoundError(f"坐标文件不存在: {coord_path}")

    expected = ["n", "mass", "radius", "x", "y", "z", "vx", "vy", "vz"]
    legacy = expected + ["fixed_x", "fixed_y", "fixed_z"]
    legacy_new = expected + ["fix_x", "fix_y", "fix_z"]

    rows = []
    with open(coord_path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split()
        if header not in (expected, legacy, legacy_new):
            raise ValueError(
                f"坐标文件表头应为: {' '.join(expected)} 或加三列 fix_x fix_y fix_z,"
                f"实际为: {' '.join(header)}")
        ncols = 9 if header == expected else 12

        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) != ncols:
                raise ValueError(
                    f"{coord_path}:{line_no} 应有 {ncols} 列,实际为 {len(parts)} 列")
            rows.append(parts)

    if not rows:
        raise ValueError(f"坐标文件没有原子数据: {coord_path}")

    atom_ids = np.array([int(row[0]) for row in rows], dtype=np.int64)
    masses = np.array([float(row[1]) for row in rows], dtype=np.float64)
    radii = np.array([float(row[2]) for row in rows], dtype=np.float64)
    positions = np.array([[float(row[3]), float(row[4]), float(row[5])] for row in rows],
                         dtype=np.float64)
    velocities = np.array([[float(row[6]), float(row[7]), float(row[8])] for row in rows],
                          dtype=np.float64)
    fixed = np.zeros((len(rows), 3), dtype=np.int64)
    if ncols == 12:
        fixed = np.array([[int(row[9]), int(row[10]), int(row[11])] for row in rows],
                         dtype=np.int64)

    if np.any(masses <= 0):
        raise ValueError("坐标文件中的质量必须为正数")
    if np.any(radii <= 0):
        raise ValueError("坐标文件中的半径必须为正数")

    # Downstream id -> row lookups (e.g. {atom_id: idx} dicts for bonds and
    # drivers) would silently keep only one row per duplicated id.
    unique_ids, counts = np.unique(atom_ids, return_counts=True)
    duplicates = unique_ids[counts > 1]
    if duplicates.size:
        raise ValueError(f"坐标文件中存在重复的原子序号: {duplicates.tolist()}")

    print(f"[compute] 已加载坐标文件: {coord_path},原子数={len(atom_ids)}")
    return atom_ids, masses, radii, positions, velocities, fixed
|
||
|
||
|
||
def load_bond_parameters(bond_path):
    """Load per-bond-type stiffness and optional rest length, keyed by bond name.

    Accepted headers::

        bond_name k               # 2 columns: rest length derived later from initial coordinates
        bond_name k rest_length   # 3 columns: explicit rest length

    Blank lines and lines starting with ``#`` are skipped. If a name appears
    more than once, its last definition wins.

    Returns:
        dict mapping bond name -> ``{"stiffness": float, "rest_length": float | None}``.

    Raises:
        FileNotFoundError: if ``bond_path`` does not exist.
        ValueError: on a bad header, a wrong column count, negative
            stiffness, non-positive rest length, or no data rows.
    """
    if not os.path.exists(bond_path):
        raise FileNotFoundError(f"键参数文件不存在: {bond_path}")

    # Header layout -> whether an explicit rest length column is present.
    layouts = {
        ("bond_name", "k"): False,
        ("bond_name", "k", "rest_length"): True,
    }

    params = {}
    with open(bond_path, "r", encoding="utf-8") as handle:
        columns = handle.readline().strip().split()
        layout = tuple(columns)
        if layout not in layouts:
            wanted = ["bond_name", "k"] + (["rest_length"] if len(columns) == 3 else [])
            raise ValueError(
                f"键参数文件表头应为: {' '.join(wanted)},实际为: {' '.join(columns)}")
        explicit_length = layouts[layout]
        width = len(layout)

        for row_no, raw in enumerate(handle, start=2):
            text = raw.strip()
            if text == "" or text.startswith("#"):
                continue
            fields = text.split()
            if len(fields) != width:
                raise ValueError(
                    f"{bond_path}:{row_no} 应有 {width} 列,实际为 {len(fields)} 列")

            k = float(fields[1])
            if k < 0:
                raise ValueError(f"{bond_path}:{row_no} 劲度系数必须非负")

            length = float(fields[2]) if explicit_length else None
            if length is not None and length <= 0:
                raise ValueError(f"{bond_path}:{row_no} 键长必须为正数")

            params[fields[0]] = {"stiffness": k, "rest_length": length}

    if not params:
        raise ValueError(f"键参数文件没有有效数据: {bond_path}")
    return params
|
||
|
||
|
||
def load_bond_connections(connection_path, atom_ids, atom_positions, bond_map):
    """Load bonded atom pairs and work out each bond's rest length.

    The file header must be ``n1 n2 bond_name``; atom numbers refer to ids in
    ``atom_ids`` and are translated to row indices. Rest-length priority:

    1. the explicit ``rest_length`` from ``bond_map`` (third column of bond.txt);
    2. otherwise the distance between the two atoms in ``atom_positions``.

    Returns:
        ``(pairs, names, stiffness, rest_lengths)``: int64 ``(n_bonds, 2)``
        row-index pairs, a list of bond names, and float64 per-bond stiffness
        and rest-length arrays. With no bonds, empty arrays (pairs shaped
        ``(0, 2)``) and an empty list are returned.

    Raises:
        FileNotFoundError: if ``connection_path`` does not exist.
        ValueError: on a bad header, a wrong column count, an unknown atom
            id, or an undefined bond name.
    """
    if not os.path.exists(connection_path):
        raise FileNotFoundError(f"成键连接文件不存在: {connection_path}")

    row_of = {int(atom_id): idx for idx, atom_id in enumerate(atom_ids)}
    records = []  # (row_1, row_2, name, stiffness, rest_length) per bond

    with open(connection_path, "r", encoding="utf-8") as handle:
        header = handle.readline().strip().split()
        expected = ["n1", "n2", "bond_name"]
        if header != expected:
            raise ValueError(
                f"成键连接文件表头应为: {' '.join(expected)},实际为: {' '.join(header)}")

        for line_no, raw in enumerate(handle, start=2):
            text = raw.strip()
            if not text or text.startswith("#"):
                continue
            fields = text.split()
            if len(fields) != 3:
                raise ValueError(f"{connection_path}:{line_no} 应有 3 列,实际为 {len(fields)} 列")

            first, second, name = int(fields[0]), int(fields[1]), fields[2]
            if first not in row_of or second not in row_of:
                raise ValueError(
                    f"{connection_path}:{line_no} 中的原子序号 {first}, {second} 不在坐标文件里")
            if name not in bond_map:
                raise ValueError(
                    f"{connection_path}:{line_no} 使用了未定义的键类型 {name}")

            i, j = row_of[first], row_of[second]
            entry = bond_map[name]
            length = entry["rest_length"]
            if length is None:
                # No explicit length: use the initial interatomic distance.
                length = float(np.linalg.norm(atom_positions[j] - atom_positions[i]))
            records.append((i, j, name, entry["stiffness"], length))

    if not records:
        return (
            np.zeros((0, 2), dtype=np.int64),
            [],
            np.zeros(0, dtype=np.float64),
            np.zeros(0, dtype=np.float64),
        )

    return (
        np.array([(i, j) for i, j, _, _, _ in records], dtype=np.int64),
        [name for _, _, name, _, _ in records],
        np.array([k for _, _, _, k, _ in records], dtype=np.float64),
        np.array([length for _, _, _, _, length in records], dtype=np.float64),
    )
|
||
|
||
|
||
# ===========================================================================
# Driving motion: loading & application
# ===========================================================================


def load_driver_file(driver_path, atom_ids):
    """Load driving-motion definitions from driver.txt.

    Format (the first line is the header)::

        n amp_x amp_y amp_z freq_x freq_y freq_z phi_x phi_y phi_z period
        <atom id> 0 0 0 0 0 10 0 0 90 all

    Meaning: ``position = A * cos(2π f t + φ)`` per axis, with φ given in
    degrees (converted to radians here). ``period`` is ``all`` (drive for the
    whole run) or a number of cycles, after which the atom is held still.

    Args:
        driver_path: path of driver.txt.
        atom_ids: atom ids from the coordinate file; ``n`` must be one of them.

    Returns:
        A list of dicts with keys ``atom_index``, ``atom_id``, ``amp``,
        ``freq``, ``phi`` (radians), ``phi_deg`` (degrees, as read),
        ``period_str``, ``period_cycles`` (None for ``all``), ``freeze_step``,
        and ``freeze_pos``. Returns None if the file is missing or has no data
        lines.

    Raises:
        ValueError: on a bad header, a wrong column count, or an unknown atom id.
    """
    if not os.path.exists(driver_path):
        print(f"[compute] 警告: 驱动力文件不存在: {driver_path}")
        return None

    atom_index = {int(aid): idx for idx, aid in enumerate(atom_ids)}
    drivers = []
    ncols = 11

    with open(driver_path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split()
        expected = ["n", "amp_x", "amp_y", "amp_z",
                    "freq_x", "freq_y", "freq_z",
                    "phi_x", "phi_y", "phi_z", "period"]
        if header != expected:
            raise ValueError(
                f"driver.txt 表头应为: {' '.join(expected)},实际为: {' '.join(header)}")

        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) != ncols:
                raise ValueError(
                    f"{driver_path}:{line_no} 应有 {ncols} 列,实际为 {len(parts)} 列")

            n = int(parts[0])
            if n not in atom_index:
                raise ValueError(f"{driver_path}:{line_no} 原子序号 {n} 不在 coord.txt 中")

            amp = np.array([float(parts[1]), float(parts[2]), float(parts[3])],
                           dtype=np.float64)
            freq = np.array([float(parts[4]), float(parts[5]), float(parts[6])],
                            dtype=np.float64)
            phi_deg = np.array([float(parts[7]), float(parts[8]), float(parts[9])],
                               dtype=np.float64)
            period_str = parts[10]

            drivers.append({
                "atom_index": atom_index[n],
                "atom_id": n,
                "amp": amp,
                "freq": freq,
                "phi": phi_deg * np.pi / 180.0,
                # Kept per driver so the summary below reports this driver's
                # own phases.
                "phi_deg": phi_deg,
                "period_str": period_str,
                "period_cycles": None if period_str == "all" else float(period_str),
                # Filled in during the simulation: freeze step / freeze position.
                "freeze_step": None,
                "freeze_pos": None,
            })

    if not drivers:
        print("[compute] 警告: driver.txt 中没有有效数据")
        return None

    print(f"[compute] 已加载驱动力: {len(drivers)} 条定义")
    for d in drivers:
        # Previously this indexed the phi_deg left over from the last parsed
        # line, so every driver printed the last driver's phase.
        amp, freq, phi = d["amp"], d["freq"], d["phi_deg"]
        print(f" 原子 {d['atom_id']}: "
              f"A=({amp[0]},{amp[1]},{amp[2]}), "
              f"f=({freq[0]},{freq[1]},{freq[2]}), "
              f"φ=({phi[0]},{phi[1]},{phi[2]})°, "
              f"period={d['period_str']}")
    return drivers
|
||
|
||
|
||
def apply_driving_force(x, y, z, vx, vy, vz, t, step, drivers, dt):
    """Overwrite the position and velocity of driven atoms with the prescribed motion.

    For each driver ``d``, the following is applied per axis to row
    ``d["atom_index"]``::

        pos = A * cos(2π f t + φ)
        vel = -A * 2π f * sin(2π f t + φ)

    The x/y/z and vx/vy/vz arrays are modified in place and also returned.

    A driver with a finite ``period_cycles`` is active up to step
    ``int(period_cycles / max|f| / dt)``, or 0 when every frequency is zero.
    After that, the atom is pinned at ``d["freeze_pos"]`` with zero velocity.

    ``freeze_pos`` is normally recorded on exactly the last active step. If
    that step is never passed in (e.g. the caller starts past it or skips
    steps), the atom's position at the first later call is captured instead.
    This keeps the atom held still rather than left free with only its
    velocity zeroed.

    Args:
        x, y, z, vx, vy, vz: per-atom coordinate/velocity arrays (mutated).
        t: current time used in the drive formula.
        step: current step index, compared against the drive duration.
        drivers: list from load_driver_file, or None (no-op).
        dt: time step used to convert the drive duration into steps.

    Returns:
        ``(x, y, z, vx, vy, vz)``.
    """
    if drivers is None:
        return x, y, z, vx, vy, vz

    for d in drivers:
        idx = d["atom_index"]

        # Work out the last step on which this driver is active.
        if d["period_cycles"] is not None:
            max_freq = np.max(np.abs(d["freq"]))
            if max_freq > 1e-12:
                period_duration = d["period_cycles"] / max_freq
                period_steps = int(period_duration / dt)
            else:
                period_steps = 0

            if step > period_steps:
                # Drive finished: hold the atom still.
                if d["freeze_pos"] is None:
                    # The exact last active step was never seen; pin the atom
                    # where it currently is.
                    d["freeze_pos"] = (float(x[idx]), float(y[idx]), float(z[idx]))
                x[idx], y[idx], z[idx] = d["freeze_pos"]
                vx[idx] = vy[idx] = vz[idx] = 0.0
                continue
        else:
            period_steps = None  # driven for the whole run

        # Prescribed position / velocity at time t (t broadcasts over axes).
        phase = 2.0 * np.pi * d["freq"] * t + d["phi"]
        pos_drive = d["amp"] * np.cos(phase)
        vel_drive = -d["amp"] * 2.0 * np.pi * d["freq"] * np.sin(phase)

        x[idx] = pos_drive[0]
        y[idx] = pos_drive[1]
        z[idx] = pos_drive[2]
        vx[idx] = vel_drive[0]
        vy[idx] = vel_drive[1]
        vz[idx] = vel_drive[2]

        # For a finite drive, remember where it ends so the atom can be pinned there.
        if period_steps is not None and step == period_steps:
            d["freeze_pos"] = (float(pos_drive[0]), float(pos_drive[1]), float(pos_drive[2]))

    return x, y, z, vx, vy, vz
|
||
def parse_gravity_vector(value):
    """Return gravity as a float64 acceleration vector of length 3.

    Backward-compatible inputs:
      - scalar ``G: 9.8``          -> ``[0, 0, -9.8]``
      - vector ``G: [gx, gy, gz]`` -> ``[gx, gy, gz]``

    Raises:
        ValueError: if a non-scalar value is not exactly three components.
    """
    is_scalar = isinstance(value, (int, float, np.integer, np.floating))
    if not is_scalar:
        components = np.asarray(value, dtype=np.float64)
        if components.shape == (3,):
            return components
        raise ValueError(f"G 必须是长度为 3 的分量数组,实际为 {value}")
    return np.array([0.0, 0.0, -float(value)], dtype=np.float64)
|
||
|
||
|
||
def parse_damping_vector(value):
    """Return damping as a float64 vector with one coefficient per axis.

    Backward-compatible inputs:
      - scalar ``B: 0.5``          -> ``[0.5, 0.5, 0.5]``
      - vector ``B: [bx, by, bz]`` -> ``[bx, by, bz]``

    Raises:
        ValueError: if a non-scalar value is not exactly three components.
    """
    if isinstance(value, (int, float, np.integer, np.floating)):
        return np.full(3, float(value), dtype=np.float64)
    components = np.asarray(value, dtype=np.float64)
    if components.shape == (3,):
        return components
    raise ValueError(f"B 必须是长度为 3 的分量数组,实际为 {value}")
|
||
|
||
|
||
def _resolve_config_path(path, base_dir):
    """Join a relative config path onto ``base_dir``; return absolute paths (or any path when ``base_dir`` is None) unchanged."""
    if base_dir is not None and not os.path.isabs(path):
        return os.path.join(base_dir, path)
    return path


def run_from_config(config, out_dir=None):
    """Populate the module-level simulation state from ``config``, run it, and log the run.

    Assigns this module's globals (atoms, bonds, integrator settings, box
    bounds, display settings, force switches, driver data, camera motion),
    then calls ``run_simulation`` and writes ``dynamics.log``.

    Recognized ``config`` keys:
        required: G, B, NT, DT, NSTEP
        optional (default):
            box_a (10.0), alpha (0.2; scalar or list),
            coord_file / connection_file / bond_file
                (input/coord.txt, input/connection.txt, input/bond.txt),
            plot_atom (first atom id), method ("explicit_euler"),
            warmup_steps (0), sample_start / sample_end (None),
            ball_radius (radius of plot_atom),
            ball_color_r/g/b (0.9, 0.2, 0.2), box_color_r/g/b (0.8, 0.8, 0.85),
            use_marker (0), gravity_field (1), gravity_interaction (0),
            elastic_force (1), damping_force (0), gravity_strength (1.0),
            driving_force (0), driver_file (input/driver.txt),
            move_camera (0), move_camera_file (input/move_camera.txt),
            save_trajectory (0), _skip_run (False: stop after loading inputs).

    Args:
        config: configuration dict.
        out_dir: base directory. Relative input-file paths from ``config``
            are resolved against it, and ``dynamics.log`` is written to
            ``get_output_dir(out_dir)``. When None, relative input paths are
            used unchanged and the log goes to get_output_dir's default.

    Returns:
        ``(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz)`` as returned by
        ``run_simulation``, or six ``None`` values when ``_skip_run`` is set.

    Raises:
        KeyError: if a required key is missing.
        ValueError / FileNotFoundError: from the input loaders, or when
            ``plot_atom`` is not an id in the coordinate file.
    """
    global box_a, alpha, COORD_FILE, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP
    global PLOT_ATOM_ID, PLOT_ATOM_ROW
    global ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS, ATOM_VELOCITIES, ATOM_FIXED
    global BOND_CONNECTION_FILE, BOND_PARAMETER_FILE, BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS, BOND_REST_LENGTHS
    global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX
    global ball_radius, ball_color_r, ball_color_g, ball_color_b
    global box_color_r, box_color_g, box_color_b
    global use_marker, camera_keyframes_raw
    global warmup_steps, sample_start, sample_end
    global GRAVITY_FIELD, GRAVITY_INTERACTION, ELASTIC_FORCE, DAMPING_FORCE, GRAVITY_STRENGTH
    global DRIVING_FORCE, DRIVER_DATA

    # --- box ------------------------------------------------------------
    box_a = float(config.get("box_a", 10.0))
    raw_alpha = config.get("alpha", 0.2)
    if isinstance(raw_alpha, (list, tuple)):
        alpha = [float(a) for a in raw_alpha]
    else:
        alpha = float(raw_alpha)

    # --- atoms ----------------------------------------------------------
    COORD_FILE = str(config.get("coord_file", os.path.join("input", "coord.txt")))
    (ATOM_IDS, ATOM_MASSES, ATOM_RADII, ATOM_POSITIONS,
     ATOM_VELOCITIES, ATOM_FIXED) = load_coord_file(_resolve_config_path(COORD_FILE, out_dir))

    # --- bonds ----------------------------------------------------------
    BOND_CONNECTION_FILE = str(config.get("connection_file", os.path.join("input", "connection.txt")))
    BOND_PARAMETER_FILE = str(config.get("bond_file", os.path.join("input", "bond.txt")))
    connection_path = _resolve_config_path(BOND_CONNECTION_FILE, out_dir)
    bond_path = _resolve_config_path(BOND_PARAMETER_FILE, out_dir)
    bond_map = load_bond_parameters(bond_path)
    (BOND_PAIRS, BOND_NAMES, BOND_STIFFNESS,
     BOND_REST_LENGTHS) = load_bond_connections(
        connection_path, ATOM_IDS, ATOM_POSITIONS, bond_map)

    # --- atom reported in the log ---------------------------------------
    PLOT_ATOM_ID = int(config.get("plot_atom", int(ATOM_IDS[0])))
    matches = np.where(ATOM_IDS == PLOT_ATOM_ID)[0]
    if len(matches) == 0:
        raise ValueError(f"plot_atom={PLOT_ATOM_ID} 不在坐标文件原子序号中")
    PLOT_ATOM_ROW = int(matches[0])

    M = float(ATOM_MASSES[PLOT_ATOM_ROW])
    X0, Y0, Z0 = [float(v) for v in ATOM_POSITIONS[PLOT_ATOM_ROW]]
    VX0, VY0, VZ0 = [float(v) for v in ATOM_VELOCITIES[PLOT_ATOM_ROW]]

    # --- integrator -----------------------------------------------------
    G = parse_gravity_vector(config["G"])
    B = parse_damping_vector(config["B"])
    # NOTE(review): run_engine defaults method to "leapfrog" (and G/B to
    # fallbacks) while this path uses "explicit_euler" and requires G/B;
    # confirm the difference is intended.
    METHOD = normalize_method_name(config.get("method", "explicit_euler"))
    NT = int(config["NT"])
    DT = float(config["DT"])
    NSTEP = int(config["NSTEP"])

    # Step-control parameters (optional)
    warmup_steps = int(config.get("warmup_steps", 0))
    sample_start = config.get("sample_start")  # None: from the beginning
    sample_end = config.get("sample_end")      # None: to the end

    X_MIN = -box_a; X_MAX = box_a
    Y_MIN = -box_a; Y_MAX = box_a
    Z_MIN = -box_a; Z_MAX = box_a

    # --- display --------------------------------------------------------
    ball_radius = float(config.get("ball_radius", ATOM_RADII[PLOT_ATOM_ROW]))
    ball_color_r = float(config.get("ball_color_r", 0.9))
    ball_color_g = float(config.get("ball_color_g", 0.2))
    ball_color_b = float(config.get("ball_color_b", 0.2))
    box_color_r = float(config.get("box_color_r", 0.8))
    box_color_g = float(config.get("box_color_g", 0.8))
    box_color_b = float(config.get("box_color_b", 0.85))
    use_marker = int(config.get("use_marker", 0))

    # --- force switches -------------------------------------------------
    GRAVITY_FIELD = int(config.get("gravity_field", 1))
    GRAVITY_INTERACTION = int(config.get("gravity_interaction", 0))
    ELASTIC_FORCE = int(config.get("elastic_force", 1))
    DAMPING_FORCE = int(config.get("damping_force", 0))
    GRAVITY_STRENGTH = float(config.get("gravity_strength", 1.0))
    DRIVING_FORCE = int(config.get("driving_force", 0))

    # Driving-motion definitions
    DRIVER_DATA = None
    if DRIVING_FORCE:
        driver_rel = str(config.get("driver_file", os.path.join("input", "driver.txt")))
        DRIVER_DATA = load_driver_file(_resolve_config_path(driver_rel, out_dir), ATOM_IDS)

    # Moving-camera segments
    camera_keyframes_raw = ""
    move_camera = int(config.get("move_camera", 0))
    if move_camera:
        cam_rel = str(config.get("move_camera_file", os.path.join("input", "move_camera.txt")))
        camera_keyframes_raw = _load_camera_motion(_resolve_config_path(cam_rel, out_dir))

    print(f"[compute] 使用算法: {METHOD}")
    print(f"[compute] 已加载成键信息: {len(BOND_PAIRS)} 条键")
    if config.get("_skip_run", False):
        return None, None, None, None, None, None

    # --- run ------------------------------------------------------------
    t_start = time.time()
    t_start_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = run_simulation(
        save_trajectory=int(config.get("save_trajectory", 0)))
    t_end = time.time()
    t_end_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    elapsed = t_end - t_start

    # --- dynamics.log ---------------------------------------------------
    log_dir = get_output_dir(out_dir)
    log_path = os.path.join(log_dir, "dynamics.log")
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("=" * 50 + "\n")
        f.write("Dynamics 计算日志\n")
        f.write("=" * 50 + "\n")
        f.write(f"计算开始时间: {t_start_str}\n")
        f.write(f"计算结束时间: {t_end_str}\n")
        f.write(f"计算耗时: {elapsed:.3f} s\n")
        f.write("-" * 50 + "\n")
        f.write("计算参数:\n")
        f.write(f" box_a: {box_a}\n")
        # alpha is a float or a list of floats; both render via str().
        f.write(f" alpha: {alpha}\n")
        f.write(f" coord_file: {COORD_FILE}\n")
        f.write(f" method: {METHOD}\n")
        f.write(f" NT: {NT}\n")
        f.write(f" DT: {DT}\n")
        f.write(f" NSTEP: {NSTEP}\n")
        f.write(f" G: ({G[0]:.2f}, {G[1]:.2f}, {G[2]:.2f})\n")
        f.write(f" B: ({B[0]:.2f}, {B[1]:.2f}, {B[2]:.2f})\n")
        f.write(f" warmup: {warmup_steps}\n")
        f.write(f" 原子数: {len(ATOM_IDS)}\n")
        f.write(f" 键数: {len(BOND_PAIRS)}\n")
        f.write("-" * 50 + "\n")
        f.write("初始状态 (plot_atom):\n")
        f.write(f" position: ({X0:.6f}, {Y0:.6f}, {Z0:.6f})\n")
        f.write(f" velocity: ({VX0:.6f}, {VY0:.6f}, {VZ0:.6f})\n")
        f.write(f" mass: {M}\n")
        f.write("=" * 50 + "\n")
    print(f"[compute] 日志已保存至: {log_path}")

    # NOTE(review): run_simulation is defined outside this view; guard in
    # case it returns None buffers instead of arrays.
    record = len(traj_x) if traj_x is not None else 0
    print(f"[compute] 计算完成: {record} 步, {elapsed:.3f} s")

    return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz
|
||
|
||
|
||
def run_engine(engine, input_dir, output_dir, config):
|
||
"""调用外部计算引擎(C/C++/Fortran),生成 trajectory.txt。
|
||
|
||
Args:
|
||
engine: 引擎名称 ("c", "cpp", "fortran")
|
||
input_dir: 输入目录(含 coord.txt, connection.txt, bond.txt)
|
||
output_dir: 输出目录(轨迹文件写入位置)
|
||
config: YAML 配置字典
|
||
Returns:
|
||
(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz)
|
||
"""
|
||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||
system = platform.system().lower()
|
||
engine_map = {
|
||
"c": "engines/c/build/dynamics_c",
|
||
"cpp": "engines/cpp/build/dynamics_cpp",
|
||
"fortran": "engines/fortran/build/dynamics_f90",
|
||
}
|
||
if engine not in engine_map:
|
||
raise ValueError(f"不支持的引擎: {engine},可选: {list(engine_map.keys())}")
|
||
|
||
engine_rel = engine_map[engine]
|
||
engine_path = os.path.join(script_dir, engine_rel)
|
||
|
||
# 自动检测可执行文件后缀和平台专用版本
|
||
candidates = [
|
||
engine_path, # 无后缀
|
||
engine_path + ".exe", # Windows .exe
|
||
engine_path + f"_{system}.exe", # 平台专用 (c_linux.exe, c_darwin.exe)
|
||
]
|
||
found = None
|
||
for p in candidates:
|
||
if os.path.exists(p):
|
||
found = p
|
||
break
|
||
if found is None:
|
||
raise FileNotFoundError(
|
||
f"引擎可执行文件不存在: 尝试了 {candidates}\n"
|
||
f"请先编译: cd engines/{engine} && make\n"
|
||
f"或安装交叉编译器后: cd engines/{engine} && make {system}")
|
||
|
||
# 构造 param.json(数值参数)
|
||
G = parse_gravity_vector(config.get("G", [0, 0, -9.8]))
|
||
B = parse_damping_vector(config.get("B", [0, 0, 0]))
|
||
param_json = {
|
||
"box_a": float(config.get("box_a", 10.0)),
|
||
"NT": int(config["NT"]),
|
||
"DT": float(config["DT"]),
|
||
"NSTEP": int(config.get("NSTEP", 1)),
|
||
"warmup_steps": int(config.get("warmup_steps", 0)),
|
||
"method": normalize_method_name(config.get("method", "leapfrog")),
|
||
"G": [float(v) for v in G],
|
||
"B": [float(v) for v in B],
|
||
"gravity_field": int(config.get("gravity_field", 1)),
|
||
"gravity_interaction": int(config.get("gravity_interaction", 0)),
|
||
"elastic_force": int(config.get("elastic_force", 1)),
|
||
"damping_force": int(config.get("damping_force", 0)),
|
||
"gravity_strength": float(config.get("gravity_strength", 1.0)),
|
||
"driving_force": int(config.get("driving_force", 0)),
|
||
"save_trajectory": int(config.get("save_trajectory", 0)),
|
||
}
|
||
param_path = os.path.join(script_dir, "engines", engine, "param.json")
|
||
os.makedirs(os.path.dirname(param_path), exist_ok=True)
|
||
with open(param_path, "w", encoding="utf-8") as f:
|
||
json.dump(param_json, f, indent=2)
|
||
|
||
# 确保输出目录存在
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
print(f"[compute] 引擎: {engine} → {os.path.basename(engine_path)}")
|
||
print(f"[compute] input: {input_dir}")
|
||
print(f"[compute] output: {output_dir}")
|
||
|
||
# ── 预校准:跑 NT=1000 步测速 ────────────────────────────────
|
||
total_steps = int(config["NT"]) - int(config.get("warmup_steps", 0))
|
||
_calib_nt = min(1000, max(100, total_steps // 10))
|
||
_calib_param = dict(param_json)
|
||
_calib_param["NT"] = _calib_nt
|
||
_calib_path = os.path.join(script_dir, "engines", engine, "_calib.json")
|
||
with open(_calib_path, "w", encoding="utf-8") as _cf:
|
||
json.dump(_calib_param, _cf, indent=2)
|
||
_calib_outdir = os.path.join(script_dir, "engines", engine, "_calib_out")
|
||
os.makedirs(_calib_outdir, exist_ok=True)
|
||
_ct0 = time.time()
|
||
subprocess.run(
|
||
[engine_path, os.path.abspath(input_dir), _calib_outdir, _calib_path],
|
||
capture_output=True, timeout=60)
|
||
# 清理校准临时文件
|
||
for _f in os.listdir(_calib_outdir):
|
||
try: os.remove(os.path.join(_calib_outdir, _f))
|
||
except OSError: pass
|
||
try: os.rmdir(_calib_outdir)
|
||
except OSError: pass
|
||
os.remove(_calib_path)
|
||
_calib_elapsed = max(time.time() - _ct0, 0.001)
|
||
_overhead = _calib_elapsed * 0.15
|
||
_step_time = max(_calib_elapsed - _overhead, 0.0001) / _calib_nt
|
||
_est_total = max(_calib_elapsed, _overhead + _step_time * total_steps)
|
||
|
||
t_start = time.time()
|
||
t_start_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||
|
||
_p = subprocess.Popen(
|
||
[engine_path, os.path.abspath(input_dir), os.path.abspath(output_dir), param_path],
|
||
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||
text=True, encoding='utf-8', errors='replace')
|
||
_engine_lines = []
|
||
|
||
try:
|
||
from tqdm import tqdm as _tqdm
|
||
_pbar = _tqdm(total=total_steps, desc=f"[compute] 引擎 {engine}",
|
||
unit="步", bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
|
||
except ImportError:
|
||
_pbar = None
|
||
|
||
try:
|
||
while True:
|
||
_line = _p.stdout.readline() if _p.stdout else ''
|
||
if _line:
|
||
_line = _line.strip()
|
||
if _line:
|
||
_engine_lines.append(_line)
|
||
# 读取外部引擎真实进度:格式 "[xxx-engine] progress: N/total"
|
||
_prog_match = re.search(r'progress:\s*(\d+)/(\d+)', _line)
|
||
if _pbar is not None and _prog_match:
|
||
_prog_done = int(_prog_match.group(1))
|
||
_prog_total = int(_prog_match.group(2))
|
||
if _prog_total > 0:
|
||
_pbar.n = min(_prog_done, total_steps)
|
||
_pbar.refresh()
|
||
if _p.poll() is not None:
|
||
if _p.stdout:
|
||
for _r in _p.stdout:
|
||
_r = _r.strip()
|
||
if _r:
|
||
_engine_lines.append(_r)
|
||
# 读取残留在管道中的进度消息,避免 20%→100% 跳变
|
||
_prog_match = re.search(r'progress:\s*(\d+)/(\d+)', _r)
|
||
if _pbar is not None and _prog_match:
|
||
_p_done = min(int(_prog_match.group(1)), total_steps)
|
||
_pbar.n = max(_pbar.n, _p_done)
|
||
if _pbar is not None: _pbar.refresh()
|
||
break
|
||
time.sleep(0.05)
|
||
finally:
|
||
if _pbar is not None:
|
||
_pbar.n = total_steps
|
||
_pbar.refresh()
|
||
_pbar.close()
|
||
if _p.poll() is None:
|
||
_p.kill()
|
||
_p.wait(timeout=5)
|
||
|
||
_rc = _p.returncode
|
||
t_end = time.time()
|
||
t_end_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||
elapsed = t_end - t_start
|
||
|
||
if _rc != 0:
|
||
_err = _p.stderr.read() if _p.stderr else ''
|
||
if _err:
|
||
print(f"[compute] 引擎错误:\n{_err}")
|
||
raise RuntimeError(f"引擎 {engine} 返回错误码 {_rc}")
|
||
|
||
if _engine_lines:
|
||
_log_path = os.path.join(output_dir, "dynamics.log")
|
||
try:
|
||
with open(_log_path, "a", encoding="utf-8") as _lf:
|
||
_lf.write("\n--- 引擎输出 ---\n")
|
||
for _ln in _engine_lines:
|
||
_lf.write(_ln + "\n")
|
||
except OSError:
|
||
pass
|
||
|
||
# 加载输出的 trajectory.txt / display.txt
|
||
save_traj = int(config.get("save_trajectory", 0))
|
||
if not save_traj:
|
||
# save_trajectory=0:引擎只写 display.txt
|
||
disp_path = os.path.join(os.path.abspath(output_dir), "display.txt")
|
||
if not os.path.exists(disp_path):
|
||
raise FileNotFoundError(f"引擎未生成 display.txt: {disp_path}")
|
||
print(f"[compute] 引擎已生成 {disp_path}")
|
||
return None, None, None, None, None, None
|
||
|
||
# save_trajectory=1:加载完整 trajectory.txt
|
||
traj_path = os.path.join(os.path.abspath(output_dir), "trajectory.txt")
|
||
if not os.path.exists(traj_path):
|
||
raise FileNotFoundError(f"引擎未生成 trajectory.txt: {traj_path}")
|
||
data = load_text_data(traj_path)
|
||
traj_x = data["traj_x"]
|
||
traj_y = data["traj_y"]
|
||
traj_z = data["traj_z"]
|
||
traj_vx = data["traj_vx"]
|
||
traj_vy = data["traj_vy"]
|
||
traj_vz = data["traj_vz"]
|
||
|
||
# 写入日志文件
|
||
log_path = os.path.join(output_dir, "dynamics.log")
|
||
with open(log_path, "w", encoding="utf-8") as f:
|
||
f.write("=" * 50 + "\n")
|
||
f.write("Dynamics 计算日志\n")
|
||
f.write("=" * 50 + "\n")
|
||
f.write(f"引擎: {engine}\n")
|
||
f.write(f"计算开始时间: {t_start_str}\n")
|
||
f.write(f"计算结束时间: {t_end_str}\n")
|
||
f.write(f"计算耗时: {elapsed:.3f} s\n")
|
||
|
||
print(f"[compute] 引擎完成: {len(traj_x)} 步, {traj_x.shape[1]} 原子, {elapsed:.3f} s")
|
||
return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz
|
||
|
||
|
||
def save_trajectory_txt(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, out_dir=None):
    """Serialize a full trajectory together with every run parameter.

    The payload (trajectory arrays, integration settings, atom/bond tables,
    box bounds, drawing colours and force switches) is handed to
    ``save_text_data`` and written to ``trajectory.txt`` inside
    ``get_output_dir(out_dir)``.

    Args:
        traj_x, traj_y, traj_z: recorded positions, one entry per recorded step.
        traj_vx, traj_vy, traj_vz: recorded velocities, same layout.
        out_dir: base directory passed to ``get_output_dir``; defaults to the
            directory that contains this script.

    Returns:
        The path of the written file.
    """
    base_dir = out_dir if out_dir is not None else os.path.dirname(os.path.abspath(__file__))
    out_path = os.path.join(get_output_dir(base_dir), "trajectory.txt")

    # "NT" in the file is the number of recorded steps (warm-up excluded).
    n_recorded = len(traj_x)
    # Unset sampling bounds are stored as -1 so the field stays an integer.
    start_saved = int(sample_start) if sample_start is not None else -1
    end_saved = int(sample_end) if sample_end is not None else -1

    def _listify(vec):
        # numpy vectors are converted to plain lists for the text format.
        return vec.tolist() if vec is not None else None

    payload = {
        # Trajectory data and step bookkeeping.
        "traj_x": traj_x,
        "traj_y": traj_y,
        "traj_z": traj_z,
        "traj_vx": traj_vx,
        "traj_vy": traj_vy,
        "traj_vz": traj_vz,
        "NT": n_recorded,
        "DT": DT,
        "NSTEP": NSTEP,
        "method": METHOD,
        # Input files.
        "coord_file": COORD_FILE,
        "connection_file": BOND_CONNECTION_FILE,
        "bond_file": BOND_PARAMETER_FILE,
        # Atom table.
        "atom_ids": ATOM_IDS,
        "atom_masses": ATOM_MASSES,
        "atom_radii": ATOM_RADII,
        "atom_positions": ATOM_POSITIONS,
        "atom_velocities": ATOM_VELOCITIES,
        "atom_fixed": ATOM_FIXED,
        # Bond table.
        "bond_pairs": BOND_PAIRS,
        "bond_names": BOND_NAMES,
        "bond_stiffness": BOND_STIFFNESS,
        "bond_rest_lengths": BOND_REST_LENGTHS,
        # Field / damping vectors.
        "G": _listify(G),
        "B": _listify(B),
        # Plotting and sampling selection.
        "plot_atom_id": PLOT_ATOM_ID,
        "plot_atom_row": PLOT_ATOM_ROW,
        "warmup_steps": warmup_steps or 0,
        "sample_start": start_saved,
        "sample_end": end_saved,
        # Box bounds.
        "X_MIN": X_MIN,
        "X_MAX": X_MAX,
        "Y_MIN": Y_MIN,
        "Y_MAX": Y_MAX,
        "Z_MIN": Z_MIN,
        "Z_MAX": Z_MAX,
        # Legacy single-particle initial conditions.
        "X0": X0,
        "Y0": Y0,
        "Z0": Z0,
        "VX0": VX0,
        "VY0": VY0,
        "VZ0": VZ0,
        "M": M,
        # Drawing settings.
        "alpha": alpha,
        "ball_radius": ball_radius,
        "ball_color_r": ball_color_r,
        "ball_color_g": ball_color_g,
        "ball_color_b": ball_color_b,
        "box_color_r": box_color_r,
        "box_color_g": box_color_g,
        "box_color_b": box_color_b,
        # Force-model switches.
        "gravity_field": GRAVITY_FIELD,
        "gravity_interaction": GRAVITY_INTERACTION,
        "elastic_force": ELASTIC_FORCE,
        "damping_force": DAMPING_FORCE,
        "gravity_strength": GRAVITY_STRENGTH,
    }

    save_text_data(out_path, payload)
    print(f"[compute] 轨迹数据已保存至: {out_path}")
    return out_path
|
||
|
||
|
||
def load_parameters(param_file):
    """Load ``key = value`` parameters from a text file into module globals.

    Lines that are blank, start with ``#`` or contain no ``=`` are skipped.
    Anything after a ``#`` in the value part is treated as an inline comment.
    A value containing ``.`` or ``e``/``E`` is converted with ``float``, any
    other value with ``int``. Values that fail conversion stay strings
    (e.g. method names).

    Only keys that already exist as module globals are accepted. Keys naming
    an imported module, a function/class, or a dunder name are refused with a
    warning, so a stray key in the parameter file cannot overwrite the
    program's own code (e.g. ``np``, ``os`` or ``main``).

    After loading, the box bounds ``X_MIN`` ... ``Z_MAX`` are derived as
    ``[-box_a, box_a]``. ``METHOD`` is normalized via
    ``normalize_method_name``, and drawing defaults are filled in.

    Exits the process with status 1 when ``box_a`` or a required parameter is
    missing. Raises ``ValueError`` (from ``normalize_method_name``) for an
    unknown method name.
    """
    global box_a, alpha, X0, Y0, Z0, VX0, VY0, VZ0, M, G, B, METHOD, NT, DT, NSTEP
    global X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX
    global ball_radius, ball_color_r, ball_color_g, ball_color_b
    global box_color_r, box_color_g, box_color_b
    import types

    print(f"[compute] 正在加载参数文件: {param_file}")
    with open(param_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    module_globals = globals()
    for raw_line in lines:
        line = raw_line.strip()
        # Skip blanks, full-line comments and anything that is not an assignment.
        if not line or line.startswith('#') or '=' not in line:
            continue

        key, value = line.split('=', 1)
        key = key.strip()
        # Drop an inline comment after the value (no-op when there is no '#').
        value = value.split('#', 1)[0].strip()

        # Numeric conversion; leave the raw string when neither parse succeeds.
        try:
            if '.' in value or 'e' in value.lower():
                value = float(value)
            else:
                value = int(value)
        except ValueError:
            pass

        if key not in module_globals:
            print(f"警告: 未知参数 '{key}',忽略")
            continue

        current = module_globals[key]
        if key.startswith('__') or callable(current) or isinstance(current, types.ModuleType):
            # Protect imports, functions, classes and interpreter dunders from
            # being replaced by a value read from the parameter file.
            print(f"警告: 参数 '{key}' 与程序内部名称冲突,忽略")
            continue

        module_globals[key] = value

    # Derived box bounds: a cube of half-width box_a centred on the origin.
    if box_a is None:
        print("错误: 未找到 box_a 参数")
        sys.exit(1)
    X_MIN, X_MAX = -box_a, box_a
    Y_MIN, Y_MAX = -box_a, box_a
    Z_MIN, Z_MAX = -box_a, box_a

    # Required parameters.
    required = ['alpha', 'X0', 'Y0', 'Z0', 'VX0', 'VY0', 'VZ0', 'M', 'G', 'B', 'NT', 'DT', 'NSTEP']
    for param in required:
        if module_globals[param] is None:
            print(f"错误: 未找到必需参数 '{param}'")
            sys.exit(1)

    METHOD = normalize_method_name(METHOD or "explicit_euler")

    # Drawing defaults for anything the file did not provide.
    if ball_radius is None:
        ball_radius = 0.28
    if ball_color_r is None:
        ball_color_r = 0.9
    if ball_color_g is None:
        ball_color_g = 0.2
    if ball_color_b is None:
        ball_color_b = 0.2
    if box_color_r is None:
        box_color_r = 0.8
    if box_color_g is None:
        box_color_g = 0.8
    if box_color_b is None:
        box_color_b = 0.85

    print(f"[compute] 参数加载完成: box_a={box_a}, alpha={alpha}, NT={NT}, DT={DT}, method={METHOD}")
|
||
|
||
def normalize_method_name(method):
    """Map a user-supplied integration-method name to its canonical form.

    Matching ignores surrounding whitespace and letter case. ``-`` and runs
    of internal whitespace are treated like ``_``, so ``"Leap-Frog"`` or
    ``"mid point"`` resolve like ``"leap_frog"`` / ``"mid_point"``. Chinese
    aliases are accepted as well.

    Args:
        method: method name from the YAML/text configuration; any object is
            accepted and converted with ``str``.

    Returns:
        One of ``"explicit_euler"``, ``"implicit_euler"``, ``"midpoint"`` or
        ``"leapfrog"``.

    Raises:
        ValueError: if the name matches no known alias.
    """
    aliases = {
        "explicit_euler": "explicit_euler",
        "explicit": "explicit_euler",
        "euler_explicit": "explicit_euler",
        "显式欧拉": "explicit_euler",
        "显式欧拉法": "explicit_euler",
        "implicit_euler": "implicit_euler",
        "implicit": "implicit_euler",
        "euler_implicit": "implicit_euler",
        "隐式欧拉": "implicit_euler",
        "隐式欧拉法": "implicit_euler",
        "midpoint": "midpoint",
        "mid_point": "midpoint",
        "midpoint_method": "midpoint",
        "中点法": "midpoint",
        "中点差分法": "midpoint",
        "leapfrog": "leapfrog",
        "leap_frog": "leapfrog",
        "蛙跳法": "leapfrog",
    }
    # str.split() with no argument also strips leading/trailing whitespace, so
    # names without '-' or inner spaces normalize exactly as strip().lower().
    key = "_".join(str(method).replace("-", " ").split()).lower()
    if key not in aliases:
        valid = ", ".join(["explicit_euler", "implicit_euler", "midpoint", "leapfrog"])
        raise ValueError(f"未知算法 '{method}'。可选值: {valid}")
    return aliases[key]
|
||
|
||
|
||
def compute_force(x, y, z, vx, vy, vz, m, g, b):
    """Return the net force ``(fx, fy, fz)`` on every atom for the given state.

    Each contribution is enabled by its own module-level switch:

    * ``GRAVITY_FIELD``       - uniform field, ``F = m * g`` per axis.
    * ``DAMPING_FORCE``       - linear drag, ``F = -b * v`` per axis.
    * ``ELASTIC_FORCE``       - Hookean springs along ``BOND_PAIRS`` with
      ``BOND_STIFFNESS`` and ``BOND_REST_LENGTHS``; a stretched bond pulls
      its two atoms toward each other.
    * ``GRAVITY_INTERACTION`` - pairwise attraction
      ``GRAVITY_STRENGTH * m_i * m_j / r**2`` over all atom pairs.

    Pairs with (squared) separation at or below ``1e-12`` are skipped to avoid
    dividing by zero.
    """
    fx, fy, fz = (np.zeros_like(coord) for coord in (x, y, z))

    if GRAVITY_FIELD:
        fx += m * g[0]
        fy += m * g[1]
        fz += m * g[2]

    if DAMPING_FORCE:
        fx -= b[0] * vx
        fy -= b[1] * vy
        fz -= b[2] * vz

    has_bonds = ELASTIC_FORCE and BOND_PAIRS is not None and len(BOND_PAIRS) > 0
    if has_bonds:
        for bond_no, (first, second) in enumerate(BOND_PAIRS):
            sep_x = x[second] - x[first]
            sep_y = y[second] - y[first]
            sep_z = z[second] - z[first]
            length = np.sqrt(sep_x * sep_x + sep_y * sep_y + sep_z * sep_z)
            if length <= 1e-12:
                continue
            # Positive when stretched: pulls `first` toward `second`.
            tension = BOND_STIFFNESS[bond_no] * (length - BOND_REST_LENGTHS[bond_no])
            pull_x = tension * (sep_x / length)
            pull_y = tension * (sep_y / length)
            pull_z = tension * (sep_z / length)
            fx[first] += pull_x
            fy[first] += pull_y
            fz[first] += pull_z
            fx[second] -= pull_x
            fy[second] -= pull_y
            fz[second] -= pull_z

    if GRAVITY_INTERACTION:
        count = len(m)
        for i in range(count):
            for j in range(i + 1, count):
                sep_x = x[j] - x[i]
                sep_y = y[j] - y[i]
                sep_z = z[j] - z[i]
                r2 = sep_x * sep_x + sep_y * sep_y + sep_z * sep_z
                if r2 <= 1e-12:
                    continue
                magnitude = GRAVITY_STRENGTH * m[i] * m[j] / r2
                r = np.sqrt(r2)
                # Force on i points toward j; j receives the opposite.
                pull_x = magnitude * sep_x / r
                pull_y = magnitude * sep_y / r
                pull_z = magnitude * sep_z / r
                fx[i] += pull_x
                fy[i] += pull_y
                fz[i] += pull_z
                fx[j] -= pull_x
                fy[j] -= pull_y
                fz[j] -= pull_z

    return fx, fy, fz
|
||
|
||
|
||
def compute_acceleration(x, y, z, vx, vy, vz, m, g, b):
    """Return per-atom acceleration ``(ax, ay, az)``, i.e. ``compute_force(...) / m``."""
    force = compute_force(x, y, z, vx, vy, vz, m, g, b)
    return tuple(component / m for component in force)
|
||
|
||
|
||
def Explicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with forward (explicit) Euler.

    Both updates use the state at the start of the step: positions move with
    the old velocity, and velocities change by the acceleration evaluated at
    the old state.

    Returns:
        ``(x, y, z, vx, vy, vz)`` after one step of size ``dt``.
    """
    accel = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)
    velocity = (vx, vy, vz)
    new_pos = [p + v * dt for p, v in zip((x, y, z), velocity)]
    new_vel = [v + a * dt for v, a in zip(velocity, accel)]
    return (*new_pos, *new_vel)
|
||
|
||
|
||
def Implicit_Euler_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with a semi-implicit (backward) Euler scheme.

    1. Predict the end-of-step velocity by solving backward Euler exactly for
       the linear part of the model (uniform gravity + linear drag):
       ``v* = (v + g*dt) / (1 + (b/m)*dt)``. Gravity and drag enter the
       predictor only when ``GRAVITY_FIELD`` / ``DAMPING_FORCE`` are on, so
       the predictor agrees with the force model used in step 2.
    2. Evaluate the full acceleration at the current positions with ``v*``
       (in ``compute_force`` velocity only enters through the drag term), then
       set ``v <- v + a*dt``.
    3. Move positions with the updated velocity: ``r <- r + v*dt``.

    With only gravity and drag enabled, step 2 reproduces ``v*`` (up to
    rounding).

    Returns:
        ``(x, y, z, vx, vy, vz)`` after one step of size ``dt``.
    """
    if GRAVITY_FIELD:
        g_x, g_y, g_z = g[0], g[1], g[2]
    else:
        # Without a uniform field, predicting with g would feed a spurious
        # g*dt velocity into the drag evaluation below.
        g_x = g_y = g_z = 0.0

    if DAMPING_FORCE:
        gamma_x = b[0] / m
        gamma_y = b[1] / m
        gamma_z = b[2] / m
    else:
        gamma_x = gamma_y = gamma_z = 0.0

    vx_next = (vx + g_x * dt) / (1.0 + gamma_x * dt)
    vy_next = (vy + g_y * dt) / (1.0 + gamma_y * dt)
    vz_next = (vz + g_z * dt) / (1.0 + gamma_z * dt)
    ax_next, ay_next, az_next = compute_acceleration(
        x, y, z, vx_next, vy_next, vz_next, m, g, b)

    vx = vx + ax_next * dt
    vy = vy + ay_next * dt
    vz = vz + az_next * dt
    x = x + vx * dt
    y = y + vy * dt
    z = z + vz * dt
    return x, y, z, vx, vy, vz
|
||
|
||
|
||
def Midpoint_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with the explicit midpoint (second-order Runge-Kutta) rule.

    A half step of ``dt/2`` from the current state gives a provisional
    midpoint. The velocity and acceleration evaluated there are then applied
    over the full step.

    Returns:
        ``(x, y, z, vx, vy, vz)`` after one step of size ``dt``.
    """
    ax0, ay0, az0 = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)

    # Provisional state at t + dt/2.
    mid_pos = (x + 0.5 * vx * dt, y + 0.5 * vy * dt, z + 0.5 * vz * dt)
    mid_vel = (vx + 0.5 * ax0 * dt, vy + 0.5 * ay0 * dt, vz + 0.5 * az0 * dt)

    ax1, ay1, az1 = compute_acceleration(*mid_pos, *mid_vel, m, g, b)
    return (
        x + mid_vel[0] * dt,
        y + mid_vel[1] * dt,
        z + mid_vel[2] * dt,
        vx + ax1 * dt,
        vy + ay1 * dt,
        vz + az1 * dt,
    )
|
||
|
||
|
||
def Leapfrog_Method(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with velocity-Verlet (kick-drift-kick).

    1. Kick: ``v_half = v + a(t) * dt/2``.
    2. Drift: ``r <- r + v_half * dt``.
    3. Predict the end-of-step velocity explicitly as
       ``v_half + a(t) * dt/2``. The prediction is only used to evaluate the
       velocity-dependent (drag) part of the force at the new positions.
    4. Kick: ``v <- v_half + a(t+dt) * dt/2``.

    Returns:
        ``(x, y, z, vx, vy, vz)`` after one step of size ``dt``.
    """
    accel_now = compute_acceleration(x, y, z, vx, vy, vz, m, g, b)
    v_half = [v + 0.5 * a * dt for v, a in zip((vx, vy, vz), accel_now)]
    new_pos = [p + vh * dt for p, vh in zip((x, y, z), v_half)]
    v_guess = [vh + 0.5 * a * dt for vh, a in zip(v_half, accel_now)]
    accel_next = compute_acceleration(*new_pos, *v_guess, m, g, b)
    new_vel = [vh + 0.5 * a * dt for vh, a in zip(v_half, accel_next)]
    return (*new_pos, *new_vel)
|
||
|
||
def Limit_in_box(a, amin, amax, va):
    """Clamp coordinates to ``[amin, amax]`` and bounce any that crossed a wall.

    Positions above ``amax`` become ``amax`` and positions below ``amin``
    become ``amin``. For every clamped entry the corresponding velocity
    component changes sign; entries inside the box are unchanged.

    Returns:
        ``(a, va)`` as new arrays.
    """
    above = a > amax
    below = a < amin
    clamped = np.where(below, amin, np.where(above, amax, a))
    bounced = np.where(np.logical_or(above, below), -va, va)
    return clamped, bounced
|
||
|
||
def apply_motion_update(x, y, z, vx, vy, vz, dt, m, g, b):
    """Advance one step with the integrator selected by ``METHOD``.

    After integrating, every coordinate is confined to the box with
    ``Limit_in_box``.

    Raises:
        ValueError: if ``METHOD`` is not one of the canonical method names.
    """
    steppers = {
        "explicit_euler": Explicit_Euler_Method,
        "implicit_euler": Implicit_Euler_Method,
        "midpoint": Midpoint_Method,
        "leapfrog": Leapfrog_Method,
    }
    stepper = steppers.get(METHOD)
    if stepper is None:
        raise ValueError(f"未知算法: {METHOD}")

    x, y, z, vx, vy, vz = stepper(x, y, z, vx, vy, vz, dt, m, g, b)

    x, vx = Limit_in_box(x, X_MIN, X_MAX, vx)
    y, vy = Limit_in_box(y, Y_MIN, Y_MAX, vy)
    z, vz = Limit_in_box(z, Z_MIN, Z_MAX, vz)
    return x, y, z, vx, vy, vz
|
||
|
||
|
||
def apply_fixed_constraints(x, y, z, vx, vy, vz):
    """Pin fixed degrees of freedom to their initial coordinate with zero velocity.

    ``ATOM_FIXED`` is compared against 0 elementwise; every non-zero entry
    marks a coordinate as fixed. That coordinate is reset from
    ``ATOM_POSITIONS`` and its velocity set to 0. The mask is broadcast
    against the ``(n, 3)`` stacks built from the component arrays.

    Returns:
        ``(x, y, z, vx, vy, vz)`` with the constraints applied.
    """
    pinned = ATOM_FIXED != 0
    pos = np.where(pinned, ATOM_POSITIONS, np.column_stack((x, y, z)))
    vel = np.where(pinned, 0.0, np.column_stack((vx, vy, vz)))
    # Iterating the transpose yields the columns pos[:, 0], pos[:, 1], ...
    return (*pos.T, *vel.T)
|
||
|
||
def wrap_position(x, y, z):
    """Wrap coordinates across the box (dynamics mode).

    A coordinate past the upper bound jumps to the lower bound, and one below
    the lower bound jumps to the upper bound. The overshoot distance is not
    carried over.
    """
    def _wrap_axis(coord, low, high):
        coord = np.where(coord > high, low, coord)
        return np.where(coord < low, high, coord)

    return (
        _wrap_axis(x, X_MIN, X_MAX),
        _wrap_axis(y, Y_MIN, Y_MAX),
        _wrap_axis(z, Z_MIN, Z_MAX),
    )
|
||
|
||
|
||
# ===========================================================================
|
||
# 主计算流程
|
||
# ===========================================================================
|
||
|
||
def _display_header_fields(record_steps):
    """Collect the metadata written into the ``display.txt`` header.

    Every value is rendered with ``str``. ``alpha`` (scalar or list) and
    ``ATOM_RADII`` are joined into comma-separated lists. Camera settings come
    from the module-level ``config`` mapping.
    """
    alpha_values = alpha if isinstance(alpha, list) else [alpha]
    return {
        "DT": str(DT),
        "NSTEP": str(NSTEP),
        "method": str(METHOD),
        "warmup_steps": str(warmup_steps or 0),
        "dynamic_steps": str(record_steps),
        "T_total": str(NT * DT),
        "X_MAX": str(X_MAX), "X_MIN": str(X_MIN),
        "Y_MAX": str(Y_MAX), "Y_MIN": str(Y_MIN),
        "Z_MAX": str(Z_MAX), "Z_MIN": str(Z_MIN),
        "ball_radius": str(ball_radius),
        "ball_color_r": str(ball_color_r),
        "ball_color_g": str(ball_color_g),
        "ball_color_b": str(ball_color_b),
        "box_color_r": str(box_color_r),
        "box_color_g": str(box_color_g),
        "box_color_b": str(box_color_b),
        "gravity_field": str(GRAVITY_FIELD),
        "gravity_interaction": str(GRAVITY_INTERACTION),
        "elastic_force": str(ELASTIC_FORCE),
        "damping_force": str(DAMPING_FORCE),
        "gravity_strength": str(GRAVITY_STRENGTH),
        "driving_force": str(DRIVING_FORCE),
        "use_marker": str(use_marker),
        "alpha": ",".join(str(a) for a in alpha_values),
        "atom_radii": ",".join(str(r) for r in ATOM_RADII),
        "camera_distance": str(config.get("camera_distance", 40.0)),
        "camera_elevation": str(config.get("camera_elevation", 0)),
        "camera_azimuth": str(config.get("camera_azimuth", 0)),
        "camera_keyframes": str(camera_keyframes_raw),
    }


def run_simulation(save_trajectory=0):
    """Integrate ``NT`` steps, sample every ``NSTEP`` steps and write ``display.txt``.

    Phases:

    * The first ``warmup_steps`` steps are integrated but not recorded.
    * The remaining ``NT - warmup_steps`` steps are recorded. Every
      ``NSTEP``-th of them, starting with the first, is kept and written to
      ``display.txt`` (new text format).
    * When ``save_trajectory`` is truthy, every recorded step is also
      buffered and written to ``trajectory.txt`` via ``save_trajectory_txt``.

    Each recorded step applies the driving force and records the state when
    due. It then advances the ``METHOD`` integrator (with box confinement),
    wraps positions and re-applies the fixed-DOF constraints.

    Returns:
        ``(x, y, z, vx, vy, vz)`` arrays of shape ``(n_frames, n_atoms)`` with
        the sampled frames.

    Raises:
        ValueError: if ``NSTEP`` is not positive or ``warmup_steps`` exceeds
            ``NT``. Checked before any integration work is done.
    """
    warmup = warmup_steps or 0

    # Fail fast on unusable sampling settings. Previously this surfaced as a
    # ZeroDivisionError, or a numpy error after the whole warm-up had run.
    if NSTEP <= 0:
        raise ValueError(f"NSTEP 必须为正整数,当前为 {NSTEP}")
    record_steps = NT - warmup
    if record_steps < 0:
        raise ValueError(f"warmup_steps ({warmup}) 不能大于 NT ({NT})")

    # Work on copies so in-place updates cannot alter the loaded atom tables.
    x = ATOM_POSITIONS[:, 0].copy()
    y = ATOM_POSITIONS[:, 1].copy()
    z = ATOM_POSITIONS[:, 2].copy()
    vx = ATOM_VELOCITIES[:, 0].copy()
    vy = ATOM_VELOCITIES[:, 1].copy()
    vz = ATOM_VELOCITIES[:, 2].copy()
    x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz)
    x, y, z, vx, vy, vz = apply_driving_force(x, y, z, vx, vy, vz, 0.0, 0, DRIVER_DATA, DT)

    # Warm-up: integrate without recording.
    if warmup > 0:
        print(f"[compute] 预热阶段: 前 {warmup} 步不记录")
        for step in trange(warmup, desc="[compute] 预热"):
            t = (step + 1) * DT
            x, y, z, vx, vy, vz = apply_driving_force(x, y, z, vx, vy, vz, t, step, DRIVER_DATA, DT)
            x, y, z, vx, vy, vz = apply_motion_update(x, y, z, vx, vy, vz, DT, ATOM_MASSES, G, B)
            x, y, z = wrap_position(x, y, z)
            x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz)
        print(
            f"[compute] 预热完成,展示原子位置: "
            f"({x[PLOT_ATOM_ROW]:.4f}, {y[PLOT_ATOM_ROW]:.4f}, {z[PLOT_ATOM_ROW]:.4f})"
        )

    # Recording phase.
    n_atoms = len(ATOM_IDS)
    n_frames = (record_steps + NSTEP - 1) // NSTEP  # ceil(record_steps / NSTEP)

    def _new_buffers(rows):
        # Six (rows, n_atoms) float64 buffers: x, y, z, vx, vy, vz.
        return [np.zeros((rows, n_atoms), dtype=np.float64) for _ in range(6)]

    sampled = _new_buffers(n_frames)  # NSTEP-sampled frames (small)
    full = _new_buffers(record_steps) if save_trajectory else None
    frames_written = 0

    for step in trange(record_steps, desc="[compute] 计算中"):
        t = (step + warmup) * DT
        x, y, z, vx, vy, vz = apply_driving_force(x, y, z, vx, vy, vz, t, step, DRIVER_DATA, DT)

        state = (x, y, z, vx, vy, vz)
        if full is not None:
            for buf, values in zip(full, state):
                buf[step] = values
        if step % NSTEP == 0:
            frame = step // NSTEP
            for buf, values in zip(sampled, state):
                buf[frame] = values
            frames_written += 1

        x, y, z, vx, vy, vz = apply_motion_update(x, y, z, vx, vy, vz, DT, ATOM_MASSES, G, B)
        x, y, z = wrap_position(x, y, z)
        x, y, z, vx, vy, vz = apply_fixed_constraints(x, y, z, vx, vy, vz)

    # Write display.txt (new text format).
    frames = tuple(buf[:frames_written] for buf in sampled)
    disp_path = os.path.join(get_output_dir(), "display.txt")
    save_display_txt(
        disp_path,
        *frames,
        np.array(ATOM_IDS), frames_written, n_atoms,
        header_fields=_display_header_fields(record_steps),
    )
    print(f"[compute] display.txt 已保存至: {disp_path} ({frames_written} 帧)")

    # Optional: full trajectory.txt.
    if full is not None:
        save_trajectory_txt(*full)
        print(f"[compute] trajectory.txt 已保存(完整轨迹)")

    return frames
|
||
|
||
|
||
def save_trajectory_table_txt(txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT):
    """Write one atom's trajectory as a human-readable table, one row per step.

    Args:
        txt_path: destination file; overwritten, UTF-8.
        traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz: per-step positions
            and velocities. 1-D sequences are used as-is. 2-D arrays are
            reduced to column ``PLOT_ATOM_ROW`` (column 0 when that is unset).
        NT: number of rows requested. Clamped to the number of available
            samples, so passing a nominal step count for a shorter (e.g.
            sampled) trajectory no longer raises ``IndexError``.
        DT: time spacing between rows; row ``k`` (0-based) is labelled ``k * DT``.
    """
    series = (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz)
    if np.ndim(traj_x) > 1:
        row = PLOT_ATOM_ROW if PLOT_ATOM_ROW is not None else 0
        series = tuple(column[:, row] for column in series)
    px, py, pz, pvx, pvy, pvz = series

    n_rows = min(int(NT), *(len(s) for s in series))

    # Build everything first so a formatting error cannot leave a partial file.
    lines = ["# 步数, 时间(s), x, y, z, vx, vy, vz\n"]
    for step in range(n_rows):
        t = step * DT
        lines.append(
            f"{step + 1:6d}, {t:.6f}, {px[step]:.6f}, {py[step]:.6f}, {pz[step]:.6f}, "
            f"{pvx[step]:.6f}, {pvy[step]:.6f}, {pvz[step]:.6f}\n"
        )

    with open(txt_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)
    print(f"[compute] 轨迹文本文件已保存至: {txt_path}")
|
||
|
||
|
||
def main():
    """Command-line entry point for the pure-Python engine.

    Usage: ``python compute.py [param_file]``. Without an argument,
    ``input.txt`` inside ``get_input_dir(<script dir>)`` is used. The function
    loads the parameters and runs the simulation, which writes
    ``display.txt``. It then writes ``trajectory_table.txt`` for the plotted
    atom from the sampled frames.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    param_file = os.path.join(get_input_dir(script_dir), "input.txt")
    if len(sys.argv) > 1:
        param_file = sys.argv[1]

    load_parameters(param_file)
    output_dir = get_output_dir(script_dir)

    print(f"[compute] 开始计算 NT={NT} DT={DT} ")
    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz = run_simulation(save_trajectory=0)
    print(f"[compute] 计算完成,共 {NT} 步")
    print(f"[compute] display.txt 已在 run_simulation 中保存")

    # For the full trajectory, call run_simulation(save_trajectory=1), which
    # also writes trajectory.txt via save_trajectory_txt.

    # run_simulation returns only the NSTEP-sampled frames, so the table has
    # len(traj_x) rows spaced NSTEP * DT apart. Times are relative to the first
    # recorded step (end of warm-up). Passing NT here used to index past the
    # sampled buffers.
    txt_path = os.path.join(output_dir, "trajectory_table.txt")
    save_trajectory_table_txt(txt_path, traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
                              len(traj_x), DT * NSTEP)


if __name__ == "__main__":
    main()
|