00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #include "msmpot_cuda.h"
00019
00020
/*
 * Compile-time sizes for the lattice-cutoff kernel.
 *   MAXLEVELS     capacity of the per-level __constant__ tables below
 *   SUBCUBESZ     grid points in one 4x4x4 subcube (equals 1 << LG_SUBCUBESZ)
 *   LG_SUBCUBESZ  log2 of SUBCUBESZ
 */
#define MAXLEVELS 28
#define SUBCUBESZ 64
#define LG_SUBCUBESZ 6

/*
 * Code-generation switches for cuda_latcut.  Each one is undefined and
 * then redefined, so a definition supplied on the command line cannot
 * change which variant is compiled.
 */
#undef PRECOMP_1
#define PRECOMP_1     /* precompute the per-(j,k) row offsets */

#undef UNROLL_1
#define UNROLL_1      /* hand-unroll the innermost i-loop */

#undef SHMEMCOPY
#define SHMEMCOPY     /* slide subcubes within shared memory */

#undef UNROLL_2
#define UNROLL_2      /* hand-unroll the 2x2 shared-memory loads */

/* Per level: x = nx, y = ny, z = nz (subcube counts including padding),
 * w = cumulative number of interior subcubes through this level. */
__constant__ int4 sinfo[MAXLEVELS];

/* Per-level scale factor that multiplies each level's result. */
__constant__ float lfac[MAXLEVELS];

/* Weight stencil of (8*srad)^3 entries, sized for the largest srad (3). */
__constant__ float wt[24*24*24];
00048
00049
/*
 * cuda_latcut - lattice-cutoff stencil for the MSM levels handled on the GPU.
 *
 * Launch layout (see Msmpot_cuda_compute_latcut):
 *   - The grid is 2D. The flattened block index gridDim.x*blockIdx.y + blockIdx.x
 *     enumerates the interior subcubes of every level.
 *   - Each block has 4x4x4 threads. A block produces one 4x4x4 subcube of
 *     potential, and each thread produces one point of it.
 *   - Blocks whose index is >= nsubcubes return immediately.
 *
 * Constant-memory inputs:
 *   - sinfo: per level, (nx, ny, nz, cumulative interior-subcube count).
 *   - lfac: per-level scale factor.
 *   - wt: stencil of (8*srad)^3 weights.
 *
 * Static shared memory is one 8x8x8 float tile (2 KB). It holds the charges
 * of 2x2x2 adjacent subcubes.
 *
 * This code implements the variant selected by PRECOMP_1, UNROLL_1,
 * SHMEMCOPY and UNROLL_2, which are all unconditionally defined at the top of
 * this file. For each (mz, my) pencil, the tile's upper x-half is first
 * primed from global memory. Then, for each mx, the upper half slides down
 * and only the subcubes at mx+1 are read from global memory.
 */
