00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00022 #include <stdio.h>
00023 #include <stdlib.h>
00024 #include <string.h>
00025 #include <math.h>
00026 #include "WKFThreads.h"
00027 #include "OrbitalJIT.h"
00028
00029
00030
00031
00032
/* Angstrom -> Bohr conversion factor.  The emitted kernel source uses a
 * macro of the same name; the OpenCL variant defines it itself with the
 * same value. */
#define ANGS_TO_BOHR 1.8897259877218677f

/* Thread-block geometry and per-thread unrolling.  Every composite
 * expansion is parenthesized so that e.g. "n / BLOCKSIZE" divides by the
 * product instead of expanding to "n / 8 * 8". */
#define UNROLLX 1
#define UNROLLY 1
#define BLOCKSIZEX 8
#define BLOCKSIZEY 8
#define BLOCKSIZE (BLOCKSIZEX * BLOCKSIZEY)

/* Grid tile covered by one block, and the matching masks.
 * NOTE(review): the masks are presumably used to align coordinates to tile
 * boundaries; they only work as bit masks while the tile sizes are powers
 * of two -- confirm at the use sites. */
#define TILESIZEX (BLOCKSIZEX * UNROLLX)
#define TILESIZEY (BLOCKSIZEY * UNROLLY)
#define GPU_X_ALIGNMASK (TILESIZEX - 1)
#define GPU_Y_ALIGNMASK (TILESIZEY - 1)

/* NOTE(review): not referenced in this chunk; meaning of 384 unknown -- confirm. */
#define MEMCOALESCE 384

/* Shell angular-momentum codes, as stored in the shell_types[] array
 * consumed by orbital_jit_generate(). */
#define S_SHELL 0
#define P_SHELL 1
#define D_SHELL 2
#define F_SHELL 3
#define G_SHELL 4
#define H_SHELL 5

/* Capacity limits.  Not referenced in this chunk.
 * NOTE(review): presumably sizes of fixed-size (constant-memory) tables
 * holding atoms, basis data, shells and wavefunction coefficients -- confirm. */
