From 2ab34362357fe2d93f8f9cf5c7c94ad07e948d04 Mon Sep 17 00:00:00 2001 From: Ying-Li Niu <64801511@qq.com> Date: Sat, 13 Jun 2026 08:26:51 +0800 Subject: [PATCH] feat: redesign plot_wave with per-atom energy panels New 4x2 layout: left col = x/y/z displacement waves, right col = per-atom KE/PE/total energy + system energy vs time. PE split 50/50 for normal bonds; 100% to non-driven atom when bonded to driver; driven atom PE = E_SHO - KE. Co-Authored-By: Claude Sonnet 4.6 --- plot_wave.py | 331 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 231 insertions(+), 100 deletions(-) diff --git a/plot_wave.py b/plot_wave.py index b1c6b26..bb40d7b 100644 --- a/plot_wave.py +++ b/plot_wave.py @@ -167,179 +167,312 @@ def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr, return ek_sys, us_sys, ug_sys, ugr_sys +def _load_driver_info(output_dir): + """从 input/driver.txt 读取驱动原子的 atom_id, amp, freq(仅用于能量计算)。 + 返回 dict: {atom_id: {'amp': [ax,ay,az], 'freq': [fx,fy,fz]}},失败返回 {}。 + """ + input_dir = os.path.join(os.path.dirname(output_dir), "input") + path = os.path.join(input_dir, "driver.txt") + if not os.path.exists(path): + return {} + drivers = {} + try: + with open(path, encoding="utf-8") as f: + f.readline() # skip header + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split() + if len(parts) < 10: + continue + n = int(parts[0]) + amp = np.array([float(parts[1]), float(parts[2]), float(parts[3])]) + freq = np.array([float(parts[4]), float(parts[5]), float(parts[6])]) + drivers[n] = {"amp": amp, "freq": freq} + except Exception: + pass + return drivers + + +def compute_per_atom_energy(x, y, z, vx, vy, vz, masses, + bond_pairs, bond_stiffness, bond_rest_lengths, + atom_ids, driver_info): + """计算每帧每个粒子的动能、势能、总能。 + + 势能分配规则: + - 非驱动粒子之间的键:势能各分一半 + - 键的一端是驱动粒子:势能全部归非驱动端 + - 驱动粒子自身:按简谐振子 PE = ½ m (2πf)² A² cos²(2πft+φ) 计算 + (此处用 KE_driven = ½m·v² 的互补式:PE_driven = E_total_sho - KE_driven) + + Returns: + ek_atom: (n_frames, n_atoms) 每原子动能 + pe_atom: (n_frames, n_atoms) 每原子势能 + et_atom: (n_frames, n_atoms) 每原子总能 + """ + n_frames, n_atoms = x.shape + + # 动能 per atom + masses_2d = masses[np.newaxis, :] + ek_atom = 0.5 * masses_2d * (vx**2 + vy**2 + vz**2) # (n_frames, n_atoms) + + # 构建 atom_id → index 映射,以及驱动原子 index 集合 + id_to_idx = {int(aid): i for i, aid in enumerate(atom_ids)} + driven_idx = set() + for aid in driver_info: + if aid in id_to_idx: + driven_idx.add(id_to_idx[aid]) + + # 势能 per atom(弹簧键) + pe_atom = np.zeros((n_frames, n_atoms)) + if bond_pairs is not None and len(bond_pairs) > 0: + for b in range(len(bond_pairs)): + i, j = int(bond_pairs[b, 0]), int(bond_pairs[b, 1]) + ddx = x[:, j] - x[:, i] + ddy = y[:, j] - y[:, i] + ddz = z[:, j] - z[:, i] + dist = np.sqrt(ddx**2 + ddy**2 + ddz**2) + bond_pe = 0.5 * bond_stiffness[b] * (dist - bond_rest_lengths[b])**2 + i_driven = i in driven_idx + j_driven = j in driven_idx + if i_driven and not j_driven: + pe_atom[:, j] += bond_pe # 全归非驱动端 + elif j_driven and not i_driven: + pe_atom[:, i] += bond_pe # 全归非驱动端 + else: + pe_atom[:, i] += bond_pe * 0.5 + pe_atom[:, j] += bond_pe * 0.5 + + # 驱动粒子:用简谐振子总能 E_sho = ½m(2πf)²A²,PE_sho = E_sho - KE + TWO_PI = 2.0 * np.pi + for aid, info in driver_info.items(): + if aid not in id_to_idx: + continue + idx = id_to_idx[aid] + m = masses[idx] + amp = info["amp"] + freq = info["freq"] + e_sho = 0.5 * m * np.sum((TWO_PI * freq)**2 * amp**2) + pe_sho = e_sho - ek_atom[:, idx] + pe_atom[:, idx] = np.maximum(pe_sho, 0.0) # SHO 势能非负 + + et_atom = ek_atom + pe_atom + return ek_atom, pe_atom, et_atom + + def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True): """主绘图函数:读取 display.txt 并生成波形+能量动画。 + 布局(4行×2列): + 左列 行0-2:x/y/z 位移波形(vs 原子序号) + 右列 行0:每粒子动能 + 右列 行1:每粒子势能 + 右列 行2:每粒子总能 + 右列 行3:系统总能量随时间变化 + Args: - output_dir: 输出目录(含 display.txt) + output_dir: 输出目录(含 display.npz 或 display.txt) save_gif: 是否保存 GIF save_mp4: 是否保存 MP4 + show: 是否弹出交互窗口 """ data = _load_wave_dataset(output_dir) n_frames = int(data["n_frames"]) t = np.array(data["t"]) - # 位置 / 速度 - x = np.array(data["x"]) - y = np.array(data["y"]) - z = np.array(data["z"]) + x = np.array(data["x"]) + y = np.array(data["y"]) + z = np.array(data["z"]) vx = np.array(data["vx"]) vy = np.array(data["vy"]) vz = np.array(data["vz"]) - # 原子信息 - pos_0 = np.array(data["pos_0"]) - masses = np.array(data["masses"]) + pos_0 = np.array(data["pos_0"]) + masses = np.array(data["masses"]) atom_ids = np.array(data["atom_ids"]) - n_atoms = len(atom_ids) + n_atoms = len(atom_ids) - # 成键 - bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64) - bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64) - bond_rest_lengths = np.array(data.get("bond_rest_lengths", []), dtype=np.float64) + bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64) + bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64) + bond_rest_lengths= np.array(data.get("bond_rest_lengths",[]), dtype=np.float64) - # 物理开关 - gravity_field = int(data.get("gravity_field", 0)) + gravity_field = int(data.get("gravity_field", 0)) gravity_interaction = int(data.get("gravity_interaction", 0)) - G = data.get("G", [0, 0, 0]) - gravity_strength = float(data.get("gravity_strength", 1.0)) - driving_force = int(data.get("driving_force", 0)) + G = data.get("G", [0, 0, 0]) + gravity_strength = float(data.get("gravity_strength", 1.0)) + driving_force = int(data.get("driving_force", 0)) - # ── 位移(偏离初始平衡位形)── - dx = x - pos_0[np.newaxis, :, 0] # 纵波(沿链方向 x) - dy = y - pos_0[np.newaxis, :, 1] # 横波 1(y 方向) - dz = z - pos_0[np.newaxis, :, 2] # 横波 2(z 方向) + # 驱动原子信息(用于势能计算) + driver_info = _load_driver_info(output_dir) if driving_force else {} - # ── 能量 ── - ek, us, ug, ugr = compute_energy( + # ── 位移 ── + dx = x - pos_0[np.newaxis, :, 0] + dy = y - pos_0[np.newaxis, :, 1] + dz = z - pos_0[np.newaxis, :, 2] + + # ── 系统总能量(用于右下时间图)── + ek_sys, us_sys, ug_sys, ugr_sys = compute_energy( x, y, z, vx, vy, vz, masses, masses, bond_pairs, bond_stiffness, bond_rest_lengths, gravity_field, G, gravity_interaction, gravity_strength) - e_total = ek + us + ug + ugr - power = np.gradient(e_total, t) # 输入功率 = dE/dt + e_total = ek_sys + us_sys + ug_sys + ugr_sys + power = np.gradient(e_total, t) - # ── 波形纵轴范围(全局统一,避免抖动)── - def get_ylim(disp_data): - vmax = np.max(np.abs(disp_data)) + # ── 每粒子能量 ── + ek_atom, pe_atom, et_atom = compute_per_atom_energy( + x, y, z, vx, vy, vz, masses, + bond_pairs, bond_stiffness, bond_rest_lengths, + atom_ids, driver_info) + + # ── y 轴范围 ── + def get_ylim(arr): + vmax = np.max(np.abs(arr)) if vmax < 1e-10: return -1.0, 1.0 - margin = vmax * 0.2 - return -vmax - margin, vmax + margin + m = vmax * 0.2 + return -vmax - m, vmax + m - ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)] - e_max = max(np.max(e_total), 0.01) * 1.3 - p_abs = np.max(np.abs(power)) - p_max = max(p_abs * 1.3, 0.01) + def get_ylim_pos(arr): + vmax = np.max(arr) + if vmax < 1e-12: + return 0.0, 1.0 + return 0.0, vmax * 1.2 + + disp_ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)] + ek_ylim = get_ylim_pos(ek_atom) + pe_ylim = get_ylim_pos(pe_atom) + et_ylim = get_ylim_pos(et_atom) + e_max = max(np.max(e_total), 0.01) * 1.3 + p_max = max(np.max(np.abs(power)) * 1.3, 0.01) atom_idx = np.arange(n_atoms) - # ── 图形设置 ── + # ── 图形布局 ── plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False - fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + fig = plt.figure(figsize=(16, 14)) fig.suptitle("波形与能量分析", fontsize=16) - # ── 3 个波形子图 ── - titles = [ - "纵波 (x 方向位移)", - "横波 (y 方向位移)", - "横波 (z 方向位移)", - ] - colors = ["#2563eb", "#ea580c", "#16a34a"] - wave_configs = list(zip( - [axes[0, 0], axes[0, 1], axes[1, 0]], - [dx, dy, dz], titles, colors, ylims - )) + import matplotlib.gridspec as gridspec + gs = gridspec.GridSpec(4, 2, figure=fig, hspace=0.45, wspace=0.35) - wave_lines = [] - time_texts = [] - for ax, disp, title, color, yl in wave_configs: + # 左列:位移波形(行 0-2) + ax_dx = fig.add_subplot(gs[0, 0]) + ax_dy = fig.add_subplot(gs[1, 0]) + ax_dz = fig.add_subplot(gs[2, 0]) + # 左列行 3 留空(或与右下对齐) + ax_blank = fig.add_subplot(gs[3, 0]) + ax_blank.set_visible(False) + + # 右列:每粒子能量(行 0-2)+ 时间能量图(行 3) + ax_ek = fig.add_subplot(gs[0, 1]) + ax_pe = fig.add_subplot(gs[1, 1]) + ax_et = fig.add_subplot(gs[2, 1]) + ax_ep = fig.add_subplot(gs[3, 1]) + + # ── 初始化位移波形 ── + wave_axes = [ax_dx, ax_dy, ax_dz] + wave_disps = [dx, dy, dz] + wave_titles = ["纵波 (x 方向位移)", "横波 (y 方向位移)", "横波 (z 方向位移)"] + wave_colors = ["#2563eb", "#ea580c", "#16a34a"] + wave_lines = [] + time_texts = [] + + for ax, disp, title, color, yl in zip(wave_axes, wave_disps, wave_titles, wave_colors, disp_ylims): ax.set_xlim(0, n_atoms - 1) ax.set_ylim(yl) ax.set_xlabel("原子序号") ax.set_ylabel("位移") ax.set_title(title) ax.grid(True, alpha=0.3) - line, = ax.plot([], [], color=color, linewidth=1.5) - wave_lines.append(line) + ln, = ax.plot([], [], color=color, linewidth=1.5) + wave_lines.append(ln) tt = ax.text(0.02, 0.95, "", transform=ax.transAxes, - fontsize=11, verticalalignment="top") + fontsize=9, verticalalignment="top") time_texts.append(tt) - # ── 能量+功率子图 ── - ax_ep = axes[1, 1] + # ── 初始化每粒子能量图 ── + energy_axes = [ax_ek, ax_pe, ax_et] + energy_arrays = [ek_atom, pe_atom, et_atom] + energy_titles = ["每粒子动能", "每粒子势能", "每粒子总能"] + energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"] + energy_ylims = [ek_ylim, pe_ylim, et_ylim] + energy_lines = [] + + for ax, arr, title, color, yl in zip(energy_axes, energy_arrays, energy_titles, energy_colors, energy_ylims): + ax.set_xlim(0, n_atoms - 1) + ax.set_ylim(yl) + ax.set_xlabel("原子序号") + ax.set_ylabel("能量") + ax.set_title(title) + ax.grid(True, alpha=0.3) + ln, = ax.plot([], [], color=color, linewidth=1.5) + energy_lines.append(ln) + + # ── 初始化系统总能量时间图 ── ax_ep.set_xlim(t[0], t[-1]) - ep_ylow = min(-0.1 * max(e_max, p_max), -p_max * 0.1) ep_yhigh = max(e_max, p_max) + ep_ylow = min(-p_max * 0.1, 0.0) ax_ep.set_ylim(ep_ylow, ep_yhigh) ax_ep.set_xlabel("时间 (s)") ax_ep.set_ylabel("能量 / 功率") ax_ep.set_title("系统能量与输入功率") ax_ep.grid(True, alpha=0.3) - line_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能") - line_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能") - line_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量") - line_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)") + ln_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能") + ln_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能") + ln_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量") + ln_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)") + ln_ug = None + ln_ugr = None if gravity_field: - line_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能") - else: - line_ug = None + ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能") if gravity_interaction and n_atoms <= 200: - line_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能") - else: - line_ugr = None - - ax_ep.legend(loc="upper right", fontsize=9) - - plt.tight_layout(rect=[0, 0, 1, 0.95]) + ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能") + ax_ep.legend(loc="upper left", fontsize=8) # ── 动画更新 ── def update(frame): - # 波形 + # 位移波形 for i in range(3): - wave_lines[i].set_data(atom_idx, [dx, dy, dz][i][frame]) - time_texts[i].set_text(f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}") + wave_lines[i].set_data(atom_idx, wave_disps[i][frame]) + time_texts[i].set_text(f"t = {t[frame]:.2f} s 帧 {frame+1}/{n_frames}") - # 能量(累计到当前帧) + # 每粒子能量 + for i in range(3): + energy_lines[i].set_data(atom_idx, energy_arrays[i][frame]) + + # 系统总能量(累计到当前帧) cur_t = t[:frame + 1] - line_ek.set_data(cur_t, ek[:frame + 1]) - line_us.set_data(cur_t, us[:frame + 1]) - line_et.set_data(cur_t, e_total[:frame + 1]) - line_pw.set_data(cur_t, power[:frame + 1]) - if line_ug: - line_ug.set_data(cur_t, ug[:frame + 1]) - if line_ugr: - line_ugr.set_data(cur_t, ugr[:frame + 1]) - - # 能量图 x 轴动态扩展 + ln_ek.set_data(cur_t, ek_sys[:frame + 1]) + ln_us.set_data(cur_t, us_sys[:frame + 1]) + ln_et.set_data(cur_t, e_total[:frame + 1]) + ln_pw.set_data(cur_t, power[:frame + 1]) + if ln_ug: ln_ug.set_data(cur_t, ug_sys[:frame + 1]) + if ln_ugr: ln_ugr.set_data(cur_t, ugr_sys[:frame + 1]) ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1])) - artists = wave_lines + time_texts + [line_ek, line_us, line_et, line_pw] - if line_ug: artists.append(line_ug) - if line_ugr: artists.append(line_ugr) + artists = wave_lines + time_texts + energy_lines + \ + [ln_ek, ln_us, ln_et, ln_pw] + if ln_ug: artists.append(ln_ug) + if ln_ugr: artists.append(ln_ugr) return artists ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True) - # ── 输出 GIF ── + # ── 输出文件 ── + gif_path = None if save_gif: gif_path = os.path.join(output_dir, "wave_animation.gif") ani.save(gif_path, writer="pillow", fps=min(20, max(1, n_frames // 5))) print(f"[plot_wave] GIF 已保存: {gif_path}") - # ── 输出 MP4(需要 ffmpeg)── - gif_path = None - if save_gif: - gif_path = os.path.join(output_dir, "wave_animation.gif") if save_mp4: try: import matplotlib.animation as manim - import matplotlib.pyplot as _plt - - # 尝试通过 imageio_ffmpeg 定位 ffmpeg ffmpeg_path = None try: import imageio_ffmpeg @@ -347,11 +480,10 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True): except Exception: pass if ffmpeg_path and os.path.exists(ffmpeg_path): - _plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path - + plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path ffps = min(20, max(1, n_frames // 5)) writer = manim.FFMpegWriter(fps=ffps, codec="libx264", - extra_args=["-pix_fmt", "yuv420p"]) + extra_args=["-pix_fmt", "yuv420p"]) mp4_path = os.path.join(output_dir, "wave_animation.mp4") ani.save(mp4_path, writer=writer) print(f"[plot_wave] MP4 已保存: {mp4_path}") @@ -360,7 +492,6 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True): except Exception as e: print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过") - # ── 最后显示动画窗口(仅直接运行时)── if show: plt.show() else: