Skip to content

Commit e7ec155

Browse files
committed
Reorganize macros
1 parent 65c5dbe commit e7ec155

5 files changed

Lines changed: 43 additions & 66 deletions

File tree

src/elements.jl

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,18 @@ eval_psp_energy_correction(T, ::Element) = zero(T)
3939
eval_psp_energy_correction(psp::Element) = eval_psp_energy_correction(Float64, psp)
4040

4141
# Fall back to the Gaussian table for Elements without pseudopotentials
42-
function valence_charge_density_fourier(el::Element, p::T)::T where {T <: Real}
42+
function valence_charge_density_fourier(el::Element, p)
4343
gaussian_valence_charge_density_fourier(el, p)
4444
end
4545

4646
"""Gaussian valence charge density using Abinit's coefficient table, in Fourier space."""
4747
function gaussian_valence_charge_density_fourier(el::Element, p::T)::T where {T <: Real}
4848
charge_ionic(el) * exp(-(p * atom_decay_length(el))^2)
4949
end
50+
function gaussian_valence_charge_density_fourier(el::Element, ps::AbstractVector{T}) where {T <: Real}
51+
arch = architecture(ps)
52+
to_device(arch, map(p -> gaussian_valence_charge_density_fourier(el, p), to_cpu(ps)))
53+
end
5054

5155
function core_charge_density_fourier(::Element, ::T)::T where {T <: Real}
5256
error("Abstract elements do not necesesarily provide core charge density.")
@@ -160,40 +164,17 @@ charge_ionic(el::ElementPsp) = charge_ionic(el.psp)
160164
has_core_density(el::ElementPsp) = has_core_density(el.psp)
161165
eval_psp_energy_correction(T, el::ElementPsp) = eval_psp_energy_correction(T, el.psp)
162166

163-
function local_potential_fourier(el::ElementPsp, p::T) where {T <: Real}
164-
eval_psp_local_fourier(el.psp, p)
165-
end
166-
local_potential_real(el::ElementPsp, r::Real) = eval_psp_local_real(el.psp, r)
167+
local_potential_fourier(el::ElementPsp, p) = eval_psp_local_fourier(el.psp, p)
168+
local_potential_real(el::ElementPsp, r) = eval_psp_local_real(el.psp, r)
167169

168-
function valence_charge_density_fourier(el::ElementPsp, p::T) where {T <: Real}
170+
function valence_charge_density_fourier(el::ElementPsp, p)
169171
if has_valence_density(el.psp)
170172
eval_psp_density_valence_fourier(el.psp, p)
171173
else
172174
gaussian_valence_charge_density_fourier(el, p)
173175
end
174176
end
175-
function core_charge_density_fourier(el::ElementPsp, p::T) where {T <: Real}
176-
eval_psp_density_core_fourier(el.psp, p)
177-
end
178-
179-
# Vectorized versions of the above, specific implementation depending on the Psp type
180-
function local_potential_fourier(el::ElementPsp, ps::AbstractVector{T}) where {T <: Real}
181-
eval_psp_local_fourier(el.psp, ps)
182-
end
183-
function local_potential_real(el::ElementPsp, rs::AbstractVector{T}) where {T <: Real}
184-
eval_psp_local_real(el.psp, rs)
185-
end
186-
function valence_charge_density_fourier(el::ElementPsp, ps::AbstractVector{T}) where {T <: Real}
187-
if has_valence_density(el.psp)
188-
eval_psp_density_valence_fourier(el.psp, ps)
189-
else
190-
gaussian_valence_charge_density_fourier(el, ps)
191-
end
192-
end
193-
function core_charge_density_fourier(el::ElementPsp, ps::AbstractVector{T}) where {T <: Real}
194-
eval_psp_density_core_fourier(el.psp, ps)
195-
end
196-
177+
core_charge_density_fourier(el::ElementPsp, p) = eval_psp_density_core_fourier(el.psp, p)
197178