__global__ static void cuda_latcut(
    unsigned int nsubcubes,  /* number of valid blocks (interior subcubes) */
    int nlevels,             /* not referenced by the kernel */
    int srad,                /* stencil radius, in subcubes */
    int padding,             /* padding subcubes on each side of a level */
    float *qgrids,           /* condensed charge subcubes, all levels */
    float *egrids            /* condensed potential subcubes, all levels */
    )
{
  __shared__ float cq[8*8*8];

  const unsigned int bid = gridDim.x * blockIdx.y + blockIdx.x;
  if (bid >= nsubcubes) return;   /* surplus block from grid rounding */

  (void) nlevels;

  /* Locate this block's level. subcube_off counts every subcube (padding
   * included) of the preceding levels. */
  int lev = 0;
  int subcube_off = 0;
  for ( ; bid >= sinfo[lev].w; ++lev) {
    subcube_off += sinfo[lev].x * sinfo[lev].y * sinfo[lev].z;
  }

  const int nx = sinfo[lev].x;
  const int ny = sinfo[lev].y;
  const int nz = sinfo[lev].z;
  const float *qlevel = qgrids + subcube_off * SUBCUBESZ;

  /* (sx, sy, sz) is this block's subcube within the padded level lattice. */
  const int local = bid - (lev > 0 ? sinfo[lev-1].w : 0);
  const int interior_x = nx - 2*padding;
  const int interior_y = ny - 2*padding;
  const int row = local / interior_x;
  const int sx = local % interior_x + padding;
  const int sy = row % interior_y + padding;
  const int sz = row / interior_y + padding;

  /* Neighbor subcube range, clamped to the lattice. The loops run m over
   * [a, b) and also read subcube m+1. */
  const int ax = (sx - srad < 0) ? 0 : sx - srad;
  const int ay = (sy - srad < 0) ? 0 : sy - srad;
  const int az = (sz - srad < 0) ? 0 : sz - srad;
  const int bx = (sx + srad > nx - 1) ? nx - 1 : sx + srad;
  const int by = (sy + srad > ny - 1) ? ny - 1 : sy + srad;
  const int bz = (sz + srad > nz - 1) ? nz - 1 : sz + srad;

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  const int tid = (tz*4 + ty)*4 + tx;   /* this thread's point in a subcube */
  const int wdim = 8*srad;              /* wt holds wdim^3 weights */

  float e = 0;

  for (int mz = az; mz < bz; mz++) {
    const int wz = (mz - sz + srad)*4;

    for (int my = ay; my < by; my++) {
      const int wy = (my - sy + srad)*4;

      /* Prime the tile's upper x-half with the 2x2 subcubes at x = ax. */
#pragma unroll
      for (int k = 0; k < 2; k++) {
#pragma unroll
        for (int j = 0; j < 2; j++) {
          const float *q = qlevel + (((mz+k)*ny + (my+j))*nx + ax) * SUBCUBESZ;
          cq[((tz + 4*k)*8 + (ty + 4*j))*8 + tx + 4] = q[tid];
        }
      }

      for (int mx = ax; mx < bx; mx++) {
        /* Slide the upper x-half down, then fetch the subcubes at mx+1.
         * Each thread moves only entries it wrote itself. The trailing
         * barrier of the previous step ensures no thread is still reading. */
#pragma unroll
        for (int k = 0; k < 2; k++) {
#pragma unroll
          for (int j = 0; j < 2; j++) {
            const int lo = ((tz + 4*k)*8 + (ty + 4*j))*8 + tx;
            const float *q =
              qlevel + (((mz+k)*ny + (my+j))*nx + (mx+1)) * SUBCUBESZ;
            cq[lo] = cq[lo + 4];
            cq[lo + 4] = q[tid];
          }
        }
        __syncthreads();   /* tile complete before anyone reads it */

        const int wx = (mx - sx + srad)*4;
#pragma unroll
        for (int k = 0; k < 4; k++) {
#pragma unroll
          for (int j = 0; j < 4; j++) {
            const int crow = ((tz + k)*8 + (ty + j))*8 + tx;
            const int wrow = ((wz + k)*wdim + (wy + j))*wdim + wx;
#pragma unroll
            for (int i = 0; i < 4; i++) {
              e += cq[crow + i] * wt[wrow + i];
            }
          }
        }
        __syncthreads();   /* all reads done before the tile is rewritten */
      }
    }
  }

  e *= lfac[lev];
  egrids[(subcube_off + (sz*ny + sy)*nx + sx) * SUBCUBESZ + tid] = e;
}
00317
00318
00319
00320
00321
/*
 * Msmpot_cuda_cleanup_latcut - release the host and device buffers used by
 * the lattice-cutoff stage.
 *
 * Every pointer is reset to NULL and every capacity counter to 0. Calling
 * this twice is therefore harmless. A later Msmpot_cuda_setup_latcut()
 * reallocates from scratch instead of writing into freed memory.
 * Errors from cudaFree are ignored (best-effort teardown).
 */
