Files
dynamics/plot_wave.py
T
admin b584c4489c refactor: merge wave/energy panels into 3 vertical subplots
- Plot 1: x/y/z displacements overlaid on one axes
- Plot 2: per-atom KE/PE/total energy overlaid on one axes
- Plot 3: system energy vs time (unchanged)
All three stacked vertically. Shared y-axis scale within each panel.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-13 08:32:41 +08:00

494 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
plot_wave.py
============
波形与能量动态图:读取 display.txt,绘制原子位移波形
(纵波 + 2 个横波)和系统能量/输入功率随时间变化的二维动画。
用法:
python plot_wave.py # 使用 dynamics 根目录下 output/
python plot_wave.py examples/case05/output # 指定案例输出目录
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
import sys
import json
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute
def load_disp_data(output_dir):
    """Load the per-frame display data stored in *output_dir*.

    ``display.npz`` takes precedence; ``display.txt`` is only consulted when
    no npz file exists. Parsing is delegated to the ``compute`` module.

    Raises:
        FileNotFoundError: if neither display file is present.
    """
    npz_file = os.path.join(output_dir, "display.npz")
    if os.path.exists(npz_file):
        return compute.load_display_npz(npz_file)

    txt_file = os.path.join(output_dir, "display.txt")
    if not os.path.exists(txt_file):
        raise FileNotFoundError(f"找不到 display.npz 或 display.txt in {output_dir}")
    return compute.load_display_txt(txt_file)
def _header_json(header_fields, key, default):
    """Decode the JSON-encoded header entry stored under *key*.

    Falls back to *default* when the key is absent, its value is empty/falsy,
    or the stored text is not valid JSON.
    """
    encoded = header_fields.get(key, "")
    if encoded:
        try:
            return json.loads(encoded)
        except json.JSONDecodeError:
            pass  # malformed metadata: treat like a missing entry
    return default