198179
#
199180
# ElementCohenBergstresser
@@ -263,7 +244,6 @@ function local_potential_fourier(el::ElementCohenBergstresser, p::T) where {T <:
263244
end
264245
# TODO Strictly speaking needs an eval_psp_energy_correction
265246

266-
267247
#
268248
# ElementGaussian
269249
#
@@ -298,25 +278,22 @@ end
298278

299279
# Generic API expectations: all element functions taking a real space scalar |r| or a
300280
# reciprocal space scalar |p| should have a vectorized version accepting vectors of |r| or |p|.
301-
# This macro vectorizes element functions by calling existing single-value version elementwise.
302-
# This is GPU safe and generic. Performance critical functions should have their own
303-
# GPU-optimized implementation.
304-
macro vectorize_element_function(fn)
305-
quote
306-
function $fn(el::Element, arg::AbstractVector{T}) where {T <: Real}
307-
arch = architecture(arg)
308-
to_device(arch, map(p -> $fn(el, p), to_cpu(arg)))
281+
# The following loop vectorizes element functions by calling the single-value version
282+
# elementwise. This is GPU-safe and generic. Performance-critical functions should have their
283+
# own GPU-optimized implementation. Note: explicit loop over Element types in order to avoid
284+
# ambiguities with the specific ElementPsp functions.
285+
for fn in [:gaussian_valence_charge_density_fourier, :core_charge_density_fourier,
286+
:local_potential_fourier, :local_potential_real]
287+
for el_type in [ElementCoulomb, ElementCohenBergstresser, ElementGaussian]
288+
@eval begin
289+
function DFTK.$fn(el::$el_type, arg::AbstractVector{T}) where {T <: Real}
290+
arch = architecture(arg)
291+
to_device(arch, map(p -> $fn(el, p), to_cpu(arg)))
292+
end
309293
end
310294
end
311295
end
312296

313-
# Generic vectorized element functions
314-
@vectorize_element_function DFTK.valence_charge_density_fourier
315-
@vectorize_element_function DFTK.gaussian_valence_charge_density_fourier
316-
@vectorize_element_function DFTK.core_charge_density_fourier
317-
@vectorize_element_function DFTK.local_potential_fourier
318-
@vectorize_element_function DFTK.local_potential_real
319-
320297
#
321298
# Helper functions
322299
#

src/pseudo/NormConservingPsp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ abstract type NormConservingPsp end
4848
# have their own GPU-optimized implementation instead of relying on this macro. The
4949
# different norm-conserving pseudopotential types are responsible for the implementation
5050
# of vectorized functions, whether by using these macros or not.
51-
macro vectorize_psp_function(fn, PspType)
51+
macro vectorize_psp_function(PspType, fn)
5252
quote
5353
function $fn(psp::$PspType, vec::AbstractVector{T}) where {T <: Real}
5454
arch = architecture(vec)
5555
to_device(arch, map(p -> $fn(psp, p), to_cpu(vec)))
5656
end
5757
end
5858
end
59-
macro vectorize_psp_projector_function(fn, PspType)
59+
macro vectorize_psp_projector_function(PspType, fn)
6060
quote
6161
function $fn(psp::$PspType, i, l, vec::AbstractVector{T}) where {T <: Real}
6262
arch = architecture(vec)

src/pseudo/PspHgh.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ function eval_psp_local_fourier(psp::PspHgh, p::T) where {T <: Real}
121121

122122
4T(π) * rloc^2 * (-Zion + sqrt(T(π) / 2) * rloc * t^2 * P) * exp(-t^2 / 2) / t^2
123123
end
124-
@vectorize_psp_function DFTK.eval_psp_local_fourier PspHgh
124+
@vectorize_psp_function PspHgh DFTK.eval_psp_local_fourier
125125

126126
# [GTH98] (1)
127127
function eval_psp_local_real(psp::PspHgh, r::T) where {T <: Real}
@@ -133,7 +133,7 @@ function eval_psp_local_real(psp::PspHgh, r::T) where {T <: Real}
133133
+ exp(-rr^2 / 2) * (cloc[1] + cloc[2] * rr^2 + cloc[3] * rr^4 + cloc[4] * rr^6)
134134
)
135135
end
136-
@vectorize_psp_function DFTK.eval_psp_local_real PspHgh
136+
@vectorize_psp_function PspHgh DFTK.eval_psp_local_real
137137

