185 lines
7.2 KiB
Python
185 lines
7.2 KiB
Python
"""
sample.py
---------
纯 Python 抽帧脚本,不依赖 VisPy。

功能:
1. 从 output/trajectory.txt 读取完整轨迹
2. 在 [sample_start, sample_end) 区间内每隔 NSTEP 步抽取一帧,生成 output/display.txt
3. output/display.txt 可直接被 draw.py 加载驱动动画
"""
|
|
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
import compute
|
|
|
|
|
|
def build_sample_indices(total_steps, sample_step, sample_start, sample_end):
    """Validate the sampling window and return the step indices to keep.

    Frames are taken every ``sample_step`` steps, starting at
    ``sample_start``, inside the half-open window
    ``[sample_start, sample_end)``. Only whole strides count: the number of
    frames is ``(sample_end - sample_start) // sample_step``, so a trailing
    partial stride contributes no frame.

    Args:
        total_steps: number of recorded steps; upper bound for ``sample_end``.
        sample_step: stride between kept frames (NSTEP); must be > 0.
        sample_start: first step index to keep; must be >= 0.
        sample_end: exclusive end of the window; must be <= ``total_steps``.

    Returns:
        1-D ``np.int64`` array of step indices.

    Raises:
        ValueError: if a setting is out of range or no whole stride fits.
    """
    # Checked in this order; the first failing rule determines the message.
    rules = (
        (sample_step <= 0,
         f"NSTEP 必须为正整数,实际为 {sample_step}"),
        (sample_start < 0,
         f"sample_start 不能小于 0,实际为 {sample_start}"),
        (sample_end > total_steps,
         f"sample_end 不能大于记录步数 {total_steps},实际为 {sample_end}"),
        (sample_start >= sample_end,
         f"sample_start 必须小于 sample_end,实际为 [{sample_start}, {sample_end})"),
    )
    for violated, message in rules:
        if violated:
            raise ValueError(message)

    frame_count = (sample_end - sample_start) // sample_step
    if frame_count <= 0:
        raise ValueError(
            f"抽帧范围 [{sample_start}, {sample_end}) 过短,按 NSTEP={sample_step} 无法抽出任何帧")

    return sample_start + sample_step * np.arange(frame_count, dtype=np.int64)
|
|
|
|
|
|
def read_optional_index(data, key, default_value):
    """Return ``int(data[key])``, or ``default_value`` when it is not usable.

    The default is returned when ``key`` is absent, when its value is
    ``None``, or when the value converts to a negative integer. A negative
    value therefore acts as an "unset" marker.
    """
    if key not in data:
        return default_value

    raw = data[key]
    if raw is None:
        return default_value

    index = int(raw)
    return default_value if index < 0 else index
