refactor: merge wave/energy panels into 3 vertical subplots

- Plot 1: x/y/z displacements overlaid on one axes
- Plot 2: per-atom KE/PE/total energy overlaid on one axes
- Plot 3: system energy vs time (unchanged)
All three stacked vertically. Shared y-axis scale within each panel.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-13 08:32:41 +08:00
parent 2ab3436235
commit b584c4489c
+55 -70
View File
@@ -264,12 +264,10 @@ def compute_per_atom_energy(x, y, z, vx, vy, vz, masses,
def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
"""主绘图函数:读取 display.txt 并生成波形+能量动画。
布局(4行×2列):
左列 行0-2:x/y/z 位移波形(vs 原子序号)
右列 行0:每粒子动能
右列 行1:每粒子势能
右列 行2:每粒子总能
右列 行3:系统总能量随时间变化
布局(3行×1列,纵向排列):
行0x/y/z 位移波形叠加在同一子图vs 原子序号)
行1:每粒子动能、势能、总能叠加在同一子图
行2:系统总能量随时间变化
Args:
output_dir: 输出目录(含 display.npz 或 display.txt
@@ -340,79 +338,66 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
return 0.0, 1.0
return 0.0, vmax * 1.2
disp_ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)]
ek_ylim = get_ylim_pos(ek_atom)
pe_ylim = get_ylim_pos(pe_atom)
et_ylim = get_ylim_pos(et_atom)
e_max = max(np.max(e_total), 0.01) * 1.3
p_max = max(np.max(np.abs(power)) * 1.3, 0.01)
# 共用 y 轴范围:位移图取三个方向最大值统一
disp_vmax = max(np.max(np.abs(dx)), np.max(np.abs(dy)), np.max(np.abs(dz)))
disp_vmax = disp_vmax if disp_vmax > 1e-10 else 1.0
disp_ylim = (-disp_vmax * 1.2, disp_vmax * 1.2)
# 每粒子能量:取三者最大值统一 y 轴
energy_vmax = max(np.max(ek_atom), np.max(pe_atom), np.max(et_atom))
energy_vmax = energy_vmax if energy_vmax > 1e-12 else 1.0
energy_ylim = (0.0, energy_vmax * 1.2)
e_max = max(np.max(e_total), 0.01) * 1.3
p_max = max(np.max(np.abs(power)) * 1.3, 0.01)
atom_idx = np.arange(n_atoms)
# ── 图形布局 ──
# ── 图形布局3 行 × 1 列,纵向排列 ──
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
fig = plt.figure(figsize=(16, 14))
fig, (ax_wave, ax_energy, ax_ep) = plt.subplots(3, 1, figsize=(12, 14))
fig.suptitle("波形与能量分析", fontsize=16)
fig.subplots_adjust(hspace=0.40, top=0.94)
import matplotlib.gridspec as gridspec
gs = gridspec.GridSpec(4, 2, figure=fig, hspace=0.45, wspace=0.35)
# ── 图1:x/y/z 位移波形叠加 ──
ax_wave.set_xlim(0, n_atoms - 1)
ax_wave.set_ylim(disp_ylim)
ax_wave.set_xlabel("原子序号")
ax_wave.set_ylabel("位移")
ax_wave.set_title("粒子位移(x / y / z 方向)")
ax_wave.grid(True, alpha=0.3)
# 左列:位移波形(行 0-2
ax_dx = fig.add_subplot(gs[0, 0])
ax_dy = fig.add_subplot(gs[1, 0])
ax_dz = fig.add_subplot(gs[2, 0])
# 左列行 3 留空(或与右下对齐)
ax_blank = fig.add_subplot(gs[3, 0])
ax_blank.set_visible(False)
# 右列:每粒子能量(行 0-2)+ 时间能量图(行 3)
ax_ek = fig.add_subplot(gs[0, 1])
ax_pe = fig.add_subplot(gs[1, 1])
ax_et = fig.add_subplot(gs[2, 1])
ax_ep = fig.add_subplot(gs[3, 1])
# ── 初始化位移波形 ──
wave_axes = [ax_dx, ax_dy, ax_dz]
wave_disps = [dx, dy, dz]
wave_titles = ["纵波 (x 方向位移)", "横波 (y 方向位移)", "横波 (z 方向位移)"]
wave_labels = ["x 方向(纵波)", "y 方向(横波)", "z 方向(横波)"]
wave_colors = ["#2563eb", "#ea580c", "#16a34a"]
wave_lines = []
time_texts = []
for ax, disp, title, color, yl in zip(wave_axes, wave_disps, wave_titles, wave_colors, disp_ylims):
ax.set_xlim(0, n_atoms - 1)
ax.set_ylim(yl)
ax.set_xlabel("原子序号")
ax.set_ylabel("位移")
ax.set_title(title)
ax.grid(True, alpha=0.3)
ln, = ax.plot([], [], color=color, linewidth=1.5)
for label, color in zip(wave_labels, wave_colors):
ln, = ax_wave.plot([], [], color=color, linewidth=1.5, label=label)
wave_lines.append(ln)
tt = ax.text(0.02, 0.95, "", transform=ax.transAxes,
fontsize=9, verticalalignment="top")
time_texts.append(tt)
ax_wave.legend(loc="upper right", fontsize=9)
time_text = ax_wave.text(0.02, 0.95, "", transform=ax_wave.transAxes,
fontsize=10, verticalalignment="top")
# ── 图2:每粒子动能、势能、总能叠加 ──
ax_energy.set_xlim(0, n_atoms - 1)
ax_energy.set_ylim(energy_ylim)
ax_energy.set_xlabel("原子序号")
ax_energy.set_ylabel("能量")
ax_energy.set_title("每粒子能量(动能 / 势能 / 总能)")
ax_energy.grid(True, alpha=0.3)
# ── 初始化每粒子能量图 ──
energy_axes = [ax_ek, ax_pe, ax_et]
energy_arrays = [ek_atom, pe_atom, et_atom]
energy_titles = ["每粒子动能", "每粒子势能", "每粒子总能"]
energy_labels = ["动能", "势能", "总能"]
energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"]
energy_ylims = [ek_ylim, pe_ylim, et_ylim]
energy_lines = []
for ax, arr, title, color, yl in zip(energy_axes, energy_arrays, energy_titles, energy_colors, energy_ylims):
ax.set_xlim(0, n_atoms - 1)
ax.set_ylim(yl)
ax.set_xlabel("原子序号")
ax.set_ylabel("能量")
ax.set_title(title)
ax.grid(True, alpha=0.3)
ln, = ax.plot([], [], color=color, linewidth=1.5)
for label, color in zip(energy_labels, energy_colors):
ln, = ax_energy.plot([], [], color=color, linewidth=1.5, label=label)
energy_lines.append(ln)
ax_energy.legend(loc="upper right", fontsize=9)
# ── 初始化系统总能量时间 ──
# ── 图3系统总能量时间 ──
ax_ep.set_xlim(t[0], t[-1])
ep_yhigh = max(e_max, p_max)
ep_ylow = min(-p_max * 0.1, 0.0)
@@ -432,20 +417,20 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
if gravity_interaction and n_atoms <= 200:
ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
ax_ep.legend(loc="upper left", fontsize=8)
ax_ep.legend(loc="upper left", fontsize=9)
# ── 动画更新 ──
def update(frame):
# 位移波形
for i in range(3):
wave_lines[i].set_data(atom_idx, wave_disps[i][frame])
time_texts[i].set_text(f"t = {t[frame]:.2f} s 帧 {frame+1}/{n_frames}")
# 图1位移波形
for i, ln in enumerate(wave_lines):
ln.set_data(atom_idx, wave_disps[i][frame])
time_text.set_text(f"t = {t[frame]:.2f} s |{frame+1}/{n_frames}")
# 每粒子能量
for i in range(3):
energy_lines[i].set_data(atom_idx, energy_arrays[i][frame])
# 图2每粒子能量
for i, ln in enumerate(energy_lines):
ln.set_data(atom_idx, energy_arrays[i][frame])
# 系统能量(累计到当前帧)
# 图3系统能量(累计到当前帧)
cur_t = t[:frame + 1]
ln_ek.set_data(cur_t, ek_sys[:frame + 1])
ln_us.set_data(cur_t, us_sys[:frame + 1])
@@ -455,7 +440,7 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
if ln_ugr: ln_ugr.set_data(cur_t, ugr_sys[:frame + 1])
ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1]))
artists = wave_lines + time_texts + energy_lines + \
artists = wave_lines + [time_text] + energy_lines + \
[ln_ek, ln_us, ln_et, ln_pw]
if ln_ug: artists.append(ln_ug)
if ln_ugr: artists.append(ln_ugr)