138138

139139
# [HGH98] (7-15) except they do it with plane waves normalized by 1/sqrt(Ω).
@@ -162,15 +162,15 @@ function eval_psp_projector_fourier(psp::PspHgh, i, l, p::T) where {T <: Real}
162162

163163
error("Not implemented for l=$l and i=$i")
164164
end
165-
@vectorize_psp_projector_function DFTK.eval_psp_projector_fourier PspHgh
165+
@vectorize_psp_projector_function PspHgh DFTK.eval_psp_projector_fourier
166166

167167
# [HGH98] (3)
168168
function eval_psp_projector_real(psp::PspHgh, i, l, r::T) where {T <: Real}
169169
rp = T(psp.rp[l + 1])
170170
ired = (4i - 1) / T(2)
171171
sqrt(T(2)) * r^(l + 2(i - 1)) * exp(-r^2 / 2rp^2) / rp^(l + ired) / sqrt(gamma(l + ired))
172172
end
173-
@vectorize_psp_projector_function DFTK.eval_psp_projector_real PspHgh
173+
@vectorize_psp_projector_function PspHgh DFTK.eval_psp_projector_real
174174

175175
function eval_psp_energy_correction(T, psp::PspHgh)
176176
# By construction we need to compute the DC component of the difference

src/pseudo/PspLinComb.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ end
8282
@make_psplincomb_call DFTK.eval_psp_density_core_real
8383
@make_psplincomb_call DFTK.eval_psp_density_core_fourier
8484

85-
@vectorize_psp_function DFTK.eval_psp_local_real PspLinComb
86-
@vectorize_psp_function DFTK.eval_psp_local_fourier PspLinComb
87-
@vectorize_psp_function DFTK.eval_psp_density_valence_real PspLinComb
88-
@vectorize_psp_function DFTK.eval_psp_density_valence_fourier PspLinComb
89-
@vectorize_psp_function DFTK.eval_psp_density_core_real PspLinComb
90-
@vectorize_psp_function DFTK.eval_psp_density_core_fourier PspLinComb
91-
@vectorize_psp_projector_function DFTK.eval_psp_projector_real PspLinComb
92-
@vectorize_psp_projector_function DFTK.eval_psp_projector_fourier PspLinComb
85+
@vectorize_psp_function PspLinComb DFTK.eval_psp_local_real
86+
@vectorize_psp_function PspLinComb DFTK.eval_psp_local_fourier
87+
@vectorize_psp_function PspLinComb DFTK.eval_psp_density_valence_real
88+
@vectorize_psp_function PspLinComb DFTK.eval_psp_density_valence_fourier
89+
@vectorize_psp_function PspLinComb DFTK.eval_psp_density_core_real
90+
@vectorize_psp_function PspLinComb DFTK.eval_psp_density_core_fourier
91+
@vectorize_psp_projector_function PspLinComb DFTK.eval_psp_projector_real
92+
@vectorize_psp_projector_function PspLinComb DFTK.eval_psp_projector_fourier

src/pseudo/PspUpf.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ has_core_density(psp::PspUpf) = !all(iszero, psp.r2_ρcore)
174174
function eval_psp_projector_real(psp::PspUpf, i, l, r::T)::T where {T<:Real}
175175
psp.r2_projs_interp[l+1][i](r) / r^2 # TODO if r is below a threshold, return zero
176176
end
177-
@vectorize_psp_projector_function DFTK.eval_psp_projector_real PspUpf
177+
@vectorize_psp_projector_function PspUpf DFTK.eval_psp_projector_real
178178

