diff --git a/src/CassetteOverlay.jl b/src/CassetteOverlay.jl index 19f5023..bd77a04 100644 --- a/src/CassetteOverlay.jl +++ b/src/CassetteOverlay.jl @@ -59,13 +59,16 @@ function generate_overlay_src( tt = Base.to_tuple_type(fargtypes) mt_worlds = methodtable(world, passtype) if mt_worlds isa Pair - method_table, worlds = mt_worlds + method_table, mtworlds = mt_worlds else method_table = mt_worlds - worlds = nothing + mtworlds = nothing end match = Base._which(tt; method_table, raise = false, world) - match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError + results = Base.Compiler.findall(tt, method_table; limit=1) + length(results) == 1 || return nothing # method match failed – the fallback implementation will raise a proper MethodError + match = results[1] + match_worlds = results.valid_worlds mi = Core.Compiler.specialize_method(match) src = Core.Compiler.retrieve_code_info(mi, world) src === nothing && return nothing # code generation failed - the fallback implementation will re-raise it @@ -78,12 +81,24 @@ function generate_overlay_src( push!(invalid_code, (world, source, passtype, fargtypes, src, selfname, fargsname)) # TODO `return nothing` when updating the minimum compat to 1.12 end - if worlds !== nothing - src.min_world, src.max_world = max(src.min_world, first(worlds)), min(src.max_world, last(worlds)) + if mtworlds !== nothing + src.min_world, src.max_world = max(src.min_world, first(mtworlds)), min(src.max_world, last(mtworlds)) end + src.min_world, src.max_world = max(src.min_world, first(match_worlds)), min(src.max_world, last(match_worlds)) return src end +function get_mt_worlds(m::Module, var::Symbol, world::UInt) + @static if VERSION ≥ v"1.12-" + @assert isconst_at_world(m, var, world) + mt, worlds = getglobal_at_world(m, var, world) + return Base.Compiler.OverlayMethodTable(world, mt::MethodTable) => worlds + else + @assert @invokelatest isconst(M, S) + return getglobal(M, S)::MethodTable + end +end + macro overlaypass(args...) if length(args) == 1 PassName = nothing @@ -92,6 +107,10 @@ macro overlaypass(args...) PassName, method_table = args end + if !(method_table === nothing || method_table isa Symbol || Meta.isexpr(method_table, :.)) + error("Unexpected @overlaypass call") + end + if PassName === nothing PassName = esc(gensym(string(method_table))) decl_pass = :(struct $PassName <: $OverlayPass end) @@ -104,9 +123,22 @@ macro overlaypass(args...) nonoverlaytype = typeof(CassetteOverlay.nonoverlay) - if method_table !== :nothing - mthd_tbl = :($CassetteOverlay.methodtable(world::UInt, ::Type{$PassName}) = - Base.Compiler.OverlayMethodTable(world, $(esc(method_table)))) + if method_table isa Symbol + mthd_tbl = :( + function $CassetteOverlay.methodtable(world::UInt, ::Type{$PassName}) + return $CassetteOverlay.get_mt_worlds($__module__, $(QuoteNode(method_table)), world) + end + ) + elseif Meta.isexpr(method_table, :.) + M, S = method_table.args + if !(M isa Symbol && S isa QuoteNode && S.value isa Symbol) + error("Unexpected @overlaypass call") + end + mthd_tbl = :( + function $CassetteOverlay.methodtable(world::UInt, ::Type{$PassName}) + return $CassetteOverlay.get_mt_worlds($(esc(M)), $S, world) + end + ) else mthd_tbl = nothing end @@ -189,22 +221,12 @@ end abstract type AbstractBindingOverlay{M, S} <: OverlayPass; end function methodtable(world::UInt, ::Type{<:AbstractBindingOverlay{M, S}}) where {M, S} - if M === nothing - return nothing - end - @static if VERSION ≥ v"1.12-" - @assert isconst_at_world(M, S, world) - mt, worlds = getglobal_at_world(M, S, world) - return Base.Compiler.OverlayMethodTable(world, mt::MethodTable) => worlds - else - @assert @invokelatest isconst(M, S) - return getglobal(M, S)::MethodTable - end + (M isa Module && S isa Symbol) || error("Unexpected AbstractBindingOverlay type") + return get_mt_worlds(M, S, world) end @overlaypass AbstractBindingOverlay nothing -struct Overlay{M, S} <: AbstractBindingOverlay{M, S} -end +struct Overlay{M, S} <: AbstractBindingOverlay{M, S} end function Overlay(mt::MethodTable) @assert @invokelatest isconst(mt.module, mt.name) @assert mt === @invokelatest getglobal(mt.module, mt.name) diff --git a/test/simple.jl b/test/simple.jl index 98bfc58..4a09a61 100644 --- a/test/simple.jl +++ b/test/simple.jl @@ -10,11 +10,9 @@ myidentity(@nospecialize x) = x kwifelse(x, y; cond=true) = ifelse(cond, x, y) # run overlayed methods -@overlay SimpleTable myidentity(@nospecialize x) = 42 -@test pass(myidentity, nothing) == 42 -@test pass() do - myidentity(nothing) -end == 42 +@overlay SimpleTable myidentity(@nospecialize x) = (@noinline; (println(devnull, "prevent inlining")); 42) +call_myidentity() = @noinline myidentity(nothing) +@test pass(call_myidentity) == 42 # kwargs @overlay SimpleTable kwifelse(x, y; cond=true) = ifelse(cond, y, x) @@ -28,8 +26,8 @@ let (x, y) = (0, 1) end # method invalidation -@overlay SimpleTable myidentity(@nospecialize x) = 0 -@test pass(myidentity, nothing) == 0 +@overlay SimpleTable myidentity(@nospecialize x) = (@noinline; (println(devnull, "prevent inlining")); 0) +@test pass(call_myidentity) == 0 # nonoverlay @overlay SimpleTable myidentity(@nospecialize x) = nonoverlay(myidentity, x)