void Msmpot_cuda_cleanup_latcut(MsmpotCuda *mc) {
  cudaFree(mc->device_qgrids);
  mc->device_qgrids = NULL;
  cudaFree(mc->device_egrids);
  mc->device_egrids = NULL;

  free(mc->host_qgrids);
  mc->host_qgrids = NULL;
  free(mc->host_egrids);
  mc->host_egrids = NULL;
  free(mc->host_wt);
  mc->host_wt = NULL;
  free(mc->host_sinfo);
  mc->host_sinfo = NULL;
  free(mc->host_lfac);
  mc->host_lfac = NULL;

  /* capacities describe the buffers above, which no longer exist */
  mc->maxgridpts = 0;
  mc->maxwts = 0;
  mc->maxlevels = 0;
}
00331
00332
/*
 * Msmpot_cuda_setup_latcut - prepare host/device state for the
 * lattice-cutoff kernel.
 *
 * Fills in, for the current Msmpot configuration:
 *   lk_nlevels  number of levels handled (msm->nlevels - 1)
 *   lk_srad     stencil radius measured in 4-point subcubes
 *   lk_padding  subcube padding on each side of each level's lattice
 *   host_lfac   per-level scale factors 1, 1/2, 1/4, ...
 *   host_sinfo  per level: nx, ny, nz (subcubes incl. padding) and the
 *               cumulative number of interior subcubes
 *   host_wt     the (8*srad)^3 weight stencil
 * It also grows the condensed host and device grid buffers when they are
 * too small.
 *
 * Returns MSMPOT_SUCCESS, or the value of ERROR(...) for one of:
 *   - an unsupported configuration (too many levels, or srad > 3);
 *   - a failed host or device allocation.
 */
