""" sample.py --------- 纯 Python 抽帧脚本,不依赖 VisPy。 功能: 1. 从 output/trajectory.txt 读取完整轨迹 2. 每隔 NSTEP 帧抽取一帧,生成 output/display.txt 3. output/display.txt 可直接被 draw.py 加载驱动动画 """ import numpy as np import os import sys sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import compute def build_sample_indices(total_steps, sample_step, sample_start, sample_end): """Validate sampling settings and build frame indices.""" if sample_step <= 0: raise ValueError(f"NSTEP 必须为正整数,实际为 {sample_step}") if sample_start < 0: raise ValueError(f"sample_start 不能小于 0,实际为 {sample_start}") if sample_end > total_steps: raise ValueError( f"sample_end 不能大于记录步数 {total_steps},实际为 {sample_end}") if sample_start >= sample_end: raise ValueError( f"sample_start 必须小于 sample_end,实际为 [{sample_start}, {sample_end})") n_frames = (sample_end - sample_start) // sample_step if n_frames <= 0: raise ValueError( f"抽帧范围 [{sample_start}, {sample_end}) 过短,按 NSTEP={sample_step} 无法抽出任何帧") return np.arange(n_frames, dtype=np.int64) * sample_step + sample_start def read_optional_index(data, key, default_value): """Read an optional integer index from txt metadata.""" if key not in data: return default_value value = data[key] if value is None or int(value) < 0: return default_value return int(value) def main(): script_dir = os.path.dirname(os.path.abspath(__file__)) output_dir = compute.get_output_dir(script_dir) traj_path = os.path.join(output_dir, "trajectory.txt") disp_path = os.path.join(output_dir, "display.txt") # ------------------------------------------------------------------------- # 1. 加载完整轨迹 # ------------------------------------------------------------------------- data = compute.load_text_data(traj_path) NT = int(data["NT"]) DT = float(data["DT"]) NSTEP = int(data["NSTEP"]) traj_x = data["traj_x"] traj_y = data["traj_y"] traj_z = data["traj_z"] traj_vx = data["traj_vx"] traj_vy = data["traj_vy"] traj_vz = data["traj_vz"] plot_atom_row = int(data["plot_atom_row"]) if "plot_atom_row" in data else 0 plot_atom_id = int(data["plot_atom_id"]) if "plot_atom_id" in data else int(data["atom_ids"][plot_atom_row]) print(f"[sample] 加载轨迹数据: NT={NT}, DT={DT}, NSTEP={NSTEP}") # ------------------------------------------------------------------------- # 2. 抽帧:支持配置文件中保存的 [sample_start, sample_end) 区间 # ------------------------------------------------------------------------- sample_start = read_optional_index(data, "sample_start", 0) sample_end = read_optional_index(data, "sample_end", NT) indices = build_sample_indices(NT, NSTEP, sample_start, sample_end) n_frames = len(indices) if traj_x.ndim == 1: selected_x = traj_x selected_y = traj_y selected_z = traj_z selected_vx = traj_vx selected_vy = traj_vy selected_vz = traj_vz all_x = traj_x[:, None] all_y = traj_y[:, None] all_z = traj_z[:, None] all_vx = traj_vx[:, None] all_vy = traj_vy[:, None] all_vz = traj_vz[:, None] else: selected_x = traj_x[:, plot_atom_row] selected_y = traj_y[:, plot_atom_row] selected_z = traj_z[:, plot_atom_row] selected_vx = traj_vx[:, plot_atom_row] selected_vy = traj_vy[:, plot_atom_row] selected_vz = traj_vz[:, plot_atom_row] all_x = traj_x all_y = traj_y all_z = traj_z all_vx = traj_vx all_vy = traj_vy all_vz = traj_vz disp_x = selected_x[indices] disp_y = selected_y[indices] disp_z = selected_z[indices] disp_vx = selected_vx[indices] disp_vy = selected_vy[indices] disp_vz = selected_vz[indices] disp_t = indices * DT # 物理时间 = step * DT disp_step = indices # 对应保存轨迹的步编号(0-based) print(f"[sample] 抽帧完成: [{sample_start}, {sample_end}) -> {n_frames} 帧") print(f"[sample] 每帧时间跨度: {NSTEP*DT:.3f} s (即每隔 {NSTEP} 步取一帧)") # ------------------------------------------------------------------------- # 3. 保存显示数组 # ------------------------------------------------------------------------- payload = { "disp_x": disp_x, "disp_y": disp_y, "disp_z": disp_z, "disp_vx": disp_vx, "disp_vy": disp_vy, "disp_vz": disp_vz, "disp_all_x": all_x[indices], "disp_all_y": all_y[indices], "disp_all_z": all_z[indices], "disp_all_vx": all_vx[indices], "disp_all_vy": all_vy[indices], "disp_all_vz": all_vz[indices], "disp_t": disp_t, "disp_step": disp_step, "n_frames": n_frames, "NT": NT, "DT": DT, "NSTEP": NSTEP, "plot_atom_id": plot_atom_id, "plot_atom_row": plot_atom_row, "method": str(data["method"]) if "method" in data else "explicit_euler", "coord_file": str(data["coord_file"]) if "coord_file" in data else os.path.join("input", "coord.txt"), "atom_ids": data["atom_ids"] if "atom_ids" in data else np.array([1]), "atom_masses": data["atom_masses"] if "atom_masses" in data else np.array([float(data["M"])]), "atom_radii": data["atom_radii"] if "atom_radii" in data else np.array([float(data["ball_radius"])]), "atom_positions": data["atom_positions"] if "atom_positions" in data else np.array([[float(data["X0"]), float(data["Y0"]), float(data["Z0"])]]), "atom_velocities": data["atom_velocities"] if "atom_velocities" in data else np.array([[float(data["VX0"]), float(data["VY0"]), float(data["VZ0"])]]), "atom_fixed": data["atom_fixed"] if "atom_fixed" in data else np.array([[0, 0, 0]]), "warmup_steps": int(data["warmup_steps"]) if "warmup_steps" in data else 0, "sample_start": sample_start, "sample_end": sample_end, "X_MIN": float(data["X_MIN"]), "X_MAX": float(data["X_MAX"]), "Y_MIN": float(data["Y_MIN"]), "Y_MAX": float(data["Y_MAX"]), "Z_MIN": float(data["Z_MIN"]), "Z_MAX": float(data["Z_MAX"]), "X0": float(data["X0"]), "Y0": float(data["Y0"]), "Z0": float(data["Z0"]), "VX0": float(data["VX0"]), "VY0": float(data["VY0"]), "VZ0": float(data["VZ0"]), "M": float(data["M"]) if "M" in data else 1.0, "alpha": float(data["alpha"]), "ball_radius": float(data["ball_radius"]), "ball_color_r": float(data["ball_color_r"]), "ball_color_g": float(data["ball_color_g"]), "ball_color_b": float(data["ball_color_b"]), "box_color_r": float(data["box_color_r"]), "box_color_g": float(data["box_color_g"]), "box_color_b": float(data["box_color_b"]), } compute.save_text_data(disp_path, payload) print(f"[sample] 显示数组已保存至: {disp_path}") if __name__ == "__main__": main()