ea99f09f9b
Add compute_energy_flux() using the Hardy formula: J_b = 1/2 * F_bond_on_i * (v_i + v_j) New 4th subplot shows J vs bond position (x coordinate of bond midpoint). J > 0: energy flows rightward; J = 0: standing wave; J < 0: leftward. Ideal standing wave would show J ≈ 0 everywhere. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
583 lines
22 KiB
Python
583 lines
22 KiB
Python
"""
|
||
plot_wave.py
|
||
============
|
||
波形与能量动态图:读取 display.txt,绘制原子位移波形
|
||
(纵波 + 2 个横波)和系统能量/输入功率随时间变化的二维动画。
|
||
|
||
用法:
|
||
python plot_wave.py # 使用 dynamics 根目录下 output/
|
||
python plot_wave.py examples/case05/output # 指定案例输出目录
|
||
"""
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from matplotlib.animation import FuncAnimation
|
||
import os
|
||
import sys
|
||
import json
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
import compute
|
||
|
||
|
||
def load_disp_data(output_dir):
    """Load display data from *output_dir*, preferring display.npz over display.txt.

    Args:
        output_dir: directory expected to contain display.npz and/or display.txt.

    Returns:
        Whatever ``compute.load_display_npz`` / ``compute.load_display_txt`` returns
        for the first file that exists.

    Raises:
        FileNotFoundError: if neither file exists in *output_dir*.
    """
    # Ordered by preference; the loader is looked up on ``compute`` only once a file is found.
    candidates = (
        (os.path.join(output_dir, "display.npz"), "load_display_npz"),
        (os.path.join(output_dir, "display.txt"), "load_display_txt"),
    )
    for path, loader_name in candidates:
        if os.path.exists(path):
            return getattr(compute, loader_name)(path)
    raise FileNotFoundError(f"找不到 display.npz 或 display.txt in {output_dir}")