int Msmpot_cuda_setup_latcut(MsmpotCuda *mc) {
  Msmpot *msm = mc->msmpot;
  const float hx = msm->hx;
  const float hy = msm->hy;
  const float hz = msm->hz;
  float hmin;  /* smallest grid spacing: it needs the widest stencil */
  const float a = msm->a;
  const int split = msm->split;
  const int maxlevels = msm->maxlevels;
  int nlevels = msm->nlevels - 1;  /* the highest level is excluded here */
  int nrad;
  int srad;
  int pad;
  int i, j, k, ii, jj, kk;
  int index;
  long btotal, stotal, maxwts, maxgridpts, memsz;
  float s, t, gs, gt;
  float lfac;
  float *wt;

  if (nlevels > MAXLEVELS) return ERROR(MSMPOT_ERROR_CUDA_SUPPORT);
  mc->lk_nlevels = nlevels;

  /* The stencil must reach distance 2a along every axis. The axis with the
   * smallest spacing needs the most points, so size the stencil from hmin. */
  hmin = hx;
  if (hy < hmin) hmin = hy;
  if (hz < hmin) hmin = hz;
  nrad = (int) ceilf(2*a/hmin) - 1;
  srad = (int) ceilf((nrad + 1) / 4.f);
  if (srad > 3) return ERROR(MSMPOT_ERROR_CUDA_SUPPORT);  /* wt[] is 24^3 */
  mc->lk_srad = srad;

  /* Periodic lattices get srad subcubes of wrapped charge on each side.
   * Otherwise a single subcube of zero padding is used. */
  if (msm->isperiodic) {
    pad = srad;
  }
  else {
    pad = 1;
  }
  mc->lk_padding = pad;
#ifdef MSMPOT_DEBUG
  printf("a=%g h=%g\n", a, hmin);
  printf("nrad=%d\n", nrad);
  printf("srad=%d\n", srad);
  printf("pad=%d\n", pad);
#endif

  /* grow the per-level tables to msm->maxlevels entries when needed */
  if (mc->maxlevels < maxlevels) {
    void *v;
    v = realloc(mc->host_lfac, maxlevels * sizeof(float));
    if (NULL == v) return ERROR(MSMPOT_ERROR_MALLOC);
    mc->host_lfac = (float *) v;
    v = realloc(mc->host_sinfo, maxlevels * 4 * sizeof(int));
    if (NULL == v) return ERROR(MSMPOT_ERROR_MALLOC);
    mc->host_sinfo = (int *) v;
    mc->maxlevels = maxlevels;
  }

  /* level scale factors: 1, 1/2, 1/4, ... */
  lfac = 1.f;
  for (i = 0; i < nlevels; i++) {
    mc->host_lfac[i] = lfac;
    lfac *= 0.5f;
  }

  /* per-level subcube dimensions and cumulative interior-subcube counts */
  stotal = 0;
  btotal = 0;
  for (i = 0; i < nlevels; i++) {
    const floatGrid *f = &(mc->msmpot->qh[i]);
    int nx = mc->host_sinfo[4*i    ] = (int) ceilf(f->ni * 0.25f) + 2*pad;
    int ny = mc->host_sinfo[4*i + 1] = (int) ceilf(f->nj * 0.25f) + 2*pad;
    int nz = mc->host_sinfo[4*i + 2] = (int) ceilf(f->nk * 0.25f) + 2*pad;
    stotal += nx * ny * nz;
    btotal += (nx - 2*pad) * (ny - 2*pad) * (nz - 2*pad);
    mc->host_sinfo[4*i + 3] = (int) btotal;
#ifdef MSMPOT_DEBUG
    printf("\nlevel %d: ni=%2d nj=%2d nk=%2d\n",
        i, (int) f->ni, (int) f->nj, (int) f->nk);
    printf(" nx=%2d ny=%2d nz=%2d stotal=%ld\n",
        nx, ny, nz, stotal);
    printf(" bx=%2d by=%2d bz=%2d btotal=%ld\n",
        nx-2*pad, ny-2*pad, nz-2*pad, btotal);
#endif
  }
#ifdef MSMPOT_DEBUG
  printf("\n");
#endif

  mc->subcube_total = stotal;
  mc->block_total = btotal;

#ifdef MSMPOT_DEBUG
  printf("nlevels=%d\n", nlevels);
  for (i = 0; i < nlevels; i++) {
    printf("ni=%d nj=%d nk=%d\n",
        (int) msm->qh[i].ni, (int) msm->qh[i].nj, (int) msm->qh[i].nk);
    printf("nx=%d ny=%d nz=%d nw=%d\n",
        mc->host_sinfo[4*i    ],
        mc->host_sinfo[4*i + 1],
        mc->host_sinfo[4*i + 2],
        mc->host_sinfo[4*i + 3]);
  }
#endif

  /* Weight stencil. Entry (ii,jj,kk) corresponds to lattice offset
   * (ii,jj,kk) - 4*srad. It is zero at or beyond distance 2a (t >= 1). */
  maxwts = (8*srad) * (8*srad) * (8*srad);
  if (mc->maxwts < maxwts) {
    void *v = realloc(mc->host_wt, maxwts * sizeof(float));
    if (NULL == v) return ERROR(MSMPOT_ERROR_MALLOC);
    mc->host_wt = (float *) v;
    mc->maxwts = maxwts;
  }
  wt = mc->host_wt;
  for (kk = 0; kk < 8*srad; kk++) {
    for (jj = 0; jj < 8*srad; jj++) {
      for (ii = 0; ii < 8*srad; ii++) {
        index = (kk*(8*srad) + jj)*(8*srad) + ii;
        i = ii - 4*srad;
        j = jj - 4*srad;
        k = kk - 4*srad;
        s = (i*i*hx*hx + j*j*hy*hy + k*k*hz*hz) / (a*a);
        t = 0.25f * s;
        if (t >= 1) {
          wt[index] = 0;
        }
        else if (s >= 1) {
          gs = 1/sqrtf(s);
          SPOLY(&gt, t, split);
          wt[index] = (gs - 0.5f * gt) / a;
        }
        else {
          SPOLY(&gs, s, split);
          SPOLY(&gt, t, split);
          wt[index] = (gs - 0.5f * gt) / a;
        }
      }
    }
  }

  /* Condensed charge/potential buffers. Capacity (mc->maxgridpts) is
   * tracked in grid points, so compare points with points. */
  maxgridpts = stotal * SUBCUBESZ;
  memsz = maxgridpts * sizeof(float);
  if (mc->maxgridpts < maxgridpts) {
    void *v;
    v = realloc(mc->host_qgrids, memsz);
    if (NULL == v) return ERROR(MSMPOT_ERROR_MALLOC);
    mc->host_qgrids = (float *) v;
    v = realloc(mc->host_egrids, memsz);
    if (NULL == v) return ERROR(MSMPOT_ERROR_MALLOC);
    mc->host_egrids = (float *) v;

    /* clear each device pointer right after freeing it, so an early
     * error return never leaves a dangling pointer behind */
    cudaFree(mc->device_qgrids);
    mc->device_qgrids = NULL;
    CUERR(MSMPOT_ERROR_CUDA_MALLOC);
    cudaMalloc(&v, memsz);
    CUERR(MSMPOT_ERROR_CUDA_MALLOC);
    mc->device_qgrids = (float *) v;

    cudaFree(mc->device_egrids);
    mc->device_egrids = NULL;
    CUERR(MSMPOT_ERROR_CUDA_MALLOC);
    cudaMalloc(&v, memsz);
    CUERR(MSMPOT_ERROR_CUDA_MALLOC);
    mc->device_egrids = (float *) v;

    mc->maxgridpts = maxgridpts;
  }

  return MSMPOT_SUCCESS;
}
00493
00494
00495
/*
 * Map a padded-lattice coordinate onto a source-grid index in [0, n).
 * On a periodic axis the coordinate wraps around. On a non-periodic axis,
 * an out-of-range coordinate yields -1.
 */