def _load_wave_dataset(output_dir):
    """Load wave/energy plotting data from display metadata or sibling input files.

    Per-frame positions/velocities come from the display file (see
    ``load_disp_data``). Static metadata (masses, reference positions, bond
    topology, gravity settings) is read from JSON-encoded header fields. For
    older outputs that lack this metadata, masses/positions and bonds are
    re-read from the case's sibling ``input/`` directory, i.e.
    ``<parent of output_dir>/input/``.

    Args:
        output_dir: directory containing display.npz or display.txt.

    Returns:
        dict with keys n_frames, t, x, y, z, vx, vy, vz, pos_0, masses,
        atom_ids, bond_pairs, bond_stiffness, bond_rest_lengths,
        gravity_field, gravity_interaction, gravity_strength, G,
        driving_force.

    Raises:
        FileNotFoundError: if no display file exists (from load_disp_data).
        ValueError: if mass/position metadata is missing and there is no
            input/coord.txt to fall back to.
    """
    disp_data = load_disp_data(output_dir)
    header = disp_data["header_fields"]
    # Per-frame arrays; n_frames is taken from axis 0
    # (assumed (n_frames, n_atoms) -- TODO confirm against compute).
    x = disp_data["frames_x"]
    y = disp_data["frames_y"]
    z = disp_data["frames_z"]
    vx = disp_data["frames_vx"]
    vy = disp_data["frames_vy"]
    vz = disp_data["frames_vz"]
    atom_ids = np.array(disp_data["atom_ids"], dtype=np.int64)
    n_frames = x.shape[0]

    # Consecutive frames are DT * NSTEP apart in time.
    dt = float(header.get("DT", 0.001))
    nstep = int(header.get("NSTEP", 1))
    t = np.arange(n_frames, dtype=np.float64) * dt * nstep

    # Static metadata stored as JSON strings in the display header.
    masses = np.array(_header_json(header, "atom_masses", []), dtype=np.float64)
    pos_0 = np.array(_header_json(header, "atom_positions", []), dtype=np.float64)
    bond_pairs = np.array(_header_json(header, "bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(_header_json(header, "bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(_header_json(header, "bond_rest_lengths", []), dtype=np.float64)
    gravity_vec = _header_json(header, "G", [0.0, 0.0, 0.0])

    # Backward-compatible fallback for older display.txt outputs: read the
    # case's sibling input/ directory. abspath() first so that a trailing
    # slash (or ".") on output_dir does not make dirname() return output_dir
    # itself instead of its parent.
    input_dir = os.path.join(os.path.dirname(os.path.abspath(output_dir)), "input")

    if masses.size == 0 or pos_0.size == 0:
        coord_path = os.path.join(input_dir, "coord.txt")
        if not os.path.exists(coord_path):
            raise ValueError("display.txt 缺少 atom_masses/atom_positions 元数据,且未找到 input/coord.txt")
        (_, masses_fb, _, positions_fb, _, _) = compute.load_coord_file(coord_path)
        masses = np.array(masses_fb, dtype=np.float64)
        pos_0 = np.array(positions_fb, dtype=np.float64)

    if bond_pairs.size == 0:
        connection_path = os.path.join(input_dir, "connection.txt")
        bond_path = os.path.join(input_dir, "bond.txt")
        # If either file is missing the bond arrays simply stay empty.
        if os.path.exists(connection_path) and os.path.exists(bond_path):
            bond_map = compute.load_bond_parameters(bond_path)
            pairs_fb, _, stiffness_fb, rest_lengths_fb = compute.load_bond_connections(
                connection_path, atom_ids, pos_0, bond_map)
            bond_pairs = np.array(pairs_fb, dtype=np.int64)
            bond_stiffness = np.array(stiffness_fb, dtype=np.float64)
            bond_rest_lengths = np.array(rest_lengths_fb, dtype=np.float64)

    return {
        "n_frames": n_frames,
        "t": t,
        "x": x,
        "y": y,
        "z": z,
        "vx": vx,
        "vy": vy,
        "vz": vz,
        "pos_0": pos_0,
        "masses": masses,
        "atom_ids": atom_ids,
        "bond_pairs": bond_pairs,
        "bond_stiffness": bond_stiffness,
        "bond_rest_lengths": bond_rest_lengths,
        "gravity_field": int(header.get("gravity_field", 0)),
        "gravity_interaction": int(header.get("gravity_interaction", 0)),
        "gravity_strength": float(header.get("gravity_strength", 1.0)),
        "G": gravity_vec,
        "driving_force": int(header.get("driving_force", 0)),
    }
def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr,
                   bond_pairs, bond_stiffness, bond_rest_lengths,
                   gravity_field, G, gravity_interaction, gravity_strength):
    """Compute the system-wide energy components for every frame.

    Position/velocity arrays are indexed ``arr[frame, atom]``.

    Args:
        x, y, z: positions, shape (n_frames, n_atoms).
        vx, vy, vz: velocities, shape (n_frames, n_atoms).
        masses: per-atom masses, shape (n_atoms,).
        mass_arr: unused; kept only for call-site compatibility.
        bond_pairs: (n_bonds, 2) atom-index pairs, or None/empty for no bonds.
        bond_stiffness: (n_bonds,) spring constants k.
        bond_rest_lengths: (n_bonds,) rest lengths d0.
        gravity_field: truthy to include the uniform-field term -m G·r.
        G: uniform field vector (gx, gy, gz).
        gravity_interaction: truthy to include pairwise gravitation
            -g * m_i * m_j / r_ij; skipped (left as zeros) for > 200 atoms.
        gravity_strength: pairwise coupling constant g.

    Returns:
        ek_sys:  system kinetic energy, (n_frames,)
        us_sys:  system elastic (spring) potential energy, (n_frames,)
        ug_sys:  uniform-field gravitational potential energy, (n_frames,)
        ugr_sys: pairwise gravitational potential energy, (n_frames,)
    """
    n_frames = x.shape[0]
    masses = np.asarray(masses, dtype=np.float64)
    masses_2d = masses[np.newaxis, :]  # (1, n_atoms), broadcasts over frames

    # Kinetic energy: Ek = 1/2 m v^2, summed over atoms.
    ek_sys = np.sum(0.5 * masses_2d * (vx**2 + vy**2 + vz**2), axis=1)

    # Elastic energy: Us = 1/2 k (d - d0)^2, evaluated for all bonds at once
    # via integer-array indexing -> (n_frames, n_bonds) intermediates.
    us_sys = np.zeros(n_frames)
    if bond_pairs is not None and len(bond_pairs) > 0:
        pairs = np.asarray(bond_pairs, dtype=np.int64).reshape(-1, 2)
        i_idx, j_idx = pairs[:, 0], pairs[:, 1]
        dist = np.sqrt((x[:, j_idx] - x[:, i_idx])**2
                       + (y[:, j_idx] - y[:, i_idx])**2
                       + (z[:, j_idx] - z[:, i_idx])**2)
        stretch = dist - np.asarray(bond_rest_lengths, dtype=np.float64)
        stiffness = np.asarray(bond_stiffness, dtype=np.float64)
        us_sys = 0.5 * np.sum(stiffness * stretch**2, axis=1)

    # Uniform-field potential: Ug = -m G·r.
    ug_sys = np.zeros(n_frames)
    if gravity_field:
        gx, gy, gz = np.asarray(G, dtype=np.float64)[:3]
        ug_sys = -np.sum(masses_2d * (gx * x + gy * y + gz * z), axis=1)

    # Pairwise gravitation: Ugr = -g * sum_{i<j} m_i m_j / r_ij.
    # O(n^2) pairs, so it is only evaluated for <= 200 atoms (zeros otherwise).
    ugr_sys = np.zeros(n_frames)
    n_atoms = len(masses)
    if gravity_interaction and n_atoms <= 200:
        for i in range(n_atoms - 1):
            # Distances from atom i to every atom j > i, for all frames at once.
            r = np.sqrt((x[:, i + 1:] - x[:, i:i + 1])**2
                        + (y[:, i + 1:] - y[:, i:i + 1])**2
                        + (z[:, i + 1:] - z[:, i:i + 1])**2)
            r = np.maximum(r, 1e-12)  # guard against coincident atoms
            ugr_sys -= gravity_strength * masses[i] * np.sum(masses[i + 1:] / r, axis=1)

    return ek_sys, us_sys, ug_sys, ugr_sys
def _load_driver_info(output_dir):
    """Read driven-atom parameters from the case's sibling ``input/driver.txt``.

    The file is looked up in ``<parent of output_dir>/input/``. Its first line
    is a header and is skipped; blank lines and ``#`` comment lines are
    ignored. A data line needs at least 10 whitespace-separated columns, of
    which only the first 7 are used: atom_id, amp_x, amp_y, amp_z, freq_x,
    freq_y, freq_z. The result is only used for the per-atom energy split.

    Returns:
        dict ``{atom_id: {"amp": ndarray(3), "freq": ndarray(3)}}``. Returns
        ``{}`` if the file does not exist. Lines that are too short or fail
        to parse are skipped individually, so one bad line no longer discards
        every driver after it.
    """
    # abspath() first so a trailing slash on output_dir still resolves to the
    # sibling input/ directory rather than output_dir/input.
    input_dir = os.path.join(os.path.dirname(os.path.abspath(output_dir)), "input")
    path = os.path.join(input_dir, "driver.txt")
    if not os.path.exists(path):
        return {}
    drivers = {}
    try:
        with open(path, encoding="utf-8") as f:
            f.readline()  # skip header
            for line in f:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) < 10:
                    continue
                try:
                    atom_id = int(parts[0])
                    amp = np.array([float(v) for v in parts[1:4]])
                    freq = np.array([float(v) for v in parts[4:7]])
                except ValueError:
                    continue  # malformed line: skip it, keep parsing the rest
                drivers[atom_id] = {"amp": amp, "freq": freq}
    except (OSError, UnicodeDecodeError) as exc:
        # Best-effort: keep whatever was parsed so far, but do not fail silently.
        print(f"[plot_wave] 警告: 读取 {path} 失败 ({exc}),驱动信息可能不完整")
    return drivers