|
|
|
|
|
|
def main():
    """Sample the recorded trajectory and write the arrays used for display.

    Steps:
      1. Load ``trajectory.txt`` from the directory returned by
         ``compute.get_output_dir``.
      2. Keep every ``NSTEP``-th recorded step inside the
         ``[sample_start, sample_end)`` window stored in the file
         (``[0, NT)`` when the keys are absent or negative).
      3. Write the sampled positions/velocities plus the scene and
         configuration metadata to ``display.txt`` via
         ``compute.save_text_data``.

    Optional metadata fall back to single-atom defaults. ``atom_ids``
    (default ``[1]``) and ``M`` (default ``1.0``) are resolved once.
    ``plot_atom_id`` and ``atom_masses`` are derived from those resolved
    values, so they use the same fallbacks instead of raising ``KeyError``.

    Raises:
        KeyError: if a required key (``NT``, ``DT``, ``NSTEP``, ``traj_*``,
            box limits, initial state, ``alpha``, ``ball_radius``, colours)
            is missing.
        ValueError: if the sampling window is invalid
            (see ``build_sample_indices``).
    """
    # Components stored as ``traj_<c>`` and written as ``disp_<c>`` /
    # ``disp_all_<c>``; this order fixes the key order of the payload.
    components = ("x", "y", "z", "vx", "vy", "vz")

    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = compute.get_output_dir(script_dir)
    traj_path = os.path.join(output_dir, "trajectory.txt")
    disp_path = os.path.join(output_dir, "display.txt")

    # -------------------------------------------------------------------------
    # 1. Load the full trajectory
    # -------------------------------------------------------------------------
    data = compute.load_text_data(traj_path)

    NT = int(data["NT"])
    DT = float(data["DT"])
    NSTEP = int(data["NSTEP"])
    traj = {c: data[f"traj_{c}"] for c in components}

    # Shared fallbacks, resolved once so every field derived from them agrees
    # when the optional keys are absent.
    atom_ids = data["atom_ids"] if "atom_ids" in data else np.array([1])
    mass = float(data["M"]) if "M" in data else 1.0

    plot_atom_row = int(data["plot_atom_row"]) if "plot_atom_row" in data else 0
    if "plot_atom_id" in data:
        plot_atom_id = int(data["plot_atom_id"])
    else:
        plot_atom_id = int(atom_ids[plot_atom_row])

    print(f"[sample] 加载轨迹数据: NT={NT}, DT={DT}, NSTEP={NSTEP}")

    # -------------------------------------------------------------------------
    # 2. Sample frames within the [sample_start, sample_end) window saved in
    #    the trajectory file
    # -------------------------------------------------------------------------
    sample_start = read_optional_index(data, "sample_start", 0)
    sample_end = read_optional_index(data, "sample_end", NT)
    indices = build_sample_indices(NT, NSTEP, sample_start, sample_end)
    n_frames = len(indices)

    # Axis 0 of every trajectory array is indexed by step. A 1-D array holds a
    # single atom: it is used as-is for the plotted atom, and it gets a
    # trailing atom axis for the "all atoms" output (plot_atom_row is not used
    # to index it). A 2-D array is indexed [step, atom].
    single_atom = traj["x"].ndim == 1
    selected = {}
    all_atoms = {}
    for c, arr in traj.items():
        if single_atom:
            selected[c] = arr
            all_atoms[c] = arr[:, None]
        else:
            selected[c] = arr[:, plot_atom_row]
            all_atoms[c] = arr

    disp = {f"disp_{c}": selected[c][indices] for c in components}
    disp_t = indices * DT  # physical time = step * DT
    disp_step = indices  # 0-based step number within the saved trajectory

    print(f"[sample] 抽帧完成: [{sample_start}, {sample_end}) -> {n_frames} 帧")
    print(f"[sample] 每帧时间跨度: {NSTEP*DT:.3f} s (即每隔 {NSTEP} 步取一帧)")

    # -------------------------------------------------------------------------
    # 3. Save the display arrays
    # -------------------------------------------------------------------------
    payload = {
        **disp,
        **{f"disp_all_{c}": all_atoms[c][indices] for c in components},
        "disp_t": disp_t,
        "disp_step": disp_step,
        "n_frames": n_frames,
        "NT": NT,
        "DT": DT,
        "NSTEP": NSTEP,
        "plot_atom_id": plot_atom_id,
        "plot_atom_row": plot_atom_row,
        "method": str(data["method"]) if "method" in data else "explicit_euler",
        "coord_file": str(data["coord_file"]) if "coord_file" in data else os.path.join("input", "coord.txt"),
        "atom_ids": atom_ids,
        "atom_masses": data["atom_masses"] if "atom_masses" in data else np.array([mass]),
        "atom_radii": data["atom_radii"] if "atom_radii" in data else np.array([float(data["ball_radius"])]),
        "atom_positions": data["atom_positions"] if "atom_positions" in data else np.array([[float(data["X0"]), float(data["Y0"]), float(data["Z0"])]]),
        "atom_velocities": data["atom_velocities"] if "atom_velocities" in data else np.array([[float(data["VX0"]), float(data["VY0"]), float(data["VZ0"])]]),
        "atom_fixed": data["atom_fixed"] if "atom_fixed" in data else np.array([[0, 0, 0]]),
        "warmup_steps": int(data["warmup_steps"]) if "warmup_steps" in data else 0,
        "sample_start": sample_start,
        "sample_end": sample_end,
        "X_MIN": float(data["X_MIN"]),
        "X_MAX": float(data["X_MAX"]),
        "Y_MIN": float(data["Y_MIN"]),
        "Y_MAX": float(data["Y_MAX"]),
        "Z_MIN": float(data["Z_MIN"]),
        "Z_MAX": float(data["Z_MAX"]),
        "X0": float(data["X0"]),
        "Y0": float(data["Y0"]),
        "Z0": float(data["Z0"]),
        "VX0": float(data["VX0"]),
        "VY0": float(data["VY0"]),
        "VZ0": float(data["VZ0"]),
        "M": mass,
        "alpha": float(data["alpha"]),
        "ball_radius": float(data["ball_radius"]),
        "ball_color_r": float(data["ball_color_r"]),
        "ball_color_g": float(data["ball_color_g"]),
        "ball_color_b": float(data["ball_color_b"]),
        "box_color_r": float(data["box_color_r"]),
        "box_color_g": float(data["box_color_g"]),
        "box_color_b": float(data["box_color_b"]),
    }
    compute.save_text_data(disp_path, payload)
    print(f"[sample] 显示数组已保存至: {disp_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: run the sampling pipeline only when executed as a script.
    main()
|