static int latcut_condense_index(int idx, int n, int periodic) {
  if (!periodic) {
    return (idx < 0 || idx >= n) ? -1 : idx;
  }
  while (idx < 0) idx += n;
  while (idx >= n) idx -= n;
  return idx;
}

/*
 * Msmpot_cuda_condense_qgrids - pack the charge grid of every handled level
 * into the padded subcube layout read by cuda_latcut.
 *
 * Layout of host_qgrids:
 *   - Each level's nx*ny*nz subcubes (dimensions from host_sinfo) follow those
 *     of the preceding levels.
 *   - Within a subcube, point (i,j,k) is stored at (k*4 + j)*4 + i, so x
 *     varies fastest.
 *   - Padding subcubes hold wrapped charge on periodic axes and zero
 *     elsewhere.
 *
 * Always returns 0.
 */
int Msmpot_cuda_condense_qgrids(MsmpotCuda *mc) {
  const int *sinfo_h = mc->host_sinfo;
  float *packed = mc->host_qgrids;
  const long nbytes = mc->subcube_total * SUBCUBESZ * sizeof(float);
  const int nlev = mc->lk_nlevels;
  const int pad = mc->lk_padding;
  const int wrap_x = (IS_SET_X(mc->msmpot->isperiodic) != 0);
  const int wrap_y = (IS_SET_Y(mc->msmpot->isperiodic) != 0);
  const int wrap_z = (IS_SET_Z(mc->msmpot->isperiodic) != 0);
  int base = 0;   /* subcubes belonging to earlier levels */
  int level;

  memset(packed, 0, nbytes);

  for (level = 0; level < nlev; level++) {
    const floatGrid *qgrid = &(mc->msmpot->qh[level]);
    const float *qbuf = qgrid->buffer;
    const int ni = (int) (qgrid->ni);
    const int nj = (int) (qgrid->nj);
    const int nk = (int) (qgrid->nk);
    const int nx = sinfo_h[4*level    ];
    const int ny = sinfo_h[4*level + 1];
    const int nz = sinfo_h[4*level + 2];
    int cz, cy, cx;

#ifdef MSMPOT_DEBUG
    printf("level=%d\n", level);
    printf(" nx=%d ny=%d nz=%d\n", nx, ny, nz);
    printf(" ni=%d nj=%d nk=%d\n", ni, nj, nk);
#endif

    for (cz = 0; cz < nz; cz++) {
      for (cy = 0; cy < ny; cy++) {
        for (cx = 0; cx < nx; cx++) {
          float *cube = packed + (((cz*ny + cy)*nx + cx) + base) * SUBCUBESZ;
          int k, j, i;

          for (k = 0; k < 4; k++) {
            const int zs = latcut_condense_index((cz - pad)*4 + k, nk, wrap_z);
            if (zs < 0) break;
            for (j = 0; j < 4; j++) {
              const int ys = latcut_condense_index((cy - pad)*4 + j, nj, wrap_y);
              if (ys < 0) break;
              for (i = 0; i < 4; i++) {
                const int xs = latcut_condense_index((cx - pad)*4 + i, ni, wrap_x);
                if (xs < 0) break;
                cube[(k*4 + j)*4 + i] = qbuf[(zs*nj + ys)*ni + xs];
              }
            }
          }
        }
      }
    }

    base += nx * ny * nz;
  }

  return 0;
}
00578
00579
00580
/*
 * Msmpot_cuda_expand_egrids - unpack the condensed potential subcubes in
 * host_egrids back into each handled level's eh[] grid.
 *
 * Only interior subcubes are read; padding subcubes are skipped. The last
 * subcube along an axis may extend past the grid edge, and those points
 * are not written. Always returns 0.
 */
