feat: 为 C/C++/Fortran 引擎增加驱动力(driving_force)支持
- param.json 新增 driving_force 开关 - C 引擎: 新增 DriverData 结构体、read_driver()、apply_driving_force() - C++ 引擎: 同上(C++ 风格实现) - Fortran 引擎: 同上(Fortran 90 风格实现) - 修复 JSON 输出末尾逗号导致加载失败的问题 - 编译通过并验证 C 引擎运行正常(100000步/6.6s)
This commit is contained in:
+206
-3
@@ -37,6 +37,7 @@ typedef struct {
|
||||
int elastic_force; /* 弹簧键力开关 */
|
||||
int damping_force; /* 阻尼开关 */
|
||||
double gravity_strength; /* 万有引力强度 */
|
||||
int driving_force; /* 驱动力开关 */
|
||||
} SimParams;
|
||||
|
||||
/* ========================================================================
|
||||
@@ -62,6 +63,139 @@ typedef struct {
|
||||
double *rest_lengths;
|
||||
} BondData;
|
||||
|
||||
/* 前向声明 */
|
||||
static void *xmalloc(size_t sz);
|
||||
|
||||
/* ========================================================================
|
||||
* 驱动力数据
|
||||
* ======================================================================== */
|
||||
typedef struct {
|
||||
int n_drivers;
|
||||
int *atom_idx;
|
||||
double *amp_x, *amp_y, *amp_z;
|
||||
double *freq_x, *freq_y, *freq_z;
|
||||
double *phi_x, *phi_y, *phi_z; /* radians */
|
||||
int *has_period; /* 0=all, 1=limited cycles */
|
||||
double *period_cycles; /* number of cycles */
|
||||
double *freeze_x, *freeze_y, *freeze_z;
|
||||
} DriverData;
|
||||
|
||||
/* 读取 driver.txt */
|
||||
static DriverData read_driver(const char *input_dir, const AtomData *atoms) {
|
||||
DriverData d;
|
||||
memset(&d, 0, sizeof(d));
|
||||
|
||||
char path[512];
|
||||
snprintf(path, sizeof(path), "%s/driver.txt", input_dir);
|
||||
FILE *f = fopen(path, "r");
|
||||
if (!f) return d;
|
||||
|
||||
char line[1024];
|
||||
if (!fgets(line, sizeof(line), f)) { fclose(f); return d; }
|
||||
|
||||
/* 第一遍:统计行数 */
|
||||
int n_lines = 0;
|
||||
while (fgets(line, sizeof(line), f)) {
|
||||
char trimmed[1024];
|
||||
int j = 0;
|
||||
for (int i = 0; line[i]; i++) {
|
||||
if (line[i] != ' ' && line[i] != '\t' && line[i] != '\n' && line[i] != '\r')
|
||||
trimmed[j++] = line[i];
|
||||
}
|
||||
trimmed[j] = '\0';
|
||||
if (strlen(trimmed) > 0 && trimmed[0] != '#') n_lines++;
|
||||
}
|
||||
|
||||
if (n_lines == 0) { fclose(f); return d; }
|
||||
|
||||
/* 分配内存 */
|
||||
d.n_drivers = n_lines;
|
||||
d.atom_idx = (int*)xmalloc(n_lines * sizeof(int));
|
||||
d.amp_x = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.amp_y = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.amp_z = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.freq_x = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.freq_y = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.freq_z = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.phi_x = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.phi_y = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.phi_z = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.has_period = (int*)xmalloc(n_lines * sizeof(int));
|
||||
d.period_cycles = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.freeze_x = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.freeze_y = (double*)xmalloc(n_lines * sizeof(double));
|
||||
d.freeze_z = (double*)xmalloc(n_lines * sizeof(double));
|
||||
|
||||
/* 初始化 freeze 数组 */
|
||||
for (int i = 0; i < n_lines; i++) {
|
||||
d.freeze_x[i] = d.freeze_y[i] = d.freeze_z[i] = 0.0;
|
||||
}
|
||||
|
||||
/* 第二遍:解析 */
|
||||
rewind(f);
|
||||
fgets(line, sizeof(line), f); /* 跳过表头 */
|
||||
|
||||
int idx = 0;
|
||||
while (idx < n_lines && fgets(line, sizeof(line), f)) {
|
||||
char trimmed[1024];
|
||||
int j = 0;
|
||||
for (int i = 0; line[i]; i++) {
|
||||
if (line[i] != ' ' && line[i] != '\t' && line[i] != '\n' && line[i] != '\r')
|
||||
trimmed[j++] = line[i];
|
||||
}
|
||||
trimmed[j] = '\0';
|
||||
if (strlen(trimmed) == 0 || trimmed[0] == '#') continue;
|
||||
|
||||
int atom_id;
|
||||
double amp_x, amp_y, amp_z;
|
||||
double freq_x, freq_y, freq_z;
|
||||
double phi_x, phi_y, phi_z;
|
||||
char period_str[256] = {0};
|
||||
|
||||
int n_parsed = sscanf(line,
|
||||
"%d %lf %lf %lf %lf %lf %lf %lf %lf %lf %255s",
|
||||
&atom_id,
|
||||
&_x, &_y, &_z,
|
||||
&freq_x, &freq_y, &freq_z,
|
||||
&phi_x, &phi_y, &phi_z,
|
||||
period_str);
|
||||
|
||||
if (n_parsed < 11) continue;
|
||||
|
||||
/* 通过原子 ID 匹配内部索引(线性搜索)*/
|
||||
int ii = -1;
|
||||
for (int k = 0; k < atoms->n_atoms; k++) {
|
||||
if (atoms->atom_ids[k] == atom_id) { ii = k; break; }
|
||||
}
|
||||
if (ii < 0) continue;
|
||||
|
||||
d.atom_idx[idx] = ii;
|
||||
d.amp_x[idx] = amp_x;
|
||||
d.amp_y[idx] = amp_y;
|
||||
d.amp_z[idx] = amp_z;
|
||||
d.freq_x[idx] = freq_x;
|
||||
d.freq_y[idx] = freq_y;
|
||||
d.freq_z[idx] = freq_z;
|
||||
/* 角度 → 弧度 */
|
||||
d.phi_x[idx] = phi_x * M_PI / 180.0;
|
||||
d.phi_y[idx] = phi_y * M_PI / 180.0;
|
||||
d.phi_z[idx] = phi_z * M_PI / 180.0;
|
||||
|
||||
if (strcmp(period_str, "all") == 0 || strcmp(period_str, "-1") == 0) {
|
||||
d.has_period[idx] = 0;
|
||||
d.period_cycles[idx] = -1.0;
|
||||
} else {
|
||||
d.has_period[idx] = 1;
|
||||
d.period_cycles[idx] = strtod(period_str, NULL);
|
||||
}
|
||||
idx++;
|
||||
}
|
||||
d.n_drivers = idx;
|
||||
|
||||
fclose(f);
|
||||
return d;
|
||||
}
|
||||
|
||||
/* ========================================================================
|
||||
* 轨迹缓冲区
|
||||
* ======================================================================== */
|
||||
@@ -169,6 +303,7 @@ static SimParams read_params(const char *path) {
|
||||
p.elastic_force = json_read_int(buf, "elastic_force");
|
||||
p.damping_force = json_read_int(buf, "damping_force");
|
||||
p.gravity_strength = json_read_double(buf, "gravity_strength");
|
||||
p.driving_force = json_read_int(buf, "driving_force");
|
||||
g_gravity_field = p.gravity_field;
|
||||
g_gravity_interaction = p.gravity_interaction;
|
||||
g_elastic_force = p.elastic_force;
|
||||
@@ -529,6 +664,63 @@ static void leapfrog_step(
|
||||
}
|
||||
}
|
||||
|
||||
/* ── 驱动力(与 Python apply_driving_force 一致)──────────────── */
|
||||
static void apply_driving_force(
|
||||
int n, double *x, double *y, double *z,
|
||||
double *vx, double *vy, double *vz,
|
||||
double t, int step, double dt,
|
||||
const DriverData *drivers)
|
||||
{
|
||||
if (!drivers || drivers->n_drivers == 0) return;
|
||||
for (int d = 0; d < drivers->n_drivers; d++) {
|
||||
int idx = drivers->atom_idx[d];
|
||||
/* 检查周期限制 */
|
||||
if (drivers->has_period[d]) {
|
||||
double max_freq = fmax(fabs(drivers->freq_x[d]),
|
||||
fmax(fabs(drivers->freq_y[d]), fabs(drivers->freq_z[d])));
|
||||
int period_steps = 0;
|
||||
if (max_freq > 1e-12) {
|
||||
period_steps = (int)(drivers->period_cycles[d] / max_freq / dt);
|
||||
}
|
||||
if (step > period_steps) {
|
||||
/* 冻结 */
|
||||
if (drivers->freeze_x) {
|
||||
x[idx] = drivers->freeze_x[d];
|
||||
y[idx] = drivers->freeze_y[d];
|
||||
z[idx] = drivers->freeze_z[d];
|
||||
}
|
||||
vx[idx] = vy[idx] = vz[idx] = 0.0;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
double px = drivers->amp_x[d] * cos(2*M_PI*drivers->freq_x[d]*t + drivers->phi_x[d]);
|
||||
double py = drivers->amp_y[d] * cos(2*M_PI*drivers->freq_y[d]*t + drivers->phi_y[d]);
|
||||
double pz = drivers->amp_z[d] * cos(2*M_PI*drivers->freq_z[d]*t + drivers->phi_z[d]);
|
||||
double vpx = -drivers->amp_x[d]*2*M_PI*drivers->freq_x[d]*sin(2*M_PI*drivers->freq_x[d]*t + drivers->phi_x[d]);
|
||||
double vpy = -drivers->amp_y[d]*2*M_PI*drivers->freq_y[d]*sin(2*M_PI*drivers->freq_y[d]*t + drivers->phi_y[d]);
|
||||
double vpz = -drivers->amp_z[d]*2*M_PI*drivers->freq_z[d]*sin(2*M_PI*drivers->freq_z[d]*t + drivers->phi_z[d]);
|
||||
|
||||
x[idx] = px; y[idx] = py; z[idx] = pz;
|
||||
vx[idx] = vpx; vy[idx] = vpy; vz[idx] = vpz;
|
||||
|
||||
/* 记录冻结位置(周期结束时) */
|
||||
if (drivers->has_period[d]) {
|
||||
double max_freq = fmax(fabs(drivers->freq_x[d]),
|
||||
fmax(fabs(drivers->freq_y[d]), fabs(drivers->freq_z[d])));
|
||||
int period_steps = 0;
|
||||
if (max_freq > 1e-12) {
|
||||
period_steps = (int)(drivers->period_cycles[d] / max_freq / dt);
|
||||
}
|
||||
if (step == period_steps) {
|
||||
drivers->freeze_x[d] = px;
|
||||
drivers->freeze_y[d] = py;
|
||||
drivers->freeze_z[d] = pz;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ── 分发器:调用对应积分方法 + 边界条件 + 自由度约束(与 Python 一致)── */
|
||||
static void apply_step(
|
||||
const char *method,
|
||||
@@ -643,7 +835,8 @@ static void write_trajectory_json(const char *path, const Trajectory *traj,
|
||||
if (b > 0) fputc(',', f);
|
||||
fprintf(f, "%.15g", bonds->rest_lengths[b]);
|
||||
}
|
||||
fprintf(f, "]\n");
|
||||
fprintf(f, "],\n");
|
||||
fprintf(f, " \"driving_force\": %d\n", params->driving_force);
|
||||
|
||||
fprintf(f, "}\n");
|
||||
fclose(f);
|
||||
@@ -668,8 +861,14 @@ int main(int argc, char **argv) {
|
||||
AtomData atoms = read_coord(input_dir);
|
||||
BondData bonds = read_bonds(input_dir, &atoms);
|
||||
|
||||
printf("[C-engine] 原子数=%d, 键数=%d, NT=%d, DT=%.6g, method=%s\n",
|
||||
atoms.n_atoms, bonds.n_bonds, params.NT, params.DT, params.method);
|
||||
DriverData drivers;
|
||||
drivers.n_drivers = 0;
|
||||
if (params.driving_force) {
|
||||
drivers = read_driver(input_dir, &atoms);
|
||||
}
|
||||
|
||||
printf("[C-engine] 原子数=%d, 键数=%d, 驱动=%d, NT=%d, DT=%.6g, method=%s\n",
|
||||
atoms.n_atoms, bonds.n_bonds, drivers.n_drivers, params.NT, params.DT, params.method);
|
||||
|
||||
int n = atoms.n_atoms;
|
||||
double *x = (double*)xmalloc(n * sizeof(double));
|
||||
@@ -701,6 +900,8 @@ int main(int argc, char **argv) {
|
||||
|
||||
/* 预热 */
|
||||
for (int s = 0; s < params.warmup_steps; s++) {
|
||||
double tw = (s + 1) * params.DT;
|
||||
if (params.driving_force) apply_driving_force(n, x, y, z, vx, vy, vz, tw, s, params.DT, &drivers);
|
||||
apply_step(params.method, n, x, y, z, vx, vy, vz,
|
||||
atoms.masses, params.G, params.B, &bonds, atoms.fixed,
|
||||
atoms.pos_0,
|
||||
@@ -709,6 +910,8 @@ int main(int argc, char **argv) {
|
||||
|
||||
/* 记录 */
|
||||
for (int s = 0; s < record_steps; s++) {
|
||||
double t = (s + params.warmup_steps) * params.DT;
|
||||
if (params.driving_force) apply_driving_force(n, x, y, z, vx, vy, vz, t, s, params.DT, &drivers);
|
||||
for (int i = 0; i < n; i++) {
|
||||
traj.x[ s * n + i] = x[i];
|
||||
traj.y[ s * n + i] = y[i];
|
||||
|
||||
Reference in New Issue
Block a user