Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 43 additions & 21 deletions src/CassetteOverlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,16 @@ function generate_overlay_src(
tt = Base.to_tuple_type(fargtypes)
mt_worlds = methodtable(world, passtype)
if mt_worlds isa Pair
method_table, worlds = mt_worlds
method_table, mtworlds = mt_worlds
else
method_table = mt_worlds
worlds = nothing
mtworlds = nothing
end
match = Base._which(tt; method_table, raise = false, world)
match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError
results = Base.Compiler.findall(tt, method_table; limit=1)
length(results) == 1 || return nothing # method match failed – the fallback implementation will raise a proper MethodError
match = results[1]
match_worlds = results.valid_worlds
mi = Core.Compiler.specialize_method(match)
src = Core.Compiler.retrieve_code_info(mi, world)
src === nothing && return nothing # code generation failed - the fallback implementation will re-raise it
Expand All @@ -78,12 +81,24 @@ function generate_overlay_src(
push!(invalid_code, (world, source, passtype, fargtypes, src, selfname, fargsname))
# TODO `return nothing` when updating the minimum compat to 1.12
end
if worlds !== nothing
src.min_world, src.max_world = max(src.min_world, first(worlds)), min(src.max_world, last(worlds))
if mtworlds !== nothing
src.min_world, src.max_world = max(src.min_world, first(mtworlds)), min(src.max_world, last(mtworlds))
end
src.min_world, src.max_world = max(src.min_world, first(match_worlds)), min(src.max_world, last(match_worlds))
return src
end

function get_mt_worlds(m::Module, var::Symbol, world::UInt)
@static if VERSION ≥ v"1.12-"
@assert isconst_at_world(m, var, world)
mt, worlds = getglobal_at_world(m, var, world)
return Base.Compiler.OverlayMethodTable(world, mt::MethodTable) => worlds
else
@assert @invokelatest isconst(M, S)
return getglobal(M, S)::MethodTable
end
end

macro overlaypass(args...)
if length(args) == 1
PassName = nothing
Expand All @@ -92,6 +107,10 @@ macro overlaypass(args...)
PassName, method_table = args
end

if !(method_table === nothing || method_table isa Symbol || Meta.isexpr(method_table, :.))
error("Unexpected @overlaypass call")
end

if PassName === nothing
PassName = esc(gensym(string(method_table)))
decl_pass = :(struct $PassName <: $OverlayPass end)
Expand All @@ -104,9 +123,22 @@ macro overlaypass(args...)

nonoverlaytype = typeof(CassetteOverlay.nonoverlay)

if method_table !== :nothing
mthd_tbl = :($CassetteOverlay.methodtable(world::UInt, ::Type{$PassName}) =
Base.Compiler.OverlayMethodTable(world, $(esc(method_table))))
if method_table isa Symbol
mthd_tbl = :(
function $CassetteOverlay.methodtable(world::UInt, ::Type{$PassName})
return $CassetteOverlay.get_mt_worlds($__module__, $(QuoteNode(method_table)), world)
end
)
elseif Meta.isexpr(method_table, :.)
M, S = method_table.args
if !(M isa Symbol && S isa QuoteNode && S.value isa Symbol)
error("Unexpected @overlaypass call")
end
mthd_tbl = :(
function $CassetteOverlay.methodtable(world::UInt, ::Type{$PassName})
return $CassetteOverlay.get_mt_worlds($(esc(M)), $S, world)
end
)
else
mthd_tbl = nothing
end
Expand Down Expand Up @@ -189,22 +221,12 @@ end

abstract type AbstractBindingOverlay{M, S} <: OverlayPass; end
function methodtable(world::UInt, ::Type{<:AbstractBindingOverlay{M, S}}) where {M, S}
if M === nothing
return nothing
end
@static if VERSION ≥ v"1.12-"
@assert isconst_at_world(M, S, world)
mt, worlds = getglobal_at_world(M, S, world)
return Base.Compiler.OverlayMethodTable(world, mt::MethodTable) => worlds
else
@assert @invokelatest isconst(M, S)
return getglobal(M, S)::MethodTable
end
(M isa Module && S isa Symbol) || error("Unexpected AbstractBindingOverlay type")
return get_mt_worlds(M, S, world)
end
@overlaypass AbstractBindingOverlay nothing

struct Overlay{M, S} <: AbstractBindingOverlay{M, S}
end
struct Overlay{M, S} <: AbstractBindingOverlay{M, S} end
function Overlay(mt::MethodTable)
@assert @invokelatest isconst(mt.module, mt.name)
@assert mt === @invokelatest getglobal(mt.module, mt.name)
Expand Down
12 changes: 5 additions & 7 deletions test/simple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ myidentity(@nospecialize x) = x
kwifelse(x, y; cond=true) = ifelse(cond, x, y)

# run overlayed methods
@overlay SimpleTable myidentity(@nospecialize x) = 42
@test pass(myidentity, nothing) == 42
@test pass() do
myidentity(nothing)
end == 42
@overlay SimpleTable myidentity(@nospecialize x) = (@noinline; (println(devnull, "prevent inlining")); 42)
call_myidentity() = @noinline myidentity(nothing)
@test pass(call_myidentity) == 42

# kwargs
@overlay SimpleTable kwifelse(x, y; cond=true) = ifelse(cond, y, x)
Expand All @@ -28,8 +26,8 @@ let (x, y) = (0, 1)
end

# method invalidation
@overlay SimpleTable myidentity(@nospecialize x) = 0
@test pass(myidentity, nothing) == 0
@overlay SimpleTable myidentity(@nospecialize x) = (@noinline; (println(devnull, "prevent inlining")); 0)
@test pass(call_myidentity) == 0

# nonoverlay
@overlay SimpleTable myidentity(@nospecialize x) = nonoverlay(myidentity, x)
Expand Down
Loading