Files
dynamics/plot_trajectory.py
T
2026-05-17 08:47:25 +08:00

266 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
plot_trajectory.py
------------------
绘制轨迹数据的x-t, y-t, z-t, vx-t, vy-t, vz-t函数图像,支持多原子同时绘制。
用法:
python plot_trajectory.py # 绘制所有原子
python plot_trajectory.py output/trajectory.txt # 指定文件
python plot_trajectory.py output/trajectory.txt --atom-id 1 # 仅绘制原子1
参数:
trajectory_file: 轨迹数据文件,支持结构化 .txt 格式(默认:output/trajectory.txt)
--atom-id: 可选,指定只绘制单个原子的轨迹
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute
def select_atom_series(data, atom_id=None):
    """Select one atom's time series from trajectory data that may hold many atoms.

    Trajectory arrays are either 1-D (a single atom) or 2-D, indexed as
    ``[:, row]`` with one column per atom.

    Parameters
    ----------
    data : mapping
        Loaded trajectory data. Must provide ``traj_x``, ``traj_y``, ``traj_z``,
        ``traj_vx``, ``traj_vy`` and ``traj_vz``. Optional keys:
        ``atom_ids`` (atom number of each column, default ``1..n``),
        ``plot_atom_row`` (column used when *atom_id* is None, default 0) and
        ``plot_atom_id`` (id reported for 1-D data, default 1).
    atom_id : int, optional
        Atom number to select from 2-D data. Ignored for 1-D data, which only
        contains one atom.

    Returns
    -------
    tuple
        ``(x, y, z, vx, vy, vz, selected_id)`` where every series is 1-D.

    Raises
    ------
    ValueError
        If *atom_id* is given but is not listed in the atom ids.
    """
    traj_x = data['traj_x']
    traj_y = data['traj_y']
    traj_z = data['traj_z']
    traj_vx = data['traj_vx']
    traj_vy = data['traj_vy']
    traj_vz = data['traj_vz']

    if traj_x.ndim == 1:
        # Single-atom data: nothing to select.
        selected_id = int(data["plot_atom_id"]) if "plot_atom_id" in data else 1
        return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, selected_id

    if "atom_ids" in data:
        # Coerce to an ndarray so the element-wise comparison below also works
        # when the loader returns a plain list (``list == int`` is just False).
        atom_ids = np.atleast_1d(np.asarray(data["atom_ids"]))
    else:
        atom_ids = np.arange(traj_x.shape[1]) + 1

    if atom_id is None:
        row = int(data["plot_atom_row"]) if "plot_atom_row" in data else 0
    else:
        matches = np.where(atom_ids == int(atom_id))[0]
        if len(matches) == 0:
            raise ValueError(f"原子序号 {atom_id} 不在轨迹文件中")
        row = int(matches[0])

    selected_id = int(atom_ids[row])
    return (
        traj_x[:, row], traj_y[:, row], traj_z[:, row],
        traj_vx[:, row], traj_vy[:, row], traj_vz[:, row],
        selected_id,
    )
def load_trajectory_txt(file_path, atom_id=None):
    """Load a structured .txt trajectory file and return a single atom's series.

    Parameters
    ----------
    file_path : str
        Path passed to ``compute.load_text_data``.
    atom_id : int, optional
        Atom number to select; forwarded to ``select_atom_series``.

    Returns
    -------
    tuple
        ``(x, y, z, vx, vy, vz, NT, DT, selected_id)``, with ``NT`` cast to
        ``int`` and ``DT`` cast to ``float``.
    """
    raw = compute.load_text_data(file_path)
    *series, chosen_id = select_atom_series(raw, atom_id)
    n_steps = int(raw['NT'])
    time_step = float(raw['DT'])
    return (*series, n_steps, time_step, chosen_id)
def load_full_trajectory(file_path):
    """Load the trajectories of all atoms from a structured .txt file.

    Returns
    -------
    tuple
        ``(x, y, z, vx, vy, vz, NT, DT, atom_ids, data)``. The six series are
        returned exactly as stored; when ``traj_x`` has more than one dimension
        its second axis counts atoms. ``atom_ids`` is ``data["atom_ids"]`` when
        present, otherwise ``1..n_atoms``. ``data`` is the raw mapping returned
        by ``compute.load_text_data``.
    """
    data = compute.load_text_data(file_path)
    series = [
        data[name]
        for name in ('traj_x', 'traj_y', 'traj_z', 'traj_vx', 'traj_vy', 'traj_vz')
    ]
    n_steps = int(data['NT'])
    time_step = float(data['DT'])
    first = series[0]
    atom_count = 1 if first.ndim <= 1 else first.shape[1]
    ids = data.get("atom_ids", np.arange(atom_count) + 1)
    return (*series, n_steps, time_step, ids, data)