int Msmpot_cuda_expand_egrids(MsmpotCuda *mc) {
  const int *sinfo_h = mc->host_sinfo;
  const float *packed = mc->host_egrids;
  const int nlev = mc->lk_nlevels;
  const int pad = mc->lk_padding;
  int base = 0;   /* subcubes belonging to earlier levels */
  int level;

  for (level = 0; level < nlev; level++) {
    floatGrid *egrid = &(mc->msmpot->eh[level]);
    float *ebuf = egrid->buffer;
    const int ni = (int) (egrid->ni);
    const int nj = (int) (egrid->nj);
    const int nk = (int) (egrid->nk);
    const int nx = sinfo_h[4*level    ];
    const int ny = sinfo_h[4*level + 1];
    const int nz = sinfo_h[4*level + 2];
    int cz, cy, cx;

    for (cz = pad; cz < nz - pad; cz++) {
      const int z0 = (cz - pad) * 4;
      for (cy = pad; cy < ny - pad; cy++) {
        const int y0 = (cy - pad) * 4;
        for (cx = pad; cx < nx - pad; cx++) {
          const int x0 = (cx - pad) * 4;
          const float *cube = packed + (((cz*ny + cy)*nx + cx) + base) * SUBCUBESZ;
          int k, j, i;

          for (k = 0; k < 4 && z0 + k < nk; k++) {
            for (j = 0; j < 4 && y0 + j < nj; j++) {
              for (i = 0; i < 4 && x0 + i < ni; i++) {
                ebuf[((z0 + k)*nj + (y0 + j))*ni + (x0 + i)] =
                  cube[(k*4 + j)*4 + i];
              }
            }
          }
        }
      }
    }

    base += nx * ny * nz;
  }

  return 0;
}
00638
00639
/*
 * Msmpot_cuda_compute_latcut - run cuda_latcut over all handled levels.
 *
 * Steps:
 *   - Upload the per-level tables (sinfo, lfac), the weight stencil (wt) and
 *     the condensed charge grids.
 *   - Launch one 4x4x4-thread block per interior subcube.
 *   - Copy the condensed potential grids back into host_egrids.
 *
 * Expects Msmpot_cuda_setup_latcut() and Msmpot_cuda_condense_qgrids() to
 * have been called first.
 *
 * Returns 0 on success, including the trivial case of no blocks.
 * Otherwise returns the value of ERROR(...), either because the launch grid
 * would exceed CUDA's limits or because a CUDA call failed.
 */
