""" plot_trajectory.py ------------------ 绘制轨迹数据的x-t, y-t, z-t, vx-t, vy-t, vz-t函数图像,支持多原子同时绘制。 用法: python plot_trajectory.py # 绘制所有原子 python plot_trajectory.py output/trajectory.txt # 指定文件 python plot_trajectory.py output/trajectory.txt --atom-id 1 # 仅绘制原子1 参数: trajectory_file: 轨迹数据文件,支持结构化 .txt 格式(默认:output/trajectory.txt) --atom-id: 可选,指定只绘制单个原子的轨迹 """ import numpy as np import matplotlib.pyplot as plt import sys import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import compute def select_atom_series(data, atom_id=None): """Select one atom from trajectory arrays when the file stores many atoms.""" traj_x = data['traj_x'] traj_y = data['traj_y'] traj_z = data['traj_z'] traj_vx = data['traj_vx'] traj_vy = data['traj_vy'] traj_vz = data['traj_vz'] if traj_x.ndim == 1: selected_id = int(data["plot_atom_id"]) if "plot_atom_id" in data else 1 return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, selected_id atom_ids = data["atom_ids"] if "atom_ids" in data else np.arange(traj_x.shape[1]) + 1 if atom_id is None: row = int(data["plot_atom_row"]) if "plot_atom_row" in data else 0 else: matches = np.where(atom_ids == int(atom_id))[0] if len(matches) == 0: raise ValueError(f"原子序号 {atom_id} 不在轨迹文件中") row = int(matches[0]) selected_id = int(atom_ids[row]) return ( traj_x[:, row], traj_y[:, row], traj_z[:, row], traj_vx[:, row], traj_vy[:, row], traj_vz[:, row], selected_id, ) def load_trajectory_txt(file_path, atom_id=None): """加载结构化 .txt 格式的轨迹数据""" data = compute.load_text_data(file_path) traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, selected_id = select_atom_series(data, atom_id) NT = int(data['NT']) DT = float(data['DT']) return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, selected_id def load_full_trajectory(file_path): """加载完整轨迹数据(所有原子),返回 2D 数组及原始数据字典。""" data = compute.load_text_data(file_path) traj_x = data['traj_x'] traj_y = data['traj_y'] traj_z = data['traj_z'] traj_vx = data['traj_vx'] traj_vy = data['traj_vy'] traj_vz = data['traj_vz'] NT = int(data['NT']) DT = float(data['DT']) n_atoms = traj_x.shape[1] if traj_x.ndim > 1 else 1 atom_ids = data.get("atom_ids", np.arange(n_atoms) + 1) return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, atom_ids, data def create_plots(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, output_dir='.', atom_ids=None, extra_data=None): """创建六个子图的图表,支持多原子绘制与能量曲线。""" # 配置中文字体支持 plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False # 生成时间数组 time = np.arange(NT) * DT # 判断单原子(1D)或多原子(2D) is_multi = traj_x.ndim == 2 n_atoms = traj_x.shape[1] if is_multi else 1 # 是否绘制能量子图(需要 multi-atom + extra_data 中有 G) has_energy = is_multi and extra_data is not None and "G" in extra_data n_rows = 3 if has_energy else 2 n_cols = 3 figsize = (15, 13) if has_energy else (15, 9) # 创建图形和子图 fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize) title = '轨迹与能量分析' if has_energy else '轨迹分析' if atom_ids is not None and not is_multi: title += f' - 原子 {int(atom_ids)}' fig.suptitle(title, fontsize=16) plot_configs = [ (axes[0, 0], traj_x, 'x'), (axes[0, 1], traj_y, 'y'), (axes[0, 2], traj_z, 'z'), (axes[1, 0], traj_vx, 'vx'), (axes[1, 1], traj_vy, 'vy'), (axes[1, 2], traj_vz, 'vz'), ] if is_multi: colors = plt.cm.tab10(np.linspace(0, 1, n_atoms)) if atom_ids is None: atom_ids = np.arange(n_atoms) + 1 for ax, data_arr, label in plot_configs: if is_multi: for i in range(n_atoms): aid = int(atom_ids[i]) ax.plot(time, data_arr[:, i], color=colors[i], linewidth=1.5, label=f"原子 {aid}") ax.legend() else: ax.plot(time, data_arr, 'b-', linewidth=1.5) ax.set_title(f'{label} - 时间') ax.set_xlabel('时间 (s)') ax.set_ylabel(label) ax.grid(True, alpha=0.3) # ── 能量子图 ───────────────────────────────────── if has_energy: masses = np.array(extra_data["atom_masses"]) # (n_atoms,) G_vec = np.array(extra_data["G"]) # [gx, gy, gz] # 动能 Ek = ½ m v² ek = 0.5 * masses[np.newaxis, :] * (traj_vx**2 + traj_vy**2 + traj_vz**2) # 重力势能 Ug = -m G·r ug = -masses[np.newaxis, :] * (G_vec[0] * traj_x + G_vec[1] * traj_y + G_vec[2] * traj_z) # 弹性势能 Us = ½ k (d - d₀)² us = np.zeros_like(ek) bond_pairs = extra_data.get("bond_pairs") bond_stiffness = extra_data.get("bond_stiffness") bond_rest_lengths = extra_data.get("bond_rest_lengths") if bond_pairs is not None and len(bond_pairs) > 0: for b_idx in range(len(bond_pairs)): i, j = bond_pairs[b_idx] dx = traj_x[:, j] - traj_x[:, i] dy = traj_y[:, j] - traj_y[:, i] dz = traj_z[:, j] - traj_z[:, i] dist = np.sqrt(dx**2 + dy**2 + dz**2) stretch = dist - bond_rest_lengths[b_idx] us[:, i] += 0.5 * bond_stiffness[b_idx] * stretch**2 us[:, j] += 0.5 * bond_stiffness[b_idx] * stretch**2 e_total = ek + ug + us # (NT, n_atoms) ek_sys = np.sum(ek, axis=1) ug_sys = np.sum(ug, axis=1) us_sys = np.sum(us, axis=1) e_sys = ek_sys + ug_sys + us_sys # 第 3 行左:各原子总能量 ax_e = axes[2, 0] for i in range(n_atoms): aid = int(atom_ids[i]) ax_e.plot(time, e_total[:, i], color=colors[i], linewidth=1.5, label=f"原子 {aid}") ax_e.set_title("各原子总能量") ax_e.set_xlabel("时间 (s)") ax_e.set_ylabel("能量") ax_e.grid(True, alpha=0.3) ax_e.legend(loc="upper right") # 第 3 行右:系统总能量 ax_sys = axes[2, 1] ax_sys.plot(time, ek_sys, 'b-', linewidth=1.5, label="系统动能") ax_sys.plot(time, ug_sys, 'g-', linewidth=1.5, label="系统重力势能") if bond_pairs is not None and len(bond_pairs) > 0: ax_sys.plot(time, us_sys, color='orange', linewidth=1.5, label="系统弹性势能") ax_sys.plot(time, e_sys, 'r--', linewidth=1.5, label="系统总能量") ax_sys.set_title("系统总能量") ax_sys.set_xlabel("时间 (s)") ax_sys.set_ylabel("能量") ax_sys.grid(True, alpha=0.3) ax_sys.legend(loc="upper right") # 隐藏第 3 行第 3 列空子图 axes[2, 2].set_visible(False) # 调整布局 plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # 保存图像 output_path = os.path.join(output_dir, 'trajectory_plots.png') plt.savefig(output_path, dpi=300, bbox_inches='tight') print(f"图表已保存至: {output_path}") # 显示图像(如果环境支持) plt.show() return output_path def main(): import argparse parser = argparse.ArgumentParser(description="绘制轨迹数据图表") parser.add_argument("trajectory_file", nargs="?", help="轨迹数据文件路径(默认: output/trajectory.txt)") parser.add_argument("--atom-id", type=int, default=None, help="指定只绘制某个原子的轨迹(默认绘制所有原子)") args = parser.parse_args() # 默认文件 script_dir = os.path.dirname(os.path.abspath(__file__)) default_file = os.path.join(compute.get_output_dir(script_dir), "trajectory.txt") input_file = args.trajectory_file or default_file # 检查文件是否存在 if not os.path.exists(input_file): print(f"错误: 文件 '{input_file}' 不存在") print(f"请先运行 compute.py 生成轨迹数据") sys.exit(1) file_ext = os.path.splitext(input_file)[1].lower() try: if args.atom_id is not None: # 单原子模式(旧版兼容) traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, selected_id = \ load_trajectory_txt(input_file, args.atom_id) ids_for_plot = selected_id extra_data = None n_atoms_str = f" 绘图原子序号: {selected_id}" else: # 全原子模式 traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, atom_ids, raw_data = \ load_full_trajectory(input_file) ids_for_plot = atom_ids extra_data = raw_data n_atoms = traj_x.shape[1] if traj_x.ndim > 1 else 1 n_atoms_str = f" 绘图原子数: {n_atoms}" except Exception as e: print(f"加载文件时出错: {e}") sys.exit(1) print(f"加载轨迹数据: NT={NT}, DT={DT}") if args.atom_id is not None: print(n_atoms_str) else: print(n_atoms_str) print(f" 位置范围: x [{traj_x.min():.3f}, {traj_x.max():.3f}], " f"y [{traj_y.min():.3f}, {traj_y.max():.3f}], " f"z [{traj_z.min():.3f}, {traj_z.max():.3f}]") print(f" 速度范围: vx [{traj_vx.min():.3f}, {traj_vx.max():.3f}], " f"vy [{traj_vy.min():.3f}, {traj_vy.max():.3f}], " f"vz [{traj_vz.min():.3f}, {traj_vz.max():.3f}]") # 创建图表 output_path = create_plots( traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, compute.get_output_dir(script_dir), ids_for_plot, extra_data, ) print("绘图完成!") if __name__ == "__main__": main()