def _compute_energies(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, extra_data):
    """Return per-atom ``(ek, ug, us)`` energy arrays for multi-atom data.

    The position/velocity arrays are 2-D with one column per atom; every
    returned array has that same shape.

    * ``ek`` - kinetic energy, 1/2 m v^2.
    * ``ug`` - gravitational potential energy, -m G.r.
    * ``us`` - elastic energy. Each bond holds 1/2 k (d - d0)^2, and that energy
      is split equally between the bond's two atoms, so summing ``us`` over
      atoms counts every bond exactly once.

    Reads ``atom_masses`` and ``G`` (required) and ``bond_pairs``,
    ``bond_stiffness`` and ``bond_rest_lengths`` (optional) from *extra_data*.
    """
    masses = np.atleast_1d(np.asarray(extra_data["atom_masses"]))
    g_vec = np.asarray(extra_data["G"])  # [gx, gy, gz]

    ek = 0.5 * masses[np.newaxis, :] * (traj_vx**2 + traj_vy**2 + traj_vz**2)
    ug = -masses[np.newaxis, :] * (
        g_vec[0] * traj_x + g_vec[1] * traj_y + g_vec[2] * traj_z
    )

    us = np.zeros_like(ek)
    bond_pairs = extra_data.get("bond_pairs")
    if bond_pairs is not None and len(bond_pairs) > 0:
        # reshape(-1, 2) also accepts a single bond stored as a flat pair.
        pairs = np.asarray(bond_pairs).reshape(-1, 2)
        stiffness = np.atleast_1d(np.asarray(extra_data.get("bond_stiffness")))
        rest_lengths = np.atleast_1d(np.asarray(extra_data.get("bond_rest_lengths")))
        for b_idx, (i, j) in enumerate(pairs):
            # Cast in case the loader hands back float-typed indices.
            i, j = int(i), int(j)
            dist = np.sqrt(
                (traj_x[:, j] - traj_x[:, i]) ** 2
                + (traj_y[:, j] - traj_y[:, i]) ** 2
                + (traj_z[:, j] - traj_z[:, i]) ** 2
            )
            bond_energy = 0.5 * stiffness[b_idx] * (dist - rest_lengths[b_idx]) ** 2
            # Half to each endpoint: adding the full bond energy to both atoms
            # would double-count it in the system sum.
            us[:, i] += 0.5 * bond_energy
            us[:, j] += 0.5 * bond_energy
    return ek, ug, us


def create_plots(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, output_dir='.', atom_ids=None, extra_data=None):
    """Plot x/y/z/vx/vy/vz against time (plus energy curves) and save the figure.

    Parameters
    ----------
    traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz : numpy.ndarray
        Position and velocity series: 1-D for a single atom, or 2-D with one
        column per atom for multi-atom plots.
    NT : int
        Number of samples. The time axis is ``np.arange(NT) * DT``, so it
        should match the first dimension of the series.
    DT : float
        Time step between samples.
    output_dir : str
        Directory where ``trajectory_plots.png`` is written.
    atom_ids : int or sequence, optional
        Single-atom mode: the atom number shown in the title.
        Multi-atom mode: one atom number per column, used in legends
        (default ``1..n``).
    extra_data : mapping, optional
        Raw trajectory data. When the series are 2-D and it provides ``G`` and
        ``atom_masses``, a third row with per-atom and system energy is drawn.

    Returns
    -------
    str
        Path of the saved PNG file.
    """
    # Fonts able to render the Chinese labels; keep the ASCII minus sign.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False

    time = np.arange(NT) * DT

    # 1-D series -> single atom; 2-D series -> one column per atom.
    is_multi = traj_x.ndim == 2
    n_atoms = traj_x.shape[1] if is_multi else 1

    # The energy row needs per-atom data plus gravity and masses.
    has_energy = (
        is_multi
        and extra_data is not None
        and "G" in extra_data
        and "atom_masses" in extra_data
    )
    n_rows = 3 if has_energy else 2
    n_cols = 3
    figsize = (15, 13) if has_energy else (15, 9)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
    title = '轨迹与能量分析' if has_energy else '轨迹分析'
    if atom_ids is not None and not is_multi:
        # np.ravel accepts a plain int as well as a 1-element list/array.
        title += f' - 原子 {int(np.ravel(atom_ids)[0])}'
    fig.suptitle(title, fontsize=16)

    plot_configs = [
        (axes[0, 0], traj_x, 'x'),
        (axes[0, 1], traj_y, 'y'),
        (axes[0, 2], traj_z, 'z'),
        (axes[1, 0], traj_vx, 'vx'),
        (axes[1, 1], traj_vy, 'vy'),
        (axes[1, 2], traj_vz, 'vz'),
    ]

    if is_multi:
        colors = plt.cm.tab10(np.linspace(0, 1, n_atoms))
        if atom_ids is None:
            atom_ids = np.arange(n_atoms) + 1

    for ax, data_arr, label in plot_configs:
        if is_multi:
            for i in range(n_atoms):
                aid = int(atom_ids[i])
                ax.plot(time, data_arr[:, i], color=colors[i], linewidth=1.5, label=f"原子 {aid}")
            ax.legend()
        else:
            ax.plot(time, data_arr, 'b-', linewidth=1.5)
        ax.set_title(f'{label} - 时间')
        ax.set_xlabel('时间 (s)')
        ax.set_ylabel(label)
        ax.grid(True, alpha=0.3)

    # ── Energy subplots ─────────────────────────────────────
    if has_energy:
        ek, ug, us = _compute_energies(
            traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, extra_data,
        )
        e_total = ek + ug + us  # per-atom total energy, shape (NT, n_atoms)
        ek_sys = np.sum(ek, axis=1)
        ug_sys = np.sum(ug, axis=1)
        us_sys = np.sum(us, axis=1)
        e_sys = ek_sys + ug_sys + us_sys

        bond_pairs = extra_data.get("bond_pairs")
        has_bonds = bond_pairs is not None and len(bond_pairs) > 0

        # Row 3, left: total energy of each atom.
        ax_e = axes[2, 0]
        for i in range(n_atoms):
            aid = int(atom_ids[i])
            ax_e.plot(time, e_total[:, i], color=colors[i], linewidth=1.5, label=f"原子 {aid}")
        ax_e.set_title("各原子总能量")
        ax_e.set_xlabel("时间 (s)")
        ax_e.set_ylabel("能量")
        ax_e.grid(True, alpha=0.3)
        ax_e.legend(loc="upper right")

        # Row 3, middle: system energy components and their sum.
        ax_sys = axes[2, 1]
        ax_sys.plot(time, ek_sys, 'b-', linewidth=1.5, label="系统动能")
        ax_sys.plot(time, ug_sys, 'g-', linewidth=1.5, label="系统重力势能")
        if has_bonds:
            ax_sys.plot(time, us_sys, color='orange', linewidth=1.5, label="系统弹性势能")
        ax_sys.plot(time, e_sys, 'r--', linewidth=1.5, label="系统总能量")
        ax_sys.set_title("系统总能量")
        ax_sys.set_xlabel("时间 (s)")
        ax_sys.set_ylabel("能量")
        ax_sys.grid(True, alpha=0.3)
        ax_sys.legend(loc="upper right")

        # Row 3, right: unused.
        axes[2, 2].set_visible(False)

    # Leave room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    output_path = os.path.join(output_dir, 'trajectory_plots.png')
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"图表已保存至: {output_path}")

    # Show interactively when the backend supports it.
    plt.show()
    # pyplot keeps figures registered until closed; release this one.
    plt.close(fig)
    return output_path
