161 lines
4.5 KiB
Python
161 lines
4.5 KiB
Python
"""
|
|
migrate_npz_outputs.py
----------------------

将旧版 `.npz` 运行产物迁移为当前使用的 `.txt` 文本格式（文件内容为 JSON）。

不依赖 numpy，便于在教学机器上直接运行：

    py -3 migrate_npz_outputs.py
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import json
|
|
import os
|
|
import struct
|
|
import zipfile
|
|
|
|
|
|
# Directory containing this script; legacy ``.npz`` inputs are looked up here.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
# Sub-directory that receives the migrated text outputs.
OUTPUT_DIR = os.path.join(BASE_DIR, "output")
|
|
|
|
|
|
def parse_npy(raw):
    """Parse the subset of the NPY format used by this project.

    Args:
        raw: The complete bytes of a single ``.npy`` file / archive member.

    Returns:
        A plain Python value: the single element for a 0-d array, otherwise
        nested lists that follow the array shape in row-major order.

    Raises:
        ValueError: On a bad magic string, an unsupported format version, a
            truncated or malformed header, or an unsupported dtype.
    """
    # Magic string (6 bytes) + major/minor version (2 bytes) + at least a
    # 2-byte header length; anything shorter cannot be a valid NPY file.
    if len(raw) < 10 or raw[:6] != b"\x93NUMPY":
        raise ValueError("无效的 NPY 文件头")

    major = raw[6]
    minor = raw[7]
    if major == 1:
        # Version 1.0: 2-byte little-endian header length, latin1 header.
        header_len = struct.unpack("<H", raw[8:10])[0]
        header_start = 10
        encoding = "latin1"
    elif major in (2, 3):
        if len(raw) < 12:
            raise ValueError("无效的 NPY 文件头")
        # Versions 2.0/3.0 widen the header length to 4 bytes; 3.0 also
        # switches the header text from latin1 to UTF-8.
        header_len = struct.unpack("<I", raw[8:12])[0]
        header_start = 12
        encoding = "utf8" if major == 3 else "latin1"
    else:
        raise ValueError(f"暂不支持的 NPY 版本: {major}.{minor}")

    header_end = header_start + header_len
    if len(raw) < header_end:
        raise ValueError("无效的 NPY 文件头")
    header = raw[header_start:header_end].decode(encoding).strip()

    # literal_eval only accepts Python literals, so a hostile header cannot
    # execute code; normalise its failures to ValueError like the rest.
    try:
        meta = ast.literal_eval(header)
    except (SyntaxError, ValueError) as err:
        raise ValueError("无效的 NPY 文件头") from err
    if not isinstance(meta, dict) or not all(
        key in meta for key in ("descr", "fortran_order", "shape")
    ):
        raise ValueError("无效的 NPY 文件头")

    descr = meta["descr"]
    shape = tuple(meta["shape"])
    values = unpack_payload(raw[header_end:], descr)
    if meta["fortran_order"]:
        values = _fortran_to_c_order(values, shape)
    return reshape_values(values, shape)


def _fortran_to_c_order(values, shape):
    """Reorder a flat column-major (Fortran) element list into row-major order."""
    import itertools

    if len(shape) < 2:
        # 0-d and 1-d arrays store their elements identically in both orders.
        return values

    # Column-major strides: the FIRST axis varies fastest. After the loop
    # ``stride`` equals the total element count.
    strides = []
    stride = 1
    for dim in shape:
        strides.append(stride)
        stride *= dim
    if len(values) < stride:
        raise ValueError("NPY 数据长度与形状不匹配")

    # itertools.product yields index tuples with the LAST axis varying
    # fastest, i.e. exactly row-major order.
    return [
        values[sum(i * s for i, s in zip(index, strides))]
        for index in itertools.product(*(range(dim) for dim in shape))
    ]
|
|
|
|
|
|
# struct format codes for the fixed-size numeric NPY kinds, keyed by the kind
# character and then by the item size (in bytes) from the dtype descriptor.
_NUMERIC_STRUCT_CODES = {
    "f": {"2": "e", "4": "f", "8": "d"},
    "i": {"1": "b", "2": "h", "4": "i", "8": "q"},
    "u": {"1": "B", "2": "H", "4": "I", "8": "Q"},
}


def unpack_payload(payload, descr):
    """Decode the data section of an NPY array into a flat list.

    Args:
        payload: The bytes that follow the NPY header.
        descr: NumPy dtype descriptor string, e.g. ``"<f8"``, ``"|b1"`` or
            ``"<U12"`` (byte order, kind, item size / character count).

    Returns:
        Flat list of ``float`` / ``int`` / ``bool`` / ``str`` values in the
        order they are stored in ``payload``.

    Raises:
        ValueError: If the descriptor is unsupported or the payload length is
            not a whole number of items.
    """
    import sys

    unsupported = f"暂不支持的数据类型: {descr}"
    # Structured dtypes arrive as a list of fields rather than a string.
    if not isinstance(descr, str) or len(descr) < 3:
        raise ValueError(unsupported)

    endian, kind, item_chars = descr[0], descr[1], descr[2:]
    # "|" means byte order is irrelevant (single-byte items); "=" is native
    # order, which must be resolved instead of being treated as big-endian.
    if endian in ("<", "|"):
        order = "<"
    elif endian == ">":
        order = ">"
    elif endian == "=":
        order = "<" if sys.byteorder == "little" else ">"
    else:
        raise ValueError(unsupported)

    codes = _NUMERIC_STRUCT_CODES.get(kind)
    if codes is not None:
        code = codes.get(item_chars)
        if code is None:
            raise ValueError(unsupported)
        count, remainder = divmod(len(payload), int(item_chars))
        if remainder:
            raise ValueError(f"NPY 数据长度与数据类型不匹配: {descr}")
        return list(struct.unpack(f"{order}{count}{code}", payload))

    if kind == "b":
        return [bool(byte) for byte in payload]

    if kind == "U":
        # NumPy stores each character as a 4-byte UTF-32 code unit, padded
        # with NULs up to the declared character count.
        item_size = int(item_chars) * 4
        if item_size == 0 or len(payload) % item_size:
            raise ValueError(f"NPY 数据长度与数据类型不匹配: {descr}")
        codec = "utf-32-le" if order == "<" else "utf-32-be"
        return [
            payload[start:start + item_size].decode(codec).rstrip("\x00")
            for start in range(0, len(payload), item_size)
        ]

    raise ValueError(unsupported)
|
|
|
|
|
|
def reshape_values(values, shape):
    """Fold a flat row-major list into nested lists matching ``shape``.

    Args:
        values: Flat sequence of elements in row-major (C) order.
        shape: Array dimensions; an empty shape denotes a 0-d (scalar) array.

    Returns:
        The single element for a 0-d shape, otherwise nested lists with
        ``len(shape)`` levels of nesting.
    """
    if len(shape) == 0:
        return values[0]
    if len(shape) == 1:
        return values[:shape[0]]

    # Number of flat elements that make up one entry along the first axis.
    step = 1
    for dim in shape[1:]:
        step *= dim

    # Iterate over the first axis instead of stepping through ``values``: a
    # zero-sized trailing axis makes ``step`` 0, which range() rejects, and
    # this also ignores surplus elements just like the 1-d branch does.
    return [
        reshape_values(values[row * step:(row + 1) * step], shape[1:])
        for row in range(shape[0])
    ]
|
|
|
|
|
|
def load_npz(path):
    """Read every ``.npy`` member of an ``.npz`` archive.

    Args:
        path: Archive path (or binary file object), passed straight to
            ``zipfile.ZipFile``.

    Returns:
        Dict mapping each array name to its parsed value. The name is the
        member name with the ``.npy`` suffix removed, matching how NumPy's
        ``NpzFile`` names arrays. Members without that suffix are skipped.
    """
    suffix = ".npy"
    data = {}
    with zipfile.ZipFile(path, "r") as zf:
        for name in zf.namelist():
            if not name.endswith(suffix):
                continue
            # Keep the full member path: reducing it to its basename would let
            # e.g. "a/x.npy" and "b/x.npy" silently overwrite each other.
            data[name[:-len(suffix)]] = parse_npy(zf.read(name))
    return data
|
|
|
|
|
|
def dump_json(path, payload):
    """Write ``payload`` to ``path`` as indented UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII text is written
    as-is (``ensure_ascii=False``) so the output stays human-readable.
    """
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises, so only create directories when the path
        # actually has a directory component.
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
def maybe_move_legacy_table():
    """Move a ``#``-headed ``trajectory.txt`` beside this script into OUTPUT_DIR.

    If ``BASE_DIR/trajectory.txt`` exists, ``OUTPUT_DIR/trajectory_table.txt``
    does not, and the file's first line starts with ``#``, the file is moved
    to that target. In every other case nothing is changed.
    """
    table_like = os.path.join(BASE_DIR, "trajectory.txt")
    table_target = os.path.join(OUTPUT_DIR, "trajectory_table.txt")
    if not os.path.exists(table_like) or os.path.exists(table_target):
        return

    try:
        # utf-8-sig strips a BOM that would otherwise hide the leading "#".
        with open(table_like, "r", encoding="utf-8-sig") as f:
            first_line = f.readline()
    except UnicodeDecodeError:
        # Not UTF-8 text, so not the table format recognised here; leave the
        # file in place rather than aborting the whole migration.
        return

    # Rename only after the handle is closed: Windows refuses to replace a
    # file that is still open.
    if first_line.startswith("#"):
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        os.replace(table_like, table_target)
|
|
|
|
|
|
def main():
    """Convert the known ``.npz`` archives beside this script into JSON ``.txt`` files.

    Ensures the output directory exists and calls ``maybe_move_legacy_table``
    first. It then converts each listed archive that exists, printing one line
    per generated file. Archives that are absent are skipped.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    maybe_move_legacy_table()

    # (archive name under BASE_DIR, destination path relative to BASE_DIR)
    conversions = (
        ("trajectory.npz", os.path.join("output", "trajectory.txt")),
        ("display.npz", os.path.join("output", "display.txt")),
    )
    for archive_name, relative_target in conversions:
        archive_path = os.path.join(BASE_DIR, archive_name)
        if os.path.exists(archive_path):
            target_path = os.path.join(BASE_DIR, relative_target)
            dump_json(target_path, load_npz(archive_path))
            print(f"[migrate] 已生成: {relative_target}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the migration only when executed as a script, not when imported.
    main()
|