00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00021
00022
00023 #if defined(VMDCPUDISPATCH) && defined(VMDUSEAVX512)
00024
00025 #include <immintrin.h>
00026
00027 #include <math.h>
00028 #include <stdio.h>
00029 #include "Orbital.h"
00030 #include "DrawMolecule.h"
00031 #include "utilities.h"
00032 #include "Inform.h"
00033 #include "WKFThreads.h"
00034 #include "WKFUtils.h"
00035 #include "ProfileHooks.h"
00036
/* Multiplier applied to every coordinate difference before the basis
 * functions are evaluated: Angstrom -> Bohr (1 / 0.529177...). */
#define ANGS_TO_BOHR 1.88972612478289694072f

/* Portable "align this variable to X bytes" qualifier.  Compilers that are
 * not GCC-compatible, and the Intel compiler, get __declspec; GCC/Clang get
 * the aligned attribute. */
#if !defined(__GNUC__) || defined(__INTEL_COMPILER)
#define __align(X) __declspec(align(X))
#else
#define __align(X) __attribute__((aligned(X)))
#endif

/* -log2(e), i.e. -1/ln(2); multiplies x to turn exp(x) into a power of two. */
#define MLOG2EF (-1.44269504088896f)
00046
#if 0
/*
 * Debug helper, compiled out via "#if 0".
 * Prints the 16 float lanes of an AVX-512 register to stdout on a single
 * line, prefixed with "mm512: " and each lane formatted with "%g ".
 * NOTE(review): tmp is 64-byte aligned, so the unaligned store is not
 * required; harmless either way.
 */
static void print_mm512_ps(__m512 v) {
  __attribute__((aligned(64))) float tmp[16];
  _mm512_storeu_ps(&tmp[0], v);   /* spill the register to memory */

  printf("mm512: ");
  int i;
  for (i=0; i<16; i++)
    printf("%g ", tmp[i]);
  printf("\n");
}
#endif
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
/*
 * Degree-4 polynomial P(d) = SCEXP0 + SCEXP1*d + ... + SCEXP4*d^4 used by
 * aexpfnxavx512f to approximate 2^d for the fractional part d of the
 * reduced exponent.  P(0) = 1 and P(-1) ~= 0.5, so the fit reproduces 2^d
 * at both ends of [-1, 0].
 */
#define SCEXP0 1.0000000000000000f
#define SCEXP1 0.6987082824680118f
#define SCEXP2 0.2633174272827404f
#define SCEXP3 0.0923611991471395f
#define SCEXP4 0.0277520543324108f

/* IEEE-754 binary32 layout: exponent bias, and the bit position at which
 * the 8-bit exponent field starts (it sits above the 23 mantissa bits). */
#define EXPOBIAS 127
#define EXPOSHIFT 23

/* Arguments of aexpfnxavx512f below this value are discarded (result 0). */
#define ACUTOFF (-10)

/* Lets the same 512 bits be viewed as 16 floats (f) or 16 int32s (i)
 * without a numeric conversion. */
