feat: redesign plot_wave with per-atom energy panels
New 4x2 layout: left col = x/y/z displacement waves, right col = per-atom KE/PE/total energy + system energy vs time. PE split 50/50 for normal bonds; 100% to non-driven atom when bonded to driver; driven atom PE = E_SHO - KE. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+214
-83
@@ -167,20 +167,121 @@ def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr,
|
||||
return ek_sys, us_sys, ug_sys, ugr_sys
|
||||
|
||||
|
||||
def _load_driver_info(output_dir):
|
||||
"""从 input/driver.txt 读取驱动原子的 atom_id, amp, freq(仅用于能量计算)。
|
||||
返回 dict: {atom_id: {'amp': [ax,ay,az], 'freq': [fx,fy,fz]}},失败返回 {}。
|
||||
"""
|
||||
input_dir = os.path.join(os.path.dirname(output_dir), "input")
|
||||
path = os.path.join(input_dir, "driver.txt")
|
||||
if not os.path.exists(path):
|
||||
return {}
|
||||
drivers = {}
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
f.readline() # skip header
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
parts = line.split()
|
||||
if len(parts) < 10:
|
||||
continue
|
||||
n = int(parts[0])
|
||||
amp = np.array([float(parts[1]), float(parts[2]), float(parts[3])])
|
||||
freq = np.array([float(parts[4]), float(parts[5]), float(parts[6])])
|
||||
drivers[n] = {"amp": amp, "freq": freq}
|
||||
except Exception:
|
||||
pass
|
||||
return drivers
|
||||
|
||||
|
||||
def compute_per_atom_energy(x, y, z, vx, vy, vz, masses,
|
||||
bond_pairs, bond_stiffness, bond_rest_lengths,
|
||||
atom_ids, driver_info):
|
||||
"""计算每帧每个粒子的动能、势能、总能。
|
||||
|
||||
势能分配规则:
|
||||
- 非驱动粒子之间的键:势能各分一半
|
||||
- 键的一端是驱动粒子:势能全部归非驱动端
|
||||
- 驱动粒子自身:按简谐振子 PE = ½ m (2πf)² A² cos²(2πft+φ) 计算
|
||||
(此处用 KE_driven = ½m·v² 的互补式:PE_driven = E_total_sho - KE_driven)
|
||||
|
||||
Returns:
|
||||
ek_atom: (n_frames, n_atoms) 每原子动能
|
||||
pe_atom: (n_frames, n_atoms) 每原子势能
|
||||
et_atom: (n_frames, n_atoms) 每原子总能
|
||||
"""
|
||||
n_frames, n_atoms = x.shape
|
||||
|
||||
# 动能 per atom
|
||||
masses_2d = masses[np.newaxis, :]
|
||||
ek_atom = 0.5 * masses_2d * (vx**2 + vy**2 + vz**2) # (n_frames, n_atoms)
|
||||
|
||||
# 构建 atom_id → index 映射,以及驱动原子 index 集合
|
||||
id_to_idx = {int(aid): i for i, aid in enumerate(atom_ids)}
|
||||
driven_idx = set()
|
||||
for aid in driver_info:
|
||||
if aid in id_to_idx:
|
||||
driven_idx.add(id_to_idx[aid])
|
||||
|
||||
# 势能 per atom(弹簧键)
|
||||
pe_atom = np.zeros((n_frames, n_atoms))
|
||||
if bond_pairs is not None and len(bond_pairs) > 0:
|
||||
for b in range(len(bond_pairs)):
|
||||
i, j = int(bond_pairs[b, 0]), int(bond_pairs[b, 1])
|
||||
ddx = x[:, j] - x[:, i]
|
||||
ddy = y[:, j] - y[:, i]
|
||||
ddz = z[:, j] - z[:, i]
|
||||
dist = np.sqrt(ddx**2 + ddy**2 + ddz**2)
|
||||
bond_pe = 0.5 * bond_stiffness[b] * (dist - bond_rest_lengths[b])**2
|
||||
i_driven = i in driven_idx
|
||||
j_driven = j in driven_idx
|
||||
if i_driven and not j_driven:
|
||||
pe_atom[:, j] += bond_pe # 全归非驱动端
|
||||
elif j_driven and not i_driven:
|
||||
pe_atom[:, i] += bond_pe # 全归非驱动端
|
||||
else:
|
||||
pe_atom[:, i] += bond_pe * 0.5
|
||||
pe_atom[:, j] += bond_pe * 0.5
|
||||
|
||||
# 驱动粒子:用简谐振子总能 E_sho = ½m(2πf)²A²,PE_sho = E_sho - KE
|
||||
TWO_PI = 2.0 * np.pi
|
||||
for aid, info in driver_info.items():
|
||||
if aid not in id_to_idx:
|
||||
continue
|
||||
idx = id_to_idx[aid]
|
||||
m = masses[idx]
|
||||
amp = info["amp"]
|
||||
freq = info["freq"]
|
||||
e_sho = 0.5 * m * np.sum((TWO_PI * freq)**2 * amp**2)
|
||||
pe_sho = e_sho - ek_atom[:, idx]
|
||||
pe_atom[:, idx] = np.maximum(pe_sho, 0.0) # SHO 势能非负
|
||||
|
||||
et_atom = ek_atom + pe_atom
|
||||
return ek_atom, pe_atom, et_atom
|
||||
|
||||
|
||||
def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
|
||||
"""主绘图函数:读取 display.txt 并生成波形+能量动画。
|
||||
|
||||
布局(4行×2列):
|
||||
左列 行0-2:x/y/z 位移波形(vs 原子序号)
|
||||
右列 行0:每粒子动能
|
||||
右列 行1:每粒子势能
|
||||
右列 行2:每粒子总能
|
||||
右列 行3:系统总能量随时间变化
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录(含 display.txt)
|
||||
output_dir: 输出目录(含 display.npz 或 display.txt)
|
||||
save_gif: 是否保存 GIF
|
||||
save_mp4: 是否保存 MP4
|
||||
show: 是否弹出交互窗口
|
||||
"""
|
||||
data = _load_wave_dataset(output_dir)
|
||||
|
||||
n_frames = int(data["n_frames"])
|
||||
t = np.array(data["t"])
|
||||
|
||||
# 位置 / 速度
|
||||
x = np.array(data["x"])
|
||||
y = np.array(data["y"])
|
||||
z = np.array(data["z"])
|
||||
@@ -188,158 +289,190 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
|
||||
vy = np.array(data["vy"])
|
||||
vz = np.array(data["vz"])
|
||||
|
||||
# 原子信息
|
||||
pos_0 = np.array(data["pos_0"])
|
||||
masses = np.array(data["masses"])
|
||||
atom_ids = np.array(data["atom_ids"])
|
||||
n_atoms = len(atom_ids)
|
||||
|
||||
# 成键
|
||||
bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64)
|
||||
bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64)
|
||||
bond_rest_lengths= np.array(data.get("bond_rest_lengths",[]), dtype=np.float64)
|
||||
|
||||
# 物理开关
|
||||
gravity_field = int(data.get("gravity_field", 0))
|
||||
gravity_interaction = int(data.get("gravity_interaction", 0))
|
||||
G = data.get("G", [0, 0, 0])
|
||||
gravity_strength = float(data.get("gravity_strength", 1.0))
|
||||
driving_force = int(data.get("driving_force", 0))
|
||||
|
||||
# ── 位移(偏离初始平衡位形)──
|
||||
dx = x - pos_0[np.newaxis, :, 0] # 纵波(沿链方向 x)
|
||||
dy = y - pos_0[np.newaxis, :, 1] # 横波 1(y 方向)
|
||||
dz = z - pos_0[np.newaxis, :, 2] # 横波 2(z 方向)
|
||||
# 驱动原子信息(用于势能计算)
|
||||
driver_info = _load_driver_info(output_dir) if driving_force else {}
|
||||
|
||||
# ── 能量 ──
|
||||
ek, us, ug, ugr = compute_energy(
|
||||
# ── 位移 ──
|
||||
dx = x - pos_0[np.newaxis, :, 0]
|
||||
dy = y - pos_0[np.newaxis, :, 1]
|
||||
dz = z - pos_0[np.newaxis, :, 2]
|
||||
|
||||
# ── 系统总能量(用于右下时间图)──
|
||||
ek_sys, us_sys, ug_sys, ugr_sys = compute_energy(
|
||||
x, y, z, vx, vy, vz, masses, masses,
|
||||
bond_pairs, bond_stiffness, bond_rest_lengths,
|
||||
gravity_field, G, gravity_interaction, gravity_strength)
|
||||
e_total = ek + us + ug + ugr
|
||||
power = np.gradient(e_total, t) # 输入功率 = dE/dt
|
||||
e_total = ek_sys + us_sys + ug_sys + ugr_sys
|
||||
power = np.gradient(e_total, t)
|
||||
|
||||
# ── 波形纵轴范围(全局统一,避免抖动)──
|
||||
def get_ylim(disp_data):
|
||||
vmax = np.max(np.abs(disp_data))
|
||||
# ── 每粒子能量 ──
|
||||
ek_atom, pe_atom, et_atom = compute_per_atom_energy(
|
||||
x, y, z, vx, vy, vz, masses,
|
||||
bond_pairs, bond_stiffness, bond_rest_lengths,
|
||||
atom_ids, driver_info)
|
||||
|
||||
# ── y 轴范围 ──
|
||||
def get_ylim(arr):
|
||||
vmax = np.max(np.abs(arr))
|
||||
if vmax < 1e-10:
|
||||
return -1.0, 1.0
|
||||
margin = vmax * 0.2
|
||||
return -vmax - margin, vmax + margin
|
||||
m = vmax * 0.2
|
||||
return -vmax - m, vmax + m
|
||||
|
||||
ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)]
|
||||
def get_ylim_pos(arr):
|
||||
vmax = np.max(arr)
|
||||
if vmax < 1e-12:
|
||||
return 0.0, 1.0
|
||||
return 0.0, vmax * 1.2
|
||||
|
||||
disp_ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)]
|
||||
ek_ylim = get_ylim_pos(ek_atom)
|
||||
pe_ylim = get_ylim_pos(pe_atom)
|
||||
et_ylim = get_ylim_pos(et_atom)
|
||||
e_max = max(np.max(e_total), 0.01) * 1.3
|
||||
p_abs = np.max(np.abs(power))
|
||||
p_max = max(p_abs * 1.3, 0.01)
|
||||
p_max = max(np.max(np.abs(power)) * 1.3, 0.01)
|
||||
|
||||
atom_idx = np.arange(n_atoms)
|
||||
|
||||
# ── 图形设置 ──
|
||||
# ── 图形布局 ──
|
||||
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
||||
fig = plt.figure(figsize=(16, 14))
|
||||
fig.suptitle("波形与能量分析", fontsize=16)
|
||||
|
||||
# ── 3 个波形子图 ──
|
||||
titles = [
|
||||
"纵波 (x 方向位移)",
|
||||
"横波 (y 方向位移)",
|
||||
"横波 (z 方向位移)",
|
||||
]
|
||||
colors = ["#2563eb", "#ea580c", "#16a34a"]
|
||||
wave_configs = list(zip(
|
||||
[axes[0, 0], axes[0, 1], axes[1, 0]],
|
||||
[dx, dy, dz], titles, colors, ylims
|
||||
))
|
||||
import matplotlib.gridspec as gridspec
|
||||
gs = gridspec.GridSpec(4, 2, figure=fig, hspace=0.45, wspace=0.35)
|
||||
|
||||
# 左列:位移波形(行 0-2)
|
||||
ax_dx = fig.add_subplot(gs[0, 0])
|
||||
ax_dy = fig.add_subplot(gs[1, 0])
|
||||
ax_dz = fig.add_subplot(gs[2, 0])
|
||||
# 左列行 3 留空(或与右下对齐)
|
||||
ax_blank = fig.add_subplot(gs[3, 0])
|
||||
ax_blank.set_visible(False)
|
||||
|
||||
# 右列:每粒子能量(行 0-2)+ 时间能量图(行 3)
|
||||
ax_ek = fig.add_subplot(gs[0, 1])
|
||||
ax_pe = fig.add_subplot(gs[1, 1])
|
||||
ax_et = fig.add_subplot(gs[2, 1])
|
||||
ax_ep = fig.add_subplot(gs[3, 1])
|
||||
|
||||
# ── 初始化位移波形 ──
|
||||
wave_axes = [ax_dx, ax_dy, ax_dz]
|
||||
wave_disps = [dx, dy, dz]
|
||||
wave_titles = ["纵波 (x 方向位移)", "横波 (y 方向位移)", "横波 (z 方向位移)"]
|
||||
wave_colors = ["#2563eb", "#ea580c", "#16a34a"]
|
||||
wave_lines = []
|
||||
time_texts = []
|
||||
for ax, disp, title, color, yl in wave_configs:
|
||||
|
||||
for ax, disp, title, color, yl in zip(wave_axes, wave_disps, wave_titles, wave_colors, disp_ylims):
|
||||
ax.set_xlim(0, n_atoms - 1)
|
||||
ax.set_ylim(yl)
|
||||
ax.set_xlabel("原子序号")
|
||||
ax.set_ylabel("位移")
|
||||
ax.set_title(title)
|
||||
ax.grid(True, alpha=0.3)
|
||||
line, = ax.plot([], [], color=color, linewidth=1.5)
|
||||
wave_lines.append(line)
|
||||
ln, = ax.plot([], [], color=color, linewidth=1.5)
|
||||
wave_lines.append(ln)
|
||||
tt = ax.text(0.02, 0.95, "", transform=ax.transAxes,
|
||||
fontsize=11, verticalalignment="top")
|
||||
fontsize=9, verticalalignment="top")
|
||||
time_texts.append(tt)
|
||||
|
||||
# ── 能量+功率子图 ──
|
||||
ax_ep = axes[1, 1]
|
||||
# ── 初始化每粒子能量图 ──
|
||||
energy_axes = [ax_ek, ax_pe, ax_et]
|
||||
energy_arrays = [ek_atom, pe_atom, et_atom]
|
||||
energy_titles = ["每粒子动能", "每粒子势能", "每粒子总能"]
|
||||
energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"]
|
||||
energy_ylims = [ek_ylim, pe_ylim, et_ylim]
|
||||
energy_lines = []
|
||||
|
||||
for ax, arr, title, color, yl in zip(energy_axes, energy_arrays, energy_titles, energy_colors, energy_ylims):
|
||||
ax.set_xlim(0, n_atoms - 1)
|
||||
ax.set_ylim(yl)
|
||||
ax.set_xlabel("原子序号")
|
||||
ax.set_ylabel("能量")
|
||||
ax.set_title(title)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ln, = ax.plot([], [], color=color, linewidth=1.5)
|
||||
energy_lines.append(ln)
|
||||
|
||||
# ── 初始化系统总能量时间图 ──
|
||||
ax_ep.set_xlim(t[0], t[-1])
|
||||
ep_ylow = min(-0.1 * max(e_max, p_max), -p_max * 0.1)
|
||||
ep_yhigh = max(e_max, p_max)
|
||||
ep_ylow = min(-p_max * 0.1, 0.0)
|
||||
ax_ep.set_ylim(ep_ylow, ep_yhigh)
|
||||
ax_ep.set_xlabel("时间 (s)")
|
||||
ax_ep.set_ylabel("能量 / 功率")
|
||||
ax_ep.set_title("系统能量与输入功率")
|
||||
ax_ep.grid(True, alpha=0.3)
|
||||
|
||||
line_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能")
|
||||
line_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能")
|
||||
line_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量")
|
||||
line_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)")
|
||||
ln_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能")
|
||||
ln_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能")
|
||||
ln_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量")
|
||||
ln_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)")
|
||||
ln_ug = None
|
||||
ln_ugr = None
|
||||
if gravity_field:
|
||||
line_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
|
||||
else:
|
||||
line_ug = None
|
||||
ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
|
||||
if gravity_interaction and n_atoms <= 200:
|
||||
line_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
|
||||
else:
|
||||
line_ugr = None
|
||||
|
||||
ax_ep.legend(loc="upper right", fontsize=9)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
||||
ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
|
||||
ax_ep.legend(loc="upper left", fontsize=8)
|
||||
|
||||
# ── 动画更新 ──
|
||||
def update(frame):
|
||||
# 波形
|
||||
# 位移波形
|
||||
for i in range(3):
|
||||
wave_lines[i].set_data(atom_idx, [dx, dy, dz][i][frame])
|
||||
time_texts[i].set_text(f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}")
|
||||
wave_lines[i].set_data(atom_idx, wave_disps[i][frame])
|
||||
time_texts[i].set_text(f"t = {t[frame]:.2f} s 帧 {frame+1}/{n_frames}")
|
||||
|
||||
# 能量(累计到当前帧)
|
||||
# 每粒子能量
|
||||
for i in range(3):
|
||||
energy_lines[i].set_data(atom_idx, energy_arrays[i][frame])
|
||||
|
||||
# 系统总能量(累计到当前帧)
|
||||
cur_t = t[:frame + 1]
|
||||
line_ek.set_data(cur_t, ek[:frame + 1])
|
||||
line_us.set_data(cur_t, us[:frame + 1])
|
||||
line_et.set_data(cur_t, e_total[:frame + 1])
|
||||
line_pw.set_data(cur_t, power[:frame + 1])
|
||||
if line_ug:
|
||||
line_ug.set_data(cur_t, ug[:frame + 1])
|
||||
if line_ugr:
|
||||
line_ugr.set_data(cur_t, ugr[:frame + 1])
|
||||
|
||||
# 能量图 x 轴动态扩展
|
||||
ln_ek.set_data(cur_t, ek_sys[:frame + 1])
|
||||
ln_us.set_data(cur_t, us_sys[:frame + 1])
|
||||
ln_et.set_data(cur_t, e_total[:frame + 1])
|
||||
ln_pw.set_data(cur_t, power[:frame + 1])
|
||||
if ln_ug: ln_ug.set_data(cur_t, ug_sys[:frame + 1])
|
||||
if ln_ugr: ln_ugr.set_data(cur_t, ugr_sys[:frame + 1])
|
||||
ax_ep.set_xlim(t[0], max(t[frame] + max(t[-1] * 0.05, 1), t[-1]))
|
||||
|
||||
artists = wave_lines + time_texts + [line_ek, line_us, line_et, line_pw]
|
||||
if line_ug: artists.append(line_ug)
|
||||
if line_ugr: artists.append(line_ugr)
|
||||
artists = wave_lines + time_texts + energy_lines + \
|
||||
[ln_ek, ln_us, ln_et, ln_pw]
|
||||
if ln_ug: artists.append(ln_ug)
|
||||
if ln_ugr: artists.append(ln_ugr)
|
||||
return artists
|
||||
|
||||
ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)
|
||||
|
||||
# ── 输出 GIF ──
|
||||
# ── 输出文件 ──
|
||||
gif_path = None
|
||||
if save_gif:
|
||||
gif_path = os.path.join(output_dir, "wave_animation.gif")
|
||||
ani.save(gif_path, writer="pillow", fps=min(20, max(1, n_frames // 5)))
|
||||
print(f"[plot_wave] GIF 已保存: {gif_path}")
|
||||
|
||||
# ── 输出 MP4(需要 ffmpeg)──
|
||||
gif_path = None
|
||||
if save_gif:
|
||||
gif_path = os.path.join(output_dir, "wave_animation.gif")
|
||||
if save_mp4:
|
||||
try:
|
||||
import matplotlib.animation as manim
|
||||
import matplotlib.pyplot as _plt
|
||||
|
||||
# 尝试通过 imageio_ffmpeg 定位 ffmpeg
|
||||
ffmpeg_path = None
|
||||
try:
|
||||
import imageio_ffmpeg
|
||||
@@ -347,8 +480,7 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
|
||||
except Exception:
|
||||
pass
|
||||
if ffmpeg_path and os.path.exists(ffmpeg_path):
|
||||
_plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path
|
||||
|
||||
plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path
|
||||
ffps = min(20, max(1, n_frames // 5))
|
||||
writer = manim.FFMpegWriter(fps=ffps, codec="libx264",
|
||||
extra_args=["-pix_fmt", "yuv420p"])
|
||||
@@ -360,7 +492,6 @@ def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
|
||||
except Exception as e:
|
||||
print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过")
|
||||
|
||||
# ── 最后显示动画窗口(仅直接运行时)──
|
||||
if show:
|
||||
plt.show()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user