@@ -39,14 +39,18 @@ eval_psp_energy_correction(T, ::Element) = zero(T)
3939eval_psp_energy_correction (psp:: Element ) = eval_psp_energy_correction (Float64, psp)
4040
4141# Fall back to the Gaussian table for Elements without pseudopotentials
42- function valence_charge_density_fourier (el:: Element , p:: T ) :: T where {T <: Real }
42+ function valence_charge_density_fourier (el:: Element , p)
4343 gaussian_valence_charge_density_fourier (el, p)
4444end
4545
4646""" Gaussian valence charge density using Abinit's coefficient table, in Fourier space."""
4747function gaussian_valence_charge_density_fourier (el:: Element , p:: T ):: T where {T <: Real }
4848 charge_ionic (el) * exp (- (p * atom_decay_length (el))^ 2 )
4949end
50+ function gaussian_valence_charge_density_fourier (el:: Element , ps:: AbstractVector{T} ) where {T <: Real }
51+ arch = architecture (ps)
52+ to_device (arch, map (p -> gaussian_valence_charge_density_fourier (el, p), to_cpu (ps)))
53+ end
5054
5155function core_charge_density_fourier (:: Element , :: T ):: T where {T <: Real }
5256 error (" Abstract elements do not necesesarily provide core charge density." )
@@ -160,40 +164,17 @@ charge_ionic(el::ElementPsp) = charge_ionic(el.psp)
160164has_core_density (el:: ElementPsp ) = has_core_density (el. psp)
161165eval_psp_energy_correction (T, el:: ElementPsp ) = eval_psp_energy_correction (T, el. psp)
162166
163- function local_potential_fourier (el:: ElementPsp , p:: T ) where {T <: Real }
164- eval_psp_local_fourier (el. psp, p)
165- end
166- local_potential_real (el:: ElementPsp , r:: Real ) = eval_psp_local_real (el. psp, r)
167+ local_potential_fourier (el:: ElementPsp , p) = eval_psp_local_fourier (el. psp, p)
168+ local_potential_real (el:: ElementPsp , r) = eval_psp_local_real (el. psp, r)
167169
168- function valence_charge_density_fourier (el:: ElementPsp , p:: T ) where {T <: Real }
170+ function valence_charge_density_fourier (el:: ElementPsp , p)
169171 if has_valence_density (el. psp)
170172 eval_psp_density_valence_fourier (el. psp, p)
171173 else
172174 gaussian_valence_charge_density_fourier (el, p)
173175 end
174176end
175- function core_charge_density_fourier (el:: ElementPsp , p:: T ) where {T <: Real }
176- eval_psp_density_core_fourier (el. psp, p)
177- end
178-
179- # Vectorized versions of the above, specific implementation depending on the Psp type
180- function local_potential_fourier (el:: ElementPsp , ps:: AbstractVector{T} ) where {T <: Real }
181- eval_psp_local_fourier (el. psp, ps)
182- end
183- function local_potential_real (el:: ElementPsp , rs:: AbstractVector{T} ) where {T <: Real }
184- eval_psp_local_real (el. psp, rs)
185- end
186- function valence_charge_density_fourier (el:: ElementPsp , ps:: AbstractVector{T} ) where {T <: Real }
187- if has_valence_density (el. psp)
188- eval_psp_density_valence_fourier (el. psp, ps)
189- else
190- gaussian_valence_charge_density_fourier (el, ps)
191- end
192- end
193- function core_charge_density_fourier (el:: ElementPsp , ps:: AbstractVector{T} ) where {T <: Real }
194- eval_psp_density_core_fourier (el. psp, ps)
195- end
196-
177+ core_charge_density_fourier (el:: ElementPsp , p) = eval_psp_density_core_fourier (el. psp, p)
197178
198179#
199180# ElementCohenBergstresser
@@ -263,7 +244,6 @@ function local_potential_fourier(el::ElementCohenBergstresser, p::T) where {T <:
263244end
264245# TODO Strictly speaking needs a eval_psp_energy_correction
265246
266-
267247#
268248# ElementGaussian
269249#
@@ -298,25 +278,22 @@ end
298278
299279# Generic API expectations: all element functions taking a real space scalar |r| or a
300280# reciprocal space scalar |p| should have a vectorized version accepting vectors of |r| or |p|.
301- # This macro vectorizes element functions by calling existing single-value version elementwise.
302- # This is GPU safe and generic. Performance critical functions should have their own
303- # GPU-optimized implementation.
304- macro vectorize_element_function (fn)
305- quote
306- function $fn (el:: Element , arg:: AbstractVector{T} ) where {T <: Real }
307- arch = architecture (arg)
308- to_device (arch, map (p -> $ fn (el, p), to_cpu (arg)))
281+ # The following loop vectorizes element functions by calling the single-value version
282+ # elementwise. This is GPU safe and generic. Performance critical functions should have their
283+ # own GPU-optimized implementation. Note: explicit loop over Element types in order to avoid
284+ # ambiguities with the specific ElementPsp functions.
285+ for fn in [:gaussian_valence_charge_density_fourier , :core_charge_density_fourier ,
286+ :local_potential_fourier , :local_potential_real ]
287+ for el_type in [ElementCoulomb, ElementCohenBergstresser, ElementGaussian]
288+ @eval begin
289+ function DFTK. $fn (el:: $el_type , arg:: AbstractVector{T} ) where {T <: Real }
290+ arch = architecture (arg)
291+ to_device (arch, map (p -> $ fn (el, p), to_cpu (arg)))
292+ end
309293 end
310294 end
311295end
312296
313- # Generic vectorized element functions
314- @vectorize_element_function DFTK. valence_charge_density_fourier
315- @vectorize_element_function DFTK. gaussian_valence_charge_density_fourier
316- @vectorize_element_function DFTK. core_charge_density_fourier
317- @vectorize_element_function DFTK. local_potential_fourier
318- @vectorize_element_function DFTK. local_potential_real
319-
320297#
321298# Helper functions
322299#
0 commit comments