feat: redesign plot_wave with per-atom energy panels

New 4x2 layout: left col = x/y/z displacement waves, right col = per-atom KE/PE/total energy + system energy vs time.
PE split 50/50 for normal bonds; 100% to non-driven atom when bonded to driver; driven atom PE = E_SHO - KE.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-13 08:26:51 +08:00
parent 39ff650539
commit 2ab3436235
+214 -83
View File
@@ -167,20 +167,121 @@ def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr,
return ek_sys, us_sys, ug_sys, ugr_sys return ek_sys, us_sys, ug_sys, ugr_sys
def _load_driver_info(output_dir):
"""从 input/driver.txt 读取驱动原子的 atom_id, amp, freq(仅用于能量计算)。
返回 dict: {atom_id: {'amp': [ax,ay,az], 'freq': [fx,fy,fz]}},失败返回 {}
"""
input_dir = os.path.join(os.path.dirname(output_dir), "input")
path = os.path.join(input_dir, "driver.txt")
if not os.path.exists(path):
return {}
drivers = {}
try:
with open(path, encoding="utf-8") as f:
f.readline() # skip header
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
parts = line.split()
if len(parts) < 10:
continue
n = int(parts[0])
amp = np.array([float(parts[1]), float(parts[2]), float(parts[3])])
freq = np.array([float(parts[4]), float(parts[5]), float(parts[6])])
drivers[n] = {"amp": amp, "freq": freq}
except Exception:
pass
return drivers
def compute_per_atom_energy(x, y, z, vx, vy, vz, masses,
bond_pairs, bond_stiffness, bond_rest_lengths,
atom_ids, driver_info):
"""计算每帧每个粒子的动能、势能、总能。
势能分配规则:
- 非驱动粒子之间的键:势能各分一半
- 键的一端是驱动粒子:势能全部归非驱动端
- 驱动粒子自身:按简谐振子 PE = ½ m (2πf)² A² cos²(2πft+φ) 计算
(此处用 KE_driven = ½m·v² 的互补式:PE_driven = E_total_sho - KE_driven
Returns:
ek_atom: (n_frames, n_atoms) 每原子动能
pe_atom: (n_frames, n_atoms) 每原子势能
et_atom: (n_frames, n_atoms) 每原子总能
"""
n_frames, n_atoms = x.shape
# 动能 per atom
masses_2d = masses[np.newaxis, :]
ek_atom = 0.5 * masses_2d * (vx**2 + vy**2 + vz**2) # (n_frames, n_atoms)
# 构建 atom_id → index 映射,以及驱动原子 index 集合
id_to_idx = {int(aid): i for i, aid in enumerate(atom_ids)}
driven_idx = set()
for aid in driver_info:
if aid in id_to_idx:
driven_idx.add(id_to_idx[aid])
# 势能 per atom(弹簧键)
pe_atom = np.zeros((n_frames, n_atoms))
if bond_pairs is not None and len(bond_pairs) > 0:
for b in range(len(bond_pairs)):
i, j = int(bond_pairs[b, 0]), int(bond_pairs[b, 1])
ddx = x[:, j] - x[:, i]
ddy = y[:, j] - y[:, i]
ddz = z[:, j] - z[:, i]
dist = np.sqrt(ddx**2 + ddy**2 + ddz**2)
bond_pe = 0.5 * bond_stiffness[b] * (dist - bond_rest_lengths[b])**2
i_driven = i in driven_idx
j_driven = j in driven_idx
if i_driven and not j_driven:
pe_atom[:, j] += bond_pe # 全归非驱动端
elif j_driven and not i_driven:
pe_atom[:, i] += bond_pe # 全归非驱动端
else:
pe_atom[:, i] += bond_pe * 0.5
pe_atom[:, j] += bond_pe * 0.5
# 驱动粒子:用简谐振子总能 E_sho = ½m(2πf)²A²,PE_sho = E_sho - KE
TWO_PI = 2.0 * np.pi
for aid, info in driver_info.items():
if aid not in id_to_idx:
continue
idx = id_to_idx[aid]
m = masses[idx]
amp = info["amp"]
freq = info["freq"]
e_sho = 0.5 * m * np.sum((TWO_PI * freq)**2 * amp**2)
pe_sho = e_sho - ek_atom[:, idx]
pe_atom[:, idx] = np.maximum(pe_sho, 0.0) # SHO 势能非负
et_atom = ek_atom + pe_atom
return ek_atom, pe_atom, et_atom
def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True): def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
"""主绘图函数:读取 display.txt 并生成波形+能量动画。 """主绘图函数:读取 display.txt 并生成波形+能量动画。
布局(4行×2列):
左列 行0-2:x/y/z 位移波形(vs 原子序号)
右列 行0:每粒子动能
右列 行1:每粒子势能
右列 行2:每粒子总能
右列 行3:系统总能量随时间变化
Args: Args:
output_dir: 输出目录(含 display.txt output_dir: 输出目录(含 display.npz 或 display.txt
save_gif: 是否保存 GIF save_gif: 是否保存 GIF
save_mp4: 是否保存 MP4 save_mp4: 是否保存 MP4
show: 是否弹出交互窗口
""" """
data = _load_wave_dataset(output_dir) data = _load_wave_dataset(output_dir)
n_frames = int(data["n_frames"]) n_frames = int(data["n_frames"])
t = np.array(data["t"]) t = np.array(data["t"])
# 位置 / 速度
x = np.array(data["x"]) x = np.array(data["x"])
y = np.array(data["y"]) y = np.array(data["y"])
z = np.array(data["z"]) z = np.array(data["z"])
@@ -188,158 +289,190 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
vy = np.array(data["vy"]) vy = np.array(data["vy"])
vz = np.array(data["vz"]) vz = np.array(data["vz"])
# 原子信息
pos_0 = np.array(data["pos_0"]) pos_0 = np.array(data["pos_0"])
masses = np.array(data["masses"]) masses = np.array(data["masses"])
atom_ids = np.array(data["atom_ids"]) atom_ids = np.array(data["atom_ids"])
n_atoms = len(atom_ids) n_atoms = len(atom_ids)
# 成键
bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64) bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64)
bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64) bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64)
bond_rest_lengths= np.array(data.get("bond_rest_lengths",[]), dtype=np.float64) bond_rest_lengths= np.array(data.get("bond_rest_lengths",[]), dtype=np.float64)
# 物理开关
gravity_field = int(data.get("gravity_field", 0)) gravity_field = int(data.get("gravity_field", 0))
gravity_interaction = int(data.get("gravity_interaction", 0)) gravity_interaction = int(data.get("gravity_interaction", 0))
G = data.get("G", [0, 0, 0]) G = data.get("G", [0, 0, 0])
gravity_strength = float(data.get("gravity_strength", 1.0)) gravity_strength = float(data.get("gravity_strength", 1.0))
driving_force = int(data.get("driving_force", 0)) driving_force = int(data.get("driving_force", 0))
# ── 位移(偏离初始平衡位形)── # 驱动原子信息(用于势能计算)
dx = x - pos_0[np.newaxis, :, 0] # 纵波(沿链方向 x driver_info = _load_driver_info(output_dir) if driving_force else {}
dy = y - pos_0[np.newaxis, :, 1] # 横波 1y 方向)
dz = z - pos_0[np.newaxis, :, 2] # 横波 2z 方向)
# ── 能量 ── # ── 位移 ──
ek, us, ug, ugr = compute_energy( dx = x - pos_0[np.newaxis, :, 0]
dy = y - pos_0[np.newaxis, :, 1]
dz = z - pos_0[np.newaxis, :, 2]
# ── 系统总能量(用于右下时间图)──
ek_sys, us_sys, ug_sys, ugr_sys = compute_energy(
x, y, z, vx, vy, vz, masses, masses, x, y, z, vx, vy, vz, masses, masses,
bond_pairs, bond_stiffness, bond_rest_lengths, bond_pairs, bond_stiffness, bond_rest_lengths,
gravity_field, G, gravity_interaction, gravity_strength) gravity_field, G, gravity_interaction, gravity_strength)
e_total = ek + us + ug + ugr e_total = ek_sys + us_sys + ug_sys + ugr_sys
power = np.gradient(e_total, t) # 输入功率 = dE/dt power = np.gradient(e_total, t)
# ── 波形纵轴范围(全局统一,避免抖动)── # ── 每粒子能量 ──
def get_ylim(disp_data): ek_atom, pe_atom, et_atom = compute_per_atom_energy(
vmax = np.max(np.abs(disp_data)) x, y, z, vx, vy, vz, masses,
bond_pairs, bond_stiffness, bond_rest_lengths,
atom_ids, driver_info)
# ── y 轴范围 ──
def get_ylim(arr):
vmax = np.max(np.abs(arr))
if vmax < 1e-10: if vmax < 1e-10:
return -1.0, 1.0 return -1.0, 1.0
margin = vmax * 0.2 m = vmax * 0.2
return -vmax - margin, vmax + margin return -vmax - m, vmax + m
ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)] def get_ylim_pos(arr):
vmax = np.max(arr)
if vmax < 1e-12:
return 0.0, 1.0
return 0.0, vmax * 1.2
disp_ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)]
ek_ylim = get_ylim_pos(ek_atom)
pe_ylim = get_ylim_pos(pe_atom)
et_ylim = get_ylim_pos(et_atom)
e_max = max(np.max(e_total), 0.01) * 1.3 e_max = max(np.max(e_total), 0.01) * 1.3
p_abs = np.max(np.abs(power)) p_max = max(np.max(np.abs(power)) * 1.3, 0.01)
p_max = max(p_abs * 1.3, 0.01)
atom_idx = np.arange(n_atoms) atom_idx = np.arange(n_atoms)
# ── 图形设置 ── # ── 图形布局 ──
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans'] plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False plt.rcParams['axes.unicode_minus'] = False
fig, axes = plt.subplots(2, 2, figsize=(14, 10)) fig = plt.figure(figsize=(16, 14))
fig.suptitle("波形与能量分析", fontsize=16) fig.suptitle("波形与能量分析", fontsize=16)
# ── 3 个波形子图 ── import matplotlib.gridspec as gridspec
titles = [ gs = gridspec.GridSpec(4, 2, figure=fig, hspace=0.45, wspace=0.35)
"纵波 (x 方向位移)",
"横波 (y 方向位移)",
"横波 (z 方向位移)",
]
colors = ["#2563eb", "#ea580c", "#16a34a"]
wave_configs = list(zip(
[axes[0, 0], axes[0, 1], axes[1, 0]],
[dx, dy, dz], titles, colors, ylims
))
# 左列:位移波形(行 0-2
ax_dx = fig.add_subplot(gs[0, 0])
ax_dy = fig.add_subplot(gs[1, 0])
ax_dz = fig.add_subplot(gs[2, 0])
# 左列行 3 留空(或与右下对齐)
ax_blank = fig.add_subplot(gs[3, 0])
ax_blank.set_visible(False)
# 右列:每粒子能量(行 0-2)+ 时间能量图(行 3)
ax_ek = fig.add_subplot(gs[0, 1])
ax_pe = fig.add_subplot(gs[1, 1])
ax_et = fig.add_subplot(gs[2, 1])
ax_ep = fig.add_subplot(gs[3, 1])
# ── 初始化位移波形 ──
wave_axes = [ax_dx, ax_dy, ax_dz]
wave_disps = [dx, dy, dz]
wave_titles = ["纵波 (x 方向位移)", "横波 (y 方向位移)", "横波 (z 方向位移)"]
wave_colors = ["#2563eb", "#ea580c", "#16a34a"]
wave_lines = [] wave_lines = []
time_texts = [] time_texts = []
for ax, disp, title, color, yl in wave_configs:
for ax, disp, title, color, yl in zip(wave_axes, wave_disps, wave_titles, wave_colors, disp_ylims):
ax.set_xlim(0, n_atoms - 1) ax.set_xlim(0, n_atoms - 1)
ax.set_ylim(yl) ax.set_ylim(yl)
ax.set_xlabel("原子序号") ax.set_xlabel("原子序号")
ax.set_ylabel("位移") ax.set_ylabel("位移")
ax.set_title(title) ax.set_title(title)
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
line, = ax.plot([], [], color=color, linewidth=1.5) ln, = ax.plot([], [], color=color, linewidth=1.5)
wave_lines.append(line) wave_lines.append(ln)
tt = ax.text(0.02, 0.95, "", transform=ax.transAxes, tt = ax.text(0.02, 0.95, "", transform=ax.transAxes,
fontsize=11, verticalalignment="top") fontsize=9, verticalalignment="top")
time_texts.append(tt) time_texts.append(tt)
# ── 能量+功率子图 ── # ── 初始化每粒子能量图 ──
ax_ep = axes[1, 1] energy_axes = [ax_ek, ax_pe, ax_et]
energy_arrays = [ek_atom, pe_atom, et_atom]
energy_titles = ["每粒子动能", "每粒子势能", "每粒子总能"]
energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"]
energy_ylims = [ek_ylim, pe_ylim, et_ylim]
energy_lines = []
for ax, arr, title, color, yl in zip(energy_axes, energy_arrays, energy_titles, energy_colors, energy_ylims):
ax.set_xlim(0, n_atoms - 1)
ax.set_ylim(yl)
ax.set_xlabel("原子序号")
ax.set_ylabel("能量")
ax.set_title(title)
ax.grid(True, alpha=0.3)
ln, = ax.plot([], [], color=color, linewidth=1.5)
energy_lines.append(ln)
# ── 初始化系统总能量时间图 ──
ax_ep.set_xlim(t[0], t[-1]) ax_ep.set_xlim(t[0], t[-1])
ep_ylow = min(-0.1 * max(e_max, p_max), -p_max * 0.1)
ep_yhigh = max(e_max, p_max) ep_yhigh = max(e_max, p_max)
ep_ylow = min(-p_max * 0.1, 0.0)
ax_ep.set_ylim(ep_ylow, ep_yhigh) ax_ep.set_ylim(ep_ylow, ep_yhigh)
ax_ep.set_xlabel("时间 (s)") ax_ep.set_xlabel("时间 (s)")
ax_ep.set_ylabel("能量 / 功率") ax_ep.set_ylabel("能量 / 功率")
ax_ep.set_title("系统能量与输入功率") ax_ep.set_title("系统能量与输入功率")
ax_ep.grid(True, alpha=0.3) ax_ep.grid(True, alpha=0.3)
line_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能") ln_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能")
line_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能") ln_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能")
line_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量") ln_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量")
line_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)") ln_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)")
ln_ug = None
ln_ugr = None
if gravity_field: if gravity_field:
line_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能") ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
else:
line_ug = None
if gravity_interaction and n_atoms <= 200: if gravity_interaction and n_atoms <= 200:
line_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能") ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
else: ax_ep.legend(loc="upper left", fontsize=8)
line_ugr = None
ax_ep.legend(loc="upper right", fontsize=9)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# ── 动画更新 ── # ── 动画更新 ──
def update(frame): def update(frame):
# 波形 # 位移波形
for i in range(3): for i in range(3):
wave_lines[i].set_data(atom_idx, [dx, dy, dz][i][frame]) wave_lines[i].set_data(atom_idx, wave_disps[i][frame])
time_texts[i].set_text(f"t = {t[frame]:.2f} s |{frame+1}/{n_frames}") time_texts[i].set_text(f"t = {t[frame]:.2f} s 帧 {frame+1}/{n_frames}")
# 能量(累计到当前帧) # 每粒子能量
for i in range(3):
energy_lines[i].set_data(atom_idx, energy_arrays[i][frame])
# 系统总能量(累计到当前帧)
cur_t = t[:frame + 1] cur_t = t[:frame + 1]
line_ek.set_data(cur_t, ek[:frame + 1]) ln_ek.set_data(cur_t, ek_sys[:frame + 1])
line_us.set_data(cur_t, us[:frame + 1]) ln_us.set_data(cur_t, us_sys[:frame + 1])
line_et.set_data(cur_t, e_total[:frame + 1]) ln_et.set_data(cur_t, e_total[:frame + 1])
line_pw.set_data(cur_t, power[:frame + 1]) ln_pw.set_data(cur_t, power[:frame + 1])
if line_ug: if ln_ug: ln_ug.set_data(cur_t, ug_sys[:frame + 1])
line_ug.set_data(cur_t, ug[:frame + 1]) if ln_ugr: ln_ugr.set_data(cur_t, ugr_sys[:frame + 1])
if line_ugr:
line_ugr.set_data(cur_t, ugr[:frame + 1])
# 能量图 x 轴动态扩展
ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1])) ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1]))
artists = wave_lines + time_texts + [line_ek, line_us, line_et, line_pw] artists = wave_lines + time_texts + energy_lines + \
if line_ug: artists.append(line_ug) [ln_ek, ln_us, ln_et, ln_pw]
if line_ugr: artists.append(line_ugr) if ln_ug: artists.append(ln_ug)
if ln_ugr: artists.append(ln_ugr)
return artists return artists
ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True) ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)
# ── 输出 GIF ── # ── 输出文件 ──
gif_path = None
if save_gif: if save_gif:
gif_path = os.path.join(output_dir, "wave_animation.gif") gif_path = os.path.join(output_dir, "wave_animation.gif")
ani.save(gif_path, writer="pillow", fps=min(20, max(1, n_frames // 5))) ani.save(gif_path, writer="pillow", fps=min(20, max(1, n_frames // 5)))
print(f"[plot_wave] GIF 已保存: {gif_path}") print(f"[plot_wave] GIF 已保存: {gif_path}")
# ── 输出 MP4(需要 ffmpeg)──
gif_path = None
if save_gif:
gif_path = os.path.join(output_dir, "wave_animation.gif")
if save_mp4: if save_mp4:
try: try:
import matplotlib.animation as manim import matplotlib.animation as manim
import matplotlib.pyplot as _plt
# 尝试通过 imageio_ffmpeg 定位 ffmpeg
ffmpeg_path = None ffmpeg_path = None
try: try:
import imageio_ffmpeg import imageio_ffmpeg
@@ -347,8 +480,7 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
except Exception: except Exception:
pass pass
if ffmpeg_path and os.path.exists(ffmpeg_path): if ffmpeg_path and os.path.exists(ffmpeg_path):
_plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path
ffps = min(20, max(1, n_frames // 5)) ffps = min(20, max(1, n_frames // 5))
writer = manim.FFMpegWriter(fps=ffps, codec="libx264", writer = manim.FFMpegWriter(fps=ffps, codec="libx264",
extra_args=["-pix_fmt", "yuv420p"]) extra_args=["-pix_fmt", "yuv420p"])
@@ -360,7 +492,6 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
except Exception as e: except Exception as e:
print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过") print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过")
# ── 最后显示动画窗口(仅直接运行时)──
if show: if show:
plt.show() plt.show()
else: else: