feat: 为 C/C++/Fortran 引擎增加驱动力(driving_force)支持
- param.json 新增 driving_force 开关 - C 引擎: 新增 DriverData 结构体、read_driver()、apply_driving_force() - C++ 引擎: 同上(C++ 风格实现) - Fortran 引擎: 同上(Fortran 90 风格实现) - 修复 JSON 输出末尾逗号导致加载失败的问题 - 编译通过并验证 C 引擎运行正常(100000步/6.6s)
This commit is contained in:
+143
-1
@@ -42,6 +42,7 @@ struct SimParams {
|
||||
int elastic_force = 1;
|
||||
int damping_force = 0;
|
||||
double gravity_strength = 1.0;
|
||||
int driving_force = 0;
|
||||
};
|
||||
|
||||
// ========================================================================
|
||||
@@ -65,6 +66,20 @@ struct BondData {
|
||||
std::vector<double> rest_lengths;
|
||||
};
|
||||
|
||||
// ========================================================================
|
||||
// 驱动力数据
|
||||
// ========================================================================
|
||||
struct DriverData {
|
||||
int n_drivers = 0;
|
||||
std::vector<int> atom_idx; // internal atom indices
|
||||
std::vector<double> amp_x, amp_y, amp_z;
|
||||
std::vector<double> freq_x, freq_y, freq_z;
|
||||
std::vector<double> phi_x, phi_y, phi_z; // radians
|
||||
std::vector<int> has_period; // 0=all, 1=limited
|
||||
std::vector<double> period_cycles;
|
||||
std::vector<double> freeze_x, freeze_y, freeze_z;
|
||||
};
|
||||
|
||||
// ========================================================================
|
||||
// 辅助函数
|
||||
// ========================================================================
|
||||
@@ -149,6 +164,7 @@ static SimParams read_params(const std::string &path) {
|
||||
p.elastic_force = json_read_int(buf, "elastic_force");
|
||||
p.damping_force = json_read_int(buf, "damping_force");
|
||||
p.gravity_strength = json_read_double(buf, "gravity_strength");
|
||||
p.driving_force = json_read_int(buf, "driving_force");
|
||||
return p;
|
||||
}
|
||||
|
||||
@@ -233,6 +249,57 @@ static BondData read_bonds(const std::string &input_dir) {
|
||||
return b;
|
||||
}
|
||||
|
||||
/* 读取 driver.txt */
|
||||
static DriverData read_driver(const std::string &input_dir, const AtomData &atoms) {
|
||||
DriverData d;
|
||||
std::string path = input_dir + "/driver.txt";
|
||||
std::ifstream f(path);
|
||||
if (!f) { std::cerr << "[C++-engine] 警告: 无法打开 " << path << std::endl; return d; }
|
||||
|
||||
std::string header;
|
||||
std::getline(f, header); // skip header
|
||||
|
||||
int n;
|
||||
double ax, ay, az, fx, fy, fz, px, py, pz;
|
||||
std::string period_str;
|
||||
|
||||
while (f >> n >> ax >> ay >> az >> fx >> fy >> fz >> px >> py >> pz >> period_str) {
|
||||
// Find atom index by id
|
||||
int idx = -1;
|
||||
for (size_t i = 0; i < atoms.ids.size(); i++) {
|
||||
if (atoms.ids[i] == n) { idx = i; break; }
|
||||
}
|
||||
if (idx < 0) {
|
||||
std::cerr << "[C++-engine] 警告: driver.txt 原子 " << n << " 不在 coord.txt 中" << std::endl;
|
||||
continue;
|
||||
}
|
||||
d.atom_idx.push_back(idx);
|
||||
d.amp_x.push_back(ax); d.amp_y.push_back(ay); d.amp_z.push_back(az);
|
||||
d.freq_x.push_back(fx); d.freq_y.push_back(fy); d.freq_z.push_back(fz);
|
||||
// Convert degrees to radians
|
||||
const double DEG2RAD = M_PI / 180.0;
|
||||
d.phi_x.push_back(px * DEG2RAD);
|
||||
d.phi_y.push_back(py * DEG2RAD);
|
||||
d.phi_z.push_back(pz * DEG2RAD);
|
||||
|
||||
if (period_str == "all") {
|
||||
d.has_period.push_back(0);
|
||||
d.period_cycles.push_back(-1.0);
|
||||
} else {
|
||||
d.has_period.push_back(1);
|
||||
d.period_cycles.push_back(std::stod(period_str));
|
||||
}
|
||||
d.freeze_x.push_back(0.0);
|
||||
d.freeze_y.push_back(0.0);
|
||||
d.freeze_z.push_back(0.0);
|
||||
d.n_drivers++;
|
||||
}
|
||||
|
||||
if (d.n_drivers > 0)
|
||||
std::cout << "[C++-engine] 已加载驱动力: " << d.n_drivers << " 条定义" << std::endl;
|
||||
return d;
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// 物理核心
|
||||
// ========================================================================
|
||||
@@ -539,6 +606,68 @@ static void apply_step(
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// 驱动力应用
|
||||
// ========================================================================
|
||||
|
||||
static void apply_driving_force(
|
||||
int n, double *x, double *y, double *z,
|
||||
double *vx, double *vy, double *vz,
|
||||
double t, int step, double dt,
|
||||
DriverData &drivers)
|
||||
{
|
||||
if (drivers.n_drivers == 0) return;
|
||||
for (int d = 0; d < drivers.n_drivers; d++) {
|
||||
int idx = drivers.atom_idx[d];
|
||||
|
||||
// Check period limits
|
||||
if (drivers.has_period[d]) {
|
||||
double max_freq = std::max({std::fabs(drivers.freq_x[d]),
|
||||
std::fabs(drivers.freq_y[d]),
|
||||
std::fabs(drivers.freq_z[d])});
|
||||
int period_steps = 0;
|
||||
if (max_freq > 1e-12) {
|
||||
period_steps = static_cast<int>(drivers.period_cycles[d] / max_freq / dt);
|
||||
}
|
||||
if (step > period_steps) {
|
||||
// Frozen: keep last position, zero velocity
|
||||
x[idx] = drivers.freeze_x[d];
|
||||
y[idx] = drivers.freeze_y[d];
|
||||
z[idx] = drivers.freeze_z[d];
|
||||
vx[idx] = vy[idx] = vz[idx] = 0.0;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const double TWO_PI = 2.0 * M_PI;
|
||||
double px = drivers.amp_x[d] * std::cos(TWO_PI * drivers.freq_x[d] * t + drivers.phi_x[d]);
|
||||
double py = drivers.amp_y[d] * std::cos(TWO_PI * drivers.freq_y[d] * t + drivers.phi_y[d]);
|
||||
double pz = drivers.amp_z[d] * std::cos(TWO_PI * drivers.freq_z[d] * t + drivers.phi_z[d]);
|
||||
double vpx = -drivers.amp_x[d] * TWO_PI * drivers.freq_x[d] * std::sin(TWO_PI * drivers.freq_x[d] * t + drivers.phi_x[d]);
|
||||
double vpy = -drivers.amp_y[d] * TWO_PI * drivers.freq_y[d] * std::sin(TWO_PI * drivers.freq_y[d] * t + drivers.phi_y[d]);
|
||||
double vpz = -drivers.amp_z[d] * TWO_PI * drivers.freq_z[d] * std::sin(TWO_PI * drivers.freq_z[d] * t + drivers.phi_z[d]);
|
||||
|
||||
x[idx] = px; y[idx] = py; z[idx] = pz;
|
||||
vx[idx] = vpx; vy[idx] = vpy; vz[idx] = vpz;
|
||||
|
||||
// Record freeze position at the last driving step
|
||||
if (drivers.has_period[d]) {
|
||||
double max_freq = std::max({std::fabs(drivers.freq_x[d]),
|
||||
std::fabs(drivers.freq_y[d]),
|
||||
std::fabs(drivers.freq_z[d])});
|
||||
int period_steps = 0;
|
||||
if (max_freq > 1e-12) {
|
||||
period_steps = static_cast<int>(drivers.period_cycles[d] / max_freq / dt);
|
||||
}
|
||||
if (step == period_steps) {
|
||||
drivers.freeze_x[d] = px;
|
||||
drivers.freeze_y[d] = py;
|
||||
drivers.freeze_z[d] = pz;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// JSON 输出
|
||||
// ========================================================================
|
||||
@@ -621,7 +750,9 @@ static void write_trajectory_json(
|
||||
if (b > 0) f << ',';
|
||||
f << bonds.rest_lengths[b];
|
||||
}
|
||||
f << "]\n";
|
||||
f << "],\n";
|
||||
|
||||
f << " \"driving_force\": " << params.driving_force << "\n";
|
||||
|
||||
f << "}\n";
|
||||
}
|
||||
@@ -647,6 +778,11 @@ int main(int argc, char **argv) {
|
||||
AtomData atoms = read_coord(input_dir);
|
||||
BondData bonds = read_bonds(input_dir);
|
||||
|
||||
DriverData drivers;
|
||||
if (params.driving_force) {
|
||||
drivers = read_driver(input_dir, atoms);
|
||||
}
|
||||
|
||||
std::cout << "[C++-engine] 原子数=" << atoms.ids.size()
|
||||
<< ", 键数=" << bonds.stiffness.size()
|
||||
<< ", NT=" << params.NT << ", DT=" << params.DT
|
||||
@@ -676,6 +812,9 @@ int main(int argc, char **argv) {
|
||||
|
||||
// 预热
|
||||
for (int s = 0; s < params.warmup_steps; s++) {
|
||||
double tw = (s + 1) * params.DT;
|
||||
if (params.driving_force)
|
||||
apply_driving_force(n, x.data(), y.data(), z.data(), vx.data(), vy.data(), vz.data(), tw, s, params.DT, drivers);
|
||||
apply_step(params.method, n, x.data(), y.data(), z.data(),
|
||||
vx.data(), vy.data(), vz.data(),
|
||||
atoms.masses.data(), params.G, params.B,
|
||||
@@ -688,6 +827,9 @@ int main(int argc, char **argv) {
|
||||
|
||||
// 记录
|
||||
for (int s = 0; s < record_steps; s++) {
|
||||
double t = (s + params.warmup_steps) * params.DT;
|
||||
if (params.driving_force)
|
||||
apply_driving_force(n, x.data(), y.data(), z.data(), vx.data(), vy.data(), vz.data(), t, s, params.DT, drivers);
|
||||
// 保存当前帧
|
||||
for (int i = 0; i < n; i++) {
|
||||
traj_x[s * n + i] = x[i];
|
||||
|
||||
Reference in New Issue
Block a user