int Msmpot_cuda_compute_latcut(MsmpotCuda *mc) {
  const int nlevels = mc->lk_nlevels;
  const int srad = mc->lk_srad;
  const int padding = mc->lk_padding;
  const int wt_total = (8*srad) * (8*srad) * (8*srad);
  const long memsz = mc->subcube_total * SUBCUBESZ * sizeof(float);
  const long nblocks = mc->block_total;
  /* gridDim.y (and gridDim.x below compute capability 3.0) is limited to
   * 65535 blocks */
  const long max_grid_dim = 65535;
  long bx, by;
  dim3 grid, block;

  /* nothing to compute; an empty grid would be an invalid launch */
  if (nblocks <= 0) return 0;

  /* Spread the blocks over a 2D grid, using as few rows as possible, each
   * row no wider than max_grid_dim. The kernel discards the surplus
   * (bx*by - nblocks) blocks. Because by >= nblocks/max_grid_dim, it
   * follows that bx <= max_grid_dim. */
  by = (nblocks + max_grid_dim - 1) / max_grid_dim;
  bx = (nblocks + by - 1) / by;
  if (by > max_grid_dim) {
    return ERROR(MSMPOT_ERROR_CUDA_SUPPORT);
  }

  grid.x = (unsigned int) bx;
  grid.y = (unsigned int) by;
  grid.z = 1;

  block.x = 4;
  block.y = 4;
  block.z = 4;

  cudaMemcpyToSymbol(sinfo, mc->host_sinfo, nlevels * sizeof(int4), 0);
  CUERR(MSMPOT_ERROR_CUDA_MEMCPY);
  cudaMemcpyToSymbol(lfac, mc->host_lfac, nlevels * sizeof(float), 0);
  CUERR(MSMPOT_ERROR_CUDA_MEMCPY);
  cudaMemcpyToSymbol(wt, mc->host_wt, wt_total * sizeof(float), 0);
  CUERR(MSMPOT_ERROR_CUDA_MEMCPY);

  cudaMemcpy(mc->device_qgrids, mc->host_qgrids, memsz, cudaMemcpyHostToDevice);
  CUERR(MSMPOT_ERROR_CUDA_MEMCPY);

#ifdef MSMPOT_DEBUG
  printf("gridDim.x=%u\n", grid.x);
  printf("gridDim.y=%u\n", grid.y);
  printf("nsubcubes=%u (using %u extra thread blocks)\n",
      (unsigned int) nblocks,
      grid.x*grid.y - (unsigned int) nblocks);
  printf("nlevels=%d\n", nlevels);
  printf("srad=%d\n", srad);
  printf("padding=%d\n", padding);
  printf("address of qgrids=%p\n", (void *) (mc->device_qgrids));
  printf("address of egrids=%p\n", (void *) (mc->device_egrids));
#endif
  cuda_latcut<<<grid, block, 0>>>((unsigned int) nblocks,
      nlevels, srad, padding, mc->device_qgrids, mc->device_egrids);
  CUERR(MSMPOT_ERROR_CUDA_KERNEL);

  /* blocking copy: also surfaces asynchronous kernel faults */
  cudaMemcpy(mc->host_egrids, mc->device_egrids, memsz, cudaMemcpyDeviceToHost);
  CUERR(MSMPOT_ERROR_CUDA_MEMCPY);

  return 0;
}