union AVX512reg_t {
  __m512  f;
  __m512i i;
};
typedef union AVX512reg_t AVX512reg;
00088
00089 __m512 aexpfnxavx512f(__m512 x) {
00090 __mmask16 mask;
00091 mask = _mm512_cmpnle_ps_mask(_mm512_set1_ps(ACUTOFF), x);
00092 #if 0
00093
00094 if (_mm512_movemask_ps(scal.f) == 0) {
00095 return _mm512_set1_ps(0.0f);
00096 }
00097
00098 #endif
00099
00100
00101
00102
00103
00104
00105
00106 __align(64) AVX512reg n;
00107 __m512 mb = _mm512_mul_ps(x, _mm512_set1_ps(MLOG2EF));
00108 n.i = _mm512_cvttps_epi32(mb);
00109 __m512 mbflr = _mm512_cvtepi32_ps(n.i);
00110 __m512 d = _mm512_sub_ps(mbflr, mb);
00111
00112
00113
00114 __m512 y;
00115 y = _mm512_fmadd_ps(d, _mm512_set1_ps(SCEXP4), _mm512_set1_ps(SCEXP3));
00116 y = _mm512_fmadd_ps(y, d, _mm512_set1_ps(SCEXP2));
00117 y = _mm512_fmadd_ps(y, d, _mm512_set1_ps(SCEXP1));
00118 y = _mm512_fmadd_ps(y, d, _mm512_set1_ps(SCEXP0));
00119
00120
00121
00122 n.i = _mm512_sub_epi32(_mm512_set1_epi32(EXPOBIAS), n.i);
00123 n.i = _mm512_slli_epi32(n.i, EXPOSHIFT);
00124 n.f = _mm512_mask_mul_ps(n.f, mask, _mm512_set1_ps(0.0f), n.f);
00125 y = _mm512_mul_ps(y, n.f);
00126 return y;
00127 }
00128
00129
00130
00131
00132
00133 int evaluate_grid_avx512f(int numatoms,
00134 const float *wave_f, const float *basis_array,
00135 const float *atompos,
00136 const int *atom_basis,
00137 const int *num_shells_per_atom,
00138 const int *num_prim_per_shell,
00139 const int *shell_types,
00140 const int *numvoxels,
00141 float voxelsize,
00142 const float *origin,
00143 int density,
00144 float * orbitalgrid) {
00145 if (!orbitalgrid)
00146 return -1;
00147
00148 int nx, ny, nz;
00149 __attribute__((aligned(64))) float sxdelta[16];
00150 for (nx=0; nx<16; nx++)
00151 sxdelta[nx] = ((float) nx) * voxelsize * ANGS_TO_BOHR;
00152
00153
00154
00155 int numgridxy = numvoxels[0]*numvoxels[1];
00156 for (nz=0; nz<numvoxels[2]; nz++) {
00157 float grid_x, grid_y, grid_z;
00158 grid_z = origin[2] + nz * voxelsize;
00159 for (ny=0; ny<numvoxels[1]; ny++) {
00160 grid_y = origin[1] + ny * voxelsize;
00161 int gaddrzy = ny*numvoxels[0] + nz*numgridxy;
00162 for (nx=0; nx<numvoxels[0]; nx+=16) {
00163 grid_x = origin[0] + nx * voxelsize;
00164
00165
00166
00167 int at;
00168 int prim, shell;
00169
00170
00171 __m512 value = _mm512_set1_ps(0.0f);
00172
00173
00174 int ifunc = 0;
00175 int shell_counter = 0;
00176
00177
00178 for (at=0; at<numatoms; at++) {
00179 int maxshell = num_shells_per_atom[at];
00180 int prim_counter = atom_basis[at];
00181
00182
00183 float sxdist = (grid_x - atompos[3*at ])*ANGS_TO_BOHR;
00184 float sydist = (grid_y - atompos[3*at+1])*ANGS_TO_BOHR;
00185 float szdist = (grid_z - atompos[3*at+2])*ANGS_TO_BOHR;
00186
00187 float sydist2 = sydist*sydist;
00188 float szdist2 = szdist*szdist;
00189 float yzdist2 = sydist2 + szdist2;
00190
00191 __m512 xdelta = _mm512_load_ps(&sxdelta[0]);
00192 __m512 xdist = _mm512_set1_ps(sxdist);
00193 xdist = _mm512_add_ps(xdist, xdelta);
00194 __m512 ydist = _mm512_set1_ps(sydist);
00195 __m512 zdist = _mm512_set1_ps(szdist);
00196 __m512 xdist2 = _mm512_mul_ps(xdist, xdist);
00197 __m512 ydist2 = _mm512_mul_ps(ydist, ydist);
00198 __m512 zdist2 = _mm512_mul_ps(zdist, zdist);
00199 __m512 dist2 = _mm512_set1_ps(yzdist2);
00200 dist2 = _mm512_add_ps(dist2, xdist2);
00201
00202
00203
00204
00205
00206
00207 for (shell=0; shell < maxshell; shell++) {
00208 __m512 contracted_gto = _mm512_set1_ps(0.0f);
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218 int maxprim = num_prim_per_shell[shell_counter];
00219 int shelltype = shell_types[shell_counter];
00220 for (prim=0; prim<maxprim; prim++) {
00221
00222 float exponent = -basis_array[prim_counter ];
00223 float contract_coeff = basis_array[prim_counter + 1];
00224
00225
00226 __m512 expval = _mm512_mul_ps(_mm512_set1_ps(exponent), dist2);
00227
00228 __m512 retval = aexpfnxavx512f(expval);
00229 contracted_gto = _mm512_fmadd_ps(_mm512_set1_ps(contract_coeff), retval, contracted_gto);
00230
00231 prim_counter += 2;
00232 }
00233
00234
00235 __m512 tmpshell = _mm512_set1_ps(0.0f);
00236 switch (shelltype) {
00237
00238 case S_SHELL:
00239 value = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), contracted_gto, value);
00240 break;
00241
00242 case P_SHELL:
00243 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), xdist, tmpshell);
00244 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), ydist, tmpshell);
00245 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), zdist, tmpshell);
00246 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00247 break;
00248
00249 case D_SHELL:
00250 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), xdist2, tmpshell);
00251 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist, ydist), tmpshell);
00252 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), ydist2, tmpshell);
00253 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist, zdist), tmpshell);
00254 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist, zdist), tmpshell);
00255 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), zdist2, tmpshell);
00256 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00257 break;
00258
00259 case F_SHELL:
00260 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, xdist), tmpshell);
00261 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, ydist), tmpshell);
00262 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, xdist), tmpshell);
00263 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, ydist), tmpshell);
00264 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, zdist), tmpshell);
00265 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(_mm512_mul_ps(xdist, ydist), zdist), tmpshell);
00266 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, zdist), tmpshell);
00267 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, xdist), tmpshell);
00268 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, ydist), tmpshell);
00269 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, zdist), tmpshell);
00270 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00271 break;
00272
00273 #if 0
00274 default:
00275
00276 int i, j;
00277 float xdp, ydp, zdp;
00278 float xdiv = 1.0f / xdist;
00279 for (j=0, zdp=1.0f; j<=shelltype; j++, zdp*=zdist) {
00280 int imax = shelltype - j;
00281 for (i=0, ydp=1.0f, xdp=pow(xdist, imax); i<=imax; i++, ydp*=ydist, xdp*=xdiv) {
00282 tmpshell += wave_f[ifunc++] * xdp * ydp * zdp;
00283 }
00284 }
00285 value += tmpshell * contracted_gto;
00286 #endif
00287 }
00288
00289 shell_counter++;
00290 }
00291 }
00292
00293
00294 if (density) {
00295 __mmask16 mask = _mm512_cmplt_ps_mask(value, _mm512_set1_ps(0.0f));
00296 __m512 sqdensity = _mm512_mul_ps(value, value);
00297 __m512 orbdensity = _mm512_mask_mul_ps(sqdensity, mask, sqdensity,
00298 _mm512_set1_ps(-1.0f));
00299 _mm512_storeu_ps(&orbitalgrid[gaddrzy + nx], orbdensity);
00300 } else {
00301 _mm512_storeu_ps(&orbitalgrid[gaddrzy + nx], value);
00302 }
00303 }
00304 }
00305 }
00306
00307
00308
00309
00310
00311 _mm256_zeroupper();
00312
00313 return 0;
00314 }
00315
00316 #endif
00317
00318