2ab3436235
New 4x2 layout: left col = x/y/z displacement waves, right col = per-atom KE/PE/total energy + system energy vs time. PE split 50/50 for normal bonds; 100% to non-driven atom when bonded to driver; driven atom PE = E_SHO - KE. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
509 lines
19 KiB
Python
509 lines
19 KiB
Python
"""
plot_wave.py
============
波形与能量动态图:读取 display.npz(优先)或 display.txt,绘制原子位移波形
(纵波 + 2 个横波)、每粒子动能/势能/总能,以及系统能量/输入功率随时间变化的二维动画。

用法:
    python plot_wave.py                           # 使用 dynamics 根目录下 output/
    python plot_wave.py examples/case05/output    # 指定案例输出目录
"""
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from matplotlib.animation import FuncAnimation
|
||
import os
|
||
import sys
|
||
import json
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
import compute
|
||
|
||
|
||
def load_disp_data(output_dir):
    """Load the recorded display frames for one case.

    ``display.npz`` is preferred when present; otherwise ``display.txt`` is
    used. Parsing is delegated to the matching ``compute.load_display_*``
    loader.

    Args:
        output_dir: Directory expected to contain one of the display files.

    Returns:
        Whatever the selected ``compute`` loader returns.

    Raises:
        FileNotFoundError: Neither display file exists in ``output_dir``.
    """
    # Ordered by preference: binary archive first, then the text dump.
    candidates = (
        ("display.npz", "load_display_npz"),
        ("display.txt", "load_display_txt"),
    )
    for file_name, loader_name in candidates:
        candidate_path = os.path.join(output_dir, file_name)
        if os.path.exists(candidate_path):
            return getattr(compute, loader_name)(candidate_path)
    raise FileNotFoundError(f"找不到 display.npz 或 display.txt in {output_dir}")
