Files
dynamics/plot_wave.py
T
admin 1cefe184d7 fix: plot_wave plt.show() crash when called non-interactively
Add show=False parameter to plot_wave(); when called from dynamics.py,
pass show=False and set matplotlib Agg backend to avoid NonGuiException.
Also print full traceback on failure for easier debugging.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-13 08:10:16 +08:00

378 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
plot_wave.py
============
波形与能量动态图:读取 display.txt,绘制原子位移波形
(纵波 + 2 个横波)和系统能量/输入功率随时间变化的二维动画。
用法:
python plot_wave.py # 使用 dynamics 根目录下 output/
python plot_wave.py examples/case05/output # 指定案例输出目录
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
import sys
import json
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute
def load_disp_data(output_dir):
    """Load the display data stored in *output_dir*.

    display.npz is preferred; display.txt is only used when no .npz exists.

    Args:
        output_dir: directory holding the simulation's display file.

    Returns:
        Whatever ``compute.load_display_npz`` / ``compute.load_display_txt``
        returns (callers in this file index it like a dict, e.g.
        ``["header_fields"]``).

    Raises:
        FileNotFoundError: if neither display.npz nor display.txt exists.
    """
    npz_file = os.path.join(output_dir, "display.npz")
    if os.path.exists(npz_file):
        return compute.load_display_npz(npz_file)

    txt_file = os.path.join(output_dir, "display.txt")
    if os.path.exists(txt_file):
        return compute.load_display_txt(txt_file)

    raise FileNotFoundError(f"找不到 display.npz 或 display.txt in {output_dir}")


def _header_json(header_fields, key, default):
    """Decode the JSON-encoded header field *key*, or return *default*.

    *default* is returned when the key is absent, its value is empty/falsy,
    or the value is not valid JSON.
    """
    encoded = header_fields.get(key, "")
    if encoded:
        try:
            decoded = json.loads(encoded)
        except json.JSONDecodeError:
            # Malformed metadata is treated exactly like missing metadata.
            decoded = default
        return decoded
    return default


