This commit is contained in:
2026-05-17 08:47:25 +08:00
parent 1159d86b8b
commit 45513fe334
27 changed files with 4734 additions and 2 deletions
+492
View File
@@ -0,0 +1,492 @@
"""VisPy 演示:加载预计算轨迹数据,驱动小球运动动画。
计算与显示完全分离:
1. 先运行 compute.py → 生成 output/trajectory.txt(全量 NT 步轨迹)
2. 再运行 sample.py → 从 output/trajectory.txt 抽帧生成 output/display.txt
3. 本文件加载 output/display.txt,按帧播放动画
用法:
python draw.py # 使用 dynamics 根目录下的 output/
python draw.py examples/case01/output # 指定案例输出目录
"""
import numpy as np
import os
import sys
from vispy import app, scene
from vispy.visuals.transforms import STTransform
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import compute
# ===========================================================================
# 加载预计算轨迹
# ===========================================================================
script_dir = os.path.dirname(os.path.abspath(__file__))
if len(sys.argv) > 1:
# 用户指定了输出目录
output_dir = os.path.abspath(sys.argv[1])
else:
output_dir = compute.get_output_dir(script_dir)
os.environ["DYNAMICS_OUTPUT_DIR"] = output_dir
disp_path = os.path.join(output_dir, "display.txt")
if not os.path.exists(disp_path):
raise FileNotFoundError(
f"找不到 display.txt\n"
f"期望路径: {disp_path}\n"
f"请先运行 compute.py 计算轨迹,再运行 sample.py 生成显示数组。\n"
f"用法: python draw.py [案例输出目录]"
)
disp_data = compute.load_text_data(disp_path)
# 单原子数据(plot_atom:用于信息显示)
DISP_X = disp_data["disp_x"]
DISP_Y = disp_data["disp_y"]
DISP_Z = disp_data["disp_z"]
DISP_VX = disp_data["disp_vx"]
DISP_VY = disp_data["disp_vy"]
DISP_VZ = disp_data["disp_vz"]
# 全原子数据(用于多球绘制)
DISP_ALL_X = disp_data["disp_all_x"] # (n_frames, n_atoms)
DISP_ALL_Y = disp_data["disp_all_y"]
DISP_ALL_Z = disp_data["disp_all_z"]
DISP_ALL_VX = disp_data["disp_all_vx"]
DISP_ALL_VY = disp_data["disp_all_vy"]
DISP_ALL_VZ = disp_data["disp_all_vz"]
DISP_T = disp_data["disp_t"]
DISP_STEP = disp_data["disp_step"]
N_FRAMES = int(disp_data["n_frames"])
NT = int(disp_data["NT"])
DT = float(disp_data["DT"])
NSTEP = int(disp_data["NSTEP"])
# 原子信息
ATOM_IDS = disp_data.get("atom_ids", np.array([1]))
ATOM_RADII = disp_data.get("atom_radii", np.array([float(disp_data["ball_radius"])]))
N_ATOMS = len(ATOM_IDS)
PLOT_ATOM_ROW = int(disp_data.get("plot_atom_row", 0))
PLOT_ATOM_ID = int(disp_data.get("plot_atom_id", ATOM_IDS[0]))
BOND_PAIRS = disp_data.get("bond_pairs", [])
if N_FRAMES <= 0:
raise ValueError(
"output/display.txt 中没有可播放的帧,请检查 sample_start/sample_end/NSTEP 配置。")
# 保留模拟边界常量(用于场景缩放、相机等),从 output/display.txt 中读取
X_MIN = float(disp_data["X_MIN"]); X_MAX = float(disp_data["X_MAX"])
Y_MIN = float(disp_data["Y_MIN"]); Y_MAX = float(disp_data["Y_MAX"])
Z_MIN = float(disp_data["Z_MIN"]); Z_MAX = float(disp_data["Z_MAX"])
X0 = float(disp_data["X0"]); Y0 = float(disp_data["Y0"]); Z0 = float(disp_data["Z0"])
raw_alpha = disp_data["alpha"]
if isinstance(raw_alpha, (list, tuple, np.ndarray)):
alpha_list = [float(a) for a in raw_alpha]
if len(alpha_list) != 6:
raise ValueError(f"alpha 数组长度须为 6,实际为 {len(alpha_list)}")
else:
alpha_list = [float(raw_alpha)] * 6
# 绘图参数
ball_radius = float(disp_data["ball_radius"])
ball_color_r = float(disp_data["ball_color_r"])
ball_color_g = float(disp_data["ball_color_g"])
ball_color_b = float(disp_data["ball_color_b"])
box_color_r = float(disp_data["box_color_r"])
box_color_g = float(disp_data["box_color_g"])
box_color_b = float(disp_data["box_color_b"])
# ===========================================================================
# 图形界面无关的几何参数(不参与物理计算,仅用于场景外观)
# ===========================================================================
info_margin = 36
axis_length = 10.0
initial_camera = {
"distance": 40.0,
"elevation": 0,
"azimuth": 0,
"center": (0, 0, 0),
}
# ===========================================================================
# 创建画布与相机
# ===========================================================================
canvas = scene.SceneCanvas(
keys="interactive",
size=(1000, 700),
bgcolor=(0.08, 0.08, 0.10, 1.0),
show=True,
)
view = canvas.central_widget.add_view()
view.camera = "turntable"
view.camera.distance = initial_camera["distance"]
view.camera.elevation = initial_camera["elevation"]
view.camera.azimuth = initial_camera["azimuth"]
view.camera.center = initial_camera["center"]
# ===========================================================================
# 场景对象
# ===========================================================================
axis_width = 3
arrow_size = 14
axes_visible = True
axes_group = []
axes_group.append(scene.visuals.Arrow(
pos=np.array([[0, 0, 0], [axis_length, 0, 0]], dtype=np.float32),
color=(1.0, 0.2, 0.2, 1.0),
width=axis_width,
arrows=np.array([[0, 0, 0, axis_length, 0, 0]], dtype=np.float32),
arrow_size=arrow_size,
parent=view.scene,
))
axes_group.append(scene.visuals.Arrow(
pos=np.array([[0, 0, 0], [0, axis_length, 0]], dtype=np.float32),
color=(0.2, 1.0, 0.2, 1.0),
width=axis_width,
arrows=np.array([[0, 0, 0, 0, axis_length, 0]], dtype=np.float32),
arrow_size=arrow_size,
parent=view.scene,
))
axes_group.append(scene.visuals.Arrow(
pos=np.array([[0, 0, 0], [0, 0, axis_length]], dtype=np.float32),
color=(0.3, 0.6, 1.0, 1.0),
width=axis_width,
arrows=np.array([[0, 0, 0, 0, 0, axis_length]], dtype=np.float32),
arrow_size=arrow_size,
parent=view.scene,
))
axes_group.append(scene.visuals.Text(text="x", color=(1.0, 0.2, 0.2, 1.0), font_size=18,
pos=(axis_length + 0.2, 0, 0), anchor_x="left", anchor_y="center", parent=view.scene))
axes_group.append(scene.visuals.Text(text="y", color=(0.2, 1.0, 0.2, 1.0), font_size=18,
pos=(0, axis_length + 0.2, 0), anchor_x="left", anchor_y="bottom", parent=view.scene))
axes_group.append(scene.visuals.Text(text="z", color=(0.3, 0.6, 1.0, 1.0), font_size=18,
pos=(0, 0, axis_length + 0.2), anchor_x="left", anchor_y="bottom", parent=view.scene))
# 所有小球(每个原子一个球,不同颜色)
TAB10_RGB = np.array([
[0.1216, 0.4667, 0.7059], # 蓝
[0.8902, 0.4667, 0.1137], # 橙
[0.1725, 0.6275, 0.1725], # 绿
[0.8392, 0.1529, 0.1569], # 红
[0.5804, 0.4039, 0.7412], # 紫
[0.5490, 0.3373, 0.2941], # 棕
[0.8902, 0.4667, 0.7608], # 粉
[0.4980, 0.4980, 0.4980], # 灰
[0.7373, 0.7412, 0.1333], # 黄绿
[0.0902, 0.7451, 0.8118], # 青
])
balls = []
for i in range(N_ATOMS):
r, g, b = TAB10_RGB[i % len(TAB10_RGB)]
s = scene.visuals.Sphere(
radius=float(ATOM_RADII[i]), method="latitude",
color=(r, g, b, 1.0), edge_color=None, parent=view.scene)
s.mesh.shading = "smooth"
balls.append(s)
# 成键线(原子之间的连接)
if len(BOND_PAIRS) > 0:
n_bonds = len(BOND_PAIRS)
bond_pos = np.zeros((n_bonds * 2, 3), dtype=np.float32)
bond_lines = scene.visuals.Line(
pos=bond_pos, color=(0.7, 0.7, 0.7, 0.8), width=3,
connect="segments", parent=view.scene)
else:
bond_lines = None
# 六个面形成立方体边界(每个面独立透明度,alpha<=0 时隐藏该面)
box_size = X_MAX - X_MIN
faces = [
((X_MAX, 0, 0), "-x"),
((X_MIN, 0, 0), "+x"),
((0, Y_MAX, 0), "-y"),
((0, Y_MIN, 0), "+y"),
((0, 0, Z_MAX), "-z"),
((0, 0, Z_MIN), "+z"),
]
for f_idx, (pos, direction) in enumerate(faces):
a = alpha_list[f_idx]
if a <= 0:
continue
face_color = (box_color_r, box_color_g, box_color_b, a)
plane = scene.visuals.Plane(
width=box_size, height=box_size, width_segments=1, height_segments=1,
direction=direction, color=face_color, parent=view.scene)
plane.set_gl_state(depth_test=False, blend=True)
plane.transform = STTransform(translate=pos)
# 右上角:相机信息
camera_info = scene.visuals.Text(
text="", color="white", font_size=14,
pos=(0, 0), anchor_x="right", anchor_y="top", parent=canvas.scene)
# 左上角:小球信息
ball_info = scene.visuals.Text(
text="", color=(0.2, 1.0, 0.2, 1.0), font_size=28,
pos=(0, 0), anchor_x="left", anchor_y="top",
face="黑体", bold=True, parent=canvas.scene)
# 左上角:reset 按钮(在 info 上方)
reset_btn_size = (60, 30)
reset_button = scene.visuals.Rectangle(
center=(reset_btn_size[0] / 2 + 8, reset_btn_size[1] / 2 + 8),
width=reset_btn_size[0], height=reset_btn_size[1],
radius=6, color=(0.18, 0.35, 0.65, 0.85),
border_color="white", parent=canvas.scene)
reset_button_label = scene.visuals.Text(
text="reset", color="white", font_size=16,
pos=(reset_btn_size[0] / 2 + 8, reset_btn_size[1] / 2 + 8),
anchor_x="center", anchor_y="center",
bold=True, parent=canvas.scene)
# 信息显示/隐藏 切换按钮(左上角小方块,在 reset 下方)
info_btn_size = (60, 30)
info_toggle_visible = True
info_button = scene.visuals.Rectangle(
center=(info_btn_size[0] / 2 + 8, info_btn_size[1] / 2 + 8),
width=info_btn_size[0], height=info_btn_size[1],
radius=6, color=(0.9, 0.3, 0.3, 0.9),
border_color="white", parent=canvas.scene)
info_button_label = scene.visuals.Text(
text="info", color="white", font_size=16,
pos=(info_btn_size[0] / 2 + 8, info_btn_size[1] / 2 + 8),
anchor_x="center", anchor_y="center",
bold=True, parent=canvas.scene)
# 强制在初始化时定位到左上角(canvas.scene坐标系:左下角为原点)
_cw, _ch = canvas.size
reset_button.center = (reset_btn_size[0] / 2 + 8, _ch - reset_btn_size[1] / 2 - 8 - info_btn_size[1] - 4)
reset_button_label.pos = reset_button.center
info_button.center = (info_btn_size[0] / 2 + 8, _ch - info_btn_size[1] / 2 - 8)
info_button_label.pos = info_button.center
# axes 显示/隐藏 按钮(在 info 下方)
axes_btn_size = (60, 30)
axes_button = scene.visuals.Rectangle(
center=(axes_btn_size[0] / 2 + 8, axes_btn_size[1] / 2 + 8),
width=axes_btn_size[0], height=axes_btn_size[1],
radius=6, color=(0.3, 0.7, 0.3, 0.9),
border_color="white", parent=canvas.scene)
axes_button_label = scene.visuals.Text(
text="axes", color="white", font_size=16,
pos=(axes_btn_size[0] / 2 + 8, axes_btn_size[1] / 2 + 8),
anchor_x="center", anchor_y="center",
bold=True, parent=canvas.scene)
axes_button.center = (axes_btn_size[0] / 2 + 8, _ch - axes_btn_size[1] / 2 - 8 - info_btn_size[1] - 4 - axes_btn_size[1] - 4)
axes_button_label.pos = axes_button.center
# ===========================================================================
# 回调函数
# ===========================================================================
def update_camera_info(event=None):
c = view.camera
camera_info.text = (
"Camera\n"
f"center = ({c.center[0]:.2f}, {c.center[1]:.2f}, {c.center[2]:.2f})\n"
f"distance = {c.distance:.2f}\n"
f"elevation = {c.elevation:.2f}\n"
f"azimuth = {c.azimuth:.2f}"
)
def update_ball_info(frame_idx, x, y, z, vx, vy, vz):
step = int(DISP_STEP[frame_idx])
t = float(DISP_T[frame_idx])
ball_info.text = (
f"原子 {PLOT_ATOM_ID} (共 {N_ATOMS} 个原子)\n"
f"frame {frame_idx+1}/{N_FRAMES} | saved step {step}/{NT-1}\n"
f"t = {t:.2f} s | dt = {DT:.3f} s | nstep = {NSTEP}\n"
f"Position: ({x:.2f}, {y:.2f}, {z:.2f})\n"
f"Velocity: ({vx:.2f}, {vy:.2f}, {vz:.2f})\n"
)
def reposition_camera_info(event=None):
width, height = canvas.size
camera_info.pos = (width - 20, height - 20)
# reset → info → axes 三按钮纵向排列
gap = 4
info_y = info_btn_size[1] / 2 + 8
reset_y = info_y + info_btn_size[1] + gap
axes_y = info_y + info_btn_size[1] + gap + axes_btn_size[1] + gap
reset_button.center = (reset_btn_size[0] / 2 + 8, height - reset_y)
reset_button_label.pos = reset_button.center
info_button.center = (info_btn_size[0] / 2 + 8, height - info_y)
info_button_label.pos = info_button.center
axes_button.center = (axes_btn_size[0] / 2 + 8, height - axes_y)
axes_button_label.pos = axes_button.center
# info 文字放在所有按钮下方
buttons_bottom = height - (axes_y + axes_btn_size[1] / 2)
ball_info.pos = (info_margin, buttons_bottom - 10)
update_ball_info(frame_idx, DISP_X[frame_idx], DISP_Y[frame_idx], DISP_Z[frame_idx],
DISP_VX[frame_idx], DISP_VY[frame_idx], DISP_VZ[frame_idx])
update_camera_info()
def handle_view_interaction(event):
update_camera_info()
def rotate_about_screen_normal(angle):
if hasattr(view.camera, "roll"):
view.camera.roll = (view.camera.roll + angle) % 360
else:
view.camera.azimuth = (view.camera.azimuth + angle) % 360
update_camera_info()
def handle_key_press(event):
key_name = ""
if getattr(event, "text", None):
key_name = event.text.lower()
elif getattr(event, "key", None) is not None:
key_name = str(event.key).lower()
if key_name == "q":
rotate_about_screen_normal(-90)
elif key_name == "e":
rotate_about_screen_normal(90)
def reset_camera_view():
global frame_idx
frame_idx = 0
# 立即复位所有小球到第 0 帧
for i in range(N_ATOMS):
balls[i].transform = STTransform(translate=(
float(DISP_ALL_X[frame_idx, i]),
float(DISP_ALL_Y[frame_idx, i]),
float(DISP_ALL_Z[frame_idx, i]),
))
if bond_lines is not None and len(BOND_PAIRS) > 0:
pos = np.zeros((len(BOND_PAIRS) * 2, 3), dtype=np.float32)
for b_idx, (i, j) in enumerate(BOND_PAIRS):
pos[b_idx * 2] = [DISP_ALL_X[frame_idx, i], DISP_ALL_Y[frame_idx, i], DISP_ALL_Z[frame_idx, i]]
pos[b_idx * 2 + 1] = [DISP_ALL_X[frame_idx, j], DISP_ALL_Y[frame_idx, j], DISP_ALL_Z[frame_idx, j]]
bond_lines.set_data(pos=pos)
# 复位相机
view.camera.distance = initial_camera["distance"]
view.camera.elevation = initial_camera["elevation"]
view.camera.azimuth = initial_camera["azimuth"]
view.camera.center = initial_camera["center"]
if hasattr(view.camera, "roll"):
view.camera.roll = 0
update_camera_info()
def handle_mouse_press(event):
global info_toggle_visible, axes_visible
if event.pos is None:
return
ex, ey = event.pos
# reset 按钮
bx, by = reset_button.center
lw = reset_btn_size[0] / 2; lh = reset_btn_size[1] / 2
if (bx - lw) <= ex <= (bx + lw) and (by - lh) <= ey <= (by + lh):
reset_camera_view()
return
# info 按钮
bx, by = info_button.center
if (bx - info_btn_size[0]/2) <= ex <= (bx + info_btn_size[0]/2) and (by - info_btn_size[1]/2) <= ey <= (by + info_btn_size[1]/2):
info_toggle_visible = not info_toggle_visible
ball_info.visible = info_toggle_visible
camera_info.visible = info_toggle_visible
return
# axes 按钮
bx, by = axes_button.center
if (bx - axes_btn_size[0]/2) <= ex <= (bx + axes_btn_size[0]/2) and (by - axes_btn_size[1]/2) <= ey <= (by + axes_btn_size[1]/2):
axes_visible = not axes_visible
for axis in axes_group:
axis.visible = axes_visible
# ===========================================================================
# 动画初始化
# ===========================================================================
frame_idx = 0
# 初始帧:摆放所有小球并刷新 UI
for i in range(N_ATOMS):
balls[i].transform = STTransform(translate=(
float(DISP_ALL_X[frame_idx, i]),
float(DISP_ALL_Y[frame_idx, i]),
float(DISP_ALL_Z[frame_idx, i]),
))
# 初始帧:更新成键线
if bond_lines is not None and len(BOND_PAIRS) > 0:
pos = np.zeros((len(BOND_PAIRS) * 2, 3), dtype=np.float32)
for b_idx, (i, j) in enumerate(BOND_PAIRS):
pos[b_idx * 2] = [DISP_ALL_X[frame_idx, i], DISP_ALL_Y[frame_idx, i], DISP_ALL_Z[frame_idx, i]]
pos[b_idx * 2 + 1] = [DISP_ALL_X[frame_idx, j], DISP_ALL_Y[frame_idx, j], DISP_ALL_Z[frame_idx, j]]
bond_lines.set_data(pos=pos)
reposition_camera_info()
update_ball_info(frame_idx,
float(DISP_X[frame_idx]), float(DISP_Y[frame_idx]), float(DISP_Z[frame_idx]),
float(DISP_VX[frame_idx]), float(DISP_VY[frame_idx]), float(DISP_VZ[frame_idx]))
print(f"[draw] 加载 output/display.txt: {N_FRAMES} 帧, {N_ATOMS} 个原子, NT={NT}, DT={DT}, NSTEP={NSTEP}")
print(f"[draw] 绘图参数: ball_radius={ball_radius}, box_color=({box_color_r:.2f},{box_color_g:.2f},{box_color_b:.2f}), alpha={alpha_list}")
# ===========================================================================
# 每帧回调:仅推进帧索引,从预存数组读取位置,零物理计算
# ===========================================================================
def update(event):
global frame_idx
frame_idx = (frame_idx + 1) % N_FRAMES # 循环播放
# 更新所有小球位置
for i in range(N_ATOMS):
x = float(DISP_ALL_X[frame_idx, i])
y = float(DISP_ALL_Y[frame_idx, i])
z = float(DISP_ALL_Z[frame_idx, i])
balls[i].transform = STTransform(translate=(x, y, z))
# 更新成键线
if bond_lines is not None and len(BOND_PAIRS) > 0:
pos = np.zeros((len(BOND_PAIRS) * 2, 3), dtype=np.float32)
for b_idx, (i, j) in enumerate(BOND_PAIRS):
pos[b_idx * 2] = [DISP_ALL_X[frame_idx, i], DISP_ALL_Y[frame_idx, i], DISP_ALL_Z[frame_idx, i]]
pos[b_idx * 2 + 1] = [DISP_ALL_X[frame_idx, j], DISP_ALL_Y[frame_idx, j], DISP_ALL_Z[frame_idx, j]]
bond_lines.set_data(pos=pos)
# 信息面板显示 plot_atom 的数据
x = float(DISP_X[frame_idx])
y = float(DISP_Y[frame_idx])
z = float(DISP_Z[frame_idx])
vx = float(DISP_VX[frame_idx])
vy = float(DISP_VY[frame_idx])
vz = float(DISP_VZ[frame_idx])
update_ball_info(frame_idx, x, y, z, vx, vy, vz)
update_camera_info()
# ===========================================================================
# 事件绑定与启动
# ===========================================================================
timer = app.Timer(interval=0.02, connect=update, start=True)
canvas.events.mouse_move.connect(handle_view_interaction)
canvas.events.mouse_press.connect(handle_mouse_press)
canvas.events.mouse_wheel.connect(handle_view_interaction)
canvas.events.resize.connect(reposition_camera_info)
canvas.events.key_press.connect(handle_key_press)
if hasattr(canvas, "native") and hasattr(canvas.native, "setFocus"):
canvas.native.setFocus()
if __name__ == "__main__":
app.run()