|
||
|
||
|
||
def _header_json(header_fields, key, default):
|
||
raw = header_fields.get(key, "")
|
||
if not raw:
|
||
return default
|
||
try:
|
||
return json.loads(raw)
|
||
except json.JSONDecodeError:
|
||
return default
|
||
|
||
|
||
def _load_wave_dataset(output_dir):
    """Load wave/energy plotting data from display metadata or sibling input files.

    Frames and metadata come from ``load_disp_data``. When the display header
    lacks ``atom_masses``/``atom_positions`` or ``bond_pairs`` (older
    ``display.txt`` outputs), they are rebuilt from the ``input`` directory
    that sits next to ``output_dir``.

    Args:
        output_dir: Case output directory. A trailing path separator is allowed.

    Returns:
        dict with keys ``n_frames``, ``t``, ``x``, ``y``, ``z``, ``vx``, ``vy``,
        ``vz``, ``pos_0``, ``masses``, ``atom_ids``, ``bond_pairs``,
        ``bond_stiffness``, ``bond_rest_lengths``, ``gravity_field``,
        ``gravity_interaction``, ``gravity_strength``, ``G`` and
        ``driving_force``.

    Raises:
        ValueError: Mass/position metadata is missing and
            ``input/coord.txt`` does not exist either.
    """
    disp_data = load_disp_data(output_dir)
    header = disp_data["header_fields"]

    x = disp_data["frames_x"]
    y = disp_data["frames_y"]
    z = disp_data["frames_z"]
    vx = disp_data["frames_vx"]
    vy = disp_data["frames_vy"]
    vz = disp_data["frames_vz"]
    atom_ids = np.array(disp_data["atom_ids"], dtype=np.int64)

    n_frames = x.shape[0]
    dt = float(header.get("DT", 0.001))
    nstep = int(header.get("NSTEP", 1))
    # Frame k is placed at k * DT * NSTEP (presumably one frame is stored every
    # NSTEP integration steps -- TODO confirm against the writer).
    t = np.arange(n_frames, dtype=np.float64) * dt * nstep

    masses = np.array(_header_json(header, "atom_masses", []), dtype=np.float64)
    pos_0 = np.array(_header_json(header, "atom_positions", []), dtype=np.float64)
    bond_pairs = np.array(_header_json(header, "bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(_header_json(header, "bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(_header_json(header, "bond_rest_lengths", []), dtype=np.float64)
    gravity_vec = _header_json(header, "G", [0.0, 0.0, 0.0])

    # Sibling "input" directory of the case. normpath strips a trailing
    # separator first; otherwise dirname("case/output/") is "case/output" and
    # the fallback would wrongly look in "case/output/input".
    input_dir = os.path.join(os.path.dirname(os.path.normpath(output_dir)), "input")

    # Backward-compatible fallback for older display.txt outputs.
    if masses.size == 0 or pos_0.size == 0:
        coord_path = os.path.join(input_dir, "coord.txt")
        if not os.path.exists(coord_path):
            raise ValueError("display.txt 缺少 atom_masses/atom_positions 元数据,且未找到 input/coord.txt")
        (_, masses_fb, _, positions_fb, _, _) = compute.load_coord_file(coord_path)
        masses = np.array(masses_fb, dtype=np.float64)
        pos_0 = np.array(positions_fb, dtype=np.float64)

    # Bonds are optional: without header metadata or both input files, the
    # dataset simply carries no bonds.
    if bond_pairs.size == 0:
        connection_path = os.path.join(input_dir, "connection.txt")
        bond_path = os.path.join(input_dir, "bond.txt")
        if os.path.exists(connection_path) and os.path.exists(bond_path):
            bond_map = compute.load_bond_parameters(bond_path)
            pairs_fb, _, stiffness_fb, rest_lengths_fb = compute.load_bond_connections(
                connection_path, atom_ids, pos_0, bond_map)
            bond_pairs = np.array(pairs_fb, dtype=np.int64)
            bond_stiffness = np.array(stiffness_fb, dtype=np.float64)
            bond_rest_lengths = np.array(rest_lengths_fb, dtype=np.float64)

    return {
        "n_frames": n_frames,
        "t": t,
        "x": x,
        "y": y,
        "z": z,
        "vx": vx,
        "vy": vy,
        "vz": vz,
        "pos_0": pos_0,
        "masses": masses,
        "atom_ids": atom_ids,
        "bond_pairs": bond_pairs,
        "bond_stiffness": bond_stiffness,
        "bond_rest_lengths": bond_rest_lengths,
        "gravity_field": int(header.get("gravity_field", 0)),
        "gravity_interaction": int(header.get("gravity_interaction", 0)),
        "gravity_strength": float(header.get("gravity_strength", 1.0)),
        "G": gravity_vec,
        "driving_force": int(header.get("driving_force", 0)),
    }
|
||
|
||
|
||
def compute_energy(x, y, z, vx, vy, vz, masses, mass_arr,
                   bond_pairs, bond_stiffness, bond_rest_lengths,
                   gravity_field, G, gravity_interaction, gravity_strength,
                   max_gravity_atoms=200):
    """Compute the system-wide energy components for every frame.

    Args:
        x, y, z: Atom positions, indexed ``[frame, atom]``.
        vx, vy, vz: Atom velocities, indexed ``[frame, atom]``.
        masses: Per-atom masses (any array-like of length n_atoms).
        mass_arr: Unused; kept so existing positional callers keep working.
        bond_pairs: Sequence of ``(i, j)`` atom-index pairs, or None / empty.
        bond_stiffness: Spring constant k for each bond.
        bond_rest_lengths: Rest length d0 for each bond.
        gravity_field: Truthy to include the uniform-field potential.
        G: Uniform field vector ``(gx, gy, gz)``.
        gravity_interaction: Truthy to include pairwise gravitational potential.
        gravity_strength: Coupling constant of the pairwise term.
        max_gravity_atoms: The pairwise term costs O(n_atoms**2) per frame, so
            it is only evaluated when ``n_atoms <= max_gravity_atoms``. Above
            that it is reported as zero and a RuntimeWarning is issued.

    Returns:
        ek_sys: kinetic energy  sum(1/2 m v^2), shape (n_frames,)
        us_sys: elastic energy  sum(1/2 k (d - d0)^2), shape (n_frames,)
        ug_sys: uniform-field potential  -sum(m G.r), shape (n_frames,)
        ugr_sys: pairwise gravitational potential
            -gravity_strength * sum_{i<j} m_i m_j / r_ij, shape (n_frames,)
    """
    import warnings

    # Accept plain lists as well as arrays (2-D indexing below needs ndarray).
    masses = np.asarray(masses, dtype=np.float64)
    n_frames = x.shape[0]
    n_atoms = len(masses)
    masses_2d = masses[np.newaxis, :]  # (1, n_atoms), broadcasts over frames

    # Kinetic energy Ek = 1/2 m v^2
    ek_sys = np.sum(0.5 * masses_2d * (vx**2 + vy**2 + vz**2), axis=1)

    # Elastic energy Us = 1/2 k (d - d0)^2, accumulated bond by bond so memory
    # stays O(n_frames) regardless of the number of bonds.
    us_sys = np.zeros(n_frames)
    if bond_pairs is not None and len(bond_pairs) > 0:
        for b_idx, (i, j) in enumerate(bond_pairs):
            dist = np.sqrt((x[:, j] - x[:, i])**2
                           + (y[:, j] - y[:, i])**2
                           + (z[:, j] - z[:, i])**2)
            stretch = dist - bond_rest_lengths[b_idx]
            us_sys += 0.5 * bond_stiffness[b_idx] * stretch**2

    # Uniform-field potential Ug = -m G.r
    ug_sys = np.zeros(n_frames)
    if gravity_field:
        g = np.asarray(G, dtype=np.float64)
        ug_sys = -np.sum(masses_2d * (g[0] * x + g[1] * y + g[2] * z), axis=1)

    # Pairwise gravitational potential -G_grav * sum_{i<j} m_i m_j / r_ij
    ugr_sys = np.zeros(n_frames)
    if gravity_interaction:
        if n_atoms <= max_gravity_atoms:
            # Vectorised over all partners j > i: one Python iteration per
            # atom instead of one per pair.
            for i in range(n_atoms - 1):
                dist = np.sqrt((x[:, i + 1:] - x[:, i:i + 1])**2
                               + (y[:, i + 1:] - y[:, i:i + 1])**2
                               + (z[:, i + 1:] - z[:, i:i + 1])**2)
                dist = np.maximum(dist, 1e-12)  # guard coincident atoms
                ugr_sys -= (gravity_strength * masses[i]
                            * np.sum(masses[i + 1:] / dist, axis=1))
        else:
            warnings.warn(
                f"compute_energy: {n_atoms} atoms exceeds max_gravity_atoms="
                f"{max_gravity_atoms}; pairwise gravitational energy reported as 0",
                RuntimeWarning, stacklevel=2)

    return ek_sys, us_sys, ug_sys, ugr_sys
|
||
|
||
|
||
def _load_driver_info(output_dir):
|
||
"""从 input/driver.txt 读取驱动原子的 atom_id, amp, freq(仅用于能量计算)。
|
||
返回 dict: {atom_id: {'amp': [ax,ay,az], 'freq': [fx,fy,fz]}},失败返回 {}。
|
||
"""
|
||
input_dir = os.path.join(os.path.dirname(output_dir), "input")
|
||
path = os.path.join(input_dir, "driver.txt")
|
||
if not os.path.exists(path):
|
||
return {}
|
||
drivers = {}
|
||
try:
|
||
with open(path, encoding="utf-8") as f:
|
||
f.readline() # skip header
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line or line.startswith("#"):
|
||
continue
|
||
parts = line.split()
|
||
if len(parts) < 10:
|
||
continue
|
||
n = int(parts[0])
|
||
amp = np.array([float(parts[1]), float(parts[2]), float(parts[3])])
|
||
freq = np.array([float(parts[4]), float(parts[5]), float(parts[6])])
|
||
drivers[n] = {"amp": amp, "freq": freq}
|
||
except Exception:
|
||
pass
|
||
return drivers
|
||
|
||
|
||
def compute_per_atom_energy(x, y, z, vx, vy, vz, masses,
                            bond_pairs, bond_stiffness, bond_rest_lengths,
                            atom_ids, driver_info):
    """Split every frame's energy into per-atom kinetic/potential/total parts.

    Spring-bond potential energy is attributed as follows:
      * neither end driven, or both ends driven: half to each end;
      * exactly one end driven: all of it to the non-driven end.

    A driven atom's potential energy is then *replaced* (not accumulated) by
    the simple-harmonic complement ``max(E_sho - KE, 0)``, where
    ``E_sho = 1/2 m * sum_axis((2 pi f)^2 A^2)``. Any share a driven atom got
    from a bond to another driven atom is therefore discarded. Driver ids
    that are not in ``atom_ids`` are ignored.

    Args:
        x, y, z, vx, vy, vz: Positions / velocities, shape (n_frames, n_atoms).
        masses: Per-atom masses (ndarray, length n_atoms).
        bond_pairs: (n_bonds, 2) ndarray of atom column indices, or None/empty.
        bond_stiffness: Spring constant per bond.
        bond_rest_lengths: Rest length per bond.
        atom_ids: Atom ids in column order; used to map driver ids to columns.
        driver_info: ``{atom_id: {"amp": ..., "freq": ...}}`` with per-axis
            amplitude and frequency arrays.

    Returns:
        (ek_atom, pe_atom, et_atom), each of shape (n_frames, n_atoms).
    """
    n_frames, n_atoms = x.shape

    # Per-atom kinetic energy 1/2 m v^2.
    ek_atom = 0.5 * masses[np.newaxis, :] * (vx**2 + vy**2 + vz**2)

    column_of = {int(aid): col for col, aid in enumerate(atom_ids)}
    driven_cols = {column_of[aid] for aid in driver_info if aid in column_of}

    # Spring potential energy, distributed onto the bond ends.
    pe_atom = np.zeros((n_frames, n_atoms))
    if bond_pairs is not None and len(bond_pairs) > 0:
        for bond_no in range(len(bond_pairs)):
            a = int(bond_pairs[bond_no, 0])
            b = int(bond_pairs[bond_no, 1])
            length = np.sqrt((x[:, b] - x[:, a])**2
                             + (y[:, b] - y[:, a])**2
                             + (z[:, b] - z[:, a])**2)
            energy = 0.5 * bond_stiffness[bond_no] * (length - bond_rest_lengths[bond_no])**2
            a_driven = a in driven_cols
            b_driven = b in driven_cols
            if a_driven != b_driven:
                # Exactly one end is driven: the free end takes everything.
                receiver = b if a_driven else a
                pe_atom[:, receiver] += energy
            else:
                half = energy * 0.5
                pe_atom[:, a] += half
                pe_atom[:, b] += half

    # Driven atoms: PE is the SHO total energy minus the actual KE, clamped at 0.
    two_pi = 2.0 * np.pi
    for atom_id, params in driver_info.items():
        col = column_of.get(atom_id)
        if col is None:
            continue
        amp = params["amp"]
        omega = two_pi * params["freq"]
        e_sho = 0.5 * masses[col] * np.sum(omega**2 * amp**2)
        pe_atom[:, col] = np.maximum(e_sho - ek_atom[:, col], 0.0)

    return ek_atom, pe_atom, ek_atom + pe_atom
|
||
|
||
|
||
def plot_wave(output_dir, save_gif=False, save_mp4=False, show=True):
    """Build the animated wave + energy figure for one case.

    Layout (4 rows x 2 columns):
        left column, rows 0-2: x / y / z displacement from the initial
            positions, plotted against atom index (row 3 is left blank);
        right column, row 0: per-atom kinetic energy;
        right column, row 1: per-atom potential energy;
        right column, row 2: per-atom total energy;
        right column, row 3: system energy components and dE/dt versus time,
            drawn up to the current frame.

    Args:
        output_dir: Output directory containing display.npz or display.txt.
        save_gif: Save ``wave_animation.gif`` into ``output_dir``.
        save_mp4: Save ``wave_animation.mp4`` into ``output_dir``. Needs
            ffmpeg; failures are printed and skipped.
        show: Open an interactive window; otherwise the figure is closed.

    Returns:
        Path of the saved GIF, or None when ``save_gif`` is False.
    """
    import matplotlib.gridspec as gridspec

    data = _load_wave_dataset(output_dir)

    n_frames = int(data["n_frames"])
    t = np.array(data["t"])

    x = np.array(data["x"])
    y = np.array(data["y"])
    z = np.array(data["z"])
    vx = np.array(data["vx"])
    vy = np.array(data["vy"])
    vz = np.array(data["vz"])

    pos_0 = np.array(data["pos_0"])
    masses = np.array(data["masses"])
    atom_ids = np.array(data["atom_ids"])
    n_atoms = len(atom_ids)

    bond_pairs = np.array(data.get("bond_pairs", []), dtype=np.int64)
    bond_stiffness = np.array(data.get("bond_stiffness", []), dtype=np.float64)
    bond_rest_lengths = np.array(data.get("bond_rest_lengths", []), dtype=np.float64)

    gravity_field = int(data.get("gravity_field", 0))
    gravity_interaction = int(data.get("gravity_interaction", 0))
    G = data.get("G", [0, 0, 0])
    gravity_strength = float(data.get("gravity_strength", 1.0))
    driving_force = int(data.get("driving_force", 0))

    # Driven-atom parameters are only needed for the per-atom PE split.
    driver_info = _load_driver_info(output_dir) if driving_force else {}

    # Displacement of every atom from its initial position.
    dx = x - pos_0[np.newaxis, :, 0]
    dy = y - pos_0[np.newaxis, :, 1]
    dz = z - pos_0[np.newaxis, :, 2]

    # System energy components for the time plot (bottom right).
    ek_sys, us_sys, ug_sys, ugr_sys = compute_energy(
        x, y, z, vx, vy, vz, masses, masses,
        bond_pairs, bond_stiffness, bond_rest_lengths,
        gravity_field, G, gravity_interaction, gravity_strength)
    e_total = ek_sys + us_sys + ug_sys + ugr_sys
    # "Input power" is estimated as dE/dt. np.gradient needs at least two
    # samples, so a single-frame run gets a flat zero curve instead of an error.
    if n_frames >= 2:
        power = np.gradient(e_total, t)
    else:
        power = np.zeros_like(e_total)

    # Per-atom energies (right column, rows 0-2).
    ek_atom, pe_atom, et_atom = compute_per_atom_energy(
        x, y, z, vx, vy, vz, masses,
        bond_pairs, bond_stiffness, bond_rest_lengths,
        atom_ids, driver_info)

    # Optional gravity curves. The 200-atom cut-off matches the default limit
    # up to which compute_energy evaluates the pairwise gravity term.
    show_ug = bool(gravity_field)
    show_ugr = bool(gravity_interaction) and n_atoms <= 200

    # ── y-axis ranges ──
    def get_ylim(arr):
        """Symmetric limits around 0 with a 20% margin."""
        vmax = np.max(np.abs(arr))
        if vmax < 1e-10:
            return -1.0, 1.0
        m = vmax * 0.2
        return -vmax - m, vmax + m

    def get_ylim_pos(arr):
        """Limits [0, 1.2 * max] for non-negative quantities."""
        vmax = np.max(arr)
        if vmax < 1e-12:
            return 0.0, 1.0
        return 0.0, vmax * 1.2

    def get_ylim_span(arrays):
        """Limits covering every array and 0, with a 15% margin.

        Both ends are fitted: gravitational potentials are negative and dE/dt
        is negative whenever the system loses energy.
        """
        lo = min(0.0, min(float(np.min(a)) for a in arrays))
        hi = max(0.0, max(float(np.max(a)) for a in arrays))
        span = hi - lo
        if span < 1e-12:
            return -1.0, 1.0
        pad = 0.15 * span
        return lo - pad, hi + pad

    disp_ylims = [get_ylim(dx), get_ylim(dy), get_ylim(dz)]
    ek_ylim = get_ylim_pos(ek_atom)
    pe_ylim = get_ylim_pos(pe_atom)
    et_ylim = get_ylim_pos(et_atom)

    ep_series = [ek_sys, us_sys, e_total, power]
    if show_ug:
        ep_series.append(ug_sys)
    if show_ugr:
        ep_series.append(ugr_sys)
    ep_ylow, ep_yhigh = get_ylim_span(ep_series)

    atom_idx = np.arange(n_atoms)
    # Avoid a singular x-range when there is only one atom.
    atom_xmax = max(n_atoms - 1, 1)

    # ── figure layout ──
    plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False

    fig = plt.figure(figsize=(16, 14))
    fig.suptitle("波形与能量分析", fontsize=16)
    gs = gridspec.GridSpec(4, 2, figure=fig, hspace=0.45, wspace=0.35)

    # Left column: displacement waves (rows 0-2); row 3 stays empty.
    ax_dx = fig.add_subplot(gs[0, 0])
    ax_dy = fig.add_subplot(gs[1, 0])
    ax_dz = fig.add_subplot(gs[2, 0])
    ax_blank = fig.add_subplot(gs[3, 0])
    ax_blank.set_visible(False)

    # Right column: per-atom energies (rows 0-2) + energy vs time (row 3).
    ax_ek = fig.add_subplot(gs[0, 1])
    ax_pe = fig.add_subplot(gs[1, 1])
    ax_et = fig.add_subplot(gs[2, 1])
    ax_ep = fig.add_subplot(gs[3, 1])

    # ── displacement wave panels ──
    wave_axes = [ax_dx, ax_dy, ax_dz]
    wave_disps = [dx, dy, dz]
    wave_titles = ["纵波 (x 方向位移)", "横波 (y 方向位移)", "横波 (z 方向位移)"]
    wave_colors = ["#2563eb", "#ea580c", "#16a34a"]
    wave_lines = []
    time_texts = []

    for ax, disp, title, color, yl in zip(wave_axes, wave_disps, wave_titles, wave_colors, disp_ylims):
        ax.set_xlim(0, atom_xmax)
        ax.set_ylim(yl)
        ax.set_xlabel("原子序号")
        ax.set_ylabel("位移")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ln, = ax.plot([], [], color=color, linewidth=1.5)
        wave_lines.append(ln)
        tt = ax.text(0.02, 0.95, "", transform=ax.transAxes,
                     fontsize=9, verticalalignment="top")
        time_texts.append(tt)

    # ── per-atom energy panels ──
    energy_axes = [ax_ek, ax_pe, ax_et]
    energy_arrays = [ek_atom, pe_atom, et_atom]
    energy_titles = ["每粒子动能", "每粒子势能", "每粒子总能"]
    energy_colors = ["#1d4ed8", "#b45309", "#7c3aed"]
    energy_ylims = [ek_ylim, pe_ylim, et_ylim]
    energy_lines = []

    for ax, arr, title, color, yl in zip(energy_axes, energy_arrays, energy_titles, energy_colors, energy_ylims):
        ax.set_xlim(0, atom_xmax)
        ax.set_ylim(yl)
        ax.set_xlabel("原子序号")
        ax.set_ylabel("能量")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ln, = ax.plot([], [], color=color, linewidth=1.5)
        energy_lines.append(ln)

    # ── system energy vs time panel ──
    # The time axis is fixed to the whole run. Changing the limits inside
    # update() does not work with blit=True: only the returned artists are
    # redrawn, so the tick labels would keep showing the old range.
    t_end = t[-1] if t[-1] > t[0] else t[0] + 1.0
    ax_ep.set_xlim(t[0], t_end)
    ax_ep.set_ylim(ep_ylow, ep_yhigh)
    ax_ep.set_xlabel("时间 (s)")
    ax_ep.set_ylabel("能量 / 功率")
    ax_ep.set_title("系统能量与输入功率")
    ax_ep.grid(True, alpha=0.3)

    ln_ek, = ax_ep.plot([], [], "b-", lw=1.5, label="动能")
    ln_us, = ax_ep.plot([], [], "orange", lw=1.5, label="弹性势能")
    ln_et, = ax_ep.plot([], [], "r--", lw=1.5, label="总能量")
    ln_pw, = ax_ep.plot([], [], "g-", lw=1.5, alpha=0.7, label="输入功率 (dE/dt)")
    ln_ug = None
    ln_ugr = None
    if show_ug:
        ln_ug, = ax_ep.plot([], [], "purple", lw=1.0, alpha=0.5, label="重力势能")
    if show_ugr:
        ln_ugr, = ax_ep.plot([], [], "brown", lw=1.0, alpha=0.5, label="万有引力势能")
    ax_ep.legend(loc="upper left", fontsize=8)

    # Every artist update() touches; blitting redraws exactly these.
    animated = wave_lines + time_texts + energy_lines + [ln_ek, ln_us, ln_et, ln_pw]
    if ln_ug is not None:
        animated.append(ln_ug)
    if ln_ugr is not None:
        animated.append(ln_ugr)

    # ── animation update ──
    def update(frame):
        # Displacement waves and the frame/time stamp.
        for ln, disp, tt in zip(wave_lines, wave_disps, time_texts):
            ln.set_data(atom_idx, disp[frame])
            tt.set_text(f"t = {t[frame]:.2f} s 帧 {frame+1}/{n_frames}")

        # Per-atom energies.
        for ln, arr in zip(energy_lines, energy_arrays):
            ln.set_data(atom_idx, arr[frame])

        # System energy history up to and including the current frame.
        upto = frame + 1
        cur_t = t[:upto]
        ln_ek.set_data(cur_t, ek_sys[:upto])
        ln_us.set_data(cur_t, us_sys[:upto])
        ln_et.set_data(cur_t, e_total[:upto])
        ln_pw.set_data(cur_t, power[:upto])
        if ln_ug is not None:
            ln_ug.set_data(cur_t, ug_sys[:upto])
        if ln_ugr is not None:
            ln_ugr.set_data(cur_t, ugr_sys[:upto])
        return animated

    ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)

    # ── output files ──
    fps = min(20, max(1, n_frames // 5))

    gif_path = None
    if save_gif:
        gif_path = os.path.join(output_dir, "wave_animation.gif")
        ani.save(gif_path, writer="pillow", fps=fps)
        print(f"[plot_wave] GIF 已保存: {gif_path}")

    if save_mp4:
        try:
            import matplotlib.animation as manim
            ffmpeg_path = None
            try:
                import imageio_ffmpeg
                ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
            except Exception:
                # imageio-ffmpeg is optional; otherwise matplotlib's configured
                # ffmpeg path is used unchanged.
                pass
            if ffmpeg_path and os.path.exists(ffmpeg_path):
                plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path
            writer = manim.FFMpegWriter(fps=fps, codec="libx264",
                                        extra_args=["-pix_fmt", "yuv420p"])
            mp4_path = os.path.join(output_dir, "wave_animation.mp4")
            ani.save(mp4_path, writer=writer)
            print(f"[plot_wave] MP4 已保存: {mp4_path}")
        except FileNotFoundError:
            print("[plot_wave] 警告: 未找到 ffmpeg,跳过 MP4 输出")
        except Exception as e:
            print(f"[plot_wave] 警告: MP4 输出失败 ({e}),跳过")

    if show:
        plt.show()
    else:
        plt.close(fig)
    return gif_path
|
||
|
||
|
||
if __name__ == "__main__":
    # CLI entry: the optional first argument is the case output directory;
    # without it, compute.get_output_dir picks the default one.
    _here = os.path.dirname(os.path.abspath(__file__))
    _cli_args = sys.argv[1:]
    if _cli_args:
        _target_dir = os.path.abspath(_cli_args[0])
    else:
        _target_dir = compute.get_output_dir(_here)
    plot_wave(_target_dir)
|