diff --git a/plot_wave.py b/plot_wave.py index bb40d7b..9d3bf96 100644 --- a/plot_wave.py +++ b/plot_wave.py @@ -264,12 +264,10 @@ def compute_per_atom_energy(x, y, z, vx, vy, vz, masses, def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True): """主绘图函数:读取 display.txt 并生成波形+能量动画。 - 布局(4行×2列): - 左列 行0-2:x/y/z 位移波形(vs 原子序号) - 右列 行0:每粒子动能 - 右列 行1:每粒子势能 - 右列 行2:每粒子总能 - 右列 行3:系统总能量随时间变化 + 布局(3行×1列,纵向排列): + 行0:x/y/z 位移波形叠加在同一子图(vs 原子序号) + 行1:每粒子动能、势能、总能叠加在同一子图 + 行2:系统总能量随时间变化 Args: output_dir: 输出目录(含 display.npz 或 display.txt) @@ -340,79 +338,66 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True): return 0.0, 1.0 return 0.0, vmax * 1.2 - disp_ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)] - ek_ylim = get_ylim_pos(ek_atom) - pe_ylim = get_ylim_pos(pe_atom) - et_ylim = get_ylim_pos(et_atom) - e_max = max(np.max(e_total), 0.01) * 1.3 - p_max = max(np.max(np.abs(power)) * 1.3, 0.01) + # 共用 y 轴范围:位移图取三个方向最大值统一 + disp_vmax = max(np.max(np.abs(dx)), np.max(np.abs(dy)), np.max(np.abs(dz))) + disp_vmax = disp_vmax if disp_vmax > 1e-10 else 1.0 + disp_ylim = (-disp_vmax * 1.2, disp_vmax * 1.2) + + # 每粒子能量:取三者最大值统一 y 轴 + energy_vmax = max(np.max(ek_atom), np.max(pe_atom), np.max(et_atom)) + energy_vmax = energy_vmax if energy_vmax > 1e-12 else 1.0 + energy_ylim = (0.0, energy_vmax * 1.2) + + e_max = max(np.max(e_total), 0.01) * 1.3 + p_max = max(np.max(np.abs(power)) * 1.3, 0.01) atom_idx = np.arange(n_atoms) - # ── 图形布局 ── + # ── 图形布局:3 行 × 1 列,纵向排列 ── plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False - fig = plt.figure(figsize=(16, 14)) + fig, (ax_wave, ax_energy, ax_ep) = plt.subplots(3, 1, figsize=(12, 14)) fig.suptitle("波形与能量分析", fontsize=16) + fig.subplots_adjust(hspace=0.40, top=0.94) - import matplotlib.gridspec as gridspec - gs = gridspec.GridSpec(4, 2, figure=fig, hspace=0.45, wspace=0.35) + # ── 图1:x/y/z 位移波形叠加 ── + ax_wave.set_xlim(0, n_atoms - 1) + ax_wave.set_ylim(disp_ylim) + ax_wave.set_xlabel("原子序号") + ax_wave.set_ylabel("位移") + ax_wave.set_title("粒子位移(x / y / z 方向)") + ax_wave.grid(True, alpha=0.3) - # 左列:位移波形(行 0-2) - ax_dx = fig.add_subplot(gs[0, 0]) - ax_dy = fig.add_subplot(gs[1, 0]) - ax_dz = fig.add_subplot(gs[2, 0]) - # 左列行 3 留空(或与右下对齐) - ax_blank = fig.add_subplot(gs[3, 0]) - ax_blank.set_visible(False) - - # 右列:每粒子能量(行 0-2)+ 时间能量图(行 3) - ax_ek = fig.add_subplot(gs[0, 1]) - ax_pe = fig.add_subplot(gs[1, 1]) - ax_et = fig.add_subplot(gs[2, 1]) - ax_ep = fig.add_subplot(gs[3, 1]) - - # ── 初始化位移波形 ── - wave_axes = [ax_dx, ax_dy, ax_dz] wave_disps = [dx, dy, dz] - wave_titles = ["纵波 (x 方向位移)", "横波 (y 方向位移)", "横波 (z 方向位移)"] + wave_labels = ["x 方向(纵波)", "y 方向(横波)", "z 方向(横波)"] wave_colors = ["#2563eb", "#ea580c", "#16a34a"] wave_lines = [] - time_texts = [] - - for ax, disp, title, color, yl in zip(wave_axes, wave_disps, wave_titles, wave_colors, disp_ylims): - ax.set_xlim(0, n_atoms - 1) - ax.set_ylim(yl) - ax.set_xlabel("原子序号") - ax.set_ylabel("位移") - ax.set_title(title) - ax.grid(True, alpha=0.3) - ln, = ax.plot([], [], color=color, linewidth=1.5) + for label, color in zip(wave_labels, wave_colors): + ln, = ax_wave.plot([], [], color=color, linewidth=1.5, label=label) wave_lines.append(ln) - tt = ax.text(0.02, 0.95, "", transform=ax.transAxes, - fontsize=9, verticalalignment="top") - time_texts.append(tt) + ax_wave.legend(loc="upper right", fontsize=9) + time_text = ax_wave.text(0.02, 0.95, "", transform=ax_wave.transAxes, + fontsize=10, verticalalignment="top") + + # ── 图2:每粒子动能、势能、总能叠加 ── + ax_energy.set_xlim(0, n_atoms - 1) + ax_energy.set_ylim(energy_ylim) + ax_energy.set_xlabel("原子序号") + ax_energy.set_ylabel("能量") + ax_energy.set_title("每粒子能量(动能 / 势能 / 总能)") + ax_energy.grid(True, alpha=0.3) - # ── 初始化每粒子能量图 ── - energy_axes = [ax_ek, ax_pe, ax_et] energy_arrays = [ek_atom, pe_atom, et_atom] - energy_titles = ["每粒子动能", "每粒子势能", "每粒子总能"] + energy_labels = ["动能", "势能", "总能"] energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"] - energy_ylims = [ek_ylim, pe_ylim, et_ylim] energy_lines = [] - - for ax, arr, title, color, yl in zip(energy_axes, energy_arrays, energy_titles, energy_colors, energy_ylims): - ax.set_xlim(0, n_atoms - 1) - ax.set_ylim(yl) - ax.set_xlabel("原子序号") - ax.set_ylabel("能量") - ax.set_title(title) - ax.grid(True, alpha=0.3) - ln, = ax.plot([], [], color=color, linewidth=1.5) + for label, color in zip(energy_labels, energy_colors): + ln, = ax_energy.plot([], [], color=color, linewidth=1.5, label=label) energy_lines.append(ln) + ax_energy.legend(loc="upper right", fontsize=9) - # ── 初始化系统总能量时间图 ── + # ── 图3:系统总能量随时间 ── ax_ep.set_xlim(t[0], t[-1]) ep_yhigh = max(e_max, p_max) ep_ylow = min(-p_max * 0.1, 0.0) @@ -432,20 +417,20 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True): ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能") if gravity_interaction and n_atoms <= 200: ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能") - ax_ep.legend(loc="upper left", fontsize=8) + ax_ep.legend(loc="upper left", fontsize=9) # ── 动画更新 ── def update(frame): - # 位移波形 - for i in range(3): - wave_lines[i].set_data(atom_idx, wave_disps[i][frame]) - time_texts[i].set_text(f"t = {t[frame]:.2f} s 帧 {frame+1}/{n_frames}") + # 图1:位移波形 + for i, ln in enumerate(wave_lines): + ln.set_data(atom_idx, wave_disps[i][frame]) + time_text.set_text(f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}") - # 每粒子能量 - for i in range(3): - energy_lines[i].set_data(atom_idx, energy_arrays[i][frame]) + # 图2:每粒子能量 + for i, ln in enumerate(energy_lines): + ln.set_data(atom_idx, energy_arrays[i][frame]) - # 系统总能量(累计到当前帧) + # 图3:系统能量(累计到当前帧) cur_t = t[:frame + 1] ln_ek.set_data(cur_t, ek_sys[:frame + 1]) ln_us.set_data(cur_t, us_sys[:frame + 1]) @@ -455,7 +440,7 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True): if ln_ugr: ln_ugr.set_data(cur_t, ugr_sys[:frame + 1]) ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1])) - artists = wave_lines + time_texts + energy_lines + \ + artists = wave_lines + [time_text] + energy_lines + \ [ln_ek, ln_us, ln_et, ln_pw] if ln_ug: artists.append(ln_ug) if ln_ugr: artists.append(ln_ugr)