This commit is contained in:
2026-05-17 08:47:25 +08:00
parent 1159d86b8b
commit 45513fe334
27 changed files with 4734 additions and 2 deletions
+265
View File
@@ -0,0 +1,265 @@
"""
plot_trajectory.py
------------------
绘制轨迹数据的x-t, y-t, z-t, vx-t, vy-t, vz-t函数图像,支持多原子同时绘制。
用法:
python plot_trajectory.py # 绘制所有原子
python plot_trajectory.py output/trajectory.txt # 指定文件
python plot_trajectory.py output/trajectory.txt --atom-id 1 # 仅绘制原子1
参数:
trajectory_file: 轨迹数据文件,支持结构化 .txt 格式(默认:output/trajectory.txt
--atom-id: 可选,指定只绘制单个原子的轨迹
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute
def select_atom_series(data, atom_id=None):
"""Select one atom from trajectory arrays when the file stores many atoms."""
traj_x = data['traj_x']
traj_y = data['traj_y']
traj_z = data['traj_z']
traj_vx = data['traj_vx']
traj_vy = data['traj_vy']
traj_vz = data['traj_vz']
if traj_x.ndim == 1:
selected_id = int(data["plot_atom_id"]) if "plot_atom_id" in data else 1
return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, selected_id
atom_ids = data["atom_ids"] if "atom_ids" in data else np.arange(traj_x.shape[1]) + 1
if atom_id is None:
row = int(data["plot_atom_row"]) if "plot_atom_row" in data else 0
else:
matches = np.where(atom_ids == int(atom_id))[0]
if len(matches) == 0:
raise ValueError(f"原子序号 {atom_id} 不在轨迹文件中")
row = int(matches[0])
selected_id = int(atom_ids[row])
return (
traj_x[:, row], traj_y[:, row], traj_z[:, row],
traj_vx[:, row], traj_vy[:, row], traj_vz[:, row],
selected_id,
)
def load_trajectory_txt(file_path, atom_id=None):
"""加载结构化 .txt 格式的轨迹数据"""
data = compute.load_text_data(file_path)
traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, selected_id = select_atom_series(data, atom_id)
NT = int(data['NT'])
DT = float(data['DT'])
return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, selected_id
def load_full_trajectory(file_path):
"""加载完整轨迹数据(所有原子),返回 2D 数组及原始数据字典。"""
data = compute.load_text_data(file_path)
traj_x = data['traj_x']
traj_y = data['traj_y']
traj_z = data['traj_z']
traj_vx = data['traj_vx']
traj_vy = data['traj_vy']
traj_vz = data['traj_vz']
NT = int(data['NT'])
DT = float(data['DT'])
n_atoms = traj_x.shape[1] if traj_x.ndim > 1 else 1
atom_ids = data.get("atom_ids", np.arange(n_atoms) + 1)
return traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, atom_ids, data
def create_plots(traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, output_dir='.', atom_ids=None, extra_data=None):
"""创建六个子图的图表,支持多原子绘制与能量曲线。"""
# 配置中文字体支持
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
# 生成时间数组
time = np.arange(NT) * DT
# 判断单原子(1D)或多原子(2D
is_multi = traj_x.ndim == 2
n_atoms = traj_x.shape[1] if is_multi else 1
# 是否绘制能量子图(需要 multi-atom + extra_data 中有 G
has_energy = is_multi and extra_data is not None and "G" in extra_data
n_rows = 3 if has_energy else 2
n_cols = 3
figsize = (15, 13) if has_energy else (15, 9)
# 创建图形和子图
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
title = '轨迹与能量分析' if has_energy else '轨迹分析'
if atom_ids is not None and not is_multi:
title += f' - 原子 {int(atom_ids)}'
fig.suptitle(title, fontsize=16)
plot_configs = [
(axes[0, 0], traj_x, 'x'),
(axes[0, 1], traj_y, 'y'),
(axes[0, 2], traj_z, 'z'),
(axes[1, 0], traj_vx, 'vx'),
(axes[1, 1], traj_vy, 'vy'),
(axes[1, 2], traj_vz, 'vz'),
]
if is_multi:
colors = plt.cm.tab10(np.linspace(0, 1, n_atoms))
if atom_ids is None:
atom_ids = np.arange(n_atoms) + 1
for ax, data_arr, label in plot_configs:
if is_multi:
for i in range(n_atoms):
aid = int(atom_ids[i])
ax.plot(time, data_arr[:, i], color=colors[i], linewidth=1.5, label=f"原子 {aid}")
ax.legend()
else:
ax.plot(time, data_arr, 'b-', linewidth=1.5)
ax.set_title(f'{label} - 时间')
ax.set_xlabel('时间 (s)')
ax.set_ylabel(label)
ax.grid(True, alpha=0.3)
# ── 能量子图 ─────────────────────────────────────
if has_energy:
masses = np.array(extra_data["atom_masses"]) # (n_atoms,)
G_vec = np.array(extra_data["G"]) # [gx, gy, gz]
# 动能 Ek = ½ m v²
ek = 0.5 * masses[np.newaxis, :] * (traj_vx**2 + traj_vy**2 + traj_vz**2)
# 重力势能 Ug = -m G·r
ug = -masses[np.newaxis, :] * (G_vec[0] * traj_x + G_vec[1] * traj_y + G_vec[2] * traj_z)
# 弹性势能 Us = ½ k (d - d₀)²
us = np.zeros_like(ek)
bond_pairs = extra_data.get("bond_pairs")
bond_stiffness = extra_data.get("bond_stiffness")
bond_rest_lengths = extra_data.get("bond_rest_lengths")
if bond_pairs is not None and len(bond_pairs) > 0:
for b_idx in range(len(bond_pairs)):
i, j = bond_pairs[b_idx]
dx = traj_x[:, j] - traj_x[:, i]
dy = traj_y[:, j] - traj_y[:, i]
dz = traj_z[:, j] - traj_z[:, i]
dist = np.sqrt(dx**2 + dy**2 + dz**2)
stretch = dist - bond_rest_lengths[b_idx]
us[:, i] += 0.5 * bond_stiffness[b_idx] * stretch**2
us[:, j] += 0.5 * bond_stiffness[b_idx] * stretch**2
e_total = ek + ug + us # (NT, n_atoms)
ek_sys = np.sum(ek, axis=1)
ug_sys = np.sum(ug, axis=1)
us_sys = np.sum(us, axis=1)
e_sys = ek_sys + ug_sys + us_sys
# 第 3 行左:各原子总能量
ax_e = axes[2, 0]
for i in range(n_atoms):
aid = int(atom_ids[i])
ax_e.plot(time, e_total[:, i], color=colors[i], linewidth=1.5, label=f"原子 {aid}")
ax_e.set_title("各原子总能量")
ax_e.set_xlabel("时间 (s)")
ax_e.set_ylabel("能量")
ax_e.grid(True, alpha=0.3)
ax_e.legend(loc="upper right")
# 第 3 行右:系统总能量
ax_sys = axes[2, 1]
ax_sys.plot(time, ek_sys, 'b-', linewidth=1.5, label="系统动能")
ax_sys.plot(time, ug_sys, 'g-', linewidth=1.5, label="系统重力势能")
if bond_pairs is not None and len(bond_pairs) > 0:
ax_sys.plot(time, us_sys, color='orange', linewidth=1.5, label="系统弹性势能")
ax_sys.plot(time, e_sys, 'r--', linewidth=1.5, label="系统总能量")
ax_sys.set_title("系统总能量")
ax_sys.set_xlabel("时间 (s)")
ax_sys.set_ylabel("能量")
ax_sys.grid(True, alpha=0.3)
ax_sys.legend(loc="upper right")
# 隐藏第 3 行第 3 列空子图
axes[2, 2].set_visible(False)
# 调整布局
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
# 保存图像
output_path = os.path.join(output_dir, 'trajectory_plots.png')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"图表已保存至: {output_path}")
# 显示图像(如果环境支持)
plt.show()
return output_path
def main():
import argparse
parser = argparse.ArgumentParser(description="绘制轨迹数据图表")
parser.add_argument("trajectory_file", nargs="?",
help="轨迹数据文件路径(默认: output/trajectory.txt")
parser.add_argument("--atom-id", type=int, default=None,
help="指定只绘制某个原子的轨迹(默认绘制所有原子)")
args = parser.parse_args()
# 默认文件
script_dir = os.path.dirname(os.path.abspath(__file__))
default_file = os.path.join(compute.get_output_dir(script_dir), "trajectory.txt")
input_file = args.trajectory_file or default_file
# 检查文件是否存在
if not os.path.exists(input_file):
print(f"错误: 文件 '{input_file}' 不存在")
print(f"请先运行 compute.py 生成轨迹数据")
sys.exit(1)
file_ext = os.path.splitext(input_file)[1].lower()
try:
if args.atom_id is not None:
# 单原子模式(旧版兼容)
traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, selected_id = \
load_trajectory_txt(input_file, args.atom_id)
ids_for_plot = selected_id
extra_data = None
n_atoms_str = f" 绘图原子序号: {selected_id}"
else:
# 全原子模式
traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz, NT, DT, atom_ids, raw_data = \
load_full_trajectory(input_file)
ids_for_plot = atom_ids
extra_data = raw_data
n_atoms = traj_x.shape[1] if traj_x.ndim > 1 else 1
n_atoms_str = f" 绘图原子数: {n_atoms}"
except Exception as e:
print(f"加载文件时出错: {e}")
sys.exit(1)
print(f"加载轨迹数据: NT={NT}, DT={DT}")
if args.atom_id is not None:
print(n_atoms_str)
else:
print(n_atoms_str)
print(f" 位置范围: x [{traj_x.min():.3f}, {traj_x.max():.3f}], "
f"y [{traj_y.min():.3f}, {traj_y.max():.3f}], "
f"z [{traj_z.min():.3f}, {traj_z.max():.3f}]")
print(f" 速度范围: vx [{traj_vx.min():.3f}, {traj_vx.max():.3f}], "
f"vy [{traj_vy.min():.3f}, {traj_vy.max():.3f}], "
f"vz [{traj_vz.min():.3f}, {traj_vz.max():.3f}]")
# 创建图表
output_path = create_plots(
traj_x, traj_y, traj_z, traj_vx, traj_vy, traj_vz,
NT, DT, compute.get_output_dir(script_dir), ids_for_plot, extra_data,
)
print("绘图完成!")
if __name__ == "__main__":
main()