def compute_per_atom_energy(x, y, z, vx, vy, vz, masses,
                            bond_pairs, bond_stiffness, bond_rest_lengths,
                            atom_ids, driver_info):
    """Attribute kinetic and potential energy to individual atoms, per frame.

    Potential-energy attribution rules:
      * bond between two non-driven atoms: each end receives half;
      * bond with exactly one driven end: the non-driven end receives all;
      * bond with both ends driven: split in half, but that share is lost
        because a driven atom's PE is overwritten by the rule below;
      * driven atom: PE = max(E_sho - KE, 0) with
        E_sho = 1/2 m * sum_k (2*pi*f_k)^2 * A_k^2 over the three axes,
        i.e. the complement of its instantaneous kinetic energy.

    Args:
        x, y, z, vx, vy, vz: positions/velocities, shape (n_frames, n_atoms).
        masses: per-atom masses, shape (n_atoms,).
        bond_pairs: (n_bonds, 2) atom-index pairs, or None/empty.
        bond_stiffness, bond_rest_lengths: per-bond k and d0.
        atom_ids: atom id for each column index.
        driver_info: ``{atom_id: {"amp": (3,), "freq": (3,)}}``; ids not
            present in atom_ids are ignored.

    Returns:
        ek_atom, pe_atom, et_atom: kinetic, potential and total energy, each
        of shape (n_frames, n_atoms).
    """
    n_frames, n_atoms = x.shape
    kinetic = 0.5 * masses[np.newaxis, :] * (vx**2 + vy**2 + vz**2)

    index_of = {int(aid): col for col, aid in enumerate(atom_ids)}
    driven = {index_of[aid] for aid in driver_info if aid in index_of}

    potential = np.zeros((n_frames, n_atoms))
    if bond_pairs is not None and len(bond_pairs) > 0:
        for b, row in enumerate(bond_pairs):
            i, j = int(row[0]), int(row[1])
            sep_x = x[:, j] - x[:, i]
            sep_y = y[:, j] - y[:, i]
            sep_z = z[:, j] - z[:, i]
            length = np.sqrt(sep_x**2 + sep_y**2 + sep_z**2)
            bond_pe = 0.5 * bond_stiffness[b] * (length - bond_rest_lengths[b])**2

            i_drv, j_drv = i in driven, j in driven
            if i_drv == j_drv:
                # neither (or both) ends driven: split evenly
                shares = ((i, 0.5), (j, 0.5))
            elif i_drv:
                shares = ((j, 1.0),)
            else:
                shares = ((i, 1.0),)
            for col, fraction in shares:
                potential[:, col] += bond_pe * fraction

    # Driven atoms: PE is the SHO total energy minus the current kinetic energy.
    for aid, params in driver_info.items():
        col = index_of.get(aid)
        if col is None:
            continue
        omega = 2.0 * np.pi * params["freq"]
        e_sho = 0.5 * masses[col] * np.sum(omega**2 * params["amp"]**2)
        potential[:, col] = np.maximum(e_sho - kinetic[:, col], 0.0)  # clamp at 0

    return kinetic, potential, kinetic + potential
def _energy_panel_ylim(series):
    """Return padded (low, high) y-limits covering every curve in *series*.

    Negative values (gravitational potentials, negative dE/dt) stay visible.
    The range always includes 0 and spans at least 0.01, so it is never
    singular.
    """
    low = min(0.0, min(float(np.min(s)) for s in series))
    high = max(0.01, max(float(np.max(s)) for s in series))
    margin = 0.15 * (high - low)
    return low - margin, high + margin


def _save_mp4(ani, output_dir, fps):
    """Save *ani* as output_dir/wave_animation.mp4 via ffmpeg; warn and skip on failure."""
    try:
        import matplotlib.animation as manim
        ffmpeg_path = None
        try:
            import imageio_ffmpeg
            ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
        except Exception:
            # imageio-ffmpeg is optional; otherwise matplotlib's configured ffmpeg is used.
            pass
        if ffmpeg_path and os.path.exists(ffmpeg_path):
            plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path
        writer = manim.FFMpegWriter(fps=fps, codec="libx264",
                                    extra_args=["-pix_fmt", "yuv420p"])
        mp4_path = os.path.join(output_dir, "wave_animation.mp4")
        ani.save(mp4_path, writer=writer)
        print(f"[plot_wave] MP4 已保存: {mp4_path}")
    except FileNotFoundError:
        print("[plot_wave] 警告: 未找到 ffmpeg,跳过 MP4 输出")
    except Exception as e:
        print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过")