def main():
    """Command-line entry point: load a trajectory file and plot it.

    With ``--atom-id`` only that atom is plotted (no energy row). Otherwise
    all atoms are plotted and the raw data is passed to ``create_plots`` for
    the optional energy curves. Exits with status 1 if the input file is
    missing or cannot be loaded.
    """
    import argparse

    parser = argparse.ArgumentParser(description="绘制轨迹数据图表")
    parser.add_argument("trajectory_file", nargs="?",
                        help="轨迹数据文件路径(默认: output/trajectory.txt)")
    parser.add_argument("--atom-id", type=int, default=None,
                        help="指定只绘制某个原子的轨迹(默认绘制所有原子)")
    args = parser.parse_args()

    # Both the default input file and the plot output live in this directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = compute.get_output_dir(script_dir)
    input_file = args.trajectory_file or os.path.join(output_dir, "trajectory.txt")

    if not os.path.exists(input_file):
        print(f"错误: 文件 '{input_file}' 不存在", file=sys.stderr)
        print("请先运行 compute.py 生成轨迹数据", file=sys.stderr)
        sys.exit(1)

    # Top-level CLI boundary: any loader failure is reported, then we exit.
    try:
        if args.atom_id is not None:
            # Single-atom mode (kept for backward compatibility).
            (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
             NT, DT, selected_id) = load_trajectory_txt(input_file, args.atom_id)
            ids_for_plot = selected_id
            extra_data = None
            atoms_summary = f" 绘图原子序号: {selected_id}"
        else:
            # All-atom mode.
            (traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
             NT, DT, atom_ids, raw_data) = load_full_trajectory(input_file)
            ids_for_plot = atom_ids
            extra_data = raw_data
            n_atoms = traj_x.shape[1] if traj_x.ndim > 1 else 1
            atoms_summary = f" 绘图原子数: {n_atoms}"
    except Exception as e:
        print(f"加载文件时出错: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"加载轨迹数据: NT={NT}, DT={DT}")
    print(atoms_summary)
    print(f" 位置范围: x [{traj_x.min():.3f}, {traj_x.max():.3f}], "
          f"y [{traj_y.min():.3f}, {traj_y.max():.3f}], "
          f"z [{traj_z.min():.3f}, {traj_z.max():.3f}]")
    print(f" 速度范围: vx [{traj_vx.min():.3f}, {traj_vx.max():.3f}], "
          f"vy [{traj_vy.min():.3f}, {traj_vy.max():.3f}], "
          f"vz [{traj_vz.min():.3f}, {traj_vz.max():.3f}]")

    create_plots(
        traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
        NT, DT, output_dir, ids_for_plot, extra_data,
    )
    print("绘图完成!")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()