#define MAX_ATOM_SZ 256
#define MAX_ATOMPOS_SZ (MAX_ATOM_SZ)
#define MAX_ATOM_BASIS_SZ (MAX_ATOM_SZ)
#define MAX_ATOMSHELL_SZ (MAX_ATOM_SZ)
#define MAX_BASIS_SZ 6144
#define MAX_SHELL_SZ 1024
#define MAX_WAVEF_SZ 6144
00080
00081
00082
00083
00084
00085
00086 int orbital_jit_generate(int jitlanguage,
00087 const char * srcfilename, int numatoms,
00088 const float *wave_f, const float *basis_array,
00089 const int *atom_basis,
00090 const int *num_shells_per_atom,
00091 const int *num_prim_per_shell,
00092 const int *shell_types) {
00093 FILE *ofp=NULL;
00094 if (srcfilename)
00095 ofp=fopen(srcfilename, "w");
00096
00097 if (ofp == NULL)
00098 ofp=stdout;
00099
00100
00101
00102 int at;
00103 int prim, shell;
00104
00105
00106 int shell_counter = 0;
00107
00108 if (jitlanguage == ORBITAL_JIT_CUDA) {
00109 fprintf(ofp,
00110 "__global__ static void cuorbitalconstmem_jit(int numatoms,\n"
00111 " float voxelsize,\n"
00112 " float originx,\n"
00113 " float originy,\n"
00114 " float grid_z, \n"
00115 " int density, \n"
00116 " float * orbitalgrid) {\n"
00117 " unsigned int xindex = blockIdx.x * blockDim.x + threadIdx.x;\n"
00118 " unsigned int yindex = blockIdx.y * blockDim.y + threadIdx.y;\n"
00119 " unsigned int outaddr = gridDim.x * blockDim.x * yindex + xindex;\n"
00120 );
00121 } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00122 fprintf(ofp,
00123 "// unit conversion \n"
00124 "#define ANGS_TO_BOHR 1.8897259877218677f \n"
00125 );
00126
00127 fprintf(ofp, "__kernel __attribute__((reqd_work_group_size(%d, %d, 1)))\n",
00128 BLOCKSIZEX, BLOCKSIZEY);
00129
00130 fprintf(ofp,
00131 "void clorbitalconstmem_jit(int numatoms, \n"
00132 " __constant float *const_atompos, \n"
00133 " __constant float *const_wave_f, \n"
00134 " float voxelsize, \n"
00135 " float originx, \n"
00136 " float originy, \n"
00137 " float grid_z, \n"
00138 " int density, \n"
00139 " __global float * orbitalgrid) { \n"
00140 " unsigned int xindex = get_global_id(0); \n"
00141 " unsigned int yindex = get_global_id(1); \n"
00142 " unsigned int outaddr = get_global_size(0) * yindex + xindex; \n"
00143 );
00144 }
00145
00146 fprintf(ofp,
00147 " float grid_x = originx + voxelsize * xindex;\n"
00148 " float grid_y = originy + voxelsize * yindex;\n"
00149
00150 " // similar to C version\n"
00151 " int at;\n"
00152 " // initialize value of orbital at gridpoint\n"
00153 " float value = 0.0f;\n"
00154 " // initialize the wavefunction and shell counters\n"
00155 " int ifunc = 0;\n"
00156 " // loop over all the QM atoms\n"
00157 " for (at = 0; at < numatoms; at++) {\n"
00158 " // calculate distance between grid point and center of atom\n"
00159
00160
00161 " float xdist = (grid_x - const_atompos[3*at ])*ANGS_TO_BOHR;\n"
00162 " float ydist = (grid_y - const_atompos[3*at+1])*ANGS_TO_BOHR;\n"
00163 " float zdist = (grid_z - const_atompos[3*at+2])*ANGS_TO_BOHR;\n"
00164 " float xdist2 = xdist*xdist;\n"
00165 " float ydist2 = ydist*ydist;\n"
00166 " float zdist2 = zdist*zdist;\n"
00167 " float dist2 = xdist2 + ydist2 + zdist2;\n"
00168 " float contracted_gto=0.0f;\n"
00169 " float tmpshell=0.0f;\n"
00170 "\n"
00171 );
00172
00173 #if 0
00174
00175 for (at=0; at<numatoms; at++) {
00176 #else
00177
00178 for (at=0; at<1; at++) {
00179 #endif
00180 int maxshell = num_shells_per_atom[at];
00181 int prim_counter = atom_basis[at];
00182
00183
00184 for (shell=0; shell < maxshell; shell++) {
00185
00186
00187 int maxprim = num_prim_per_shell[shell_counter];
00188 int shelltype = shell_types[shell_counter];
00189 for (prim=0; prim<maxprim; prim++) {
00190 float exponent = basis_array[prim_counter ];
00191 float contract_coeff = basis_array[prim_counter + 1];
00192 #if 1
00193 if (jitlanguage == ORBITAL_JIT_CUDA) {
00194 if (prim == 0) {
00195 fprintf(ofp," contracted_gto = %ff * exp2f(-%ff*dist2);\n",
00196 contract_coeff, exponent);
00197 } else {
00198 fprintf(ofp," contracted_gto += %ff * exp2f(-%ff*dist2);\n",
00199 contract_coeff, exponent);
00200 }
00201 } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00202 if (prim == 0) {
00203 fprintf(ofp," contracted_gto = %ff * native_exp2(-%ff*dist2);\n",
00204 contract_coeff, exponent);
00205 } else {
00206 fprintf(ofp," contracted_gto += %ff * native_exp2(-%ff*dist2);\n",
00207 contract_coeff, exponent);
00208 }
00209 }
00210 #else
00211 if (jitlanguage == ORBITAL_JIT_CUDA) {
00212 if (prim == 0) {
00213 fprintf(ofp," contracted_gto = %ff * expf(-%ff*dist2);\n",
00214 contract_coeff, exponent);
00215 } else {
00216 fprintf(ofp," contracted_gto += %ff * expf(-%ff*dist2);\n",
00217 contract_coeff, exponent);
00218 }
00219 } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00220 if (prim == 0) {
00221 fprintf(ofp," contracted_gto = %ff * native_exp(-%ff*dist2);\n",
00222 contract_coeff, exponent);
00223 } else {
00224 fprintf(ofp," contracted_gto += %ff * native_exp(-%ff*dist2);\n",
00225 contract_coeff, exponent);
00226 }
00227 }
00228 #endif
00229 prim_counter += 2;
00230 }
00231
00232
00233 switch (shelltype) {
00234 case S_SHELL:
00235 fprintf(ofp,
00236 " // S_SHELL\n"
00237 " value += const_wave_f[ifunc++] * contracted_gto;\n");
00238 break;
00239
00240 case P_SHELL:
00241 fprintf(ofp,
00242 " // P_SHELL\n"
00243 " tmpshell = const_wave_f[ifunc++] * xdist;\n"
00244 " tmpshell += const_wave_f[ifunc++] * ydist;\n"
00245 " tmpshell += const_wave_f[ifunc++] * zdist;\n"
00246 " value += tmpshell * contracted_gto;\n"
00247 );
00248 break;
00249
00250 case D_SHELL:
00251 fprintf(ofp,
00252 " // D_SHELL\n"
00253 " tmpshell = const_wave_f[ifunc++] * xdist2;\n"
00254 " tmpshell += const_wave_f[ifunc++] * xdist * ydist;\n"
00255 " tmpshell += const_wave_f[ifunc++] * ydist2;\n"
00256 " tmpshell += const_wave_f[ifunc++] * xdist * zdist;\n"
00257 " tmpshell += const_wave_f[ifunc++] * ydist * zdist;\n"
00258 " tmpshell += const_wave_f[ifunc++] * zdist2;\n"
00259 " value += tmpshell * contracted_gto;\n"
00260 );
00261 break;
00262
00263 case F_SHELL:
00264 fprintf(ofp,
00265 " // F_SHELL\n"
00266 " tmpshell = const_wave_f[ifunc++] * xdist2 * xdist;\n"
00267 " tmpshell += const_wave_f[ifunc++] * xdist2 * ydist;\n"
00268 " tmpshell += const_wave_f[ifunc++] * ydist2 * xdist;\n"
00269 " tmpshell += const_wave_f[ifunc++] * ydist2 * ydist;\n"
00270 " tmpshell += const_wave_f[ifunc++] * xdist2 * zdist;\n"
00271 " tmpshell += const_wave_f[ifunc++] * xdist * ydist * zdist;\n"
00272 " tmpshell += const_wave_f[ifunc++] * ydist2 * zdist;\n"
00273 " tmpshell += const_wave_f[ifunc++] * zdist2 * xdist;\n"
00274 " tmpshell += const_wave_f[ifunc++] * zdist2 * ydist;\n"
00275 " tmpshell += const_wave_f[ifunc++] * zdist2 * zdist;\n"
00276 " value += tmpshell * contracted_gto;\n"
00277 );
00278 break;
00279
00280 case G_SHELL:
00281 fprintf(ofp,
00282 " // G_SHELL\n"
00283 " tmpshell = const_wave_f[ifunc++] * xdist2 * xdist2;\n"
00284 " tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * ydist;\n"
00285 " tmpshell += const_wave_f[ifunc++] * xdist2 * ydist2;\n"
00286 " tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * xdist;\n"
00287 " tmpshell += const_wave_f[ifunc++] * ydist2 * ydist2;\n"
00288 " tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * zdist;\n"
00289 " tmpshell += const_wave_f[ifunc++] * xdist2 * ydist * zdist;\n"
00290 " tmpshell += const_wave_f[ifunc++] * ydist2 * xdist * zdist;\n"
00291 " tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * zdist;\n"
00292 " tmpshell += const_wave_f[ifunc++] * xdist2 * zdist2;\n"
00293 " tmpshell += const_wave_f[ifunc++] * zdist2 * xdist * ydist;\n"
00294 " tmpshell += const_wave_f[ifunc++] * ydist2 * zdist2;\n"
00295 " tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * xdist;\n"
00296 " tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * ydist;\n"
00297 " tmpshell += const_wave_f[ifunc++] * zdist2 * zdist2;\n"
00298 " value += tmpshell * contracted_gto;\n"
00299 );
00300 break;
00301
00302 }
00303 fprintf(ofp, "\n");
00304
00305 shell_counter++;
00306 }
00307 }
00308
00309 fprintf(ofp,
00310 " }\n"
00311 "\n"
00312 " // return either orbital density or orbital wavefunction amplitude \n"
00313 " if (density) { \n"
00314 );
00315
00316 if (jitlanguage == ORBITAL_JIT_CUDA) {
00317 fprintf(ofp, " orbitalgrid[outaddr] = copysignf(value*value, value);\n");
00318 } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00319 fprintf(ofp, " orbitalgrid[outaddr] = copysign(value*value, value);\n");
00320 }
00321
00322 fprintf(ofp,
00323 " } else { \n"
00324 " orbitalgrid[outaddr] = value; \n"
00325 " }\n"
00326 "}\n"
00327 );
00328
00329 if (ofp != stdout)
00330 fclose(ofp);
00331
00332 return 0;
00333 }
00334
00335
00336