|
||
|
||
|
||
def _header_json(header_fields, key, default):
|
||
raw = header_fields.get(key, "")
|
||
if not raw:
|
||
return default
|
||
try:
|
||
return json.loads(raw)
|
||
except json.JSONDecodeError:
|
||
return default
|
||
|
||
|
||
def _load_wave_dataset(output_dir):
    """Load wave/energy plotting data from display metadata or sibling input files.

    Atom masses, initial positions, bonds and gravity settings are read from the
    display header first. Older display.txt files lack atom/bond metadata; in that
    case masses and positions come from ``<case>/input/coord.txt`` and bonds from
    ``<case>/input/connection.txt`` + ``<case>/input/bond.txt``, where ``<case>`` is
    the parent directory of *output_dir*.

    Args:
        output_dir: directory containing display.npz or display.txt.

    Returns:
        dict with keys ``n_frames``, ``t``, ``x``/``y``/``z``, ``vx``/``vy``/``vz``,
        ``pos_0``, ``masses``, ``atom_ids``, ``bond_pairs``, ``bond_stiffness``,
        ``bond_rest_lengths``, ``gravity_field``, ``gravity_interaction``,
        ``gravity_strength``, ``G`` and ``driving_force``.

    Raises:
        FileNotFoundError: if neither display.npz nor display.txt exists.
        ValueError: if masses/positions are missing from both the header and coord.txt.
    """
    disp_data = load_disp_data(output_dir)
    header = disp_data["header_fields"]

    x = disp_data["frames_x"]
    y = disp_data["frames_y"]
    z = disp_data["frames_z"]
    vx = disp_data["frames_vx"]
    vy = disp_data["frames_vy"]
    vz = disp_data["frames_vz"]
    atom_ids = np.array(disp_data["atom_ids"], dtype=np.int64)

    # Time of frame k is k * DT * NSTEP (presumably one frame every NSTEP steps of size DT).
    n_frames = x.shape[0]
    dt = float(header.get("DT", 0.001))
    nstep = int(header.get("NSTEP", 1))
    t = np.arange(n_frames, dtype=np.float64) * dt * nstep

    masses = np.array(_header_json(header, "atom_masses", []), dtype=np.float64)
    pos_0 = np.array(_header_json(header, "atom_positions", []), dtype=np.float64)
    bond_pairs = np.array(_header_json(header, "bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(_header_json(header, "bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(_header_json(header, "bond_rest_lengths", []), dtype=np.float64)
    gravity_vec = _header_json(header, "G", [0.0, 0.0, 0.0])

    # Legacy inputs live in "<case>/input" next to "<case>/output". Normalise to an
    # absolute path first: for "case/output/" (trailing separator) a bare dirname()
    # would return "case/output" itself and the fallback files would never be found.
    input_dir = os.path.join(os.path.dirname(os.path.abspath(output_dir)), "input")

    # Backward-compatible fallback for older display.txt outputs.
    if masses.size == 0 or pos_0.size == 0:
        coord_path = os.path.join(input_dir, "coord.txt")
        if not os.path.exists(coord_path):
            raise ValueError("display.txt 缺少 atom_masses/atom_positions 元数据,且未找到 input/coord.txt")
        (_, masses_fb, _, positions_fb, _, _) = compute.load_coord_file(coord_path)
        masses = np.array(masses_fb, dtype=np.float64)
        pos_0 = np.array(positions_fb, dtype=np.float64)

    if bond_pairs.size == 0:
        connection_path = os.path.join(input_dir, "connection.txt")
        bond_path = os.path.join(input_dir, "bond.txt")
        # Best effort: without both files the system is treated as having no bonds.
        if os.path.exists(connection_path) and os.path.exists(bond_path):
            bond_map = compute.load_bond_parameters(bond_path)
            pairs_fb, _, stiffness_fb, rest_lengths_fb = compute.load_bond_connections(
                connection_path, atom_ids, pos_0, bond_map)
            bond_pairs = np.array(pairs_fb, dtype=np.int64)
            bond_stiffness = np.array(stiffness_fb, dtype=np.float64)
            bond_rest_lengths = np.array(rest_lengths_fb, dtype=np.float64)

    return {
        "n_frames": n_frames,
        "t": t,
        "x": x,
        "y": y,
        "z": z,
        "vx": vx,
        "vy": vy,
        "vz": vz,
        "pos_0": pos_0,
        "masses": masses,
        "atom_ids": atom_ids,
        "bond_pairs": bond_pairs,
        "bond_stiffness": bond_stiffness,
        "bond_rest_lengths": bond_rest_lengths,
        "gravity_field": int(header.get("gravity_field", 0)),
        "gravity_interaction": int(header.get("gravity_interaction", 0)),
        "gravity_strength": float(header.get("gravity_strength", 1.0)),
        "G": gravity_vec,
        "driving_force": int(header.get("driving_force", 0)),
    }
|
||
|
||
|
||
def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr,
                   bond_pairs, bond_stiffness, bond_rest_lengths,
                   gravity_field, G, gravity_interaction, gravity_strength):
    """Compute the system-wide energy components for every frame.

    Args:
        x, y, z: atom positions indexed as ``[frame, atom]``.
        vx, vy, vz: atom velocities, same layout as the positions.
        masses: per-atom masses (1-D array of length ``n_atoms``).
        mass_arr: unused; kept only so existing callers keep working.
        bond_pairs: ``(n_bonds, 2)`` atom-index pairs, or ``None``/empty for no bonds.
        bond_stiffness: per-bond spring constant k.
        bond_rest_lengths: per-bond rest length d0.
        gravity_field: truthy to include the uniform-field potential ``-m G·r``.
        G: uniform gravitational field vector (three components).
        gravity_interaction: truthy to include the pairwise gravitational potential.
        gravity_strength: gravitational constant used for the pairwise term.

    Returns:
        ek_sys: system kinetic energy, shape (n_frames,)
        us_sys: system elastic (spring) potential energy, shape (n_frames,)
        ug_sys: uniform-field gravitational potential energy, shape (n_frames,)
        ugr_sys: pairwise gravitational potential energy, shape (n_frames,);
            left at zero when the system has more than 200 atoms.
    """
    n_frames = x.shape[0]
    masses_2d = masses[np.newaxis, :]  # (1, n_atoms), broadcasts over frames

    # Kinetic energy Ek = ½ m v²
    ek_sys = np.sum(0.5 * masses_2d * (vx**2 + vy**2 + vz**2), axis=1)

    # Elastic potential energy Us = ½ k (d - d0)²
    us_sys = np.zeros(n_frames)
    if bond_pairs is not None and len(bond_pairs) > 0:
        for b_idx in range(len(bond_pairs)):
            i, j = bond_pairs[b_idx]
            dist = np.sqrt((x[:, j] - x[:, i])**2
                           + (y[:, j] - y[:, i])**2
                           + (z[:, j] - z[:, i])**2)
            stretch = dist - bond_rest_lengths[b_idx]
            us_sys += 0.5 * bond_stiffness[b_idx] * stretch**2

    # Uniform-field potential energy Ug = -m G·r
    ug_sys = np.zeros(n_frames)
    if gravity_field:
        g = np.asarray(G, dtype=np.float64)
        ug_sys = np.sum(-masses_2d * (g[0] * x + g[1] * y + g[2] * z), axis=1)

    # Pairwise gravitational potential Ugr = -G_grav Σ_{i<j} m_i m_j / r_ij
    ugr_sys = np.zeros(n_frames)
    n_atoms = len(masses)
    # The pair sum is O(n_atoms²) per frame, so it is only evaluated for small systems.
    if gravity_interaction and n_atoms <= 200:
        for i in range(n_atoms - 1):
            # Distances from atom i to every later atom j > i: (n_frames, n_atoms - i - 1).
            dist = np.sqrt((x[:, i + 1:] - x[:, i:i + 1])**2
                           + (y[:, i + 1:] - y[:, i:i + 1])**2
                           + (z[:, i + 1:] - z[:, i:i + 1])**2)
            dist = np.maximum(dist, 1e-12)  # guard against coincident atoms
            ugr_sys -= gravity_strength * masses[i] * np.sum(masses[i + 1:] / dist, axis=1)

    return ek_sys, us_sys, ug_sys, ugr_sys
|
||
|
||
|
||
def _load_driver_info(output_dir):
|
||
"""从 input/driver.txt 读取驱动原子的 atom_id, amp, freq(仅用于能量计算)。
|
||
返回 dict: {atom_id: {'amp': [ax,ay,az], 'freq': [fx,fy,fz]}},失败返回 {}。
|
||
"""
|
||
input_dir = os.path.join(os.path.dirname(output_dir), "input")
|
||
path = os.path.join(input_dir, "driver.txt")
|
||
if not os.path.exists(path):
|
||
return {}
|
||
drivers = {}
|
||
try:
|
||
with open(path, encoding="utf-8") as f:
|
||
f.readline() # skip header
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line or line.startswith("#"):
|
||
continue
|
||
parts = line.split()
|
||
if len(parts) < 10:
|
||
continue
|
||
n = int(parts[0])
|
||
amp = np.array([float(parts[1]), float(parts[2]), float(parts[3])])
|
||
freq = np.array([float(parts[4]), float(parts[5]), float(parts[6])])
|
||
drivers[n] = {"amp": amp, "freq": freq}
|
||
except Exception:
|
||
pass
|
||
return drivers
|
||
|
||
|
||
def compute_per_atom_energy(x, y, z, vx, vy, vz, masses,
                            bond_pairs, bond_stiffness, bond_rest_lengths,
                            atom_ids, driver_info):
    """Compute kinetic, potential and total energy of every atom in every frame.

    Potential-energy bookkeeping for spring bonds:
      - bond between two non-driven atoms: its energy is split half/half;
      - bond with exactly one driven end: its energy goes entirely to the other end;
      - bond between two driven atoms: split half/half, but see below.
    A driven atom's potential energy is then REPLACED (not added to) by a harmonic-
    oscillator estimate: PE = max(E_sho - KE, 0) with E_sho = ½ m Σ (2π f)² A² summed
    over the x/y/z components, so any bond share assigned to it above is discarded.

    Args:
        x, y, z, vx, vy, vz: positions/velocities indexed as ``[frame, atom]``.
        masses: per-atom masses.
        bond_pairs: ``(n_bonds, 2)`` atom-index pairs, or ``None``/empty.
        bond_stiffness, bond_rest_lengths: per-bond k and rest length.
        atom_ids: atom id of each atom index (used to match ``driver_info`` keys).
        driver_info: {atom_id: {"amp": array(3), "freq": array(3)}} for driven atoms.

    Returns:
        ek_atom: (n_frames, n_atoms) kinetic energy per atom
        pe_atom: (n_frames, n_atoms) potential energy per atom
        et_atom: (n_frames, n_atoms) total energy per atom
    """
    frame_count, atom_count = x.shape

    ek_atom = 0.5 * masses[np.newaxis, :] * (vx**2 + vy**2 + vz**2)

    # atom_id -> column index, and the set of column indices that are driven
    index_of = {int(aid): pos for pos, aid in enumerate(atom_ids)}
    driven = {index_of[aid] for aid in driver_info if aid in index_of}

    pe_atom = np.zeros((frame_count, atom_count))
    if bond_pairs is not None and len(bond_pairs) > 0:
        for bond in range(len(bond_pairs)):
            a = int(bond_pairs[bond, 0])
            c = int(bond_pairs[bond, 1])
            length = np.sqrt((x[:, c] - x[:, a])**2
                             + (y[:, c] - y[:, a])**2
                             + (z[:, c] - z[:, a])**2)
            bond_pe = 0.5 * bond_stiffness[bond] * (length - bond_rest_lengths[bond])**2
            a_driven = a in driven
            c_driven = c in driven
            if a_driven == c_driven:
                # Neither (or both) ends driven: share equally.
                pe_atom[:, a] += bond_pe * 0.5
                pe_atom[:, c] += bond_pe * 0.5
            elif a_driven:
                pe_atom[:, c] += bond_pe
            else:
                pe_atom[:, a] += bond_pe

    # Driven atoms: harmonic-oscillator energy minus current kinetic energy.
    two_pi = 2.0 * np.pi
    for aid, info in driver_info.items():
        pos = index_of.get(aid)
        if pos is None:
            continue
        e_sho = 0.5 * masses[pos] * np.sum((two_pi * info["freq"])**2 * info["amp"]**2)
        pe_atom[:, pos] = np.maximum(e_sho - ek_atom[:, pos], 0.0)  # clamp: PE is non-negative

    et_atom = ek_atom + pe_atom
    return ek_atom, pe_atom, et_atom
|
||
|
||
|
||
def compute_energy_flux(x, y, z, vx, vy, vz,
                        bond_pairs, bond_stiffness, bond_rest_lengths):
    """Compute the energy flux through every bond in every frame (Hardy formula).

    For bond b joining atoms i and j (oriented i -> j):

        J_b = ½ · F_{b,j} · (v_i + v_j) = -½ · F_{b,i} · (v_i + v_j)

    where F_{b,i} = k (d - r0) (r_j - r_i) / d is the spring force the bond exerts on
    atom i and F_{b,j} = -F_{b,i} the force on atom j. With the bond energy split
    half/half between its ends, J_b is the rate at which energy crosses the bond
    midpoint from the i side to the j side. (Using F_{b,i} instead gives the rate at
    which the i side *gains* energy, i.e. the opposite sign.)

    - J > 0: energy flows from i to j (rightward when x_i < x_j)
    - J < 0: energy flows from j to i
    - a standing wave gives a J that oscillates about zero (zero time average),
      while a travelling wave gives a J of one sign on average.

    Args:
        x, y, z, vx, vy, vz: positions/velocities indexed as ``[frame, atom]``.
        bond_pairs: ``(n_bonds, 2)`` atom-index pairs, or ``None``/empty.
        bond_stiffness: per-bond spring constant k.
        bond_rest_lengths: per-bond rest length r0.

    Returns:
        flux: (n_frames, n_bonds) scalar energy flux per bond per frame
        bond_xpos: (n_bonds,) x coordinate of each bond midpoint in the first frame
            (horizontal axis for plotting)
    """
    n_frames = x.shape[0]
    if bond_pairs is None or len(bond_pairs) == 0:
        return np.zeros((n_frames, 0)), np.zeros(0)

    pairs = np.asarray(bond_pairs, dtype=np.int64)
    n_bonds = len(pairs)
    flux = np.zeros((n_frames, n_bonds))

    # Bond midpoints in the first frame, used to place bonds on the plot's x axis.
    bond_xpos = 0.5 * (x[0, pairs[:, 0]] + x[0, pairs[:, 1]])

    for b, (i, j) in enumerate(pairs):
        k = bond_stiffness[b]
        r0 = bond_rest_lengths[b]

        # Bond vector i -> j and its length (guarded against coincident atoms).
        rx = x[:, j] - x[:, i]
        ry = y[:, j] - y[:, i]
        rz = z[:, j] - z[:, i]
        d = np.maximum(np.sqrt(rx**2 + ry**2 + rz**2), 1e-12)

        # Force on atom j: a stretched bond (d > r0) pulls j back toward i.
        fac = -k * (d - r0) / d
        fx = fac * rx
        fy = fac * ry
        fz = fac * rz

        flux[:, b] = 0.5 * (fx * (vx[:, i] + vx[:, j])
                            + fy * (vy[:, i] + vy[:, j])
                            + fz * (vz[:, i] + vz[:, j]))

    return flux, bond_xpos
|
||
|
||
|
||
def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
    """Read the display output and build the wave + energy animation.

    Layout (4 rows x 1 column, stacked vertically):
        row 0: x/y/z displacement of every atom, overlaid (vs atom index)
        row 1: per-atom kinetic, potential and total energy, overlaid
        row 2: bond energy flux J (Hardy formula, vs bond-midpoint x position)
        row 3: system energy components and dE/dt over time

    Args:
        output_dir: output directory containing display.npz or display.txt.
        save_gif: if True, save ``wave_animation.gif`` into *output_dir*.
        save_mp4: if True, try to save ``wave_animation.mp4`` into *output_dir*;
            skipped with a printed warning if ffmpeg is missing or encoding fails.
        show: if True, open an interactive window; otherwise close the figure.

    Returns:
        Path of the saved GIF, or None if no GIF was saved.
    """
    data = _load_wave_dataset(output_dir)

    n_frames = int(data["n_frames"])
    t = np.array(data["t"])

    x = np.array(data["x"])
    y = np.array(data["y"])
    z = np.array(data["z"])
    vx = np.array(data["vx"])
    vy = np.array(data["vy"])
    vz = np.array(data["vz"])

    pos_0 = np.array(data["pos_0"])
    masses = np.array(data["masses"])
    atom_ids = np.array(data["atom_ids"])
    n_atoms = len(atom_ids)

    bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(data.get("bond_rest_lengths", []), dtype=np.float64)

    gravity_field = int(data.get("gravity_field", 0))
    gravity_interaction = int(data.get("gravity_interaction", 0))
    G = data.get("G", [0, 0, 0])
    gravity_strength = float(data.get("gravity_strength", 1.0))
    driving_force = int(data.get("driving_force", 0))

    # Driven-atom amplitudes/frequencies, only needed for the per-atom potential energy.
    driver_info = _load_driver_info(output_dir) if driving_force else {}

    # -- Displacements from the initial positions --
    dx = x - pos_0[np.newaxis, :, 0]
    dy = y - pos_0[np.newaxis, :, 1]
    dz = z - pos_0[np.newaxis, :, 2]

    # -- System energies for the time-series panel --
    ek_sys, us_sys, ug_sys, ugr_sys = compute_energy(
        x, y, z, vx, vy, vz, masses, masses,
        bond_pairs, bond_stiffness, bond_rest_lengths,
        gravity_field, G, gravity_interaction, gravity_strength)
    e_total = ek_sys + us_sys + ug_sys + ugr_sys
    # np.gradient raises on fewer than two samples; a single frame has no dE/dt.
    power = np.gradient(e_total, t) if n_frames >= 2 else np.zeros_like(e_total)

    # -- Per-atom energies --
    ek_atom, pe_atom, et_atom = compute_per_atom_energy(
        x, y, z, vx, vy, vz, masses,
        bond_pairs, bond_stiffness, bond_rest_lengths,
        atom_ids, driver_info)

    # -- Bond energy flux --
    flux, bond_xpos = compute_energy_flux(
        x, y, z, vx, vy, vz,
        bond_pairs, bond_stiffness, bond_rest_lengths)
    # Order bonds by midpoint position so the curve is drawn left to right even when
    # bond_pairs is not listed in spatial order.
    flux_order = np.argsort(bond_xpos, kind="stable")
    flux_x = bond_xpos[flux_order]
    flux_sorted = flux[:, flux_order]

    # -- Axis ranges --
    # Displacement panel: one symmetric range shared by the x, y and z curves.
    disp_vmax = max(np.max(np.abs(dx)), np.max(np.abs(dy)), np.max(np.abs(dz)))
    disp_vmax = disp_vmax if disp_vmax > 1e-10 else 1.0
    disp_ylim = (-disp_vmax * 1.2, disp_vmax * 1.2)

    # Per-atom energy panel: one non-negative range shared by the three curves.
    energy_vmax = max(np.max(ek_atom), np.max(pe_atom), np.max(et_atom))
    energy_vmax = energy_vmax if energy_vmax > 1e-12 else 1.0
    energy_ylim = (0.0, energy_vmax * 1.2)

    # Flux panel: symmetric about zero.
    if flux.size > 0:
        flux_vmax = np.max(np.abs(flux))
        flux_vmax = flux_vmax if flux_vmax > 1e-12 else 1.0
        flux_ylim = (-flux_vmax * 1.2, flux_vmax * 1.2)
    else:
        flux_ylim = (-1.0, 1.0)

    # Time-series panel: cover every plotted curve. The pairwise gravitational
    # potential is always negative, and the field potential and dE/dt can be, so the
    # lower limit has to follow the data. The atom-count cutoff mirrors compute_energy.
    show_ugr = bool(gravity_interaction) and n_atoms <= 200
    ep_series = [ek_sys, us_sys, e_total, power]
    if gravity_field:
        ep_series.append(ug_sys)
    if show_ugr:
        ep_series.append(ugr_sys)
    ep_values = np.concatenate([np.ravel(s) for s in ep_series])
    ep_values = ep_values[np.isfinite(ep_values)]  # set_ylim rejects NaN/inf
    ep_low = min(float(ep_values.min()), 0.0) if ep_values.size else 0.0
    ep_high = max(float(ep_values.max()), 0.01) if ep_values.size else 0.01
    ep_margin = 0.2 * (ep_high - ep_low)

    t_start, t_end = float(t[0]), float(t[-1])
    if t_end <= t_start:  # single frame or zero time step: avoid a singular axis
        t_end = t_start + 1.0

    atom_idx = np.arange(n_atoms)
    atom_xmax = max(n_atoms - 1, 1)  # avoid identical x limits for a single atom

    # -- Figure layout: 4 rows x 1 column --
    plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False

    fig, (ax_wave, ax_energy, ax_flux, ax_ep) = plt.subplots(4, 1, figsize=(12, 18))
    fig.suptitle("波形与能量分析", fontsize=16)
    fig.subplots_adjust(hspace=0.42, top=0.95)

    # -- Panel 1: x/y/z displacement waveforms --
    ax_wave.set_xlim(0, atom_xmax)
    ax_wave.set_ylim(disp_ylim)
    ax_wave.set_xlabel("原子序号")
    ax_wave.set_ylabel("位移")
    ax_wave.set_title("粒子位移(x / y / z 方向)")
    ax_wave.grid(True, alpha=0.3)

    wave_disps = [dx, dy, dz]
    wave_labels = ["x 方向(纵波)", "y 方向(横波)", "z 方向(横波)"]
    wave_colors = ["#2563eb", "#ea580c", "#16a34a"]
    wave_lines = []
    for label, color in zip(wave_labels, wave_colors):
        ln, = ax_wave.plot([], [], color=color, linewidth=1.5, label=label)
        wave_lines.append(ln)
    ax_wave.legend(loc="upper right", fontsize=9)
    time_text = ax_wave.text(0.02, 0.95, "", transform=ax_wave.transAxes,
                             fontsize=10, verticalalignment="top")

    # -- Panel 2: per-atom kinetic / potential / total energy --
    ax_energy.set_xlim(0, atom_xmax)
    ax_energy.set_ylim(energy_ylim)
    ax_energy.set_xlabel("原子序号")
    ax_energy.set_ylabel("能量")
    ax_energy.set_title("每粒子能量(动能 / 势能 / 总能)")
    ax_energy.grid(True, alpha=0.3)

    energy_arrays = [ek_atom, pe_atom, et_atom]
    energy_labels = ["动能", "势能", "总能"]
    energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"]
    energy_lines = []
    for label, color in zip(energy_labels, energy_colors):
        ln, = ax_energy.plot([], [], color=color, linewidth=1.5, label=label)
        energy_lines.append(ln)
    ax_energy.legend(loc="upper right", fontsize=9)

    # -- Panel 3: bond energy flux J (Hardy formula) --
    if flux_x.size > 0:
        xmin_flux, xmax_flux = float(flux_x[0]), float(flux_x[-1])
        if xmax_flux <= xmin_flux:  # a single bond (or all midpoints coincide)
            xmin_flux, xmax_flux = xmin_flux - 0.5, xmax_flux + 0.5
    else:
        xmin_flux, xmax_flux = 0, atom_xmax
    ax_flux.set_xlim(xmin_flux, xmax_flux)
    ax_flux.set_ylim(flux_ylim)
    ax_flux.axhline(0, color="gray", linewidth=0.8, linestyle="--")
    ax_flux.set_xlabel("位置(键中点 x 坐标)")
    ax_flux.set_ylabel("能流密度 J")
    ax_flux.set_title("键能流密度 J = ½ F·(vᵢ+vⱼ) (J>0 向右传播,J<0 向左传播)")
    ax_flux.grid(True, alpha=0.3)
    flux_line, = ax_flux.plot([], [], color="#dc2626", linewidth=1.5)

    # -- Panel 4: system energies over time --
    # Limits are fixed up front: with blit=True, changing them inside update() would
    # redraw the curves in new coordinates while the cached tick labels stay stale.
    ax_ep.set_xlim(t_start, t_end)
    ax_ep.set_ylim(ep_low - ep_margin, ep_high + ep_margin)
    ax_ep.set_xlabel("时间 (s)")
    ax_ep.set_ylabel("能量 / 功率")
    ax_ep.set_title("系统能量与输入功率")
    ax_ep.grid(True, alpha=0.3)

    ln_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能")
    ln_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能")
    ln_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量")
    ln_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)")
    ln_ug = None
    ln_ugr = None
    if gravity_field:
        ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
    if show_ugr:
        ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
    ax_ep.legend(loc="upper left", fontsize=9)

    # Every artist that update() modifies; returned for blitting.
    artists = wave_lines + [time_text] + energy_lines + \
        [flux_line, ln_ek, ln_us, ln_et, ln_pw]
    if ln_ug is not None:
        artists.append(ln_ug)
    if ln_ugr is not None:
        artists.append(ln_ugr)

    # -- Animation update --
    def update(frame):
        # Panel 1: displacement waveforms
        for ln, disp in zip(wave_lines, wave_disps):
            ln.set_data(atom_idx, disp[frame])
        time_text.set_text(f"t = {t[frame]:.2f} s | 帧 {frame+1}/{n_frames}")

        # Panel 2: per-atom energies
        for ln, arr in zip(energy_lines, energy_arrays):
            ln.set_data(atom_idx, arr[frame])

        # Panel 3: energy flux
        if flux_sorted.shape[1] > 0:
            flux_line.set_data(flux_x, flux_sorted[frame])

        # Panel 4: system energies accumulated up to the current frame
        upto = frame + 1
        cur_t = t[:upto]
        ln_ek.set_data(cur_t, ek_sys[:upto])
        ln_us.set_data(cur_t, us_sys[:upto])
        ln_et.set_data(cur_t, e_total[:upto])
        ln_pw.set_data(cur_t, power[:upto])
        if ln_ug is not None:
            ln_ug.set_data(cur_t, ug_sys[:upto])
        if ln_ugr is not None:
            ln_ugr.set_data(cur_t, ugr_sys[:upto])

        return artists

    ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)

    # -- Output files --
    fps = min(20, max(1, n_frames // 5))
    gif_path = None
    if save_gif:
        gif_path = os.path.join(output_dir, "wave_animation.gif")
        ani.save(gif_path, writer="pillow", fps=fps)
        print(f"[plot_wave] GIF 已保存: {gif_path}")

    if save_mp4:
        try:
            import matplotlib.animation as manim
            # Prefer the ffmpeg bundled with imageio-ffmpeg when it is installed.
            ffmpeg_path = None
            try:
                import imageio_ffmpeg
                ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
            except Exception:
                pass  # optional dependency: fall back to ffmpeg on PATH
            if ffmpeg_path and os.path.exists(ffmpeg_path):
                plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path
            writer = manim.FFMpegWriter(fps=fps, codec="libx264",
                                        extra_args=["-pix_fmt", "yuv420p"])
            mp4_path = os.path.join(output_dir, "wave_animation.mp4")
            ani.save(mp4_path, writer=writer)
            print(f"[plot_wave] MP4 已保存: {mp4_path}")
        except FileNotFoundError:
            print("[plot_wave] 警告: 未找到 ffmpeg,跳过 MP4 输出")
        except Exception as e:
            print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return gif_path
|
||
|
||
|
||
if __name__ == "__main__":
    # CLI: optional first argument is the case output directory; otherwise use the
    # default output directory resolved by compute.get_output_dir().
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = (os.path.abspath(sys.argv[1]) if len(sys.argv) > 1
                  else compute.get_output_dir(script_dir))
    plot_wave(output_dir)
|