def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
    """Build the wave/energy animation from the display data in *output_dir*.

    Layout (3 rows x 1 column, stacked vertically):
      row 0: x/y/z displacement from the reference positions vs atom index;
      row 1: per-atom kinetic / potential / total energy vs atom index;
      row 2: system energy components and input power dE/dt vs time, drawn
             cumulatively up to the current frame.

    Args:
        output_dir: output directory containing display.npz or display.txt;
            saved animations are written here too.
        save_gif: write wave_animation.gif (Pillow writer).
        save_mp4: write wave_animation.mp4 via ffmpeg; failures only warn.
        show: open an interactive window; otherwise the figure is closed.

    Returns:
        Path of the saved GIF, or None when save_gif is False.
    """
    data = _load_wave_dataset(output_dir)
    n_frames = int(data["n_frames"])
    t = np.array(data["t"])
    x = np.array(data["x"])
    y = np.array(data["y"])
    z = np.array(data["z"])
    vx = np.array(data["vx"])
    vy = np.array(data["vy"])
    vz = np.array(data["vz"])
    pos_0 = np.array(data["pos_0"])
    masses = np.array(data["masses"])
    atom_ids = np.array(data["atom_ids"])
    n_atoms = len(atom_ids)
    bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(data.get("bond_rest_lengths", []), dtype=np.float64)
    gravity_field = int(data.get("gravity_field", 0))
    gravity_interaction = int(data.get("gravity_interaction", 0))
    G = data.get("G", [0, 0, 0])
    gravity_strength = float(data.get("gravity_strength", 1.0))
    driving_force = int(data.get("driving_force", 0))

    # Driven-atom parameters, only needed for the per-atom PE split.
    driver_info = _load_driver_info(output_dir) if driving_force else {}

    # Displacement from the reference positions (pos_0 assumed (n_atoms, 3)).
    dx = x - pos_0[np.newaxis, :, 0]
    dy = y - pos_0[np.newaxis, :, 1]
    dz = z - pos_0[np.newaxis, :, 2]

    # System energies for the time panel.
    ek_sys, us_sys, ug_sys, ugr_sys = compute_energy(
        x, y, z, vx, vy, vz, masses, masses,
        bond_pairs, bond_stiffness, bond_rest_lengths,
        gravity_field, G, gravity_interaction, gravity_strength)
    e_total = ek_sys + us_sys + ug_sys + ugr_sys
    # np.gradient needs at least two samples; a single frame has no power.
    power = np.gradient(e_total, t) if n_frames >= 2 else np.zeros_like(e_total)

    # Per-atom energies.
    ek_atom, pe_atom, et_atom = compute_per_atom_energy(
        x, y, z, vx, vy, vz, masses,
        bond_pairs, bond_stiffness, bond_rest_lengths,
        atom_ids, driver_info)

    # Shared y-range for the displacement panel: largest |d| over x/y/z.
    disp_vmax = max(np.max(np.abs(dx)), np.max(np.abs(dy)), np.max(np.abs(dz)))
    disp_vmax = disp_vmax if disp_vmax > 1e-10 else 1.0
    disp_ylim = (-disp_vmax * 1.2, disp_vmax * 1.2)
    # Shared y-range for per-atom energies (all non-negative).
    energy_vmax = max(np.max(ek_atom), np.max(pe_atom), np.max(et_atom))
    energy_vmax = energy_vmax if energy_vmax > 1e-12 else 1.0
    energy_ylim = (0.0, energy_vmax * 1.2)

    # Which optional curves the time panel shows (ugr matches compute_energy's limit).
    show_ug = bool(gravity_field)
    show_ugr = bool(gravity_interaction) and n_atoms <= 200
    ep_series = [ek_sys, us_sys, e_total, power]
    if show_ug:
        ep_series.append(ug_sys)
    if show_ugr:
        ep_series.append(ugr_sys)

    atom_idx = np.arange(n_atoms)
    # Avoid singular x-ranges for a single atom or a single frame.
    atom_xlim = (0, max(n_atoms - 1, 1))
    t_xlim = (t[0], t[-1]) if t[-1] > t[0] else (t[0], t[0] + 1.0)

    # Figure layout: 3 rows x 1 column.
    plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
    fig, (ax_wave, ax_energy, ax_ep) = plt.subplots(3, 1, figsize=(12, 14))
    fig.suptitle("波形与能量分析", fontsize=16)
    fig.subplots_adjust(hspace=0.40, top=0.94)

    # Panel 1: x/y/z displacements overlaid.
    ax_wave.set_xlim(*atom_xlim)
    ax_wave.set_ylim(disp_ylim)
    ax_wave.set_xlabel("原子序号")
    ax_wave.set_ylabel("位移")
    ax_wave.set_title("粒子位移(x / y / z 方向)")
    ax_wave.grid(True, alpha=0.3)
    wave_disps = [dx, dy, dz]
    wave_labels = ["x 方向(纵波)", "y 方向(横波)", "z 方向(横波)"]
    wave_colors = ["#2563eb", "#ea580c", "#16a34a"]
    wave_lines = []
    for label, color in zip(wave_labels, wave_colors):
        ln, = ax_wave.plot([], [], color=color, linewidth=1.5, label=label)
        wave_lines.append(ln)
    ax_wave.legend(loc="upper right", fontsize=9)
    time_text = ax_wave.text(0.02, 0.95, "", transform=ax_wave.transAxes,
                             fontsize=10, verticalalignment="top")

    # Panel 2: per-atom kinetic / potential / total energy overlaid.
    ax_energy.set_xlim(*atom_xlim)
    ax_energy.set_ylim(energy_ylim)
    ax_energy.set_xlabel("原子序号")
    ax_energy.set_ylabel("能量")
    ax_energy.set_title("每粒子能量(动能 / 势能 / 总能)")
    ax_energy.grid(True, alpha=0.3)
    energy_arrays = [ek_atom, pe_atom, et_atom]
    energy_labels = ["动能", "势能", "总能"]
    energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"]
    energy_lines = []
    for label, color in zip(energy_labels, energy_colors):
        ln, = ax_energy.plot([], [], color=color, linewidth=1.5, label=label)
        energy_lines.append(ln)
    ax_energy.legend(loc="upper right", fontsize=9)

    # Panel 3: system energy and input power vs time. Limits are fixed for the
    # whole animation: with blit=True, changing them per frame would leave
    # stale tick labels. The y-range covers negative values too.
    ax_ep.set_xlim(*t_xlim)
    ax_ep.set_ylim(*_energy_panel_ylim(ep_series))
    ax_ep.set_xlabel("时间 (s)")
    ax_ep.set_ylabel("能量 / 功率")
    ax_ep.set_title("系统能量与输入功率")
    ax_ep.grid(True, alpha=0.3)
    ln_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能")
    ln_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能")
    ln_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量")
    ln_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)")
    ln_ug = None
    ln_ugr = None
    if show_ug:
        ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
    if show_ugr:
        ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
    ax_ep.legend(loc="upper left", fontsize=9)

    # Every artist updated per frame (returned for blitting).
    artists = wave_lines + [time_text] + energy_lines + [ln_ek, ln_us, ln_et, ln_pw]
    artists += [ln for ln in (ln_ug, ln_ugr) if ln is not None]

    def update(frame):
        # Panels 1 and 2: snapshot of the current frame.
        for ln, disp in zip(wave_lines, wave_disps):
            ln.set_data(atom_idx, disp[frame])
        time_text.set_text(f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}")
        for ln, arr in zip(energy_lines, energy_arrays):
            ln.set_data(atom_idx, arr[frame])
        # Panel 3: history up to and including the current frame.
        end = frame + 1
        cur_t = t[:end]
        ln_ek.set_data(cur_t, ek_sys[:end])
        ln_us.set_data(cur_t, us_sys[:end])
        ln_et.set_data(cur_t, e_total[:end])
        ln_pw.set_data(cur_t, power[:end])
        if ln_ug is not None:
            ln_ug.set_data(cur_t, ug_sys[:end])
        if ln_ugr is not None:
            ln_ugr.set_data(cur_t, ugr_sys[:end])
        return artists

    ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)

    # Output files.
    fps = min(20, max(1, n_frames // 5))
    gif_path = None
    if save_gif:
        gif_path = os.path.join(output_dir, "wave_animation.gif")
        ani.save(gif_path, writer="pillow", fps=fps)
        print(f"[plot_wave] GIF 已保存: {gif_path}")
    if save_mp4:
        _save_mp4(ani, output_dir, fps)

    if show:
        plt.show()
    else:
        plt.close(fig)
    return gif_path
if __name__ == "__main__":
    # CLI: optional first argument is the case output directory; otherwise
    # the default output directory resolved by compute is used.
    cli_args = sys.argv[1:]
    if cli_args:
        target_dir = os.path.abspath(cli_args[0])
    else:
        here = os.path.dirname(os.path.abspath(__file__))
        target_dir = compute.get_output_dir(here)
    plot_wave(target_dir)