""" plot_wave.py ============ 波形与能量动态图:读取 display.txt,绘制原子位移波形 (纵波 + 2 个横波)和系统能量/输入功率随时间变化的二维动画。 用法: python plot_wave.py # 使用 dynamics 根目录下 output/ python plot_wave.py examples/case05/output # 指定案例输出目录 """ import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation import os import sys import json sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import compute def load_disp_data(output_dir): """加载 display.npz(优先)或 display.txt。""" npz_path = os.path.join(output_dir, "display.npz") txt_path = os.path.join(output_dir, "display.txt") if os.path.exists(npz_path): return compute.load_display_npz(npz_path) if os.path.exists(txt_path): return compute.load_display_txt(txt_path) raise FileNotFoundError(f"找不到 display.npz 或 display.txt in {output_dir}") def _header_json(header_fields, key, default): raw = header_fields.get(key, "") if not raw: return default try: return json.loads(raw) except json.JSONDecodeError: return default def _load_wave_dataset(output_dir): """Load wave/energy plotting data from display metadata or sibling input files.""" disp_data = load_disp_data(output_dir) header = disp_data["header_fields"] x = disp_data["frames_x"] y = disp_data["frames_y"] z = disp_data["frames_z"] vx = disp_data["frames_vx"] vy = disp_data["frames_vy"] vz = disp_data["frames_vz"] atom_ids = np.array(disp_data["atom_ids"], dtype=np.int64) n_frames = x.shape[0] dt = float(header.get("DT", 0.001)) nstep = int(header.get("NSTEP", 1)) t = np.arange(n_frames, dtype=np.float64) * dt * nstep masses = np.array(_header_json(header, "atom_masses", []), dtype=np.float64) pos_0 = np.array(_header_json(header, "atom_positions", []), dtype=np.float64) bond_pairs = np.array(_header_json(header, "bond_pairs", []), dtype=np.int64) bond_stiffness = np.array(_header_json(header, "bond_stiffness", []), dtype=np.float64) bond_rest_lengths = np.array(_header_json(header, "bond_rest_lengths", []), dtype=np.float64) gravity_vec = _header_json(header, "G", [0.0, 0.0, 0.0]) # Backward-compatible fallback for older display.txt outputs. if masses.size == 0 or pos_0.size == 0: input_dir = os.path.join(os.path.dirname(output_dir), "input") coord_path = os.path.join(input_dir, "coord.txt") if os.path.exists(coord_path): (_, masses_fb, _, positions_fb, _, _) = compute.load_coord_file(coord_path) masses = np.array(masses_fb, dtype=np.float64) pos_0 = np.array(positions_fb, dtype=np.float64) else: raise ValueError("display.txt 缺少 atom_masses/atom_positions 元数据,且未找到 input/coord.txt") if bond_pairs.size == 0: input_dir = os.path.join(os.path.dirname(output_dir), "input") connection_path = os.path.join(input_dir, "connection.txt") bond_path = os.path.join(input_dir, "bond.txt") if os.path.exists(connection_path) and os.path.exists(bond_path): bond_map = compute.load_bond_parameters(bond_path) pairs_fb, _, stiffness_fb, rest_lengths_fb = compute.load_bond_connections( connection_path, atom_ids, pos_0, bond_map) bond_pairs = np.array(pairs_fb, dtype=np.int64) bond_stiffness = np.array(stiffness_fb, dtype=np.float64) bond_rest_lengths = np.array(rest_lengths_fb, dtype=np.float64) return { "n_frames": n_frames, "t": t, "x": x, "y": y, "z": z, "vx": vx, "vy": vy, "vz": vz, "pos_0": pos_0, "masses": masses, "atom_ids": atom_ids, "bond_pairs": bond_pairs, "bond_stiffness": bond_stiffness, "bond_rest_lengths": bond_rest_lengths, "gravity_field": int(header.get("gravity_field", 0)), "gravity_interaction": int(header.get("gravity_interaction", 0)), "gravity_strength": float(header.get("gravity_strength", 1.0)), "G": gravity_vec, "driving_force": int(header.get("driving_force", 0)), } def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr, bond_pairs, bond_stiffness, bond_rest_lengths, gravity_field, G, gravity_interaction, gravity_strength): """计算系统各能量分量。 Returns: ek_sys: 系统动能 (n_frames,) us_sys: 系统弹性势能 (n_frames,) ug_sys: 系统重力势能 (n_frames,) ugr_sys: 系统万有引力势能 (n_frames,) """ n_frames = x.shape[0] masses_2d = masses[np.newaxis, :] # (1, n_atoms) # 动能 Ek = ½ m v² ek = 0.5 * masses_2d * (vx**2 + vy**2 + vz**2) ek_sys = np.sum(ek, axis=1) # 弹性势能 Us = ½ k (d - d₀)² us_sys = np.zeros(n_frames) if bond_pairs is not None and len(bond_pairs) > 0: for b_idx in range(len(bond_pairs)): i, j = bond_pairs[b_idx] dx = x[:, j] - x[:, i] dy = y[:, j] - y[:, i] dz = z[:, j] - z[:, i] dist = np.sqrt(dx**2 + dy**2 + dz**2) stretch = dist - bond_rest_lengths[b_idx] us_sys += 0.5 * bond_stiffness[b_idx] * stretch**2 # 均匀重力场势能 Ug = -m G·r ug_sys = np.zeros(n_frames) if gravity_field: G_vec = np.array(G) ug_sys = -masses_2d * (G_vec[0] * x + G_vec[1] * y + G_vec[2] * z) ug_sys = np.sum(ug_sys, axis=1) # 万有引力势能 Ug_grav = -G_grav Σ m_i m_j / r ugr_sys = np.zeros(n_frames) if gravity_interaction: n_atoms = len(masses) # 为避免巨大计算量,仅当原子数较少时计算 if n_atoms <= 200: for i in range(n_atoms): for j in range(i + 1, n_atoms): dx = x[:, j] - x[:, i] dy = y[:, j] - y[:, i] dz = z[:, j] - z[:, i] dist = np.sqrt(dx**2 + dy**2 + dz**2) dist = np.maximum(dist, 1e-12) pair_pe = -gravity_strength * masses[i] * masses[j] / dist ugr_sys += pair_pe return ek_sys, us_sys, ug_sys, ugr_sys def _load_driver_info(output_dir): """从 input/driver.txt 读取驱动原子的 atom_id, amp, freq(仅用于能量计算)。 返回 dict: {atom_id: {'amp': [ax,ay,az], 'freq': [fx,fy,fz]}},失败返回 {}。 """ input_dir = os.path.join(os.path.dirname(output_dir), "input") path = os.path.join(input_dir, "driver.txt") if not os.path.exists(path): return {} drivers = {} try: with open(path, encoding="utf-8") as f: f.readline() # skip header for line in f: line = line.strip() if not line or line.startswith("#"): continue parts = line.split() if len(parts) < 10: continue n = int(parts[0]) amp = np.array([float(parts[1]), float(parts[2]), float(parts[3])]) freq = np.array([float(parts[4]), float(parts[5]), float(parts[6])]) drivers[n] = {"amp": amp, "freq": freq} except Exception: pass return drivers def compute_per_atom_energy(x, y, z, vx, vy, vz, masses, bond_pairs, bond_stiffness, bond_rest_lengths, atom_ids, driver_info): """计算每帧每个粒子的动能、势能、总能。 势能分配规则: - 非驱动粒子之间的键:势能各分一半 - 键的一端是驱动粒子:势能全部归非驱动端 - 驱动粒子自身:按简谐振子 PE = ½ m (2πf)² A² cos²(2πft+φ) 计算 (此处用 KE_driven = ½m·v² 的互补式:PE_driven = E_total_sho - KE_driven) Returns: ek_atom: (n_frames, n_atoms) 每原子动能 pe_atom: (n_frames, n_atoms) 每原子势能 et_atom: (n_frames, n_atoms) 每原子总能 """ n_frames, n_atoms = x.shape # 动能 per atom masses_2d = masses[np.newaxis, :] ek_atom = 0.5 * masses_2d * (vx**2 + vy**2 + vz**2) # (n_frames, n_atoms) # 构建 atom_id → index 映射,以及驱动原子 index 集合 id_to_idx = {int(aid): i for i, aid in enumerate(atom_ids)} driven_idx = set() for aid in driver_info: if aid in id_to_idx: driven_idx.add(id_to_idx[aid]) # 势能 per atom(弹簧键) pe_atom = np.zeros((n_frames, n_atoms)) if bond_pairs is not None and len(bond_pairs) > 0: for b in range(len(bond_pairs)): i, j = int(bond_pairs[b, 0]), int(bond_pairs[b, 1]) ddx = x[:, j] - x[:, i] ddy = y[:, j] - y[:, i] ddz = z[:, j] - z[:, i] dist = np.sqrt(ddx**2 + ddy**2 + ddz**2) bond_pe = 0.5 * bond_stiffness[b] * (dist - bond_rest_lengths[b])**2 i_driven = i in driven_idx j_driven = j in driven_idx if i_driven and not j_driven: pe_atom[:, j] += bond_pe # 全归非驱动端 elif j_driven and not i_driven: pe_atom[:, i] += bond_pe # 全归非驱动端 else: pe_atom[:, i] += bond_pe * 0.5 pe_atom[:, j] += bond_pe * 0.5 # 驱动粒子:用简谐振子总能 E_sho = ½m(2πf)²A²,PE_sho = E_sho - KE TWO_PI = 2.0 * np.pi for aid, info in driver_info.items(): if aid not in id_to_idx: continue idx = id_to_idx[aid] m = masses[idx] amp = info["amp"] freq = info["freq"] e_sho = 0.5 * m * np.sum((TWO_PI * freq)**2 * amp**2) pe_sho = e_sho - ek_atom[:, idx] pe_atom[:, idx] = np.maximum(pe_sho, 0.0) # SHO 势能非负 et_atom = ek_atom + pe_atom return ek_atom, pe_atom, et_atom def compute_energy_flux(x, y, z, vx, vy, vz, bond_pairs, bond_stiffness, bond_rest_lengths): """计算每帧每根键的能流密度(Hardy 公式)。 对键 b 连接原子 i, j(j > i,方向 i→j): J_b = ½ · F_{b,i} · (v_i + v_j) 其中 F_{b,i} = k(d - r₀) * (r_j - r_i)/d 是键对原子 i 的弹力矢量, 点乘取两端速度均值。 - J > 0:能量从 i 流向 j(沿键方向正流) - J = 0:驻波,能量不流动 - 沿链从左到右,J 的分布揭示能量传播方向 Returns: flux: (n_frames, n_bonds) 每帧每键的能流(标量) bond_xpos: (n_bonds,) 各键中点的初始 x 坐标(用于绘图横轴) """ if bond_pairs is None or len(bond_pairs) == 0: n_frames = x.shape[0] return np.zeros((n_frames, 0)), np.zeros(0) n_bonds = len(bond_pairs) n_frames = x.shape[0] flux = np.zeros((n_frames, n_bonds)) # 键中点初始 x 坐标(用于横轴定位) bond_xpos = np.array([ 0.5 * (x[0, bond_pairs[b, 0]] + x[0, bond_pairs[b, 1]]) for b in range(n_bonds) ]) for b in range(n_bonds): i, j = int(bond_pairs[b, 0]), int(bond_pairs[b, 1]) k = bond_stiffness[b] r0 = bond_rest_lengths[b] # 键矢量与长度 dx_ = x[:, j] - x[:, i] dy_ = y[:, j] - y[:, i] dz_ = z[:, j] - z[:, i] d = np.sqrt(dx_**2 + dy_**2 + dz_**2) d = np.maximum(d, 1e-12) # 弹力矢量(作用于原子 i,指向 j 方向) fac = k * (d - r0) / d # 标量因子 fx = fac * dx_ fy = fac * dy_ fz = fac * dz_ # 能流 = F_i · (v_i + v_j) / 2 flux[:, b] = 0.5 * (fx * (vx[:, i] + vx[:, j]) + fy * (vy[:, i] + vy[:, j]) + fz * (vz[:, i] + vz[:, j])) return flux, bond_xpos def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True): """主绘图函数:读取 display.txt 并生成波形+能量动画。 布局(4行×1列,纵向排列): 行0:x/y/z 位移波形叠加在同一子图(vs 原子序号) 行1:每粒子动能、势能、总能叠加在同一子图 行2:键能流密度 J(Hardy 公式,vs 键中点位置) 行3:系统总能量随时间变化 Args: output_dir: 输出目录(含 display.npz 或 display.txt) save_gif: 是否保存 GIF save_mp4: 是否保存 MP4 show: 是否弹出交互窗口 """ data = _load_wave_dataset(output_dir) n_frames = int(data["n_frames"]) t = np.array(data["t"]) x = np.array(data["x"]) y = np.array(data["y"]) z = np.array(data["z"]) vx = np.array(data["vx"]) vy = np.array(data["vy"]) vz = np.array(data["vz"]) pos_0 = np.array(data["pos_0"]) masses = np.array(data["masses"]) atom_ids = np.array(data["atom_ids"]) n_atoms = len(atom_ids) bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64) bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64) bond_rest_lengths= np.array(data.get("bond_rest_lengths",[]), dtype=np.float64) gravity_field = int(data.get("gravity_field", 0)) gravity_interaction = int(data.get("gravity_interaction", 0)) G = data.get("G", [0, 0, 0]) gravity_strength = float(data.get("gravity_strength", 1.0)) driving_force = int(data.get("driving_force", 0)) # 驱动原子信息(用于势能计算) driver_info = _load_driver_info(output_dir) if driving_force else {} # ── 位移 ── dx = x - pos_0[np.newaxis, :, 0] dy = y - pos_0[np.newaxis, :, 1] dz = z - pos_0[np.newaxis, :, 2] # ── 系统总能量(用于右下时间图)── ek_sys, us_sys, ug_sys, ugr_sys = compute_energy( x, y, z, vx, vy, vz, masses, masses, bond_pairs, bond_stiffness, bond_rest_lengths, gravity_field, G, gravity_interaction, gravity_strength) e_total = ek_sys + us_sys + ug_sys + ugr_sys power = np.gradient(e_total, t) # ── 每粒子能量 ── ek_atom, pe_atom, et_atom = compute_per_atom_energy( x, y, z, vx, vy, vz, masses, bond_pairs, bond_stiffness, bond_rest_lengths, atom_ids, driver_info) # ── 能流密度 ── flux, bond_xpos = compute_energy_flux( x, y, z, vx, vy, vz, bond_pairs, bond_stiffness, bond_rest_lengths) # ── y 轴范围 ── def get_ylim(arr): vmax = np.max(np.abs(arr)) if vmax < 1e-10: return -1.0, 1.0 m = vmax * 0.2 return -vmax - m, vmax + m def get_ylim_pos(arr): vmax = np.max(arr) if vmax < 1e-12: return 0.0, 1.0 return 0.0, vmax * 1.2 # 共用 y 轴范围:位移图取三个方向最大值统一 disp_vmax = max(np.max(np.abs(dx)), np.max(np.abs(dy)), np.max(np.abs(dz))) disp_vmax = disp_vmax if disp_vmax > 1e-10 else 1.0 disp_ylim = (-disp_vmax * 1.2, disp_vmax * 1.2) # 每粒子能量:取三者最大值统一 y 轴 energy_vmax = max(np.max(ek_atom), np.max(pe_atom), np.max(et_atom)) energy_vmax = energy_vmax if energy_vmax > 1e-12 else 1.0 energy_ylim = (0.0, energy_vmax * 1.2) e_max = max(np.max(e_total), 0.01) * 1.3 p_max = max(np.max(np.abs(power)) * 1.3, 0.01) # 能流 y 轴范围(对称,正负各半) if flux.size > 0: flux_vmax = np.max(np.abs(flux)) flux_vmax = flux_vmax if flux_vmax > 1e-12 else 1.0 flux_ylim = (-flux_vmax * 1.2, flux_vmax * 1.2) else: flux_ylim = (-1.0, 1.0) atom_idx = np.arange(n_atoms) # ── 图形布局:4 行 × 1 列,纵向排列 ── plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False fig, (ax_wave, ax_energy, ax_flux, ax_ep) = plt.subplots(4, 1, figsize=(12, 18)) fig.suptitle("波形与能量分析", fontsize=16) fig.subplots_adjust(hspace=0.42, top=0.95) # ── 图1:x/y/z 位移波形叠加 ── ax_wave.set_xlim(0, n_atoms - 1) ax_wave.set_ylim(disp_ylim) ax_wave.set_xlabel("原子序号") ax_wave.set_ylabel("位移") ax_wave.set_title("粒子位移(x / y / z 方向)") ax_wave.grid(True, alpha=0.3) wave_disps = [dx, dy, dz] wave_labels = ["x 方向(纵波)", "y 方向(横波)", "z 方向(横波)"] wave_colors = ["#2563eb", "#ea580c", "#16a34a"] wave_lines = [] for label, color in zip(wave_labels, wave_colors): ln, = ax_wave.plot([], [], color=color, linewidth=1.5, label=label) wave_lines.append(ln) ax_wave.legend(loc="upper right", fontsize=9) time_text = ax_wave.text(0.02, 0.95, "", transform=ax_wave.transAxes, fontsize=10, verticalalignment="top") # ── 图2:每粒子动能、势能、总能叠加 ── ax_energy.set_xlim(0, n_atoms - 1) ax_energy.set_ylim(energy_ylim) ax_energy.set_xlabel("原子序号") ax_energy.set_ylabel("能量") ax_energy.set_title("每粒子能量(动能 / 势能 / 总能)") ax_energy.grid(True, alpha=0.3) energy_arrays = [ek_atom, pe_atom, et_atom] energy_labels = ["动能", "势能", "总能"] energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"] energy_lines = [] for label, color in zip(energy_labels, energy_colors): ln, = ax_energy.plot([], [], color=color, linewidth=1.5, label=label) energy_lines.append(ln) ax_energy.legend(loc="upper right", fontsize=9) # ── 图3:能流密度 J(Hardy 公式)── xmin_flux = bond_xpos[0] if len(bond_xpos) > 0 else 0 xmax_flux = bond_xpos[-1] if len(bond_xpos) > 0 else n_atoms - 1 ax_flux.set_xlim(xmin_flux, xmax_flux) ax_flux.set_ylim(flux_ylim) ax_flux.axhline(0, color="gray", linewidth=0.8, linestyle="--") ax_flux.set_xlabel("位置(键中点 x 坐标)") ax_flux.set_ylabel("能流密度 J") ax_flux.set_title("键能流密度 J = ½ F·(vᵢ+vⱼ) (J>0 向右传播,J<0 向左传播)") ax_flux.grid(True, alpha=0.3) flux_line, = ax_flux.plot([], [], color="#dc2626", linewidth=1.5) # ── 图4:系统总能量随时间 ── ax_ep.set_xlim(t[0], t[-1]) ep_yhigh = max(e_max, p_max) ep_ylow = min(-p_max * 0.1, 0.0) ax_ep.set_ylim(ep_ylow, ep_yhigh) ax_ep.set_xlabel("时间 (s)") ax_ep.set_ylabel("能量 / 功率") ax_ep.set_title("系统能量与输入功率") ax_ep.grid(True, alpha=0.3) ln_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能") ln_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能") ln_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量") ln_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)") ln_ug = None ln_ugr = None if gravity_field: ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能") if gravity_interaction and n_atoms <= 200: ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能") ax_ep.legend(loc="upper left", fontsize=9) # ── 动画更新 ── def update(frame): # 图1:位移波形 for i, ln in enumerate(wave_lines): ln.set_data(atom_idx, wave_disps[i][frame]) time_text.set_text(f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}") # 图2:每粒子能量 for i, ln in enumerate(energy_lines): ln.set_data(atom_idx, energy_arrays[i][frame]) # 图3:能流密度 if flux.shape[1] > 0: flux_line.set_data(bond_xpos, flux[frame]) # 图4:系统能量(累计到当前帧) cur_t = t[:frame + 1] ln_ek.set_data(cur_t, ek_sys[:frame + 1]) ln_us.set_data(cur_t, us_sys[:frame + 1]) ln_et.set_data(cur_t, e_total[:frame + 1]) ln_pw.set_data(cur_t, power[:frame + 1]) if ln_ug: ln_ug.set_data(cur_t, ug_sys[:frame + 1]) if ln_ugr: ln_ugr.set_data(cur_t, ugr_sys[:frame + 1]) ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1])) artists = wave_lines + [time_text] + energy_lines + \ [flux_line, ln_ek, ln_us, ln_et, ln_pw] if ln_ug: artists.append(ln_ug) if ln_ugr: artists.append(ln_ugr) return artists ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True) # ── 输出文件 ── gif_path = None if save_gif: gif_path = os.path.join(output_dir, "wave_animation.gif") ani.save(gif_path, writer="pillow", fps=min(20, max(1, n_frames // 5))) print(f"[plot_wave] GIF 已保存: {gif_path}") if save_mp4: try: import matplotlib.animation as manim ffmpeg_path = None try: import imageio_ffmpeg ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe() except Exception: pass if ffmpeg_path and os.path.exists(ffmpeg_path): plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path ffps = min(20, max(1, n_frames // 5)) writer = manim.FFMpegWriter(fps=ffps, codec="libx264", extra_args=["-pix_fmt", "yuv420p"]) mp4_path = os.path.join(output_dir, "wave_animation.mp4") ani.save(mp4_path, writer=writer) print(f"[plot_wave] MP4 已保存: {mp4_path}") except FileNotFoundError: print("[plot_wave] 警告: 未找到 ffmpeg,跳过 MP4 输出") except Exception as e: print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过") if show: plt.show() else: plt.close(fig) return gif_path if __name__ == "__main__": script_dir = os.path.dirname(os.path.abspath(__file__)) if len(sys.argv) > 1: output_dir = os.path.abspath(sys.argv[1]) else: output_dir = compute.get_output_dir(script_dir) plot_wave(output_dir)