266 lines
10 KiB
Python
266 lines
10 KiB
Python
"""
|
||
plot_trajectory.py
|
||
------------------
|
||
绘制轨迹数据的x-t, y-t, z-t, vx-t, vy-t, vz-t函数图像,支持多原子同时绘制。
|
||
|
||
用法:
|
||
python plot_trajectory.py # 绘制所有原子
|
||
python plot_trajectory.py output/trajectory.txt # 指定文件
|
||
python plot_trajectory.py output/trajectory.txt --atom-id 1 # 仅绘制原子1
|
||
|
||
参数:
|
||
trajectory_file: 轨迹数据文件,支持结构化 .txt 格式(默认:output/trajectory.txt)
|
||
--atom-id: 可选,指定只绘制单个原子的轨迹
|
||
"""
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import sys
|
||
import os
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
import compute
|
||
|
||
def select_atom_series(data, atom_id=None):
    """Select one atom's time series from the trajectory arrays.

    The six arrays ``traj_x`` ... ``traj_vz`` are either 1-D (the file stores
    a single atom) or indexed ``[step, atom]`` with one column per atom.

    Args:
        data: Mapping loaded from the trajectory file. Must contain the six
            ``traj_*`` arrays; may contain ``atom_ids``, ``plot_atom_id`` and
            ``plot_atom_row``.
        atom_id: Atom id to select. ``None`` selects the column given by
            ``plot_atom_row`` (first column if absent). Ignored for 1-D data.

    Returns:
        ``(x, y, z, vx, vy, vz, selected_id)``: six 1-D series plus the
        integer id of the selected atom.

    Raises:
        ValueError: if ``atom_id`` does not appear in the atom ids.
    """
    keys = ("traj_x", "traj_y", "traj_z", "traj_vx", "traj_vy", "traj_vz")
    series = [data[key] for key in keys]
    traj_x = series[0]

    if traj_x.ndim == 1:
        # Single-atom file: nothing to select, report the stored id.
        selected_id = int(data["plot_atom_id"]) if "plot_atom_id" in data else 1
        return (*series, selected_id)

    if "atom_ids" in data:
        # Normalise to a 1-D array so that lists / scalars compare
        # element-wise below (a list == int is just False).
        atom_ids = np.atleast_1d(np.asarray(data["atom_ids"]))
    else:
        atom_ids = np.arange(traj_x.shape[1]) + 1

    if atom_id is None:
        row = int(data["plot_atom_row"]) if "plot_atom_row" in data else 0
    else:
        matches = np.flatnonzero(atom_ids == int(atom_id))
        if matches.size == 0:
            raise ValueError(f"原子序号 {atom_id} 不在轨迹文件中")
        row = int(matches[0])

    selected_id = int(atom_ids[row])
    return (*(arr[:, row] for arr in series), selected_id)
|
||
|
||
|
||
def load_trajectory_txt(file_path, atom_id=None):
    """Load a structured .txt trajectory file and return one atom's data.

    Returns ``(x, y, z, vx, vy, vz, NT, DT, selected_id)``. The six series and
    ``selected_id`` come from ``select_atom_series``; ``NT`` is the stored step
    count converted to ``int`` and ``DT`` the stored time step as ``float``.
    """
    raw = compute.load_text_data(file_path)
    *series, chosen_id = select_atom_series(raw, atom_id)
    step_count = int(raw["NT"])
    time_step = float(raw["DT"])
    return (*series, step_count, time_step, chosen_id)
|
||
|
||
|
||
def load_full_trajectory(file_path):
    """Load the trajectories of all atoms from a structured .txt file.

    Returns:
        ``(x, y, z, vx, vy, vz, NT, DT, atom_ids, data)``: the six trajectory
        arrays exactly as stored, ``NT`` as ``int``, ``DT`` as ``float``, the
        atom ids (``data["atom_ids"]`` when present, otherwise ``1..n_atoms``
        where ``n_atoms`` is axis 1 of multi-dimensional arrays, else 1) and
        the raw mapping returned by ``compute.load_text_data``.
    """
    data = compute.load_text_data(file_path)
    names = ("traj_x", "traj_y", "traj_z", "traj_vx", "traj_vy", "traj_vz")
    arrays = tuple(data[name] for name in names)
    step_count = int(data["NT"])
    time_step = float(data["DT"])

    first = arrays[0]
    n_atoms = 1 if first.ndim <= 1 else first.shape[1]
    ids = data.get("atom_ids", np.arange(n_atoms) + 1)
    return (*arrays, step_count, time_step, ids, data)
|
||
|
||
def _energy_terms(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, extra_data):
    """Compute per-atom kinetic, gravitational and elastic energy arrays.

    The trajectory arrays are indexed ``[step, atom]``. ``extra_data`` must
    provide ``atom_masses`` (one value per atom) and ``G`` (3 components);
    ``bond_pairs`` / ``bond_stiffness`` / ``bond_rest_lengths`` are optional.

    Returns:
        ``(ek, ug, us, has_bonds)``: three arrays shaped like the trajectory
        and a flag telling whether elastic energy was computed. Each bond's
        energy ½k(d-d₀)² is split evenly between its two atoms, so summing
        ``us`` over atoms gives the system elastic energy counted once.
    """
    masses = np.atleast_1d(np.asarray(extra_data["atom_masses"], dtype=float))
    g_vec = np.ravel(np.asarray(extra_data["G"], dtype=float))
    m = masses[np.newaxis, :]

    # Kinetic energy Ek = ½ m v²
    ek = 0.5 * m * (traj_vx**2 + traj_vy**2 + traj_vz**2)
    # Gravitational potential energy Ug = -m G·r
    ug = -m * (g_vec[0] * traj_x + g_vec[1] * traj_y + g_vec[2] * traj_z)

    us = np.zeros_like(ek)
    bond_pairs = extra_data.get("bond_pairs")
    stiffness = extra_data.get("bond_stiffness")
    rest_lengths = extra_data.get("bond_rest_lengths")
    has_bonds = (
        bond_pairs is not None
        and np.size(bond_pairs) > 0
        and stiffness is not None
        and rest_lengths is not None
    )
    if has_bonds:
        # Text loaders may yield a flat [i, j] row for a single bond and float
        # indices; normalise to an integer (n_bonds, 2) array.
        pairs = np.asarray(bond_pairs).reshape(-1, 2).astype(int)
        n_bonds = len(pairs)
        # A single shared value is broadcast to all bonds; a length mismatch
        # raises ValueError instead of silently dropping bonds.
        k_arr = np.broadcast_to(np.ravel(np.asarray(stiffness, dtype=float)), (n_bonds,))
        d0_arr = np.broadcast_to(np.ravel(np.asarray(rest_lengths, dtype=float)), (n_bonds,))
        for (i, j), k, d0 in zip(pairs, k_arr, d0_arr):
            dist = np.sqrt(
                (traj_x[:, j] - traj_x[:, i]) ** 2
                + (traj_y[:, j] - traj_y[:, i]) ** 2
                + (traj_z[:, j] - traj_z[:, i]) ** 2
            )
            # Elastic energy Us = ½ k (d - d₀)², attributed half to each atom
            # (adding the full value to both atoms would double-count it).
            half = 0.25 * k * (dist - d0) ** 2
            us[:, i] += half
            us[:, j] += half
    return ek, ug, us, has_bonds


