"""
plot_wave.py
============
Wave and energy animation: loads display.npz (preferred) or display.txt and
draws a 2-D animation of the per-atom displacement waveforms (one
longitudinal and two transverse components) together with the system
energies and the input power over time.

Usage:
    python plot_wave.py                          # uses output/ under the dynamics root
    python plot_wave.py examples/case05/output   # a specific case's output directory
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
import sys
import json

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute


def load_disp_data(output_dir):
    """Load display.npz (preferred) or display.txt from *output_dir*.

    Args:
        output_dir: Directory expected to contain the display file.

    Returns:
        The mapping produced by compute.load_display_npz /
        compute.load_display_txt. This module reads "header_fields",
        "frames_x/y/z", "frames_vx/vy/vz" and "atom_ids" from it.

    Raises:
        FileNotFoundError: if neither display.npz nor display.txt exists.
    """
    npz_path = os.path.join(output_dir, "display.npz")
    txt_path = os.path.join(output_dir, "display.txt")
    if os.path.exists(npz_path):
        return compute.load_display_npz(npz_path)
    if os.path.exists(txt_path):
        return compute.load_display_txt(txt_path)
    raise FileNotFoundError(f"找不到 display.npz 或 display.txt in {output_dir}")


def _header_json(header_fields, key, default):
    """Return the JSON-decoded header value stored under *key*.

    Missing, empty or undecodable text values yield *default*. A value that
    is not text (e.g. one the loader already decoded) is returned unchanged
    rather than being handed to json.loads, which would raise TypeError.
    """
    raw = header_fields.get(key, "")
    if isinstance(raw, (str, bytes, bytearray)):
        if not raw:
            return default
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return default
    if raw is None:
        return default
    # NOTE(review): assumes non-text header values are already decoded — confirm
    # against compute.load_display_npz.
    return raw


def _load_wave_dataset(output_dir):
    """Collect everything the wave/energy plot needs for *output_dir*.

    Atom masses, initial positions and bond data come from the display
    header metadata. For older display files that lack them, the function
    falls back to <parent of output_dir>/input/coord.txt, connection.txt
    and bond.txt.

    Returns:
        dict with frame count, time axis, per-frame positions/velocities,
        reference positions, masses, atom ids, bond arrays and the physics
        switches read from the header.

    Raises:
        ValueError: if masses/positions are neither in the header nor
            recoverable from input/coord.txt.
    """
    disp_data = load_disp_data(output_dir)
    header = disp_data["header_fields"]

    x = disp_data["frames_x"]
    y = disp_data["frames_y"]
    z = disp_data["frames_z"]
    vx = disp_data["frames_vx"]
    vy = disp_data["frames_vy"]
    vz = disp_data["frames_vz"]
    atom_ids = np.array(disp_data["atom_ids"], dtype=np.int64)

    n_frames = x.shape[0]
    dt = float(header.get("DT", 0.001))
    nstep = int(header.get("NSTEP", 1))
    # Frame i is placed at time i * DT * NSTEP.
    t = np.arange(n_frames, dtype=np.float64) * dt * nstep

    masses = np.array(_header_json(header, "atom_masses", []), dtype=np.float64)
    pos_0 = np.array(_header_json(header, "atom_positions", []), dtype=np.float64)
    bond_pairs = np.array(_header_json(header, "bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(_header_json(header, "bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(_header_json(header, "bond_rest_lengths", []), dtype=np.float64)
    gravity_vec = _header_json(header, "G", [0.0, 0.0, 0.0])

    # The input/ directory is a sibling of output_dir. normpath strips a
    # trailing separator, which would otherwise make dirname() return
    # output_dir itself.
    input_dir = os.path.join(os.path.dirname(os.path.normpath(output_dir)), "input")

    # Backward-compatible fallback for older display.txt outputs.
    if masses.size == 0 or pos_0.size == 0:
        coord_path = os.path.join(input_dir, "coord.txt")
        if not os.path.exists(coord_path):
            raise ValueError("display.txt 缺少 atom_masses/atom_positions 元数据,且未找到 input/coord.txt")
        (_, masses_fb, _, positions_fb, _, _) = compute.load_coord_file(coord_path)
        masses = np.array(masses_fb, dtype=np.float64)
        pos_0 = np.array(positions_fb, dtype=np.float64)

    if bond_pairs.size == 0:
        connection_path = os.path.join(input_dir, "connection.txt")
        bond_path = os.path.join(input_dir, "bond.txt")
        if os.path.exists(connection_path) and os.path.exists(bond_path):
            bond_map = compute.load_bond_parameters(bond_path)
            pairs_fb, _, stiffness_fb, rest_lengths_fb = compute.load_bond_connections(
                connection_path, atom_ids, pos_0, bond_map)
            bond_pairs = np.array(pairs_fb, dtype=np.int64)
            bond_stiffness = np.array(stiffness_fb, dtype=np.float64)
            bond_rest_lengths = np.array(rest_lengths_fb, dtype=np.float64)

    return {
        "n_frames": n_frames,
        "t": t,
        "x": x,
        "y": y,
        "z": z,
        "vx": vx,
        "vy": vy,
        "vz": vz,
        "pos_0": pos_0,
        "masses": masses,
        "atom_ids": atom_ids,
        "bond_pairs": bond_pairs,
        "bond_stiffness": bond_stiffness,
        "bond_rest_lengths": bond_rest_lengths,
        "gravity_field": int(header.get("gravity_field", 0)),
        "gravity_interaction": int(header.get("gravity_interaction", 0)),
        "gravity_strength": float(header.get("gravity_strength", 1.0)),
        "G": gravity_vec,
        "driving_force": int(header.get("driving_force", 0)),
    }


def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr, bond_pairs,
                   bond_stiffness, bond_rest_lengths, gravity_field, G,
                   gravity_interaction, gravity_strength):
    """Compute the system energy components for every frame.

    Position and velocity arrays are indexed as [frame, atom].

    Args:
        x, y, z, vx, vy, vz: Per-frame positions and velocities.
        masses: Per-atom masses.
        mass_arr: Unused; kept for call compatibility.
        bond_pairs: Sequence of (i, j) atom column indices, one per bond.
        bond_stiffness, bond_rest_lengths: Per-bond k and rest length d0.
        gravity_field: Truthy to include the uniform-field term -m G.r.
        G: Uniform gravity vector (first three components are used).
        gravity_interaction: Truthy to include pairwise -g m_i m_j / r
            (only evaluated when there are <= 200 atoms).
        gravity_strength: Pairwise gravitational constant g.

    Returns:
        ek_sys: system kinetic energy, shape (n_frames,)
        us_sys: system elastic (bond) energy, shape (n_frames,)
        ug_sys: uniform-field gravitational potential, shape (n_frames,)
        ugr_sys: pairwise gravitational potential, shape (n_frames,)
    """
    n_frames = x.shape[0]
    masses = np.asarray(masses, dtype=np.float64)
    masses_2d = masses[np.newaxis, :]  # (1, n_atoms), broadcasts over frames

    # Kinetic energy Ek = 1/2 m v^2
    ek_sys = np.sum(0.5 * masses_2d * (vx**2 + vy**2 + vz**2), axis=1)

    # Elastic energy Us = 1/2 k (d - d0)^2, summed over bonds
    us_sys = np.zeros(n_frames)
    if bond_pairs is not None and len(bond_pairs) > 0:
        for b_idx, (i, j) in enumerate(bond_pairs):
            dist = np.sqrt((x[:, j] - x[:, i])**2
                           + (y[:, j] - y[:, i])**2
                           + (z[:, j] - z[:, i])**2)
            stretch = dist - bond_rest_lengths[b_idx]
            us_sys += 0.5 * bond_stiffness[b_idx] * stretch**2

    # Uniform gravity field potential Ug = -m G.r
    ug_sys = np.zeros(n_frames)
    if gravity_field:
        g_vec = np.asarray(G, dtype=np.float64)
        ug_sys = -np.sum(masses_2d * (g_vec[0] * x + g_vec[1] * y + g_vec[2] * z), axis=1)

    # Pairwise gravitational potential Ugr = -g * sum_{i<j} m_i m_j / r_ij
    ugr_sys = np.zeros(n_frames)
    n_atoms = len(masses)
    # O(n^2) in the number of atoms, so it is skipped for large systems.
    if gravity_interaction and n_atoms <= 200:
        for i in range(n_atoms - 1):
            # Distances from atom i to every atom j > i: (n_frames, n_atoms - i - 1)
            ddx = x[:, i + 1:] - x[:, i:i + 1]
            ddy = y[:, i + 1:] - y[:, i:i + 1]
            ddz = z[:, i + 1:] - z[:, i:i + 1]
            dist = np.maximum(np.sqrt(ddx**2 + ddy**2 + ddz**2), 1e-12)
            ugr_sys -= gravity_strength * masses[i] * np.sum(masses[i + 1:] / dist, axis=1)

    return ek_sys, us_sys, ug_sys, ugr_sys


def _wave_ylim(disp):
    """Return symmetric y-limits covering max |disp| plus a 20 % margin.

    Falls back to (-1, 1) when the data is empty, (numerically) zero or
    non-finite, which would otherwise give a degenerate or invalid range.
    """
    vmax = float(np.max(np.abs(disp))) if np.size(disp) else 0.0
    if not np.isfinite(vmax) or vmax < 1e-10:
        return -1.0, 1.0
    margin = vmax * 0.2
    return -vmax - margin, vmax + margin


def _energy_ylim(series):
    """Return y-limits containing every finite value of the plotted *series*.

    The range always includes 0 and at least [0, 0.01]. A 10 % margin is
    added below and 30 % above, which leaves headroom for the legend.
    Negative values (e.g. gravitational potential, falling energy) stay
    visible.
    """
    values = np.concatenate([np.ravel(s) for s in series])
    values = values[np.isfinite(values)]
    lo = min(0.0, float(values.min())) if values.size else 0.0
    hi = max(0.01, float(values.max())) if values.size else 0.01
    span = hi - lo
    return lo - 0.1 * span, hi + 0.3 * span


def _save_mp4(ani, output_dir, fps):
    """Save *ani* as wave_animation.mp4 via ffmpeg; print a warning and skip on failure."""
    try:
        import matplotlib.animation as manim

        # Prefer the ffmpeg binary bundled with imageio_ffmpeg when available.
        ffmpeg_path = None
        try:
            import imageio_ffmpeg
            ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
        except Exception:
            # Optional helper: fall back to whatever ffmpeg matplotlib finds.
            pass
        if ffmpeg_path and os.path.exists(ffmpeg_path):
            plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path

        writer = manim.FFMpegWriter(fps=fps, codec="libx264",
                                    extra_args=["-pix_fmt", "yuv420p"])
        mp4_path = os.path.join(output_dir, "wave_animation.mp4")
        ani.save(mp4_path, writer=writer)
        print(f"[plot_wave] MP4 已保存: {mp4_path}")
    except FileNotFoundError:
        print("[plot_wave] 警告: 未找到 ffmpeg,跳过 MP4 输出")
    except Exception as e:
        print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过")


def plot_wave(output_dir, save_gif=False, save_mp4=False):
    """Build the displacement-waveform + energy animation for *output_dir*.

    Args:
        output_dir: Output directory containing display.npz or display.txt.
            Saved animations are written here too.
        save_gif: Save the animation as wave_animation.gif.
        save_mp4: Save the animation as wave_animation.mp4 (requires ffmpeg).

    Returns:
        The GIF path when save_gif is True, otherwise None.
    """
    data = _load_wave_dataset(output_dir)
    n_frames = int(data["n_frames"])
    t = np.array(data["t"])

    # Positions / velocities, indexed [frame, atom]
    x = np.array(data["x"])
    y = np.array(data["y"])
    z = np.array(data["z"])
    vx = np.array(data["vx"])
    vy = np.array(data["vy"])
    vz = np.array(data["vz"])

    # Atom information
    pos_0 = np.array(data["pos_0"])
    masses = np.array(data["masses"])
    atom_ids = np.array(data["atom_ids"])
    n_atoms = len(atom_ids)

    # Bonds
    bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(data.get("bond_rest_lengths", []), dtype=np.float64)

    # Physics switches
    gravity_field = int(data.get("gravity_field", 0))
    gravity_interaction = int(data.get("gravity_interaction", 0))
    G = data.get("G", [0, 0, 0])
    gravity_strength = float(data.get("gravity_strength", 1.0))

    # ── Displacement from the initial configuration ──
    dx = x - pos_0[np.newaxis, :, 0]  # longitudinal wave (along the chain, x)
    dy = y - pos_0[np.newaxis, :, 1]  # transverse wave 1 (y)
    dz = z - pos_0[np.newaxis, :, 2]  # transverse wave 2 (z)
    wave_data = (dx, dy, dz)

    # ── Energies ──
    ek, us, ug, ugr = compute_energy(
        x, y, z, vx, vy, vz, masses, masses,
        bond_pairs, bond_stiffness, bond_rest_lengths,
        gravity_field, G, gravity_interaction, gravity_strength)
    e_total = ek + us + ug + ugr
    # Input power = dE/dt. np.gradient needs at least two samples.
    if n_frames > 1:
        power = np.gradient(e_total, t)
    else:
        power = np.zeros_like(e_total)
    # compute_energy only evaluates the pairwise term for small systems.
    show_ugr = bool(gravity_interaction) and n_atoms <= 200

    atom_idx = np.arange(n_atoms)

    # ── Figure setup ──
    plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle("波形与能量分析", fontsize=16)

    # ── Three waveform panels; y-limits are global to avoid per-frame jitter ──
    titles = [
        "纵波 (x 方向位移)",
        "横波 (y 方向位移)",
        "横波 (z 方向位移)",
    ]
    colors = ["#2563eb", "#ea580c", "#16a34a"]
    wave_axes = [axes[0, 0], axes[0, 1], axes[1, 0]]

    wave_lines = []
    time_texts = []
    for ax, disp, title, color in zip(wave_axes, wave_data, titles, colors):
        # Avoid a singular x-range when there is only one atom.
        ax.set_xlim(0, max(n_atoms - 1, 1))
        ax.set_ylim(_wave_ylim(disp))
        ax.set_xlabel("原子序号")
        ax.set_ylabel("位移")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        line, = ax.plot([], [], color=color, linewidth=1.5)
        wave_lines.append(line)
        tt = ax.text(0.02, 0.95, "", transform=ax.transAxes,
                     fontsize=11, verticalalignment="top")
        time_texts.append(tt)

    # ── Energy + power panel ──
    ax_ep = axes[1, 1]
    # Fixed full time range: with blit=True, per-frame limit changes would
    # not redraw the tick labels.
    ax_ep.set_xlim(t[0], t[-1] if t[-1] > t[0] else t[0] + 1.0)
    ax_ep.set_xlabel("时间 (s)")
    ax_ep.set_ylabel("能量 / 功率")
    ax_ep.set_title("系统能量与输入功率")
    ax_ep.grid(True, alpha=0.3)

    line_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能")
    line_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能")
    line_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量")
    line_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)")
    energy_curves = [(line_ek, ek), (line_us, us), (line_et, e_total), (line_pw, power)]

    if gravity_field:
        line_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
        energy_curves.append((line_ug, ug))
    if show_ugr:
        line_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
        energy_curves.append((line_ugr, ugr))

    # Limits cover every plotted curve, including negative potentials/power.
    ax_ep.set_ylim(_energy_ylim([series for _, series in energy_curves]))
    ax_ep.legend(loc="upper right", fontsize=9)

    plt.tight_layout(rect=[0, 0, 1, 0.95])

    # ── Animation update ──
    artists = wave_lines + time_texts + [line for line, _ in energy_curves]

    def update(frame):
        label = f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}"
        for line, disp, text in zip(wave_lines, wave_data, time_texts):
            line.set_data(atom_idx, disp[frame])
            text.set_text(label)
        # Energy curves are drawn cumulatively up to the current frame.
        upto = frame + 1
        cur_t = t[:upto]
        for line, series in energy_curves:
            line.set_data(cur_t, series[:upto])
        return artists

    ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)
    fps = min(20, max(1, n_frames // 5))

    # ── GIF output ──
    gif_path = None
    if save_gif:
        gif_path = os.path.join(output_dir, "wave_animation.gif")
        ani.save(gif_path, writer="pillow", fps=fps)
        print(f"[plot_wave] GIF 已保存: {gif_path}")

    # ── MP4 output (requires ffmpeg) ──
    if save_mp4:
        _save_mp4(ani, output_dir, fps)

    # ── Finally show the interactive window ──
    plt.show()
    return gif_path


if __name__ == "__main__":
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if len(sys.argv) > 1:
        output_dir = os.path.abspath(sys.argv[1])
    else:
        output_dir = compute.get_output_dir(script_dir)
    plot_wave(output_dir)