Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions src/Type/Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -415,15 +415,11 @@ inferRecDef2 topLevel coreDef divergent (def,mbAssumed)
coreX <- subst simexpr
-- traceDoc $ \penv -> prettyExpr penv coreX
(mvars,msub) <- Op.freshSub Bound tvars
let -- coreX = simplify expr -- coref0 (Core.defExpr coreDef)
-- mvars = [TypeVar id kind Bound | TypeVar id kind _ <- tvars]
-- msub = subNew (zip tvars (map TVar mvars))


resCoreX = (CoreVar.|~>) [(Core.TName ({- unqualify -} name) assumedTpX,
Core.TypeApp (Core.Var (Core.TName ({- unqualify -} name) (resTp1)) info)
(map TVar mvars))] -- TODO: check: was `tvars` TODO: wrong for unannotated polymorphic recursion: see codegen/wrong/rec2
(msub |-> coreX)
let -- replace recursive calls: match TypeApp (Var name) [args] as a unit
-- and produce TypeApp (Var resTp1) [mvars] with correct arg count/order
newVar = Core.Var (Core.TName ({- unqualify -} name) (resTp1)) info
oldName = Core.TName ({- unqualify -} name) assumedTpX
resCoreX = replaceRecCallEx oldName newVar (map TVar mvars) (msub |-> coreX)

resCoreY = Core.addTypeLambdas mvars resCoreX
-- TODO: check: this was:
Expand All @@ -441,14 +437,20 @@ inferRecDef2 topLevel coreDef divergent (def,mbAssumed)
do assumedTpX <- normalize True assumedTp >>= subst -- resTp0
simResCore1 <- return resCore1 -- liftUnique $ uniqueSimplify penv False False 1 0 resCore1
coreX <- subst simResCore1
let resCoreX = (CoreVar.|~>) [(Core.TName ({- unqualify -} name) assumedTpX, Core.Var (Core.TName ({- unqualify -} name) resTp1) info)] coreX
let newVar = Core.Var (Core.TName ({- unqualify -} name) resTp1) info
oldName = Core.TName ({- unqualify -} name) assumedTpX
newTpArgs = map TVar (fst (splitTypeScheme resTp1))
resCoreX = replaceRecCallEx oldName newVar newTpArgs coreX
return (resTp1, resCoreX)
_ -- ensure we insert the right info (test: static/div2-ack)
-> -- trace " rec normal" $
do assumedTpX <- normalize True assumedTp >>= subst
simResCore1 <- return resCore1 -- liftUnique $ uniqueSimplify penv False False 1 0 resCore1
coreX <- subst simResCore1
let resCoreX = (CoreVar.|~>) [(Core.TName ({- unqualify -} name) assumedTpX, Core.Var (Core.TName ({- unqualify -} name) resTp1) info)] coreX
let newVar = Core.Var (Core.TName ({- unqualify -} name) resTp1) info
oldName = Core.TName ({- unqualify -} name) assumedTpX
newTpArgs = map TVar (fst (splitTypeScheme resTp1))
resCoreX = replaceRecCallEx oldName newVar newTpArgs coreX
return (resTp1, resCoreX)
--(Nothing,_)
-- -> return (resTp1,resCore1) -- (CoreVar.|~>) [(unqualify name, Core.Var (Core.TName (unqualify name) resTp1) Core.InfoNone)] resCore1
Expand All @@ -461,6 +463,38 @@ inferRecDef2 topLevel coreDef divergent (def,mbAssumed)



-- | Replace recursive calls in a core expression.
-- Matches both bare `Var oldName` and `TypeApp (Var oldName) [args]` and replaces
-- the entire construct with `TypeApp newVar newTpArgs`. This correctly handles cases
-- where the generalized type has fewer forall variables than the assumed type
-- (e.g., when effect variables are closed during normalization), or where the
-- forall variable order differs between the assumed and generalized types.
replaceRecCallEx :: Core.TName -> Core.Expr -> [Type] -> Core.Expr -> Core.Expr
replaceRecCallEx oldName newVar newTpArgs expr
= go expr
where
isShadowed = (== oldName)

go (Core.TypeApp (Core.Var tn _) _) | tn == oldName = Core.makeTypeApp newVar newTpArgs
go (Core.Var tn _) | tn == oldName = Core.makeTypeApp newVar newTpArgs
go (Core.Lam tnames eff body) = Core.Lam tnames eff (if any isShadowed tnames then body else go body)
go (Core.App e args) = Core.App (go e) (map go args)
go (Core.TypeLam tvs body) = Core.TypeLam tvs (go body)
go (Core.TypeApp e tps) = Core.TypeApp (go e) tps
go (Core.Let dgs body) = let defnames = map Core.defName (Core.flattenDefGroups dgs)
in Core.Let (map goDg dgs) (if any (== Core.getName oldName) defnames then body else go body)
go (Core.Case es bs) = Core.Case (map go es) (map goBranch bs)
go e = e -- Var (non-matching), Con, Lit

goDg (Core.DefRec defs) = Core.DefRec (map goDef defs)
goDg (Core.DefNonRec def) = Core.DefNonRec (goDef def)
goDef def = def{ Core.defExpr = go (Core.defExpr def) }

goBranch (Core.Branch pats guards) = if any isShadowed (S.toList (CoreVar.bv pats))
then Core.Branch pats guards
else Core.Branch pats (map goGuard guards)
goGuard (Core.Guard test body) = Core.Guard (go test) (go body)

inferDef :: Bool -> Expect -> Def Type -> Inf Core.Def
inferDef topLevel expect (Def (ValueBinder name mbTp expr nameRng vrng) rng vis sort inl doc)
=do penv <- getPrettyEnv
Expand Down
25 changes: 25 additions & 0 deletions test/cgen/type-of.kk
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
type at<a,l>
Local(v: a)
Remote

fun unwrap(x: at<a,_>): a
match x
Local(v) -> v
Remote -> impossible("invariant violation")

alias khor<k,e> = <khor-div<k,e>,div|e>
div effect khor-div<k,e>
fun locally(loc1: k<l1>, khor: (forall<w> at<w,l1> -> w) -> khor<k,e> a): at<a,l1>

extern loc/unsafe-cast(x: k<l1>): k<l2> { inline "#1" }

fun epp1/epp(loc1: k<l1>, khor: () -> khor<k,e> a, ?(==): forall<x,y> (k<x>,k<y>) -> bool): <div|e> at<a,l1>
with handler
return(v) Local(v)
fun locally(loc1', khor') match loc1 == loc1'
False -> Remote
True -> loc1.unsafe-cast.epp(fn() khor'(unwrap))
khor()

fun main()
println("runs")
1 change: 1 addition & 0 deletions test/cgen/type-of.kk.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
runs
Loading