def _load_wave_dataset(output_dir):
    """Load wave/energy plotting data for *output_dir*.

    Frame arrays come from display.npz / display.txt via ``load_disp_data``.
    Static model data (masses, initial positions, bonds, gravity vector) is
    read from JSON-encoded header fields when present. Older outputs without
    that metadata fall back to files in the sibling ``input/`` directory
    (i.e. ``<parent of output_dir>/input``).

    Args:
        output_dir: simulation output directory.

    Returns:
        dict with keys n_frames, t, x, y, z, vx, vy, vz, pos_0, masses,
        atom_ids, bond_pairs, bond_stiffness, bond_rest_lengths,
        gravity_field, gravity_interaction, gravity_strength, G,
        driving_force.

    Raises:
        FileNotFoundError: if no display file exists in *output_dir*.
        ValueError: if mass/position metadata is missing and
            input/coord.txt cannot be found.
    """
    disp_data = load_disp_data(output_dir)
    header = disp_data["header_fields"]
    x = disp_data["frames_x"]
    y = disp_data["frames_y"]
    z = disp_data["frames_z"]
    vx = disp_data["frames_vx"]
    vy = disp_data["frames_vy"]
    vz = disp_data["frames_vz"]
    atom_ids = np.array(disp_data["atom_ids"], dtype=np.int64)
    n_frames = x.shape[0]

    # Frame k is placed at t = k * DT * NSTEP.
    dt = float(header.get("DT", 0.001))
    nstep = int(header.get("NSTEP", 1))
    t = np.arange(n_frames, dtype=np.float64) * dt * nstep

    # Model metadata embedded in the header as JSON (newer outputs).
    masses = np.array(_header_json(header, "atom_masses", []), dtype=np.float64)
    pos_0 = np.array(_header_json(header, "atom_positions", []), dtype=np.float64)
    bond_pairs = np.array(_header_json(header, "bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(_header_json(header, "bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(_header_json(header, "bond_rest_lengths", []), dtype=np.float64)
    gravity_vec = _header_json(header, "G", [0.0, 0.0, 0.0])

    # Backward-compatible fallback for older outputs lacking that metadata.
    # abspath() also normalizes the path, so a trailing separator
    # ("case/output/") still maps to "case/input" instead of
    # "case/output/input".
    input_dir = os.path.join(os.path.dirname(os.path.abspath(output_dir)), "input")

    if masses.size == 0 or pos_0.size == 0:
        coord_path = os.path.join(input_dir, "coord.txt")
        if not os.path.exists(coord_path):
            raise ValueError("display.txt 缺少 atom_masses/atom_positions 元数据,且未找到 input/coord.txt")
        _, masses_fb, _, positions_fb, _, _ = compute.load_coord_file(coord_path)
        masses = np.array(masses_fb, dtype=np.float64)
        pos_0 = np.array(positions_fb, dtype=np.float64)

    if bond_pairs.size == 0:
        connection_path = os.path.join(input_dir, "connection.txt")
        bond_path = os.path.join(input_dir, "bond.txt")
        # Bonds are optional: without both files the system is treated as unbonded.
        if os.path.exists(connection_path) and os.path.exists(bond_path):
            bond_map = compute.load_bond_parameters(bond_path)
            pairs_fb, _, stiffness_fb, rest_lengths_fb = compute.load_bond_connections(
                connection_path, atom_ids, pos_0, bond_map)
            bond_pairs = np.array(pairs_fb, dtype=np.int64)
            bond_stiffness = np.array(stiffness_fb, dtype=np.float64)
            bond_rest_lengths = np.array(rest_lengths_fb, dtype=np.float64)

    return {
        "n_frames": n_frames,
        "t": t,
        "x": x,
        "y": y,
        "z": z,
        "vx": vx,
        "vy": vy,
        "vz": vz,
        "pos_0": pos_0,
        "masses": masses,
        "atom_ids": atom_ids,
        "bond_pairs": bond_pairs,
        "bond_stiffness": bond_stiffness,
        "bond_rest_lengths": bond_rest_lengths,
        "gravity_field": int(header.get("gravity_field", 0)),
        "gravity_interaction": int(header.get("gravity_interaction", 0)),
        "gravity_strength": float(header.get("gravity_strength", 1.0)),
        "G": gravity_vec,
        "driving_force": int(header.get("driving_force", 0)),
    }


def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr,
                   bond_pairs, bond_stiffness, bond_rest_lengths,
                   gravity_field, G, gravity_interaction, gravity_strength):
    """Compute the per-frame energy components of the system.

    Args:
        x, y, z: position arrays indexed ``[frame, atom]``.
        vx, vy, vz: velocity arrays with the same layout as the positions.
        masses: per-atom masses (sequence or array of length n_atoms).
        mass_arr: unused; kept only for backward compatibility with callers.
        bond_pairs: (n_bonds, 2) column-index pairs into the position arrays;
            may be empty or None.
        bond_stiffness: spring constant k for each bond.
        bond_rest_lengths: rest length d0 for each bond.
        gravity_field: truthy to include the uniform-field potential.
        G: uniform field vector; components 0, 1, 2 are used.
        gravity_interaction: truthy to include pairwise gravitational energy.
        gravity_strength: gravitational constant of the pair term.

    Returns:
        (ek_sys, us_sys, ug_sys, ugr_sys), each of shape (n_frames,):
        kinetic, elastic, uniform-field and mutual-gravity energy.
        The mutual-gravity term is only evaluated for at most 200 atoms;
        for larger systems it is returned as zeros.
    """
    masses = np.asarray(masses, dtype=np.float64)
    n_frames = x.shape[0]
    masses_2d = masses[np.newaxis, :]  # (1, n_atoms): broadcasts over frames

    # Kinetic energy Ek = 1/2 m v^2, summed over atoms.
    ek_sys = np.sum(0.5 * masses_2d * (vx**2 + vy**2 + vz**2), axis=1)

    # Elastic energy Us = 1/2 k (d - d0)^2, summed over bonds.
    us_sys = np.zeros(n_frames)
    if bond_pairs is not None and len(bond_pairs) > 0:
        for b_idx, (i, j) in enumerate(bond_pairs):
            dist = np.sqrt((x[:, j] - x[:, i])**2
                           + (y[:, j] - y[:, i])**2
                           + (z[:, j] - z[:, i])**2)
            stretch = dist - bond_rest_lengths[b_idx]
            us_sys += 0.5 * bond_stiffness[b_idx] * stretch**2

    # Uniform-field potential Ug = -m G.r, summed over atoms.
    ug_sys = np.zeros(n_frames)
    if gravity_field:
        G_vec = np.asarray(G, dtype=np.float64)
        ug_sys = np.sum(-masses_2d * (G_vec[0] * x + G_vec[1] * y + G_vec[2] * z), axis=1)

    # Mutual gravity Ug_grav = -G_grav * sum_{i<j} m_i m_j / r_ij.
    # Quadratic in the atom count, so it is skipped for large systems.
    ugr_sys = np.zeros(n_frames)
    n_atoms = len(masses)
    if gravity_interaction and n_atoms <= 200:
        for i in range(n_atoms - 1):
            # Separations from atom i to every later atom j > i, all frames at once.
            rx = x[:, i + 1:n_atoms] - x[:, i:i + 1]
            ry = y[:, i + 1:n_atoms] - y[:, i:i + 1]
            rz = z[:, i + 1:n_atoms] - z[:, i:i + 1]
            # Clamp to avoid division by zero for coincident atoms.
            dist = np.maximum(np.sqrt(rx**2 + ry**2 + rz**2), 1e-12)
            ugr_sys -= gravity_strength * masses[i] * np.sum(masses[i + 1:] / dist, axis=1)

    return ek_sys, us_sys, ug_sys, ugr_sys


def _symmetric_ylim(values):
    """Return symmetric y-limits (-1.2*m, 1.2*m), m = max finite |value|.

    Falls back to (-1.0, 1.0) when the data is numerically all zero or has
    no finite values, so the axis never collapses.
    """
    magnitudes = np.abs(np.asarray(values, dtype=np.float64)).ravel()
    magnitudes = magnitudes[np.isfinite(magnitudes)]
    vmax = float(magnitudes.max()) if magnitudes.size else 0.0
    if vmax < 1e-10:
        return -1.0, 1.0
    # 20 % headroom on each side.
    return -1.2 * vmax, 1.2 * vmax


def _padded_ylim(series_list, pad_frac=0.15, min_span=0.01):
    """Return (low, high) y-limits containing every finite value in *series_list*.

    The range always includes 0 and is widened on both sides by
    ``pad_frac`` times its span, with the span floored at ``min_span`` so
    flat data still gets a visible axis. Negative values are kept in range.
    """
    low, high = 0.0, 0.0
    if series_list:
        values = np.concatenate(
            [np.asarray(s, dtype=np.float64).ravel() for s in series_list])
        values = values[np.isfinite(values)]
        if values.size:
            low = min(float(values.min()), 0.0)
            high = max(float(values.max()), 0.0)
    margin = pad_frac * max(high - low, min_span)
    return low - margin, high + margin


def _save_mp4(ani, mp4_path, fps):
    """Save *ani* to *mp4_path* as H.264 (yuv420p) through ffmpeg.

    Best effort: a missing ffmpeg or any writer failure prints a warning
    and returns instead of raising.
    """
    try:
        from matplotlib.animation import FFMpegWriter
        # Prefer the ffmpeg binary bundled with imageio_ffmpeg when it is
        # installed; otherwise matplotlib's configured ffmpeg path is used.
        ffmpeg_path = None
        try:
            import imageio_ffmpeg
            ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
        except Exception:
            # Optional dependency; absence is not an error.
            ffmpeg_path = None
        if ffmpeg_path and os.path.exists(ffmpeg_path):
            plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path
        writer = FFMpegWriter(fps=fps, codec="libx264",
                              extra_args=["-pix_fmt", "yuv420p"])
        ani.save(mp4_path, writer=writer)
        print(f"[plot_wave] MP4 已保存: {mp4_path}")
    except FileNotFoundError:
        print("[plot_wave] 警告: 未找到 ffmpeg,跳过 MP4 输出")
    except Exception as e:
        print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过")


def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
    """Build the wave + energy animation for one simulation output directory.

    Animates, frame by frame, the x/y/z displacement of every atom from its
    initial position (three panels) and the cumulative system energies and
    dE/dt (fourth panel).

    Args:
        output_dir: output directory containing display.npz or display.txt;
            saved animations are written here too.
        save_gif: save ``wave_animation.gif`` with the pillow writer.
        save_mp4: save ``wave_animation.mp4`` (needs ffmpeg; failures are
            reported and skipped, never raised).
        show: call ``plt.show()``; when False the figure is closed instead,
            for non-interactive callers.

    Returns:
        Path of the saved GIF, or None when ``save_gif`` is False.

    Raises:
        ValueError: if the display data contains no frames.
    """
    data = _load_wave_dataset(output_dir)
    n_frames = int(data["n_frames"])
    if n_frames < 1:
        raise ValueError(f"[plot_wave] no frames in display data under {output_dir}")
    t = np.asarray(data["t"], dtype=np.float64)

    # Positions / velocities, indexed [frame, atom].
    x = np.asarray(data["x"])
    y = np.asarray(data["y"])
    z = np.asarray(data["z"])
    vx = np.asarray(data["vx"])
    vy = np.asarray(data["vy"])
    vz = np.asarray(data["vz"])

    # Static atom data.
    pos_0 = np.asarray(data["pos_0"])
    masses = np.asarray(data["masses"])
    n_atoms = len(data["atom_ids"])

    # Bonds (may be empty).
    bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(data.get("bond_rest_lengths", []), dtype=np.float64)

    # Physics switches recorded with the run.
    gravity_field = int(data.get("gravity_field", 0))
    gravity_interaction = int(data.get("gravity_interaction", 0))
    G = data.get("G", [0, 0, 0])
    gravity_strength = float(data.get("gravity_strength", 1.0))

    # Displacement from the initial configuration: x is the longitudinal
    # (along-chain) direction, y and z the two transverse directions.
    dx = x - pos_0[np.newaxis, :, 0]
    dy = y - pos_0[np.newaxis, :, 1]
    dz = z - pos_0[np.newaxis, :, 2]

    ek, us, ug, ugr = compute_energy(
        x, y, z, vx, vy, vz, masses, masses,
        bond_pairs, bond_stiffness, bond_rest_lengths,
        gravity_field, G, gravity_interaction, gravity_strength)
    e_total = ek + us + ug + ugr
    # Input power approximated as dE/dt. np.gradient needs at least two
    # samples, so a single-frame run gets zero power instead of a ValueError.
    if n_frames >= 2:
        power = np.gradient(e_total, t)
    else:
        power = np.zeros_like(e_total)

    atom_idx = np.arange(n_atoms)

    # CJK-capable fonts first so the Chinese labels render. These are global
    # rcParams and persist after this call.
    plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle("波形与能量分析", fontsize=16)

    # Three displacement panels with fixed, global y-limits (no jitter).
    wave_data = (dx, dy, dz)
    wave_specs = zip(
        (axes[0, 0], axes[0, 1], axes[1, 0]),
        wave_data,
        ("纵波 (x 方向位移)", "横波 (y 方向位移)", "横波 (z 方向位移)"),
        ("#2563eb", "#ea580c", "#16a34a"),
    )
    wave_lines = []
    time_texts = []
    for ax, disp, title, color in wave_specs:
        # A single atom would give identical x-limits; keep a unit-wide axis.
        ax.set_xlim(0, max(n_atoms - 1, 1))
        ax.set_ylim(_symmetric_ylim(disp))
        ax.set_xlabel("原子序号")
        ax.set_ylabel("位移")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        line, = ax.plot([], [], color=color, linewidth=1.5)
        wave_lines.append(line)
        time_texts.append(ax.text(0.02, 0.95, "", transform=ax.transAxes,
                                  fontsize=11, verticalalignment="top"))

    # Energy / power panel.
    ax_ep = axes[1, 1]
    ax_ep.set_xlabel("时间 (s)")
    ax_ep.set_ylabel("能量 / 功率")
    ax_ep.set_title("系统能量与输入功率")
    ax_ep.grid(True, alpha=0.3)
    line_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能")
    line_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能")
    line_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量")
    line_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)")
    # (line, full series) pairs revealed cumulatively by the animation.
    energy_series = [(line_ek, ek), (line_us, us), (line_et, e_total), (line_pw, power)]
    if gravity_field:
        line_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
        energy_series.append((line_ug, ug))
    # compute_energy leaves mutual gravity at zero above 200 atoms; don't plot it then.
    if gravity_interaction and n_atoms <= 200:
        line_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
        energy_series.append((line_ugr, ugr))

    # Limits are fixed once for the whole run: changing them inside a blitted
    # animation leaves stale tick labels. The y-range covers every plotted
    # series, including negative potentials and negative power.
    t_end = t[-1] if t[-1] > t[0] else t[0] + 1.0
    ax_ep.set_xlim(t[0], t_end)
    ax_ep.set_ylim(*_padded_ylim([series for _, series in energy_series]))
    ax_ep.legend(loc="upper right", fontsize=9)
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    # Artists updated on every frame (returned for blit=True).
    artists = wave_lines + time_texts + [line for line, _ in energy_series]

    def update(frame):
        """Draw *frame*: current displacements plus energies up to this frame."""
        stamp = f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}"
        for line, disp, text in zip(wave_lines, wave_data, time_texts):
            line.set_data(atom_idx, disp[frame])
            text.set_text(stamp)
        upto = frame + 1
        for line, series in energy_series:
            line.set_data(t[:upto], series[:upto])
        return artists

    ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)

    # Export rate: n_frames // 5 fps (about 5 s of video), clamped to 1..20.
    fps = min(20, max(1, n_frames // 5))
    gif_path = None
    if save_gif:
        gif_path = os.path.join(output_dir, "wave_animation.gif")
        ani.save(gif_path, writer="pillow", fps=fps)
        print(f"[plot_wave] GIF 已保存: {gif_path}")
    if save_mp4:
        _save_mp4(ani, os.path.join(output_dir, "wave_animation.mp4"), fps)

    # Interactive window only when requested; otherwise release the figure.
    if show:
        plt.show()
    else:
        plt.close(fig)
    return gif_path


if __name__ == "__main__":
    # CLI: `python plot_wave.py [output_dir]`. Without an argument the default
    # output directory is resolved by compute.get_output_dir() relative to
    # this script's directory.
    cli_args = sys.argv[1:]
    if cli_args:
        output_dir = os.path.abspath(cli_args[0])
    else:
        output_dir = compute.get_output_dir(os.path.dirname(os.path.abspath(__file__)))
    plot_wave(output_dir)