def create_plots(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, output_dir='.', atom_ids=None, extra_data=None):
    """Plot x/y/z/vx/vy/vz against time and, optionally, energy curves.

    Args:
        traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz: 1-D arrays for a
            single atom, or 2-D arrays indexed ``[step, atom]``.
        NT: Stored number of time steps. Kept for compatibility; the time
            axis is built from the number of rows actually present so that it
            always matches the data.
        DT: Time step (the axis is labelled in seconds).
        output_dir: Directory receiving ``trajectory_plots.png``.
        atom_ids: Per-column atom ids for multi-atom data (defaults to
            ``1..n``), or the single atom's id for 1-D data (shown in title).
        extra_data: Raw mapping from the trajectory file. When the data is
            multi-atom and it contains both ``G`` and ``atom_masses``, a third
            row with per-atom and system energy curves is drawn.

    Returns:
        Path of the saved PNG. The figure is shown and then closed.
    """
    # Fonts able to render the Chinese labels.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False

    # Time axis from the real sample count (a mismatch with NT would make
    # every plot() call fail).
    time = np.arange(len(traj_x)) * DT

    # Single atom (1-D) or multiple atoms (2-D).
    is_multi = traj_x.ndim == 2
    n_atoms = traj_x.shape[1] if is_multi else 1

    # Energy curves need the masses as well as G; checking both avoids a
    # KeyError when only G is stored.
    has_energy = (
        is_multi
        and extra_data is not None
        and "G" in extra_data
        and "atom_masses" in extra_data
    )
    n_rows = 3 if has_energy else 2
    n_cols = 3
    figsize = (15, 13) if has_energy else (15, 9)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
    title = '轨迹与能量分析' if has_energy else '轨迹分析'
    if atom_ids is not None and not is_multi:
        # atom_ids may be a plain int or a 1-element array here.
        title += f' - 原子 {int(np.ravel(atom_ids)[0])}'
    fig.suptitle(title, fontsize=16)

    plot_configs = [
        (axes[0, 0], traj_x, 'x'),
        (axes[0, 1], traj_y, 'y'),
        (axes[0, 2], traj_z, 'z'),
        (axes[1, 0], traj_vx, 'vx'),
        (axes[1, 1], traj_vy, 'vy'),
        (axes[1, 2], traj_vz, 'vz'),
    ]

    if is_multi:
        colors = plt.cm.tab10(np.linspace(0, 1, n_atoms))
        if atom_ids is None:
            atom_ids = np.arange(n_atoms) + 1

    for ax, data_arr, label in plot_configs:
        if is_multi:
            for i in range(n_atoms):
                aid = int(atom_ids[i])
                ax.plot(time, data_arr[:, i], color=colors[i], linewidth=1.5, label=f"原子 {aid}")
            ax.legend()
        else:
            ax.plot(time, data_arr, 'b-', linewidth=1.5)
        ax.set_title(f'{label} - 时间')
        ax.set_xlabel('时间 (s)')
        ax.set_ylabel(label)
        ax.grid(True, alpha=0.3)

    # ── Energy subplots ─────────────────────────────────────
    if has_energy:
        ek, ug, us, has_bonds = _energy_terms(
            traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, extra_data
        )
        e_total = ek + ug + us  # (n_steps, n_atoms)
        ek_sys = np.sum(ek, axis=1)
        ug_sys = np.sum(ug, axis=1)
        us_sys = np.sum(us, axis=1)
        e_sys = ek_sys + ug_sys + us_sys

        # Row 3, left: total energy of each atom.
        ax_e = axes[2, 0]
        for i in range(n_atoms):
            aid = int(atom_ids[i])
            ax_e.plot(time, e_total[:, i], color=colors[i], linewidth=1.5, label=f"原子 {aid}")
        ax_e.set_title("各原子总能量")
        ax_e.set_xlabel("时间 (s)")
        ax_e.set_ylabel("能量")
        ax_e.grid(True, alpha=0.3)
        ax_e.legend(loc="upper right")

        # Row 3, middle: system energy components and total.
        ax_sys = axes[2, 1]
        ax_sys.plot(time, ek_sys, 'b-', linewidth=1.5, label="系统动能")
        ax_sys.plot(time, ug_sys, 'g-', linewidth=1.5, label="系统重力势能")
        if has_bonds:
            ax_sys.plot(time, us_sys, color='orange', linewidth=1.5, label="系统弹性势能")
        ax_sys.plot(time, e_sys, 'r--', linewidth=1.5, label="系统总能量")
        ax_sys.set_title("系统总能量")
        ax_sys.set_xlabel("时间 (s)")
        ax_sys.set_ylabel("能量")
        ax_sys.grid(True, alpha=0.3)
        ax_sys.legend(loc="upper right")

        # Row 3, right: unused.
        axes[2, 2].set_visible(False)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    output_path = os.path.join(output_dir, 'trajectory_plots.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"图表已保存至: {output_path}")

    # Show if the backend supports it, then release the figure so repeated
    # calls do not accumulate open figures.
    plt.show()
    plt.close(fig)

    return output_path
|
||
|
||
def main():
    """Command-line entry point: load a trajectory file and plot it.

    Without ``--atom-id`` every atom in the file is plotted and the raw file
    data is forwarded to ``create_plots`` (enabling energy curves). With
    ``--atom-id`` only that atom's series is plotted. Exits with status 1 if
    the input file does not exist or fails to load.
    """
    import argparse

    parser = argparse.ArgumentParser(description="绘制轨迹数据图表")
    parser.add_argument("trajectory_file", nargs="?",
                        help="轨迹数据文件路径(默认: output/trajectory.txt)")
    parser.add_argument("--atom-id", type=int, default=None,
                        help="指定只绘制某个原子的轨迹(默认绘制所有原子)")
    args = parser.parse_args()

    # Resolve the output directory once; it also hosts the default input file.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = compute.get_output_dir(script_dir)
    input_file = args.trajectory_file or os.path.join(output_dir, "trajectory.txt")

    if not os.path.exists(input_file):
        print(f"错误: 文件 '{input_file}' 不存在")
        print("请先运行 compute.py 生成轨迹数据")
        sys.exit(1)

    try:
        if args.atom_id is not None:
            # Single-atom mode (legacy): one column, no energy data.
            (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
             NT, DT, selected_id) = load_trajectory_txt(input_file, args.atom_id)
            ids_for_plot = selected_id
            extra_data = None
            n_atoms_str = f" 绘图原子序号: {selected_id}"
        else:
            # All-atom mode: keep the raw data for the energy subplots.
            (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
             NT, DT, ids_for_plot, extra_data) = load_full_trajectory(input_file)
            n_atoms = traj_x.shape[1] if traj_x.ndim > 1 else 1
            n_atoms_str = f" 绘图原子数: {n_atoms}"
    except Exception as e:  # top-level boundary: report and exit
        print(f"加载文件时出错: {e}")
        sys.exit(1)

    print(f"加载轨迹数据: NT={NT}, DT={DT}")
    print(n_atoms_str)
    print(f" 位置范围: x [{traj_x.min():.3f}, {traj_x.max():.3f}], "
          f"y [{traj_y.min():.3f}, {traj_y.max():.3f}], "
          f"z [{traj_z.min():.3f}, {traj_z.max():.3f}]")
    print(f" 速度范围: vx [{traj_vx.min():.3f}, {traj_vx.max():.3f}], "
          f"vy [{traj_vy.min():.3f}, {traj_vy.max():.3f}], "
          f"vz [{traj_vz.min():.3f}, {traj_vz.max():.3f}]")

    create_plots(
        traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
        NT, DT, output_dir, ids_for_plot, extra_data,
    )

    print("绘图完成!")


if __name__ == "__main__":
    main()
|