Files
dynamics/plot_wave.py
T
admin 80520590d1 feat: 新增波形能量动画系统 plot_wave.py
- 创建 plot_wave.py: 从 display.txt 读取原子位移数据
  绘制纵波(x) + 横波(y) + 横波(z) 波形随时间的动画
  同时绘制系统动能/弹性势能/总能量/输入功率(dE/dt)时变曲线
  输出 wave_animation.gif
- 所有 input.txt 新增 step_plot_wave: 0 开关
- case05 开启 step_plot_wave: 1
- dynamics.py disp_data 新增 bond_stiffness/bond_rest_lengths
- 更新案例文档
2026-06-11 12:39:46 +08:00

261 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
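plot_wave.py
============
Wave and energy animation: read display.txt and draw the atomic displacement
waveforms (one longitudinal wave plus two transverse waves) together with the
system energy / input power over time as a 2-D animation.

Usage:
    python plot_wave.py                          # use output/ under the dynamics root directory
    python plot_wave.py examples/case05/output   # use a specific case's output directory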
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute
def load_disp_data(output_dir):
    """Load ``display.txt`` from *output_dir*.

    Args:
        output_dir: Directory expected to contain ``display.txt``.

    Returns:
        Whatever ``compute.load_text_data`` returns for that file.

    Raises:
        FileNotFoundError: If ``display.txt`` is missing from *output_dir*
            or the path exists but is not a regular file.
    """
    disp_path = os.path.join(output_dir, "display.txt")
    # isfile() rather than exists(): a directory that happens to be called
    # display.txt must be reported as "not found" here instead of failing
    # later inside the loader with an unrelated error.
    if not os.path.isfile(disp_path):
        raise FileNotFoundError(f"找不到 {disp_path}")
    return compute.load_text_data(disp_path)
def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr,
                   bond_pairs, bond_stiffness, bond_rest_lengths,
                   gravity_field, G, gravity_interaction, gravity_strength):
    """Compute the per-frame energy components of the whole system.

    Args:
        x, y, z: Atom coordinates, indexed as ``[frame, atom]``.
        vx, vy, vz: Atom velocities, indexed as ``[frame, atom]``.
        masses: One mass per atom.
        mass_arr: Unused; kept only so existing callers keep working.
        bond_pairs: Sequence of ``(i, j)`` atom-index pairs, or ``None``/empty
            when there are no bonds. Indices are cast to ``int``.
        bond_stiffness: Spring constant per bond (a scalar is applied to
            every bond).
        bond_rest_lengths: Rest length per bond (a scalar is applied to
            every bond).
        gravity_field: Truthy to include the uniform-field potential.
        G: Three components of the uniform field vector.
        gravity_interaction: Truthy to include the pairwise gravitational
            potential.
        gravity_strength: Coupling constant of the pairwise potential.

    Returns:
        Tuple ``(ek_sys, us_sys, ug_sys, ugr_sys)``, each of shape
        ``(n_frames,)``: kinetic energy, elastic potential energy,
        uniform-field potential energy and pairwise gravitational potential
        energy. Components that are switched off are all zeros.

    Raises:
        IndexError: If fewer stiffness / rest-length values than bonds are
            supplied.
    """
    import warnings

    def _per_bond(values, n_bonds, label):
        # Return one float per bond, accepting either a scalar or a sequence.
        arr = np.asarray(values, dtype=float)
        if arr.ndim == 0:
            return np.full(n_bonds, float(arr))
        if arr.shape[0] < n_bonds:
            raise IndexError(
                f"{label} has {arr.shape[0]} entries but there are "
                f"{n_bonds} bonds"
            )
        return arr[:n_bonds]

    x, y, z = (np.asarray(a, dtype=float) for a in (x, y, z))
    vx, vy, vz = (np.asarray(a, dtype=float) for a in (vx, vy, vz))
    masses = np.asarray(masses, dtype=float)

    n_frames = x.shape[0]
    masses_2d = masses[np.newaxis, :]  # broadcast over the frame axis

    # Kinetic energy: Ek = 1/2 m v^2, summed over atoms.
    ek_sys = np.sum(0.5 * masses_2d * (vx**2 + vy**2 + vz**2), axis=1)

    # Elastic potential energy: Us = 1/2 k (d - d0)^2, summed over bonds.
    us_sys = np.zeros(n_frames)
    if bond_pairs is not None and len(bond_pairs) > 0:
        # Cast to int so index pairs that were parsed as floats still work.
        pair_idx = np.asarray(bond_pairs).astype(int)
        first, second = pair_idx.T
        n_bonds = pair_idx.shape[0]
        stiffness = _per_bond(bond_stiffness, n_bonds, "bond_stiffness")
        rest = _per_bond(bond_rest_lengths, n_bonds, "bond_rest_lengths")
        # All bonds at once: arrays below are indexed [frame, bond].
        dist = np.sqrt(
            (x[:, second] - x[:, first]) ** 2
            + (y[:, second] - y[:, first]) ** 2
            + (z[:, second] - z[:, first]) ** 2
        )
        us_sys = 0.5 * np.sum(stiffness * (dist - rest) ** 2, axis=1)

    # Uniform-field potential energy: Ug = -m G.r, summed over atoms.
    ug_sys = np.zeros(n_frames)
    if gravity_field:
        g_vec = np.asarray(G, dtype=float)
        ug_sys = -np.sum(
            masses_2d * (g_vec[0] * x + g_vec[1] * y + g_vec[2] * z),
            axis=1,
        )

    # Pairwise gravitational potential: -strength * sum_{i<j} m_i m_j / r_ij.
    ugr_sys = np.zeros(n_frames)
    if gravity_interaction:
        n_atoms = len(masses)
        # The pair sum is O(n_atoms^2); it is only evaluated for small systems.
        if n_atoms <= 200:
            for i in range(n_atoms - 1):
                # Distances from atom i to every later atom, [frame, j > i].
                sep = np.sqrt(
                    (x[:, i + 1:] - x[:, i:i + 1]) ** 2
                    + (y[:, i + 1:] - y[:, i:i + 1]) ** 2
                    + (z[:, i + 1:] - z[:, i:i + 1]) ** 2
                )
                # Clamp to avoid division by zero for coincident atoms.
                sep = np.maximum(sep, 1e-12)
                ugr_sys -= gravity_strength * masses[i] * np.sum(
                    masses[i + 1:] / sep, axis=1
                )
        else:
            # Make the skip visible instead of silently returning zeros.
            warnings.warn(
                f"原子数 {n_atoms} 超过 200,已跳过万有引力势能计算(按 0 处理)",
                RuntimeWarning,
                stacklevel=2,
            )

    return ek_sys, us_sys, ug_sys, ugr_sys
def plot_wave(output_dir):
    """Read ``display.txt`` and render the waveform + energy animation.

    The figure has a 2x2 layout: three panels show the per-atom displacement
    from the initial positions along x, y and z for the current frame, and
    the fourth shows the energy components and dE/dt accumulated up to the
    current frame. The animation is written to ``wave_animation.gif`` in
    *output_dir*; an MP4 copy is attempted on a best-effort basis. The
    figure is then shown with ``plt.show()``.

    Args:
        output_dir: Directory containing ``display.txt``; also receives the
            rendered animation files.

    Returns:
        Path of the saved GIF.

    Raises:
        FileNotFoundError: Propagated from ``load_disp_data`` when
            ``display.txt`` is missing.
    """
    data = load_disp_data(output_dir)
    n_frames = int(data["n_frames"])
    t = np.array(data["disp_t"])

    # Positions / velocities, indexed [frame, atom].
    x = np.array(data["disp_all_x"])
    y = np.array(data["disp_all_y"])
    z = np.array(data["disp_all_z"])
    vx = np.array(data["disp_all_vx"])
    vy = np.array(data["disp_all_vy"])
    vz = np.array(data["disp_all_vz"])

    # Atom information.
    pos_0 = np.array(data["atom_positions"])  # indexed [atom, xyz]
    masses = np.array(data["atom_masses"])
    atom_ids = np.array(data["atom_ids"])
    n_atoms = len(atom_ids)

    # Bonds.
    bond_pairs = data.get("bond_pairs", [])
    bond_stiffness = np.array(data.get("bond_stiffness", []))
    bond_rest_lengths = np.array(data.get("bond_rest_lengths", []))

    # Physics switches.
    gravity_field = int(data.get("gravity_field", 0))
    gravity_interaction = int(data.get("gravity_interaction", 0))
    G = data.get("G", [0, 0, 0])
    gravity_strength = float(data.get("gravity_strength", 1.0))

    # Displacement from the initial configuration: x is plotted as the
    # longitudinal wave, y and z as the two transverse waves.
    displacements = (
        x - pos_0[np.newaxis, :, 0],
        y - pos_0[np.newaxis, :, 1],
        z - pos_0[np.newaxis, :, 2],
    )

    # Energies.
    ek, us, ug, ugr = compute_energy(
        x, y, z, vx, vy, vz, masses, masses,
        bond_pairs, bond_stiffness, bond_rest_lengths,
        gravity_field, G, gravity_interaction, gravity_strength)
    e_total = ek + us + ug + ugr
    # Input power = dE/dt. np.gradient needs at least two samples, so a
    # single-frame run falls back to zero power instead of raising.
    if len(t) > 1:
        power = np.gradient(e_total, t)
    else:
        power = np.zeros_like(e_total)

    def get_ylim(disp_data):
        # Symmetric limits with a 20% margin, fixed for the whole animation
        # so the axis does not jitter from frame to frame.
        vmax = np.max(np.abs(disp_data))
        if vmax < 1e-10:
            return -1.0, 1.0
        margin = vmax * 0.2
        return -vmax - margin, vmax + margin

    atom_idx = np.arange(n_atoms)

    # Figure setup.
    plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle("波形与能量分析", fontsize=16)

    # Three waveform panels.
    titles = [
        "纵波 (x 方向位移)",
        "横波 (y 方向位移)",
        "横波 (z 方向位移)",
    ]
    colors = ["#2563eb", "#ea580c", "#16a34a"]
    wave_axes = [axes[0, 0], axes[0, 1], axes[1, 0]]
    wave_lines = []
    time_texts = []
    for ax, disp, title, color in zip(wave_axes, displacements, titles, colors):
        ax.set_xlim(0, n_atoms - 1)
        ax.set_ylim(get_ylim(disp))
        ax.set_xlabel("原子序号")
        ax.set_ylabel("位移")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        line, = ax.plot([], [], color=color, linewidth=1.5)
        wave_lines.append(line)
        tt = ax.text(0.02, 0.95, "", transform=ax.transAxes,
                     fontsize=11, verticalalignment="top")
        time_texts.append(tt)

    # Energy + power panel.
    ax_ep = axes[1, 1]
    ax_ep.set_xlim(t[0], t[-1])
    ax_ep.set_xlabel("时间 (s)")
    ax_ep.set_ylabel("能量 / 功率")
    ax_ep.set_title("系统能量与输入功率")
    ax_ep.grid(True, alpha=0.3)
    line_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能")
    line_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能")
    line_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量")
    line_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)")
    # (line, full series) pairs, in the order the artists are returned.
    energy_curves = [
        (line_ek, ek), (line_us, us), (line_et, e_total), (line_pw, power),
    ]
    if gravity_field:
        line_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
        energy_curves.append((line_ug, ug))
    # compute_energy only evaluates the pairwise term for <= 200 atoms.
    if gravity_interaction and n_atoms <= 200:
        line_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
        energy_curves.append((line_ugr, ugr))

    # Derive the y-range from every plotted series so that negative power
    # and negative potential energies are not clipped. Zero is always
    # included, and extra headroom is left at the top for the legend.
    y_low = min(0.0, *(float(np.min(series)) for _, series in energy_curves))
    y_high = max(0.01, *(float(np.max(series)) for _, series in energy_curves))
    y_span = y_high - y_low
    ax_ep.set_ylim(y_low - 0.1 * y_span, y_high + 0.3 * y_span)

    ax_ep.legend(loc="upper right", fontsize=9)
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    artists = wave_lines + time_texts + [line for line, _ in energy_curves]

    def update(frame):
        # Waveforms of the current frame.
        stamp = f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}"
        for line, label, disp in zip(wave_lines, time_texts, displacements):
            line.set_data(atom_idx, disp[frame])
            label.set_text(stamp)
        # Energy curves accumulated up to (and including) the current frame.
        upto = frame + 1
        cur_t = t[:upto]
        for line, series in energy_curves:
            line.set_data(cur_t, series[:upto])
        # Extend the time axis once the current time approaches the end.
        ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1]))
        return artists

    # blit=False: update() changes axis limits, which a blitted redraw would
    # not repaint (ticks would go stale in the interactive window).
    ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=False)
    fps = min(20, max(1, n_frames // 5))

    # Save the GIF.
    gif_path = os.path.join(output_dir, "wave_animation.gif")
    ani.save(gif_path, writer="pillow", fps=fps)
    print(f"[plot_wave] GIF 已保存: {gif_path}")

    # MP4 is optional (it needs ffmpeg): report why it was skipped instead
    # of failing the whole run or hiding the reason.
    mp4_path = os.path.join(output_dir, "wave_animation.mp4")
    try:
        ani.save(mp4_path, writer="ffmpeg", fps=fps)
    except Exception as exc:
        print(f"[plot_wave] 跳过 MP4 输出: {exc}")
    else:
        print(f"[plot_wave] MP4 已保存: {mp4_path}")

    plt.show()
    return gif_path
if __name__ == "__main__":
    # An optional first CLI argument selects the output directory; without
    # it the directory is obtained from compute.get_output_dir().
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = (
        os.path.abspath(sys.argv[1])
        if len(sys.argv) > 1
        else compute.get_output_dir(script_dir)
    )
    plot_wave(output_dir)