179179
function eval_psp_projector_fourier(psp::PspUpf, i, l, p::T)::T where {T<:Real}
180180
# The projectors may have been cut off before the end of the radial mesh
@@ -185,7 +185,7 @@ function eval_psp_projector_fourier(psp::PspUpf, i, l, p::T)::T where {T<:Real}
185185
r2_proj = @view psp.r2_projs[l+1][i][1:ircut_proj]
186186
hankel(rgrid, r2_proj, l, p)
187187
end
188-
@vectorize_psp_projector_function DFTK.eval_psp_projector_fourier PspUpf
188+
@vectorize_psp_projector_function PspUpf DFTK.eval_psp_projector_fourier
189189

190190
count_n_pswfc_radial(psp::PspUpf, l) = length(psp.r2_pswfcs[l+1])
191191

@@ -194,7 +194,7 @@ pswfc_label(psp::PspUpf, i, l) = psp.pswfc_labels[l+1][i]
194194
function eval_psp_pswfc_real(psp::PspUpf, i, l, r::T)::T where {T<:Real}
195195
psp.r2_pswfcs_interp[l+1][i](r) / r^2 # TODO if r is below a threshold, return zero
196196
end
197-
@vectorize_psp_projector_function DFTK.eval_psp_pswfc_real PspUpf
197+
@vectorize_psp_projector_function PspUpf DFTK.eval_psp_pswfc_real
198198

199199
function eval_psp_pswfc_fourier(psp::PspUpf, i, l, p::T)::T where {T<:Real}
200200
# Pseudo-atomic wavefunctions are _not_ currently cut off like the other
@@ -203,10 +203,10 @@ function eval_psp_pswfc_fourier(psp::PspUpf, i, l, p::T)::T where {T<:Real}
203203
# If issues arise, try cutting them off too.
204204
return hankel(psp.rgrid, psp.r2_pswfcs[l+1][i], l, p)
205205
end
206-
@vectorize_psp_projector_function DFTK.eval_psp_pswfc_fourier PspUpf
206+
@vectorize_psp_projector_function PspUpf DFTK.eval_psp_pswfc_fourier
207207

208208
eval_psp_local_real(psp::PspUpf, r::T) where {T<:Real} = psp.vloc_interp(r)
209-
@vectorize_psp_function DFTK.eval_psp_local_real PspUpf
209+
@vectorize_psp_function PspUpf DFTK.eval_psp_local_real
210210

211211
# Low-level function for the local part of the pseudopotential in reciprocal space
212212
function _eval_psp_local_fourier(quadrature, rgrid, vloc, Zion, p::T)::T where {T<:Real}
@@ -247,19 +247,19 @@ end
247247
function eval_psp_density_valence_real(psp::PspUpf, r::T) where {T<:Real}
248248
psp.r2_ρion_interp(r) / r^2 # TODO if r is below a threshold, return zero
249249
end
250-
@vectorize_psp_function DFTK.eval_psp_density_valence_real PspUpf
250+
@vectorize_psp_function PspUpf DFTK.eval_psp_density_valence_real
251251

252252
function eval_psp_density_valence_fourier(psp::PspUpf, p::T) where {T<:Real}
253253
rgrid = @view psp.rgrid[1:psp.ircut]
254254
r2_ρion = @view psp.r2_ρion[1:psp.ircut]
255255
return hankel(rgrid, r2_ρion, 0, p)
256256
end
257-
@vectorize_psp_function DFTK.eval_psp_density_valence_fourier PspUpf
257+
@vectorize_psp_function PspUpf DFTK.eval_psp_density_valence_fourier
258258

259259
function eval_psp_density_core_real(psp::PspUpf, r::T) where {T<:Real}
260260
psp.r2_ρcore_interp(r) / r^2 # TODO if r is below a threshold, return zero
261261
end
262-
@vectorize_psp_function DFTK.eval_psp_density_core_real PspUpf
262+
@vectorize_psp_function PspUpf DFTK.eval_psp_density_core_real
263263

264264
function eval_psp_density_core_fourier(psp::PspUpf, p::T) where {T<:Real}
265265
rgrid = @view psp.rgrid[1:psp.ircut]

0 commit comments

Comments
 (0)