refactor: merge wave/energy panels into 3 vertical subplots
- Plot 1: x/y/z displacements overlaid on one axes - Plot 2: per-atom KE/PE/total energy overlaid on one axes - Plot 3: system energy vs time (unchanged) All three stacked vertically. Shared y-axis scale within each panel. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+55
-70
@@ -264,12 +264,10 @@ def compute_per_atom_energy(x, y, z, vx, vy, vz, masses,
|
|||||||
def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
|
def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
|
||||||
"""主绘图函数:读取 display.txt 并生成波形+能量动画。
|
"""主绘图函数:读取 display.txt 并生成波形+能量动画。
|
||||||
|
|
||||||
布局(4行×2列):
|
布局(3行×1列,纵向排列):
|
||||||
左列 行0-2:x/y/z 位移波形(vs 原子序号)
|
行0:x/y/z 位移波形叠加在同一子图(vs 原子序号)
|
||||||
右列 行0:每粒子动能
|
行1:每粒子动能、势能、总能叠加在同一子图
|
||||||
右列 行1:每粒子势能
|
行2:系统总能量随时间变化
|
||||||
右列 行2:每粒子总能
|
|
||||||
右列 行3:系统总能量随时间变化
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_dir: 输出目录(含 display.npz 或 display.txt)
|
output_dir: 输出目录(含 display.npz 或 display.txt)
|
||||||
@@ -340,79 +338,66 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
|
|||||||
return 0.0, 1.0
|
return 0.0, 1.0
|
||||||
return 0.0, vmax * 1.2
|
return 0.0, vmax * 1.2
|
||||||
|
|
||||||
disp_ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)]
|
# 共用 y 轴范围:位移图取三个方向最大值统一
|
||||||
ek_ylim = get_ylim_pos(ek_atom)
|
disp_vmax = max(np.max(np.abs(dx)), np.max(np.abs(dy)), np.max(np.abs(dz)))
|
||||||
pe_ylim = get_ylim_pos(pe_atom)
|
disp_vmax = disp_vmax if disp_vmax > 1e-10 else 1.0
|
||||||
et_ylim = get_ylim_pos(et_atom)
|
disp_ylim = (-disp_vmax * 1.2, disp_vmax * 1.2)
|
||||||
e_max = max(np.max(e_total), 0.01) * 1.3
|
|
||||||
p_max = max(np.max(np.abs(power)) * 1.3, 0.01)
|
# 每粒子能量:取三者最大值统一 y 轴
|
||||||
|
energy_vmax = max(np.max(ek_atom), np.max(pe_atom), np.max(et_atom))
|
||||||
|
energy_vmax = energy_vmax if energy_vmax > 1e-12 else 1.0
|
||||||
|
energy_ylim = (0.0, energy_vmax * 1.2)
|
||||||
|
|
||||||
|
e_max = max(np.max(e_total), 0.01) * 1.3
|
||||||
|
p_max = max(np.max(np.abs(power)) * 1.3, 0.01)
|
||||||
|
|
||||||
atom_idx = np.arange(n_atoms)
|
atom_idx = np.arange(n_atoms)
|
||||||
|
|
||||||
# ── 图形布局 ──
|
# ── 图形布局:3 行 × 1 列,纵向排列 ──
|
||||||
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
|
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
|
||||||
plt.rcParams['axes.unicode_minus'] = False
|
plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
|
||||||
fig = plt.figure(figsize=(16, 14))
|
fig, (ax_wave, ax_energy, ax_ep) = plt.subplots(3, 1, figsize=(12, 14))
|
||||||
fig.suptitle("波形与能量分析", fontsize=16)
|
fig.suptitle("波形与能量分析", fontsize=16)
|
||||||
|
fig.subplots_adjust(hspace=0.40, top=0.94)
|
||||||
|
|
||||||
import matplotlib.gridspec as gridspec
|
# ── 图1:x/y/z 位移波形叠加 ──
|
||||||
gs = gridspec.GridSpec(4, 2, figure=fig, hspace=0.45, wspace=0.35)
|
ax_wave.set_xlim(0, n_atoms - 1)
|
||||||
|
ax_wave.set_ylim(disp_ylim)
|
||||||
|
ax_wave.set_xlabel("原子序号")
|
||||||
|
ax_wave.set_ylabel("位移")
|
||||||
|
ax_wave.set_title("粒子位移(x / y / z 方向)")
|
||||||
|
ax_wave.grid(True, alpha=0.3)
|
||||||
|
|
||||||
# 左列:位移波形(行 0-2)
|
|
||||||
ax_dx = fig.add_subplot(gs[0, 0])
|
|
||||||
ax_dy = fig.add_subplot(gs[1, 0])
|
|
||||||
ax_dz = fig.add_subplot(gs[2, 0])
|
|
||||||
# 左列行 3 留空(或与右下对齐)
|
|
||||||
ax_blank = fig.add_subplot(gs[3, 0])
|
|
||||||
ax_blank.set_visible(False)
|
|
||||||
|
|
||||||
# 右列:每粒子能量(行 0-2)+ 时间能量图(行 3)
|
|
||||||
ax_ek = fig.add_subplot(gs[0, 1])
|
|
||||||
ax_pe = fig.add_subplot(gs[1, 1])
|
|
||||||
ax_et = fig.add_subplot(gs[2, 1])
|
|
||||||
ax_ep = fig.add_subplot(gs[3, 1])
|
|
||||||
|
|
||||||
# ── 初始化位移波形 ──
|
|
||||||
wave_axes = [ax_dx, ax_dy, ax_dz]
|
|
||||||
wave_disps = [dx, dy, dz]
|
wave_disps = [dx, dy, dz]
|
||||||
wave_titles = ["纵波 (x 方向位移)", "横波 (y 方向位移)", "横波 (z 方向位移)"]
|
wave_labels = ["x 方向(纵波)", "y 方向(横波)", "z 方向(横波)"]
|
||||||
wave_colors = ["#2563eb", "#ea580c", "#16a34a"]
|
wave_colors = ["#2563eb", "#ea580c", "#16a34a"]
|
||||||
wave_lines = []
|
wave_lines = []
|
||||||
time_texts = []
|
for label, color in zip(wave_labels, wave_colors):
|
||||||
|
ln, = ax_wave.plot([], [], color=color, linewidth=1.5, label=label)
|
||||||
for ax, disp, title, color, yl in zip(wave_axes, wave_disps, wave_titles, wave_colors, disp_ylims):
|
|
||||||
ax.set_xlim(0, n_atoms - 1)
|
|
||||||
ax.set_ylim(yl)
|
|
||||||
ax.set_xlabel("原子序号")
|
|
||||||
ax.set_ylabel("位移")
|
|
||||||
ax.set_title(title)
|
|
||||||
ax.grid(True, alpha=0.3)
|
|
||||||
ln, = ax.plot([], [], color=color, linewidth=1.5)
|
|
||||||
wave_lines.append(ln)
|
wave_lines.append(ln)
|
||||||
tt = ax.text(0.02, 0.95, "", transform=ax.transAxes,
|
ax_wave.legend(loc="upper right", fontsize=9)
|
||||||
fontsize=9, verticalalignment="top")
|
time_text = ax_wave.text(0.02, 0.95, "", transform=ax_wave.transAxes,
|
||||||
time_texts.append(tt)
|
fontsize=10, verticalalignment="top")
|
||||||
|
|
||||||
|
# ── 图2:每粒子动能、势能、总能叠加 ──
|
||||||
|
ax_energy.set_xlim(0, n_atoms - 1)
|
||||||
|
ax_energy.set_ylim(energy_ylim)
|
||||||
|
ax_energy.set_xlabel("原子序号")
|
||||||
|
ax_energy.set_ylabel("能量")
|
||||||
|
ax_energy.set_title("每粒子能量(动能 / 势能 / 总能)")
|
||||||
|
ax_energy.grid(True, alpha=0.3)
|
||||||
|
|
||||||
# ── 初始化每粒子能量图 ──
|
|
||||||
energy_axes = [ax_ek, ax_pe, ax_et]
|
|
||||||
energy_arrays = [ek_atom, pe_atom, et_atom]
|
energy_arrays = [ek_atom, pe_atom, et_atom]
|
||||||
energy_titles = ["每粒子动能", "每粒子势能", "每粒子总能"]
|
energy_labels = ["动能", "势能", "总能"]
|
||||||
energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"]
|
energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"]
|
||||||
energy_ylims = [ek_ylim, pe_ylim, et_ylim]
|
|
||||||
energy_lines = []
|
energy_lines = []
|
||||||
|
for label, color in zip(energy_labels, energy_colors):
|
||||||
for ax, arr, title, color, yl in zip(energy_axes, energy_arrays, energy_titles, energy_colors, energy_ylims):
|
ln, = ax_energy.plot([], [], color=color, linewidth=1.5, label=label)
|
||||||
ax.set_xlim(0, n_atoms - 1)
|
|
||||||
ax.set_ylim(yl)
|
|
||||||
ax.set_xlabel("原子序号")
|
|
||||||
ax.set_ylabel("能量")
|
|
||||||
ax.set_title(title)
|
|
||||||
ax.grid(True, alpha=0.3)
|
|
||||||
ln, = ax.plot([], [], color=color, linewidth=1.5)
|
|
||||||
energy_lines.append(ln)
|
energy_lines.append(ln)
|
||||||
|
ax_energy.legend(loc="upper right", fontsize=9)
|
||||||
|
|
||||||
# ── 初始化系统总能量时间图 ──
|
# ── 图3:系统总能量随时间 ──
|
||||||
ax_ep.set_xlim(t[0], t[-1])
|
ax_ep.set_xlim(t[0], t[-1])
|
||||||
ep_yhigh = max(e_max, p_max)
|
ep_yhigh = max(e_max, p_max)
|
||||||
ep_ylow = min(-p_max * 0.1, 0.0)
|
ep_ylow = min(-p_max * 0.1, 0.0)
|
||||||
@@ -432,20 +417,20 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
|
|||||||
ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
|
ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
|
||||||
if gravity_interaction and n_atoms <= 200:
|
if gravity_interaction and n_atoms <= 200:
|
||||||
ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
|
ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
|
||||||
ax_ep.legend(loc="upper left", fontsize=8)
|
ax_ep.legend(loc="upper left", fontsize=9)
|
||||||
|
|
||||||
# ── 动画更新 ──
|
# ── 动画更新 ──
|
||||||
def update(frame):
|
def update(frame):
|
||||||
# 位移波形
|
# 图1:位移波形
|
||||||
for i in range(3):
|
for i, ln in enumerate(wave_lines):
|
||||||
wave_lines[i].set_data(atom_idx, wave_disps[i][frame])
|
ln.set_data(atom_idx, wave_disps[i][frame])
|
||||||
time_texts[i].set_text(f"t = {t[frame]:.2f} s 帧 {frame+1}/{n_frames}")
|
time_text.set_text(f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}")
|
||||||
|
|
||||||
# 每粒子能量
|
# 图2:每粒子能量
|
||||||
for i in range(3):
|
for i, ln in enumerate(energy_lines):
|
||||||
energy_lines[i].set_data(atom_idx, energy_arrays[i][frame])
|
ln.set_data(atom_idx, energy_arrays[i][frame])
|
||||||
|
|
||||||
# 系统总能量(累计到当前帧)
|
# 图3:系统能量(累计到当前帧)
|
||||||
cur_t = t[:frame + 1]
|
cur_t = t[:frame + 1]
|
||||||
ln_ek.set_data(cur_t, ek_sys[:frame + 1])
|
ln_ek.set_data(cur_t, ek_sys[:frame + 1])
|
||||||
ln_us.set_data(cur_t, us_sys[:frame + 1])
|
ln_us.set_data(cur_t, us_sys[:frame + 1])
|
||||||
@@ -455,7 +440,7 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
|
|||||||
if ln_ugr: ln_ugr.set_data(cur_t, ugr_sys[:frame + 1])
|
if ln_ugr: ln_ugr.set_data(cur_t, ugr_sys[:frame + 1])
|
||||||
ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1]))
|
ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1]))
|
||||||
|
|
||||||
artists = wave_lines + time_texts + energy_lines + \
|
artists = wave_lines + [time_text] + energy_lines + \
|
||||||
[ln_ek, ln_us, ln_et, ln_pw]
|
[ln_ek, ln_us, ln_et, ln_pw]
|
||||||
if ln_ug: artists.append(ln_ug)
|
if ln_ug: artists.append(ln_ug)
|
||||||
if ln_ugr: artists.append(ln_ugr)
|
if ln_ugr: artists.append(ln_ugr)
|
||||||
|
|||||||
Reference in New Issue
Block a user