diff --git a/futhark.cabal b/futhark.cabal index fe0da38546..4f117628a3 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -174,6 +174,17 @@ library Futhark.Analysis.PrimExp.Convert Futhark.Analysis.PrimExp.Parse Futhark.Analysis.PrimExp.Simplify + Futhark.Analysis.Refinement + Futhark.Analysis.Refinement.CNF + Futhark.Analysis.Refinement.Forward + Futhark.Analysis.Refinement.Prop + Futhark.Analysis.Refinement.Rules + Futhark.Analysis.Refinement.Match + Futhark.Analysis.Refinement.Convert + Futhark.Analysis.Refinement.Latex + Futhark.Analysis.Refinement.Relations + Futhark.Analysis.Refinement.Representation + Futhark.Analysis.Refinement.Monad Futhark.Analysis.SymbolTable Futhark.Analysis.UsageTable Futhark.Bench @@ -203,6 +214,7 @@ library Futhark.CLI.PyOpenCL Futhark.CLI.Python Futhark.CLI.Query + Futhark.CLI.Refinement Futhark.CLI.REPL Futhark.CLI.Run Futhark.CLI.Test @@ -274,6 +286,15 @@ library Futhark.Construct Futhark.Doc.Generator Futhark.Error + Futhark.SoP.Convert + Futhark.SoP.Expression + Futhark.SoP.Monad + Futhark.SoP.Refine + Futhark.SoP.RefineEquivs + Futhark.SoP.FourierMotzkin + Futhark.SoP.RefineRanges + Futhark.SoP.SoP + Futhark.SoP.Util Futhark.FreshNames Futhark.IR Futhark.IR.Aliases @@ -517,6 +538,9 @@ library , mwc-random , prettyprinter >= 1.7 , prettyprinter-ansi-terminal >= 1.1 + , multiset + , HaTeX + , process executable futhark import: common @@ -544,6 +568,10 @@ test-suite unit Futhark.IR.Mem.IxFun.Alg Futhark.IR.Mem.IxFunTests Futhark.IR.Mem.IxFunWrapper + Futhark.SoP.RefineTests + Futhark.SoP.FourierMotzkinTests + Futhark.SoP.Parse + Futhark.SoP.SoPTests Language.Futhark.CoreTests Language.Futhark.PrimitiveTests Language.Futhark.SyntaxTests diff --git a/prelude/prelude.fut b/prelude/prelude.fut index 0bacd9c732..822fd4187e 100644 --- a/prelude/prelude.fut +++ b/prelude/prelude.fut @@ -6,6 +6,7 @@ open import "array" open import "math" open import "functional" open import "ad" +open import "refinement" -- | Create single-precision float from integer. def r32 (x: i32): f32 = f32.i32 x diff --git a/prelude/refinement.fut b/prelude/refinement.fut new file mode 100644 index 0000000000..b5991a7187 --- /dev/null +++ b/prelude/refinement.fut @@ -0,0 +1,25 @@ +def elems 'a 'b (_: a) : b = ??? + +def toSet 'a 'b (_: a) : b = ??? + +def (<+>) 'a 'b (_ : a) (_ : a) : b = ??? + +def permutationOf 'a 'b 'c (_ : a) (_ : b) : c = ??? + +def subseteq 'a 'b 'c (_ : a) (_ : b) : c = ??? + +def subeq 'a 'b 'c (_ : a) (_ : b) : c = ??? + +def without 'a 'b 'c (_ : a) (_ : b) : c = ??? + +def union 'a 'b 'c (_ : a) (_ : b) : c = ??? + +def sum 't [n] (_: [n]t) : t = ??? + +def foreach 'a '^b 'c (_ : a) (_ : b) : c = ??? + +def forall 'a '^b 'c (_ : a) (_ : b) : c = ??? + +def elem 'a 'b 'c (_ : a) (_ : b) : c = ??? + +def axiom 'a 'b (_ : a) : b = ??? diff --git a/src/Futhark/Analysis/PrimExp.hs b/src/Futhark/Analysis/PrimExp.hs index 9bb4be1860..7f71449105 100644 --- a/src/Futhark/Analysis/PrimExp.hs +++ b/src/Futhark/Analysis/PrimExp.hs @@ -66,6 +66,7 @@ module Futhark.Analysis.PrimExp (~+~), (~-~), (~==~), + oneIshExp, ) where diff --git a/src/Futhark/Analysis/Refinement.hs b/src/Futhark/Analysis/Refinement.hs new file mode 100644 index 0000000000..8f57547287 --- /dev/null +++ b/src/Futhark/Analysis/Refinement.hs @@ -0,0 +1,148 @@ +module Futhark.Analysis.Refinement where + +import Control.Monad.RWS +import Data.List qualified as L +import Data.Map qualified as M +import Data.String +import Futhark.Analysis.Refinement.CNF +import Futhark.Analysis.Refinement.Convert +import Futhark.Analysis.Refinement.Forward +import Futhark.Analysis.Refinement.Latex +import Futhark.Analysis.Refinement.Monad +import Futhark.Analysis.Refinement.Prop +import Futhark.Analysis.Refinement.Representation +import Futhark.Analysis.Refinement.Rules +import Futhark.MonadFreshNames +import Futhark.SoP.SoP hiding (Range, SoP, Term) +import Futhark.Util.Pretty +import Language.Futhark qualified as E +import Language.Futhark.Semantic +import Text.LaTeX.Packages.AMSMath qualified as Math + +refineProg :: VNameSource -> Imports -> [Log] +refineProg vns prog = execRefineM (refineImports prog) vns + +refineImports :: [(ImportName, FileModule)] -> RefineM () +refineImports = mapM_ (refineDecs . E.progDecs . fileProg . snd) + +refineDecs :: [E.Dec] -> RefineM () +refineDecs [] = pure () +refineDecs (E.ValDec vb : rest) = do + refineValBind vb + refineDecs rest +refineDecs (_ : ds) = refineDecs ds + +refineValBind :: E.ValBind -> RefineM () +refineValBind (E.ValBind _ _ ret _ _ params body _ _ _) = do + mapM paramRefs params + forwards body + case ret of + Just (E.TERefine t p _) -> do + goal <- mkProp p $ getRes body + s <- get + res <- backwards (fmap (\g -> (g, s, mempty)) goal) body + pure () + _ -> pure () + where + getRes :: E.Exp -> [E.Exp] + getRes (E.AppExp (E.LetPat _ p e body _) _) = + getRes body + getRes (E.TupLit es _) = + concatMap getRes es + getRes e = [e] + + paramRefs (E.PatParens p _) = paramRefs p + paramRefs (E.PatAttr _ p _) = paramRefs p + paramRefs (E.PatAscription e@(E.Id v _ _) (E.TERefine t (E.Lambda [pat] body _ _ _) _) _) + | (E.Named x, _, _) <- E.patternParam pat = do + -- fix, can't handle or + info <- + (concat . cnfToLists . substituteOne (x, Var v)) + <$> unsafeConvert toCNF body -- Fix + insertType v (E.patternType e) + modify $ \senv -> + senv + { known = known senv ++ info, + known_map = + M.insertWith (<>) v info $ known_map senv + } + paramRefs e + | (E.Named x, _, _) <- E.patternParam e = + insertType x (E.patternType e) + paramRefs _ = pure () + +mkProp :: E.Exp -> [E.Exp] -> RefineM (CNF Prop) +mkProp (E.Lambda ps body _ _ _) args = do + m <- mconcat <$> zipWithM mkSubstParam (concatMap unwrapTuple ps) args + g <- substitute m <$> unsafeConvert toCNF body + tell [Math.text "Proving: " <> toLaTeX g] + pure g + where + unwrapTuple (E.TuplePat ps _) = ps + unwrapTuple (E.RecordPat ps _) = map snd ps + unwrapTuple p = [p] + mkSubstParam p arg + | (E.Named x, _, _) <- E.patternParam p = do + arg' <- unsafeConvert toExp arg + pure $ M.singleton x arg' + | otherwise = do + error $ unlines [prettyString p, prettyString arg] + pure mempty +mkProp e _ = + error $ + "Unsupported predicate: " <> prettyString e + +instance Show VNameSource where + show (VNameSource i) = show i + +rewriteProps :: CNF (Prop, SEnv, [Log]) -> RefineM (CNF (Prop, SEnv, [Log])) +rewriteProps gs = do + gs' <- bindCNFM rewriteProp gs + let cnf = fmap (\(g, _, _) -> flatten g) gs' + ws = foldMap (\(_, _, w) -> w) gs' + gs'' = fmap (\(g, s, _) -> (flatten g, s, [])) gs' + if not (null ws) + then do + tell $ L.nub $ ws + tell [toLaTeX cnf] + rewriteProps gs'' + else pure gs'' + where + rewriteProp :: Prop -> CNFM Prop + rewriteProp g = do + g' <- simplify g + mg' <- applyRules g' + case mg' of + Nothing -> pure g' + Just (s, g'') -> do + known_props <- gets known + tell $ + [Math.text "Rewrote " <> toLaTeX g <> Math.text " using " <> Math.mathtt (fromString s)] + pure g'' + +unsafeConvert :: (E.Exp -> RefineM (Maybe a)) -> E.Exp -> RefineM a +unsafeConvert f e = do + me' <- f e + case me' of + Nothing -> error $ "Couldn't convert Exp!: " <> prettyString e + Just e' -> pure e' + +backwards :: (CNF (Prop, SEnv, [Log])) -> E.Exp -> RefineM (CNF (Prop, SEnv, [Log])) +backwards gs (E.AppExp (E.LetPat _ p e body _) _) + | (E.Named x, _, _) <- E.patternParam $ fixTuple p = do + gs' <- backwards gs body + insertExp x e + rewriteProps $ addExps x e $ gs' + where + -- fix, only supports wildcards for now + isWildcard (E.Wildcard {}) = True + isWildcard _ = False + fixTuple (E.TuplePat ps _) = head $ filter (not . isWildcard) ps + fixTuple p = p + addExps x e = + fmap + ( \(g, senv, ws) -> + (g, senv {exps = M.insert x e $ exps senv}, ws) + ) +backwards gs _ = do + rewriteProps gs diff --git a/src/Futhark/Analysis/Refinement/CNF.hs b/src/Futhark/Analysis/Refinement/CNF.hs new file mode 100644 index 0000000000..d0f2638c43 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/CNF.hs @@ -0,0 +1,156 @@ +module Futhark.Analysis.Refinement.CNF + ( CNF (..), + DNF (..), + Or (..), + And (..), + (&&&), + (|||), + cnfTrue, + cnfFalse, + cnfNub, + cnfIsValid, + cnfIsValidM, + toDNF, + listsToCNF, + cnfToLists, + dnfToLists, + negateCNF, + atomCNF, + justAnds, + ) +where + +import Control.Applicative +import Control.Monad +import Futhark.SoP.Util +import Futhark.Util +import Futhark.Util.Pretty + +-- Wrapper around a list (instead of a set) because we want to add +-- 'Monad' instances. +newtype And a = And {ands :: [a]} + deriving (Show, Eq, Ord) + +newtype Or a = Or {ors :: [a]} + deriving (Show, Eq, Ord) + +newtype CNF a = CNF {getCNF :: And (Or a)} + deriving (Show, Eq, Ord) + +newtype DNF a = DNF {getDNF :: Or (And a)} + deriving (Show, Eq, Ord) + +instance (Pretty a) => Pretty (Or a) where + pretty = + concatWith (surround " ∨ ") . map pretty . ors + +instance (Pretty a) => Pretty (And a) where + pretty = + concatWith (surround " ∧ ") . map (parens . pretty) . ands + +instance (Pretty a) => Pretty (CNF a) where + pretty = pretty . getCNF + +instance (Pretty a) => Pretty (DNF a) where + pretty = pretty . getDNF + +atomCNF :: a -> CNF a +atomCNF a = listsToCNF [[a]] + +cnfFalse :: CNF a +cnfFalse = CNF $ And [Or mempty] + +cnfTrue :: CNF a +cnfTrue = CNF $ And mempty + +cnfToLists :: CNF a -> [[a]] +cnfToLists = map ors . ands . getCNF + +dnfToLists :: DNF a -> [[a]] +dnfToLists = map ands . ors . getDNF + +listsToCNF :: [[a]] -> CNF a +listsToCNF = CNF . And . map Or + +listsToDNF :: [[a]] -> DNF a +listsToDNF = DNF . Or . map And + +cnfNub :: (Ord a) => CNF a -> CNF a +cnfNub = listsToCNF . nubOrd . map nubOrd . cnfToLists + +isFalse :: CNF a -> Bool +isFalse (CNF (And [Or as])) = null as +isFalse _ = False + +negateCNF :: (a -> a) -> CNF a -> CNF a +negateCNF = fmap + +(&&&) :: CNF a -> CNF a -> CNF a +xss &&& yss + | isFalse xss || isFalse yss = cnfFalse + | otherwise = listsToCNF $ cnfToLists xss <> cnfToLists yss + +infixr 3 &&& + +(|||) :: CNF a -> CNF a -> CNF a +xss ||| yss + | isFalse xss = yss + | isFalse yss = xss + | otherwise = listsToCNF $ do + xs <- cnfToLists xss + x <- xs + ys <- cnfToLists yss + pure $ x : ys + +infixr 2 ||| + +cnfIsValid :: (a -> Bool) -> CNF a -> Bool +cnfIsValid isValid = all (any isValid) . cnfToLists + +cnfIsValidM :: (Monad m) => (a -> m Bool) -> CNF a -> m Bool +cnfIsValidM isValid = allM (anyM isValid) . cnfToLists + +toDNF :: CNF a -> DNF a +toDNF = listsToDNF . toOrs . cnfToLists + where + toOrs [] = [[]] + toOrs (or' : ors') = + [a : and' | and' <- toOrs ors', a <- or'] + +instance Functor CNF where + fmap f = listsToCNF . (map . map) f . cnfToLists + +instance Functor DNF where + fmap f = listsToDNF . (map . map) f . dnfToLists + +instance Applicative CNF where + pure = listsToCNF . pure . pure + liftA2 = liftM2 + +instance Monad CNF where + xss >>= f = + foldr (((&&&)) . foldr (|||) cnfFalse) cnfTrue ((map . map) f (cnfToLists xss)) + +instance Alternative CNF where + empty = cnfFalse + (<|>) = (|||) + +instance MonadPlus CNF + +instance Semigroup (CNF a) where + (<>) = (&&&) + +instance Monoid (CNF a) where + mempty = cnfFalse + +instance Foldable CNF where + foldMap f = (foldMap . foldMap) f . cnfToLists + +instance Foldable DNF where + foldMap f = (foldMap . foldMap) f . dnfToLists + +instance Traversable CNF where + traverse f = fmap listsToCNF . (traverse . traverse) f . cnfToLists + +justAnds :: CNF a -> Bool +justAnds = all (\c -> length c <= 1) . cnfToLists diff --git a/src/Futhark/Analysis/Refinement/Convert.hs b/src/Futhark/Analysis/Refinement/Convert.hs new file mode 100644 index 0000000000..346e076740 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Convert.hs @@ -0,0 +1,233 @@ +module Futhark.Analysis.Refinement.Convert where + +import Control.Monad.RWS +import Data.Bifunctor +import Data.List (find) +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Debug.Trace +import Futhark.Analysis.Refinement.CNF +import Futhark.Analysis.Refinement.Monad +import Futhark.Analysis.Refinement.Representation +import Futhark.MonadFreshNames +import Futhark.SoP.Refine qualified as SoP +import Futhark.SoP.SoP qualified as SoP +import Futhark.Util.Pretty +import Language.Futhark qualified as E +import Language.Futhark.Primitive qualified as EP + +toExp :: (Monad m) => E.Exp -> RefineT m (Maybe Exp) +toExp (E.Attr _ e _) = toExp e +toExp e@(E.AppExp (E.BinOp (op, _) _ (e_x, _) (e_y, _) _) _) + | E.baseTag (E.qualLeaf op) <= E.maxIntrinsicTag, + name <- E.baseString $ E.qualLeaf op, + Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = do + x <- toExp e_x + y <- toExp e_y + case bop of + E.Plus -> pure $ (~+~) <$> x <*> y + E.Times -> pure $ (~*~) <$> x <*> y + E.Minus -> pure $ (~-~) <$> x <*> y + _ -> error $ show bop +-- | otherwise = do +-- x <- toExp e_x +-- y <- toExp e_y +-- case E.baseString (E.qualLeaf op) of +-- "<+>" -> pure $ (Union . Set . S.fromList) <$> sequence [x, y] +-- "union" -> pure $ (Union . Set . S.fromList) <$> sequence [x, y] +-- "++" -> pure $ Concat <$> x <*> y +-- "without" -> pure $ Without <$> x <*> y +-- s -> error $ s +toExp (E.ArrayLit es _ _) = do + es' <- mapM toExp es + pure $ Array <$> sequence es' +toExp (E.Var (E.QualName qs x) _ _) = + pure $ Just $ Var x +toExp e@(E.AppExp (E.Apply f args _) _) + | f `isFun` "length", + [(_, xs)] <- NE.toList args = do + xs' <- toExp xs + pure $ Len <$> xs' + | f `isFun` "elems", + [(_, n)] <- NE.toList args = do + n' <- toExp n + pure $ Elems <$> n' + | f `isFun` "toSet", + [(_, n)] <- NE.toList args = do + n' <- toExp n + case n' of + Just (Array es) -> pure $ Just $ Set $ S.fromList es + _ -> error $ show n' + | f `isFun` "bool", + [(_, n)] <- NE.toList args = do + n' <- toExp n + pure $ BoolToInt <$> n' + -- | f `isFun` "union", + -- [(_, x), (_, y)] <- NE.toList args = do + -- x' <- (fmap . fmap) termToSet $ toExp x + -- y' <- (fmap . fmap) termToSet $ toExp y + -- pure $ (Union . Set . S.fromList) <$> sequence [x', y'] + -- | f `isFun` "sum", + -- [(_, x)] <- NE.toList args = do + -- x' <- toExp x + -- i <- newVName "i" + -- let set = + -- fmap (\x'' -> intToTerm 0 ... (Len x'' ~-~ intToTerm 1)) x' + -- pure $ Sigma (Var i) <$> set <*> ((flip Idx (Var i)) <$> x') + -- | f `isFun` "without", + -- [(_, x), (_, y)] <- NE.toList args = do + -- x' <- (fmap . fmap) termToSet $ toExp x + -- y' <- (fmap . fmap) termToSet $ toExp y + -- pure $ Without <$> x' <*> y' + -- | f `isFun` "iota", + -- [(_, x)] <- NE.toList args = do + -- x' <- (fmap . fmap) termToSet $ toExp x + -- pure $ (intToTerm 0 ...) <$> x' + | f `isFun` "scan", -- TODO: replication of scan_sum rule, consolidate + [E.OpSection (E.QualName [] vn) _ _, _, xs] <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args, + "+" <- E.baseString vn = do + xsm <- toExp xs + case xsm of + Nothing -> pure Nothing + Just (Idx xs' r@(Range from step to)) -> do + i <- newVName "i" + j <- newVName "j" + pure $ + Just $ + Unions (Var i) (Range (intToTerm 0) step (Len r ~-~ intToTerm 1)) (CNFTerm cnfTrue) $ + Sigma + (Var j) + (Range from step (from ~+~ Var i)) + (Idx xs' $ Var j) +toExp e@(E.AppExp (E.Range from mstep (E.ToInclusive to) _) _) = do + from' <- toExp from + to' <- toExp to + step' <- case mstep of + Nothing -> pure $ Just $ SoP $ SoP.int2SoP 1 + Just step -> toExp step + pure $ Range <$> from' <*> step' <*> to' +toExp (E.Parens e _) = toExp e +toExp (E.Literal pv _) + | E.SignedValue iv <- pv = pure $ Just $ intToTerm $ EP.valueIntegral iv + | E.UnsignedValue iv <- pv = pure $ Just $ intToTerm $ EP.valueIntegral iv +toExp (E.IntLit x _ _) = pure $ Just $ SoP $ SoP.int2SoP x +toExp (E.Negate (E.IntLit x _ _) _) = pure $ Just $ SoP $ SoP.negSoP $ SoP.int2SoP x +toExp e@(E.AppExp (E.If c t f _) _) = do + c' <- toCNF c + t' <- toExp t + f' <- toExp f + pure $ If <$> (CNFTerm <$> c') <*> t' <*> f' +toExp (E.AppExp (E.Index xs [E.DimFix i] _) _) = do + i' <- toExp i + xs' <- toExp xs + pure $ Idx <$> xs' <*> i' +toExp (E.AppExp (E.Index xs [E.DimSlice mstart mend mstep] _) _) = do + start <- maybe (pure $ Just $ intToTerm 0) toExp mstart + step <- maybe (pure $ Just $ intToTerm 1) toExp mstep + xs' <- toExp xs + end <- maybe (pure $ (\xs -> Len xs ~-~ intToTerm 1) <$> xs') toExp mend + pure $ Idx <$> xs' <*> (Range <$> start <*> step <*> end) +-- toExp exp@(E.Coerce e te _ _) = do +-- -- fix +-- me_dims <- sequence <$> mapM toExp (E.shapeDims $ E.arrayShape $ E.typeOf e) +-- mexp_dims <- sequence <$> mapM toExp (E.shapeDims $ E.arrayShape $ E.typeOf exp) +-- case (me_dims, mexp_dims) of +-- (Just e_dims, Just exp_dims) -> do +-- zipWithM (\x y -> SoP.addRel $ expToSoP x SoP.:==: expToSoP y) exp_dims e_dims +-- modify $ \senv -> +-- senv +-- { known = known senv ++ zipWith (:==) exp_dims e_dims, +-- known_map = +-- M.unionWith (<>) (known_map senv) (M.fromList (zipWith (\(Var x) y -> (x, [Var x :== y])) exp_dims e_dims)) +-- } +-- toExp e +-- _ -> pure Nothing +toExp exp@(E.AppExp (E.LetPat sizes p e body _) _) + | (E.Named x, _, _) <- E.patternParam p = do + e' <- toExp e + body' <- toExp body + case e' of + Just e'' -> pure $ SoP.substituteOne (x, e'') <$> body' + _ -> + -- pure Nothing + error $ unlines [prettyString e, show e] +toExp e = error $ prettyString e <> "\n" <> show e + +-- pure Nothing -- + +toCNF :: (Monad m) => E.Exp -> RefineT m (Maybe (CNF Prop)) +toCNF (E.Parens e _) = toCNF e +toCNF e@(E.AppExp (E.BinOp (op, _) _ (e_x, _) (e_y, _) _) _) + | E.baseTag (E.qualLeaf op) <= E.maxIntrinsicTag, + name <- E.baseString $ E.qualLeaf op, + Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = do + x <- toCNF e_x + y <- toCNF e_y + case bop of + E.LogAnd -> do + pure $ (&&&) <$> x <*> y + E.LogOr -> do + pure $ (|||) <$> x <*> y + _ -> (fmap . fmap) atomCNF $ toProp e +toCNF e = (fmap . fmap) atomCNF $ toProp e + +toProp :: (Monad m) => E.Exp -> RefineT m (Maybe Prop) +toProp (E.Parens e _) = toProp e +toProp e@(E.AppExp (E.BinOp (op, _) _ (e_x, _) (e_y, _) _) _) + | E.baseTag (E.qualLeaf op) <= E.maxIntrinsicTag, + name <- E.baseString $ E.qualLeaf op, + Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = do + x <- toExp e_x + y <- toExp e_y + case bop of + -- E.NotEqual -> pure $ (:/=) <$> x <*> y + E.Greater -> pure $ (:>) <$> x <*> y + -- E.Less -> pure $ (:<) <$> x <*> y + -- E.Geq -> pure $ (:>=) <$> x <*> y + E.Equal -> pure $ (:==) <$> x <*> y + _ -> error $ show bop +-- | otherwise = do +-- x <- toExp e_x +-- y <- toExp e_y +-- case E.baseString (E.qualLeaf op) of +-- "subseteq" -> pure $ SubsetEq <$> x <*> y +-- "subeq" -> pure $ SubEq <$> x <*> y +-- _ -> error $ show op +toProp e@(E.AppExp (E.Apply f args _) _) + -- | f `isFun` "axiom", + -- [(_, p)] <- NE.toList args = do + -- p' <- toProp p + -- pure $ Axiom <$> p' + | f `isFun` "permutationOf", + [(_, x), (_, y)] <- NE.toList args = do + x' <- (fmap . fmap) termToSet $ toExp x + y' <- (fmap . fmap) termToSet $ toExp y + pure $ PermutationOf <$> x' <*> y' + +-- | f `isFun` "subseteq", +-- [(_, x), (_, y)] <- NE.toList args = do +-- x' <- (fmap . fmap) termToSet $ toExp x +-- y' <- (fmap . fmap) termToSet $ toExp y +-- pure $ SubsetEq <$> x' <*> y' +-- | f `isFun` "forall", +-- [(_, xs), (_, E.Lambda [p] body _ _ _)] <- map (second (\e -> fromMaybe e (E.stripExp e))) (NE.toList args), +-- (E.Named x, _, _) <- E.patternParam p = do +-- i <- newVName "i" +-- xs' <- (fmap . fmap) termToSet $ toExp xs +-- pred' <- (fmap . fmap) (SoP.substituteOne (x, Var i)) $ toProp body +-- pure $ ForAll (Var i) <$> xs' <*> pred' +-- | f `isFun` "foreach", +-- [(_, xs), (_, E.Lambda [p] body _ _ _)] <- map (second (\e -> fromMaybe e (E.stripExp e))) (NE.toList args), +-- (E.Named x, _, _) <- E.patternParam p = do +-- i <- newVName "k" +-- xs' <- (fmap . fmap) termToSet $ toExp xs +-- pred' <- (fmap . fmap) (SoP.substituteOne (x, Var i)) $ toProp body +-- pure $ ForEach (Var i) <$> xs' <*> pred' +toProp e = toExp e + +-- A total hack +isFun :: E.Exp -> String -> Bool +isFun (E.Var (E.QualName _ vn) _ _) fname = fname == E.baseString vn +isFun _ _ = False diff --git a/src/Futhark/Analysis/Refinement/Forward.hs b/src/Futhark/Analysis/Refinement/Forward.hs new file mode 100644 index 0000000000..539c633876 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Forward.hs @@ -0,0 +1,43 @@ +module Futhark.Analysis.Refinement.Forward where + +import Control.Monad.RWS +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe +import Debug.Trace +import Futhark.Analysis.Refinement.Monad +import Futhark.Analysis.Refinement.Prop +import Futhark.Analysis.Refinement.Relations +import Futhark.Analysis.Refinement.Representation +import Futhark.MonadFreshNames +import Futhark.SoP.Refine +import Futhark.SoP.SoP qualified as SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark qualified as E +import Language.Futhark.Prop qualified as E + +forwards :: E.Exp -> RefineM () +forwards (E.AppExp (E.LetPat _ p e body _) _) + | (E.Named x, _, _) <- E.patternParam p = do + forward x e + forwards body +forwards _ = pure () + +forward :: E.VName -> E.Exp -> RefineM () +forward x e = pure () + +-- ifM +-- (isNonNeg e) +-- ( do +-- i <- newVName "i" +-- let info = ForAll (Var i) (Var x) (Var i :>= intToExp 0) +-- modify $ \senv -> +-- senv +-- { known = known senv ++ [info], +-- known_map = +-- M.insertWith (<>) x [info] $ known_map senv +-- } +-- ) +-- (pure ()) diff --git a/src/Futhark/Analysis/Refinement/Latex.hs b/src/Futhark/Analysis/Refinement/Latex.hs new file mode 100644 index 0000000000..7395af6f75 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Latex.hs @@ -0,0 +1,212 @@ +module Futhark.Analysis.Refinement.Latex + ( mkLaTeX, + LaTeXC (..), + ToLaTeX (..), + LaTeX (..), + ) +where + +import Data.Bifunctor +import Data.List (intersperse) +import Data.List qualified as L +import Data.Set qualified as S +import Futhark.Analysis.Refinement.CNF +import Futhark.Analysis.Refinement.Representation +import Futhark.SoP.SoP qualified as SoP +import Language.Futhark qualified as E +import Text.LaTeX +import Text.LaTeX.Base.Class +import Text.LaTeX.Base.Syntax +import Text.LaTeX.Packages.AMSMath +import Text.LaTeX.Packages.AMSSymb +import Text.LaTeX.Packages.Inputenc +import Text.LaTeX.Packages.Trees.Qtree +import Prelude hiding (concat) + +class ToLaTeX a where + toLaTeX_ :: (Eq l, LaTeXC l) => Int -> a -> l + toLaTeX :: (Eq l, LaTeXC l) => a -> l + toLaTeX = toLaTeX_ 0 + +concatWith :: (Foldable t, LaTeXC l) => (l -> l -> l) -> t l -> l +concatWith _ ls + | null ls = mempty +concatWith op ls = + foldr1 op ls + +joinWith :: (Eq l, LaTeXC l) => (l -> l -> l) -> l -> l -> l +joinWith _ l r + | r == mempty = l +joinWith _ l r + | l == mempty = r +joinWith op l r = l `op` r + +enclose :: (LaTeXC l) => l -> l -> l -> l +enclose l r x = l <> x <> r + +surround :: (LaTeXC l) => l -> l -> l -> l +surround x l r = enclose l r x + +emptySet :: (LaTeXC l) => l +emptySet = raw "\\varnothing" + +neg :: (LaTeXC l) => l -> l +neg = (comm0 "neg" <>) + +num :: (Show x, LaTeXC l) => x -> l +num = fromString . show + +concat :: (LaTeXC l) => l -> l -> l +concat = surround (raw "\\mathbin{+\\mkern-10mu+}") + +substack :: (LaTeXC l) => [l] -> l +substack = + comm1 "substack" . mconcat . intersperse lnbk + +fun :: (LaTeXC l) => l -> [l] -> l +fun f args = + f <> autoParens (concatWith (surround ", ") args) + +(<+>) :: (LaTeXC l) => l -> l -> l +l <+> r = l <> space <> r + +bigC :: (LaTeXC l) => l -> l +bigC x = raw "\\makebox{\\huge\\ensuremath{C}}" !: x + +autoParensP :: (LaTeXC l) => Int -> Int -> l -> l +autoParensP prec new_prec l + | prec >= new_prec = autoParens l + | otherwise = l + +instance (ToLaTeX u) => ToLaTeX (SoP.Term u) where + toLaTeX_ p t = + autoParensP p 7 $ + concatWith cdot $ + map (toLaTeX_ 7) $ + SoP.termToList t + +instance (ToLaTeX u, Ord u) => ToLaTeX (SoP.SoP u) where + toLaTeX_ p sop + | Just c <- SoP.justConstant sop = num c + | null pos_terms = + (if length neg_neg_terms > 1 then autoParensP p 6 else id) + ( "-" <> concatWith (surround "-") (map mult neg_neg_terms) + ) + | otherwise = + autoParensP p 6 $ + joinWith + (surround "-") + (concatWith (surround "+") (map mult pos_terms)) + (concatWith (surround "-") (map mult neg_neg_terms)) + where + mult (t, n) + | SoP.isConstTerm t = num n + mult (t, 1) = toLaTeX_ p t + mult (t, n) = num n `cdot` toLaTeX_ 7 t + (pos_terms, neg_terms) = L.partition ((>= 0) . snd) $ SoP.sopToList sop + neg_neg_terms = map (second negate) neg_terms + +instance ToLaTeX E.VName where + toLaTeX_ _ vn = fromString $ E.baseString vn + +instance ToLaTeX Hole where + toLaTeX_ _ _ = square + +instance ToLaTeX Term where + toLaTeX_ p (Var x) = toLaTeX_ p x + toLaTeX_ p (THole h) = toLaTeX_ p h + toLaTeX_ p (SoP sop) = toLaTeX_ p sop + toLaTeX_ _ (Len x) = mathtt "len" <> autoParens (toLaTeX x) + toLaTeX_ _ (Elems x) = mathtt "elems" <> autoParens (toLaTeX x) + toLaTeX_ _ (Set es) = + autoBraces $ concatWith (surround ", ") $ map toLaTeX $ S.toList es + toLaTeX_ _ (Array es) = + autoSquareBrackets $ concatWith (surround ", ") $ map toLaTeX es + toLaTeX_ _ (Range from step to') = + autoBraces $ toLaTeX from <> "," <> toLaTeX (from ~+~ step) <> "," <> ldots <> "," <> toLaTeX to' + toLaTeX_ _ (Idx x y) = + toLaTeX_ 9 x <> autoSquareBrackets (toLaTeX y) + toLaTeX_ p (Union x y) = + autoParensP p 4 $ toLaTeX_ 4 x `cup` toLaTeX_ 4 y + toLaTeX_ _ (Unions i set conds x) = + bigcupFromTo (substack [toLaTeX i `in_` toLaTeX set, toLaTeX conds]) mempty <> autoBraces (toLaTeX x) + toLaTeX_ _ (Sigma i set x) = + sumFromTo (toLaTeX i `in_` toLaTeX set) mempty <> toLaTeX_ 9 x + toLaTeX_ _ (x :< y) = toLaTeX_ 3 x <: toLaTeX_ 3 y + toLaTeX_ _ (x :<= y) = toLaTeX_ 3 x <=: toLaTeX_ 3 y + toLaTeX_ _ (x :> y) = toLaTeX_ 3 x >: toLaTeX_ 3 y + toLaTeX_ _ (x :>= y) = toLaTeX_ 3 x >=: toLaTeX_ 3 y + toLaTeX_ _ (x :== y) = toLaTeX_ 3 x =: toLaTeX_ 3 y + toLaTeX_ _ (x :/= y) = toLaTeX_ 3 x /=: toLaTeX_ 3 y + toLaTeX_ _ (PermutationOf x y) = + fun (mathtt "permutation_of") $ map toLaTeX [x, y] + toLaTeX_ _ (Bool b) = mathtt $ fromString $ show b + toLaTeX_ _ (Not x) = neg $ toLaTeX_ 9 x + toLaTeX_ p (CNFTerm cnf) = toLaTeX_ p cnf + toLaTeX_ p (If c t f) = + autoParensP p 4 $ + mathtt "if" + <+> toLaTeX_ 4 c + <+> mathtt "then" + <+> toLaTeX_ 4 t + <+> mathtt "else" + <+> toLaTeX_ 4 f + toLaTeX_ _ (BoolToInt x) = + mathtt "bool_to_int" <> autoParens (toLaTeX x) + toLaTeX_ _ (Forall {}) = undefined + +instance (ToLaTeX a) => ToLaTeX (Or a) where + toLaTeX_ p x = + (if length (ors x) > 1 then autoParensP p 1 else id) $ + concatWith vee $ + map (toLaTeX_ 1) $ + ors x + +instance (ToLaTeX a) => ToLaTeX (And a) where + toLaTeX_ p x = + (if length (ands x) > 1 then autoParensP p 2 else id) $ + concatWith wedge $ + map (toLaTeX_ 2) $ + ands x + +instance (ToLaTeX a) => ToLaTeX (CNF a) where + toLaTeX_ p = toLaTeX_ p . getCNF + +instance (ToLaTeX a) => ToLaTeX [a] where + toLaTeX_ p = mconcat . intersperse lnbk . map (toLaTeX_ p) + +instance ToLaTeX LaTeX where + toLaTeX_ _ = fromLaTeX + +mkLaTeX :: (ToLaTeX a) => FilePath -> [a] -> IO () +mkLaTeX fp as = + renderFile fp content + where + content :: LaTeX + content = + mconcat + [ documentclass [a0paper] article, + usepackage [] qtree, + usepackage [utf8] inputenc, + usepackage [] amsmath, + usepackage [] amssymb, + fit $ + align_ $ + intersperse "\n" $ + map toLaTeX as + ] + fit x = + mconcat $ + intersperse + "\n" + [ raw "\\begin{document}", + raw "\\hoffset=-1in", + raw "\\voffset=-1in", + vbox x, + raw "\\pdfpageheight=\\dimexpr\\ht0+\\dp0\\relax", + raw "\\pdfpagewidth=\\wd0", + raw "\\shipout\\box0", + raw "\\stop" + ] + + vbox x = raw "\\setbox0" <> liftL (\l -> TeXComm "vbox" [FixArg l]) x diff --git a/src/Futhark/Analysis/Refinement/Match.hs b/src/Futhark/Analysis/Refinement/Match.hs new file mode 100644 index 0000000000..f04445c306 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Match.hs @@ -0,0 +1,226 @@ +{-# LANGUAGE IncoherentInstances #-} +{-# LANGUAGE UndecidableInstances #-} + +module Futhark.Analysis.Refinement.Match where + +import Control.Applicative hiding (Const) +import Control.Monad +import Control.Monad.RWS +import Control.Monad.State +import Control.Monad.Trans.Class +import Data.Bifunctor +import Data.Foldable (toList) +import Data.Functor.Identity +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S +import Debug.Trace +import Futhark.Analysis.Refinement.CNF +import Futhark.Analysis.Refinement.Monad +import Futhark.Analysis.Refinement.Representation +import Futhark.MonadFreshNames +import Futhark.SoP.SoP (SoP) +import Futhark.SoP.SoP qualified as SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark (VName) +import Language.Futhark qualified as E + +class Unify a b where + unify :: VName -> a -> b -> [Subst] + +instance (AddSubst b, Holes b, Free b, Context b) => Unify VName b where + unify k x b + | x >= k = mempty + | x `occursIn` b = mempty + | not $ S.null (S.filter (>= k) $ fv b) = mempty + | hasHoles b = mempty + | otherwise = + let s = M.singleton x b + in pure $ addSubst x b mempty + +instance Unify Term Term where + unify k e1 e2 = unifies [(e1, e2)] + where + unifies :: [(Term, Term)] -> [Subst] + unifies [] = pure mempty + unifies ((t1, t2) : es) = do + (s, new_es) <- unifyOne (flatten t1) (flatten t2) + let es' = replace s es + fmap (s <>) $ unifies $ new_es ++ es' + + unifyHole :: + Hole -> Term -> [Subst] + unifyHole Unnamed _ = mempty + unifyHole h@(CHole x a@(SoP sop)) b@(SoP sop') = + choice $ + ( do + (ctx, arg) <- sop_contexts + s <- unify k a arg + pure [addSubst x ctx s] + ) + ++ ( do + (ctx, arg) <- contexts b + s <- unifyHole h arg + pure [addSubst x ctx s] + ) + where + sop_contexts = do + perm <- L.permutations $ SoP.sopToLists sop' + let (cand, ctx) = L.splitAt (SoP.numTerms sop) perm + pure + ((SoP (SoP.sopFromList ctx) ~+~), SoP $ SoP.sopFromList cand) + unifyHole h@(CHole x a) b = + choice $ + unify k a b : do + (ctx, arg) <- contexts b + s <- unifyHole h arg + pure [addSubst x ctx s] + unifyHole (Hole x) b = + pure $ addSubst x b mempty + + unifyOneSoPTerm :: + (SoP.Term Term, Integer) -> + (SoP.Term Term, Integer) -> + [Subst] + unifyOneSoPTerm (xs, a) (ys, b) = + case (xs', ys') of + ([THole h], _) -> unifyHole h $ SoP $ SoP.term2SoP ys b + (_, [THole h]) -> unifyHole h $ SoP $ SoP.term2SoP xs a + _ + | length xs' == length ys', + a == b -> + do + xs'' <- L.permutations xs' + unifies $ zip xs'' ys' + _ -> mempty + where + xs' = SoP.termToList xs + ys' = SoP.termToList ys + isHole THole {} = True + isHole _ = False + + unifyOneSoP :: SoP Term -> SoP Term -> [(Subst, [(Term, Term)])] + unifyOneSoP x y + | length xs == length ys = do + xs' <- L.permutations xs + noSubProblems $ + unifySoPTerms $ + zip xs' ys + | otherwise = mempty + where + xs = SoP.sopToList x + ys = SoP.sopToList y + unifySoPTerms :: [((SoP.Term Term, Integer), (SoP.Term Term, Integer))] -> [Subst] + unifySoPTerms [] = pure mempty + unifySoPTerms ((t1, t2) : es) = do + s <- unifyOneSoPTerm t1 t2 + let es' = + map + ( \((t1', a), (t2', b)) -> + ((replace s t1', a), (replace s t2', b)) + ) + es + fmap (s <>) $ unifySoPTerms es' + + unifyOne :: Term -> Term -> [(Subst, [(Term, Term)])] + unifyOne t1 t2 + | t1 == t2 = pure mempty + unifyOne (THole h) t2 = + noSubProblems $ unifyHole h t2 + unifyOne t1 (THole h) = + noSubProblems $ unifyHole h t1 + unifyOne (Var x) t2 = noSubProblems $ unify k x t2 + unifyOne e1 (Var y) = pure (mempty, [(Var y, e1)]) + unifyOne (SoP x) (SoP y) = unifyOneSoP x y + unifyOne (Len x) (Len y) = pure (mempty, [(x, y)]) + unifyOne (Elems x) (Elems y) = pure (mempty, [(x, y)]) + unifyOne (Set xs) (Set ys) = + pure (mempty, zip (S.toList xs) (S.toList ys)) + unifyOne (Array xs) (Array ys) = + pure (mempty, zip xs ys) + unifyOne (Range from step to) (Range from' step' to') = + pure (mempty, [(from, from'), (step, step'), (to, to')]) + unifyOne (Idx arr i) (Idx arr' i') = + pure (mempty, [(arr, arr'), (i, i')]) + unifyOne (Union x y) (Union x' y') = + pure (mempty, [(x, x'), (y, y')]) + unifyOne (Unions i s c xs) (Unions i' s' c' xs') = + pure (mempty, [(i, i'), (s, s'), (c, c'), (xs, xs')]) + unifyOne (Sigma i s e) (Sigma i' s' e') = + pure (mempty, [(i, i'), (s, s'), (e, e')]) + unifyOne (If c t f) (If c' t' f') = + pure (mempty, [(c, c'), (t, t'), (f, f')]) + unifyOne (BoolToInt x) (BoolToInt x') = + pure (mempty, [(x, x')]) + unifyOne (e1 :== e2) (e1' :== e2') = pure (mempty, [(e1, e1'), (e2, e2')]) + unifyOne (e1 :> e2) (e1' :> e2') = pure (mempty, [(e1, e1'), (e2, e2')]) + unifyOne (PermutationOf e1 e2) (PermutationOf e1' e2') = pure (mempty, [(e1, e1'), (e2, e2')]) + unifyOne (Forall x p1 p2) (Forall y p1' p2') = pure (mempty, [(p1, p1'), (p2, p2')]) + unifyOne (Not p) (Not p') = + pure (mempty, [(p, p')]) + unifyOne _ _ = mempty + + noSubProblems = fmap (,mempty) + +unifyM :: + (Show a, Show b, MonadFreshNames m, Rename a, Rename b, Holes a, Holes b, Unify a b) => + a -> + b -> + m ([Subst], (a, b)) +unifyM a b = do + a' <- instHole a + b' <- instHole b + k <- newVName "k" + vns <- getNameSource + a'' <- rename mempty a' + putNameSource vns + b'' <- rename mempty b' + pure (unify k a'' b'', (a'', b'')) + +class Match a b where + match :: (MonadFreshNames m) => a -> b -> m [Subst] + testmatch :: (MonadFreshNames m) => a -> b -> m [(b, Subst)] + varMatch :: (MonadFreshNames m) => a -> b -> m [Subst] + +instance + (Show a, Show b, Rename a, Rename b, Holes a, Holes b, Replace a, Replace b, Unify a b) => + Match a b + where + varMatch x y = fst <$> unifyM x y + match x y = + filter + ( \(Subst s_term s_ctx) -> + S.fromList (M.keys s_term <> M.keys s_ctx) + `S.isSubsetOf` holes x + ) + <$> varMatch x y + testmatch x y = do + (mm, (x', y')) <- unifyM x y + pure $ do + s@(Subst s_term s_ectx) <- mm + guard + ( S.fromList (M.keys s_term <> M.keys s_ectx) + `S.isSubsetOf` holes x + ) + pure s + pure $ (replace s y', s) + +class (Unify a b) => Instantiate a b where + instantiateWith :: (MonadFreshNames m) => a -> b -> m [a] + +instance Instantiate Term Term where + instantiateWith e p = do + ctx <- mkCHole + ss <- varMatch (ctx e :: Term) inner_p + pure $ map (flip replace $ ctx e) ss + where + (inner_p, reqs) = popBinders p + popBinders (Forall x p1 p2) = + let (p, ps) = popBinders p2 + in (p, p1 : ps) + popBinders p = (p, mempty) diff --git a/src/Futhark/Analysis/Refinement/Monad.hs b/src/Futhark/Analysis/Refinement/Monad.hs new file mode 100644 index 0000000000..fd5b0f3a80 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Monad.hs @@ -0,0 +1,130 @@ +module Futhark.Analysis.Refinement.Monad where + +import Control.Applicative +import Control.Monad.RWS +import Data.Foldable (toList) +import Data.Functor.Identity +import Data.Map (Map) +import Data.Map qualified as M +import Futhark.Analysis.Refinement.CNF +import Futhark.Analysis.Refinement.Latex +import Futhark.Analysis.Refinement.Representation +import Futhark.MonadFreshNames +import Futhark.SoP.Monad +import Futhark.SoP.SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark qualified as E +import Language.Futhark.Prop qualified as E +import Language.Futhark.Semantic + +data SEnv = SEnv + { vnamesource :: VNameSource, + algenv :: AlgEnv Exp E.Exp, + exps :: Map E.VName E.Exp, + types :: Map E.VName E.PatType, + known :: [Prop], + known_map :: Map E.VName [Prop] + } + deriving (Eq) + +-- addExps :: SEnv -> Map E.VName E.Exp -> SEnv +-- addExps senv es = senv {exps = exps senv <> es} + +-- addKnown :: Monad m => [Prop] -> RefineT m () +-- addKnown ps = +-- modify $ \senv -> senv {known = known senv ++ ps} + +-- data Log +-- = Fail String +-- | Success String +-- | Message String +type Log = LaTeX + +type RefineM = RefineT Identity + +type CNFM = RefineT CNF + +newtype RefineT m a = RefineT {getRWST :: RWST () [Log] SEnv m a} + +deriving instance (Monad m) => Functor (RefineT m) + +deriving instance (Monad m) => Applicative (RefineT m) + +deriving instance (Monad m) => Monad (RefineT m) + +deriving instance (Monad m) => MonadState SEnv (RefineT m) + +deriving instance (Monad m) => MonadFreshNames (RefineT m) + +deriving instance (Monad m) => MonadWriter [Log] (RefineT m) + +deriving instance (Monad m, MonadPlus m) => MonadPlus (RefineT m) + +deriving instance (Monad m, MonadPlus m, Alternative m) => Alternative (RefineT m) + +instance (Monad m, Monoid w) => MonadFreshNames (RWST r w SEnv m) where + getNameSource = gets vnamesource + putNameSource vns = modify $ \senv -> senv {vnamesource = vns} + +instance (Monad m) => MonadSoP Exp E.Exp (RefineT m) where + getUntrans = gets (untrans . algenv) + getRanges = gets (ranges . algenv) + getEquivs = gets (equivs . algenv) + modifyEnv f = modify $ \env -> env {algenv = f $ algenv env} + +execRefineT :: (Monad m) => RefineT m a -> VNameSource -> m [Log] +execRefineT (RefineT m) vns = snd <$> execRWST m mempty (SEnv vns mempty mempty mempty mempty mempty) + +execRefineM :: RefineM a -> VNameSource -> [Log] +execRefineM m = runIdentity . execRefineT m + +fromCNFM :: CNFM a -> RefineM (CNF (a, SEnv, [Log])) +fromCNFM (RefineT (RWST m)) = + RefineT $ + RWST $ + \r -> \s -> pure (m r s, s, mempty) -- fix + +toCNFM :: RefineM (CNF (a, SEnv, [Log])) -> CNFM a +toCNFM (RefineT (RWST m)) = + RefineT $ + RWST $ + \r s -> (\(x, _, _) -> x) $ runIdentity $ m r s + +bindCNFM :: (a -> CNFM b) -> CNF (a, SEnv, [Log]) -> RefineM (CNF (b, SEnv, [Log])) +bindCNFM f cnf = fromCNFM $ m >>= f + where + m = RefineT $ RWST $ \_ _ -> cnf + +asCNFM :: (CNFM a -> CNFM b) -> RefineM (CNF (a, SEnv, [Log])) -> RefineM (CNF (b, SEnv, [Log])) +asCNFM f = fromCNFM . f . toCNFM + +withCNF :: (a -> Bool) -> CNFM a -> RefineM Bool +withCNF check (RefineT (RWST m)) = + RefineT $ + RWST $ \r s -> + let result = cnfIsValid (check . (\(b, _, _) -> b)) $ m r s + log = foldMap (\(_, _, w) -> w) $ toList $ m r s + in pure (result, s, log) + +addGoals :: (Pretty a) => [CNFM a] -> CNFM a +addGoals ms = + RefineT $ + RWST $ \r -> \s -> + foldr (&&&) cnfTrue (map (\m -> runRWST (getRWST m) r s) ms) + +lookupVName :: (Monad m) => E.VName -> RefineT m (Maybe E.Exp) +lookupVName x = + (M.!? x) <$> gets exps + +lookupType :: (Monad m) => E.VName -> RefineT m (Maybe E.PatType) +lookupType x = + (M.!? x) <$> gets types + +insertExp :: (Monad m) => E.VName -> E.Exp -> RefineT m () +insertExp x e = + modify $ \env -> env {exps = M.insert x e $ exps env} + +insertType :: (Monad m) => E.VName -> E.PatType -> RefineT m () +insertType x t = + modify $ \env -> env {types = M.insert x t $ types env} diff --git a/src/Futhark/Analysis/Refinement/Prop.hs b/src/Futhark/Analysis/Refinement/Prop.hs new file mode 100644 index 0000000000..ece04b25e4 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Prop.hs @@ -0,0 +1,105 @@ +module Futhark.Analysis.Refinement.Prop where + +import Control.Monad.RWS +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S +import Futhark.Analysis.Refinement.Monad +import Futhark.Analysis.Refinement.Relations +import Futhark.Analysis.Refinement.Representation +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark qualified as E +import Language.Futhark.Prop qualified as E + +-- isNonNeg :: Monad m => E.Exp -> RefineT m Bool +-- isNonNeg (E.AppExp (E.Apply f args _) _) +-- | Just "iota" <- getFun f = pure True +-- | Just fname <- getFun f, +-- "map" `L.isPrefixOf` fname, +-- lam : args' <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args = +-- isNonNeg lam +-- | Just "scan" <- getFun f, +-- [E.OpSection (E.QualName [] vn) _ _, _, xs] <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args, +-- "+" <- E.baseString vn = +-- isNonNeg xs +-- isNonNeg e@(E.AppExp (E.BinOp (op, _) _ (e_x, _) (e_y, _) _) _) +-- | E.baseTag (E.qualLeaf op) <= E.maxIntrinsicTag, +-- name <- E.baseString $ E.qualLeaf op, +-- Just bop <- L.find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = +-- case bop of +-- E.Plus -> isNonNeg e_x ^&& isNonNeg e_y +-- _ -> pure False +-- isNonNeg (E.AppExp (E.If c t f _) _) = +-- (&&) <$> isNonNeg t <*> isNonNeg f +-- isNonNeg (E.IntLit x _ _) = +-- pure $ x >= 0 +-- isNonNeg (E.Lambda _ body _ _ _) = +-- isNonNeg body +-- isNonNeg (E.AppExp (E.Index xs _ _) _) = +-- isNonNeg xs +-- isNonNeg (E.Var (E.QualName qs x) _ _) = do +-- ifM +-- (Var x ^>=^ intToExp 0) +-- (pure True) +-- ( do +-- km <- gets known_map +-- case km M.!? x of +-- Just ps -> +-- anyM isNonNegProp ps +-- Nothing -> pure False +-- ) +-- where +-- isNonNegProp (ForAll i xs (x :>= y)) = +-- ifM +-- ((i ^==^ x) ^&& (intToExp 0 ^==^ y)) +-- (pure True) +-- (pure False) +-- isNonNegProp _ = pure False +-- isNonNeg _ = pure False +-- +-- -- addInfo :: E.Exp -> RefineT m () +-- -- addInfo (E.AppExp (E.Apply f args _) _) +-- -- | Just "iota" <- getFun f = do +-- -- i <- newVName "i" +-- -- ForEach (Var i) (intToExp 0 ... +-- -- +-- -- +-- -- +-- -- | Just fname <- getFun f, +-- -- "map" `L.isPrefixOf` fname, +-- -- E.Lambda params body _ _ _ : args' <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args = do +-- +simplify :: (Monad m) => a -> RefineT m a +simplify = pure + +-- simplify = astMap m +-- where +-- m = +-- ASTMapper +-- { mapOnLit = astMap m, +-- mapOnExp = +-- \e -> +-- case e of +-- (Var x) -> do +-- km <- gets known_map +-- case km M.!? x of +-- Just ps -> +-- let isFwdEq [] = pure $ Var x +-- isFwdEq ((z :== y) : rest) +-- | z == Var x = pure y +-- | otherwise = isFwdEq rest +-- isFwdEq (_ : rest) = isFwdEq rest +-- in isFwdEq ps +-- _ -> pure $ Var x +-- _ -> astMap m e, +-- mapOnProp = +-- \p -> +-- ifM +-- (checkProp p) +-- (pure $ Bool True) +-- (astMap m p) +-- } diff --git a/src/Futhark/Analysis/Refinement/Relations.hs b/src/Futhark/Analysis/Refinement/Relations.hs new file mode 100644 index 0000000000..aa93082ac3 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Relations.hs @@ -0,0 +1,107 @@ +module Futhark.Analysis.Refinement.Relations where + +import Control.Monad.RWS +import Data.Map qualified as M +import Data.Set qualified as S +import Futhark.Analysis.Refinement.CNF +import Futhark.Analysis.Refinement.Monad +import Futhark.Analysis.Refinement.Representation +import Futhark.SoP.FourierMotzkin +import Futhark.SoP.SoP (SoP, Substitute (..), substituteOne) +import Futhark.SoP.SoP qualified as SoP +import Futhark.SoP.Util + +-- isTrue :: (Monad m) => CNF Prop -> RefineT m Bool +-- isTrue = cnfIsValidM checkProp +-- +-- isFalse :: (Monad m) => CNF Prop -> RefineT m Bool +-- isFalse = cnfIsValidM (checkProp . negateProp) +-- +-- checkProp :: (Monad m) => Prop -> RefineT m Bool +-- checkProp (x :< y) = x ^<^ y +-- checkProp (x :<= y) = x ^<=^ y +-- checkProp (x :> y) = x ^>^ y +-- checkProp (x :>= y) = x ^>=^ y +-- checkProp (x :== y) = x ^==^ y +-- checkProp _ = pure False +-- +(^==^) :: (Monad m) => Term -> Term -> RefineT m Bool +SoP x ^==^ SoP y = x $==$ y +x ^==^ y = pure $ x == y + +(^<^) :: (Monad m) => Term -> Term -> RefineT m Bool +x ^<^ y = + (x ~+~ intToTerm 1) ^<=^ y + +(^>=^) :: (Monad m) => Term -> Term -> RefineT m Bool +x ^>=^ y = y ^<=^ x + +(^>^) :: (Monad m) => Term -> Term -> RefineT m Bool +x ^>^ y = y ^<^ x + +(^<=^) :: (Monad m) => Term -> Term -> RefineT m Bool +x ^<=^ y = + ifM + (termToSoP x $<=$ termToSoP y) + (pure True) + (x ^==^ y ^|| x ^^<=^^ y) + where + -- z ^^<=^^ Idx xs@(Var xs') i = do + -- let satForAll (ForAll y ys pred) = + -- case (y, ys) of + -- (Var y', xs) -> + -- let pred' = substituteOne (y', Idx xs i) pred + -- in ((z :<= Idx xs i) == pred') + -- || ((Idx xs i :>= z) == pred') + -- _ -> False + -- satForAll _ = False + -- km <- gets known_map + -- case km M.!? xs' of + -- Just ps -> pure $ any satForAll ps + -- _ -> pure False + SoP x ^^<=^^ Sigma (Var y_i) (Range y_from y_step y_to) y_e + | Just (1, Sigma (Var x_i) (Range x_from x_step x_to) x_e, -1) <- SoP.justAffine x = + andM + [ y_e ^<=^ intToTerm 1, + y_to ^==^ x_to, + y_from ^==^ intToTerm 1, + x_from ^==^ intToTerm 0, + x_step ^==^ intToTerm 1, + y_step ^==^ intToTerm 1, + x_e ^<=^ y_e + ] + SoP x ^^<=^^ SoP y = x $<=$ y + Sigma (Var x_i) (Range x_from x_step x_to) x_e ^^<=^^ Sigma (Var y_i) (Range y_from y_step y_to) y_e = do + x <- x_e ^<=^ substituteOne (y_i, Var x_i) y_e + y <- x_from ^<=^ y_from + z <- x_to ^<=^ y_to + andM + [ x_e ^<=^ substituteOne (y_i, Var x_i) y_e, + y_from ^<=^ x_from, + x_to ^<=^ y_to, + x_step ^==^ intToTerm 1, + y_step ^==^ intToTerm 1 + ] + x ^^<=^^ Sigma _ (Range from _ to) e = + orM + [ andM + [ x ^<=^ intToTerm 0, + intToTerm 0 ^<=^ e + ], + andM + [ x ^<=^ e, + from ^<=^ to + ], + andM + [ x ^<=^ intToTerm 0, + to ^<^ from + ] + ] + x ^^<=^^ BoolToInt {} = + x ^<=^ intToTerm 0 + _ ^^<=^^ _ = pure False + +-- isEmptySet :: (Monad m) => Term -> RefineT m Bool +-- isEmptySet Set es = pure $ S.null es +-- isEmptySet (Range from _ to) = to ^<^ from +-- isEmptySet _ = pure False diff --git a/src/Futhark/Analysis/Refinement/Rep.hs b/src/Futhark/Analysis/Refinement/Rep.hs new file mode 100644 index 0000000000..7cc6dfe5c1 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Rep.hs @@ -0,0 +1 @@ +module Futhark.Analysis.Refinement.Rep where diff --git a/src/Futhark/Analysis/Refinement/Representation.hs b/src/Futhark/Analysis/Refinement/Representation.hs new file mode 100644 index 0000000000..b09a08bab1 --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Representation.hs @@ -0,0 +1,809 @@ +{-# LANGUAGE UndecidableInstances #-} + +module Futhark.Analysis.Refinement.Representation where + +import Control.Applicative +import Control.Monad.State +import Data.Functor.Identity +import Data.List qualified as L +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S +import Futhark.Analysis.Refinement.CNF +import Futhark.MonadFreshNames +import Futhark.SoP.Convert +import Futhark.SoP.Monad +import Futhark.SoP.SoP (SoP, Substitute (..)) +import Futhark.SoP.SoP qualified as SoP +import Futhark.Util.Pretty +import Language.Futhark (VName) +import Language.Futhark qualified as E +import Prelude + +data Hole + = Unnamed + | Hole VName + | CHole VName Term + deriving (Show, Eq, Ord) + +data Term + = -- Exps + Var VName + | THole Hole + | SoP (SoP Term) + | Len Term + | Elems Term + | Set (S.Set Term) + | Array [Term] + | Range Term Term Term + | Idx Term Term + | Union Exp Exp + | Unions Term Term Term Term + | Sigma Exp Exp Exp + | If Exp Exp Exp + | BoolToInt Term + | -- Props + (:<) Exp Exp + | (:<=) Exp Exp + | (:>) Exp Exp + | (:>=) Exp Exp + | (:==) Exp Exp + | (:/=) Exp Exp + | PermutationOf Term Term + | Forall VName Term Term + | Not Term + | Bool Bool + | CNFTerm (CNF Term) + deriving (Show, Eq, Ord) + +instance Pretty Term where + pretty (Var x) = pretty x + pretty (THole h) = pretty h + pretty (SoP sop) = pretty sop + pretty (Len xs) = prettyPre "len" [xs] + pretty (Elems x) = prettyPre "elems" [x] + pretty (Set ts) = pretty ts + pretty (Array ts) = pretty ts + pretty (Idx arr i) = parens (pretty arr) <> "[" <> pretty i <> "]" + pretty (Union x y) = pretty x <+> "U" <+> pretty y + pretty (Unions i range cond xs) = + "U_{" + <> pretty i + <> "=" + <+> pretty range + <> "," + <+> pretty cond + <> "}" + <+> parens (pretty xs) + pretty (Sigma i set e) = + "Σ_" + <> pretty i + <> "=" + <+> pretty set + <+> parens (pretty e) + pretty (If c t f) = + "If" + <+> parens (pretty c) + <+> "then" + <+> parens (pretty t) + <+> "else" + <+> parens (pretty f) + pretty (BoolToInt e) = "bool_to_int" <> parens (pretty e) + pretty (Range from step to) = prettyPre "range" [from, from ~+~ step, to] + pretty (x :> y) = prettyBin ">" x y + pretty (x :< y) = prettyBin "<" x y + pretty (x :<= y) = prettyBin "<=" x y + pretty (x :>= y) = prettyBin ">=" x y + pretty (x :== y) = prettyBin "==" x y + pretty (x :/= y) = prettyBin "/=" x y + pretty (PermutationOf x y) = prettyPre "permutationOf" [x, y] + pretty (Not x) = "!" <> parens (pretty x) + pretty (Bool b) = pretty b + pretty (CNFTerm cnf) = pretty cnf + pretty x = error $ show x + +instance Pretty Hole where + pretty (CHole _ t) = "nest(" <> pretty t <> ")" + pretty _ = "_" + +prettyPre :: (Pretty b) => Doc a -> [b] -> Doc a +prettyPre op args = + op <> parens (mconcat $ punctuate comma $ map pretty args) + +prettyBin :: (Pretty b) => Doc a -> b -> b -> Doc a +prettyBin op x y = + pretty x <+> op <+> pretty y + +instance ASTMappable Term where + astMap m (Var x) = (mapOnTerm m) $ Var x + astMap m (THole h) = (mapOnTerm m) $ THole h + astMap m (SoP sop) = do + sop' <- foldl (SoP..+.) (SoP.int2SoP 0) <$> mapM g (SoP.sopToLists sop) + case SoP.justSym sop' of + Just x -> pure x + Nothing -> pure $ SoP sop' + where + g (ts, n) = do + ts' <- traverse (mapOnTerm m) ts + pure $ foldl (SoP..*.) (SoP.int2SoP 1) (SoP.int2SoP n : map termToSoP ts') + astMap m (Len t) = Len <$> mapOnTerm m t + astMap m (Elems t) = Elems <$> mapOnTerm m t + astMap m (Set ts) = (Set . S.fromList) <$> traverse (mapOnTerm m) (S.toList ts) + astMap m (Array ts) = Array <$> traverse (mapOnTerm m) ts + astMap m (Range from step to) = Range <$> astMap m from <*> astMap m step <*> astMap m to + astMap m (Idx arr i) = Idx <$> astMap m arr <*> astMap m i + astMap m (Union x y) = Union <$> astMap m x <*> astMap m y + astMap m (Unions i s c xs) = Unions <$> astMap m i <*> astMap m s <*> astMap m c <*> astMap m xs + astMap m (Sigma i set e) = Sigma <$> astMap m i <*> astMap m set <*> astMap m e + astMap m (BoolToInt x) = BoolToInt <$> astMap m x + astMap m (If c t f) = If <$> astMap m c <*> astMap m t <*> astMap m f + astMap m (x :> y) = (:>) <$> mapOnTerm m x <*> mapOnTerm m y + astMap m (x :< y) = (:<) <$> mapOnTerm m x <*> mapOnTerm m y + astMap m (x :>= y) = (:>=) <$> mapOnTerm m x <*> mapOnTerm m y + astMap m (x :<= y) = (:<=) <$> mapOnTerm m x <*> mapOnTerm m y + astMap m (x :== y) = (:==) <$> mapOnTerm m x <*> mapOnTerm m y + astMap m (x :/= y) = (:/=) <$> mapOnTerm m x <*> mapOnTerm m y + astMap m (PermutationOf x y) = PermutationOf <$> mapOnTerm m x <*> mapOnTerm m y + astMap m (Forall x p1 p2) = Forall x <$> mapOnTerm m p1 <*> mapOnTerm m p2 + astMap _ t@(Bool {}) = pure t + astMap m (Not p) = Not <$> astMap m p + astMap m (CNFTerm cnf) = CNFTerm <$> astMap m cnf + +class ASTMappable a where + astMap :: (Monad m) => ASTMapper m -> a -> m a + +data ASTMapper m = ASTMapper + { mapOnTerm :: Term -> m Term + } + +instance (ASTMappable a) => Substitute VName Term a where + substitute subst = idMap m + where + m = + ASTMapper + { mapOnTerm = + \e -> + case e of + (Var x) + | Just x' <- subst M.!? x -> pure x' + | otherwise -> pure $ Var x + _ -> astMap m e + } + +idMap :: (ASTMappable a) => ASTMapper Identity -> a -> a +idMap m = runIdentity . astMap m + +flatten :: (ASTMappable a) => a -> a +flatten = idMap m + where + m = + ASTMapper + { mapOnTerm = + \e -> + case e of + Var x -> pure $ Var x + THole h -> pure $ THole h + _ -> astMap m e + } + +class Free a where + fv :: a -> Set VName + +instance (Free a) => Free (CNF a) where + fv = (foldMap . foldMap) fv . cnfToLists + +instance Free Term where + fv = flip execState mempty . astMap m + where + m = + ASTMapper + { mapOnTerm = + \e -> + case e of + Var x -> do + modify (S.insert x) + pure e + THole h -> do + put $ fv h + pure e + Unions i s cond xs -> do + put $ (fv s <> fv cond <> fv xs) S.\\ fv i + pure e + Sigma i set e -> do + put $ (fv set <> fv e) S.\\ fv i + pure e + Forall x p1 p2 -> do + put $ (fv p1 <> fv p2) S.\\ S.singleton x + pure e + _ -> astMap m e + } + +instance Free Hole where + fv Unnamed = mempty + fv (Hole x) = S.singleton x + fv (CHole x term) = fv term + +data Subst = Subst + { termSubst :: Map VName Term, + context :: Map VName (Term -> Term) + } + +instance Show Subst where + show (Subst s_term ctxs) = + show s_term ++ "\n" ++ show (fmap ($ (Bool True)) ctxs) + +instance Semigroup Subst where + (Subst s_term1 s_ctx1) <> (Subst s_term2 s_ctx2) = + Subst (s_term1 <> s_term2) (s_ctx1 <> s_ctx2) + +instance Monoid Subst where + mempty = Subst mempty mempty + +instance Free Subst where + fv (Subst s_term _) = foldMap fv $ M.elems s_term + +class AddSubst a where + addSubst :: VName -> a -> Subst -> Subst + +instance AddSubst Term where + addSubst x t s = s {termSubst = M.insert x t $ termSubst s} + +instance AddSubst (Term -> Term) where + addSubst x ctx s = + s {context = M.insertWith (.) x ctx $ context s} + +class Context a where + contexts :: a -> [(Term -> Term, Term)] + +splits :: [a] -> [([a], a, [a])] +splits = splits' [] + where + splits' _ [] = [] + splits' xs (y : ys) = (xs, y, ys) : splits' (xs ++ [y]) ys + +instance Context Term where + contexts Var {} = mempty + contexts THole {} = mempty + contexts (SoP sop) = + term_contexts + where + term_contexts = do + (lt, (ts, a), rt) <- splits $ SoP.sopToLists sop + (l, t, r) <- splits ts + pure + ( \t' -> + SoP $ + SoP.sopFromList $ + lt ++ [(l ++ [t'] ++ r, a)] ++ rt, + t + ) + contexts (Len t) = + [(Len, t)] + contexts (Elems t) = + [(Elems, t)] + contexts (Set ts) = + map (\(l, t, r) -> ((\t' -> Set $ S.fromList $ l ++ [t'] ++ r), t)) $ + splits $ + S.toList ts + contexts (Array ts) = + map (\(l, t, r) -> ((\t' -> Array $ l ++ [t'] ++ r), t)) $ + splits ts + contexts (Range from step to) = + [ ( \from' -> Range from' step to, + from + ), + ( \step' -> Range from step' to, + step + ), + ( \to' -> Range from step to', + to + ) + ] + contexts (Idx arr i) = + [ ( \arr' -> Idx arr' i, + arr + ), + ( \i' -> Idx arr i', + i + ) + ] + contexts (Union x y) = + [ ( \x' -> Union x' y, + x + ), + ( \y' -> Union x y', + y + ) + ] + contexts (Unions i s c xs) = + [ ( \s' -> Unions i s' c xs, + s + ), + ( \c' -> Unions i s c' xs, + c + ), + ( \xs' -> Unions i s c xs', + xs + ) + ] + contexts (Sigma i s e) = + [ ( \s' -> Sigma i s' e, + s + ), + ( \e' -> Sigma i s e', + e + ) + ] + contexts (If c t f) = + [ ( \c' -> If c' t f, + c + ), + ( \t' -> If c t' f, + t + ), + ( \f' -> If c t f', + f + ) + ] + contexts (BoolToInt x) = + [ ( BoolToInt, + x + ) + ] + contexts (e1 :> e2) = + [ ( \e1' -> e1' :> e2, + e1 + ), + ( \e2' -> e1 :> e2', + e2 + ) + ] + contexts (e1 :< e2) = + [ ( \e1' -> e1' :< e2, + e1 + ), + ( \e2' -> e1 :< e2', + e2 + ) + ] + contexts (e1 :>= e2) = + [ ( \e1' -> e1' :>= e2, + e1 + ), + ( \e2' -> e1 :>= e2', + e2 + ) + ] + contexts (e1 :<= e2) = + [ ( \e1' -> e1' :<= e2, + e1 + ), + ( \e2' -> e1 :<= e2', + e2 + ) + ] + contexts (e1 :/= e2) = + [ ( \e1' -> e1' :/= e2, + e1 + ), + ( \e2' -> e1 :/= e2', + e2 + ) + ] + contexts (e1 :== e2) = + [ ( \e1' -> e1' :== e2, + e1 + ), + ( \e2' -> e1 :== e2', + e2 + ) + ] + contexts (PermutationOf e1 e2) = + [ ( \e1' -> PermutationOf e1' e2, + e1 + ), + ( \e2' -> PermutationOf e1 e2', + e2 + ) + ] + contexts (Forall x p1 p2) = + [ ( \p1' -> Forall x p1' p2, + p1 + ), + ( \p2' -> Forall x p1 p2', + p2 + ) + ] + contexts (Not p) = + [ ( \p' -> Not p', + p + ) + ] + contexts Bool {} = mempty + contexts (CNFTerm cnf) = mempty + +class Replace a where + replace :: Subst -> a -> a + +(@) :: (Replace a) => a -> Subst -> a +x @ s = replace s x + +infixl 9 @ + +instance {-# OVERLAPS #-} (Ord u, Replace u) => Replace (SoP.Term u) where + replace s = + SoP.toTerm . map (replace s) . SoP.termToList + +instance Replace Term where + replace s = idMap m + where + m = + ASTMapper + { mapOnTerm = + \t -> + case t of + Var x -> do + case termSubst s M.!? x of + Nothing -> pure t + Just t' -> pure t' + THole Unnamed -> pure t + THole (Hole x) -> + case termSubst s M.!? x of + Nothing -> pure t + Just t' -> pure t' + THole (CHole x term) -> + case context s M.!? x of + Nothing -> pure $ THole $ CHole x $ replace s term + Just ctx -> pure $ ctx $ replace s term + _ -> astMap m t + } + +instance (Functor f, Replace a) => Replace (f a) where + replace s = fmap (replace s) + +instance (ASTMappable a) => ASTMappable (CNF a) where + astMap m = traverse (astMap m) + +class Rename a where + rename :: (MonadFreshNames m) => Subst -> a -> m a + +instance Rename Hole where + rename s h@(Unnamed) = pure h + rename s h@(Hole {}) = pure h + rename s (CHole f term) = + CHole f <$> rename s term + +instance Rename Term where + rename s = astMap m + where + m = + ASTMapper + { mapOnTerm = + \t -> + case t of + Var x -> do + case termSubst s M.!? x of + Nothing -> pure t + Just t' -> pure t' + THole h -> THole <$> rename s h + Unions (Var i) set c xs -> do + i' <- newVName $ E.baseString i + let s' = addSubst i (Var i') s + Unions (Var i') <$> rename s' set <*> rename s' c <*> rename s' xs + Sigma (Var i) set e -> do + i' <- newVName $ E.baseString i + let s' = addSubst i (Var i') s + Sigma (Var i') <$> rename s' set <*> rename s' e + Forall x p1 p2 -> do + x' <- newVName $ E.baseString x + let s' = addSubst x (Var x') s + Forall x' <$> rename s' p1 <*> rename s' p2 + _ -> astMap m t + } + +class Holes a where + holes :: a -> Set VName + instHole :: (MonadFreshNames m) => a -> m a + hole :: a + mkHole :: (MonadFreshNames m) => m a + mkCHole :: (MonadFreshNames m) => m (Term -> a) + +instance Holes Hole where + holes Unnamed = mempty + holes (Hole x) = S.singleton x + holes (CHole x term) = S.singleton x <> holes term + + instHole Unnamed = mkHole + instHole h = pure h + + mkHole = do + h <- newVName "hole" + pure $ Hole h + + mkCHole = do + ctx <- newVName "ctx" + pure $ CHole ctx + + hole = Unnamed + +instance (Holes a) => Holes (CNF a) where + holes = (foldMap . foldMap) holes . cnfToLists + instHole = traverse instHole + +instance Holes Term where + holes = flip execState mempty . astMap m + where + m = + ASTMapper + { mapOnTerm = + \t -> + case t of + Var x -> pure t + THole h -> do + modify $ (<> holes h) + pure t + _ -> astMap m t + } + + instHole = astMap m + where + m = + ASTMapper + { mapOnTerm = + \t -> + case t of + Var x -> pure t + THole h -> THole <$> instHole h + _ -> astMap m t + } + + mkHole = THole <$> mkHole + + mkCHole = do + (THole .) <$> mkCHole + + hole = THole hole + +hasHoles :: (Holes a) => a -> Bool +hasHoles = not . S.null . holes + +mkHoles :: (Holes a, MonadFreshNames m) => Int -> m [a] +mkHoles = flip replicateM mkHole + +instance Nameable Term where + mkName (VNameSource i) = (Var $ E.VName "x" i, VNameSource $ i + 1) + +instance ToSoP Term E.Exp where + toSoPNum e = do + x <- lookupUntransPE e + pure (1, SoP.sym2SoP x) + +occursIn :: (Free a) => VName -> a -> Bool +occursIn x = S.member x . fv + +getFun :: E.Exp -> Maybe String +getFun (E.Var (E.QualName [] vn) _ _) = Just $ E.baseString vn +getFun _ = Nothing + +negateProp :: Prop -> Prop +negateProp (x :== y) = x :/= y +negateProp (Not p) = p +negateProp (x :< y) = x :>= y +negateProp (x :<= y) = x :> y +negateProp (x :> y) = x :<= y +negateProp (x :>= y) = x :< y +negateProp (x :/= y) = x :== y +negateProp p = Not p + +setSize :: Exp -> Maybe Exp +setSize (Range from step to) + | step == intToTerm 1 -- Should be a monadic check with env + = + Just $ to ~-~ from ~+~ intToTerm 1 +setSize _ = Nothing + +-- Fix +knownFromCtx :: (Term -> Term) -> [Term] +knownFromCtx ctx = fromMaybe mempty $ knownFromCtx' $ ctx garbage + where + garbage = Not $ Len $ THole $ Unnamed + knownFromCtx' :: Term -> Maybe [Term] + knownFromCtx' (Not (Len (THole Unnamed))) = pure mempty + knownFromCtx' (Unions (Var i) set (CNFTerm cnf) e) + | [ps] <- dnfToLists $ toDNF cnf = + (ps ++) <$> (knownFromCtx' set <|> knownFromCtx' e) + knownFromCtx' e = choice $ map (knownFromCtx' . snd) $ contexts e + +choice :: (Foldable t, Alternative f) => t (f a) -> f a +choice = foldl (<|>) empty + +type Prop = Term + +type Exp = Term + +(~-~) :: Exp -> Exp -> Exp +x ~-~ y = flatten $ SoP $ termToSoP x SoP..-. termToSoP y + +(~+~) :: Exp -> Exp -> Exp +x ~+~ y = flatten $ SoP $ termToSoP x SoP..+. termToSoP y + +(~*~) :: Exp -> Exp -> Exp +x ~*~ y = flatten $ SoP $ termToSoP x SoP..*. termToSoP y + +termToSoP :: Exp -> SoP Exp +termToSoP e = + case flatten e of + SoP sop -> sop + e -> SoP.sym2SoP e + +termToSet :: Exp -> Exp +termToSet (Array rs) = Set $ S.fromList rs +termToSet e = e + +intToTerm :: Integer -> Term +intToTerm = SoP . SoP.int2SoP + +(...) :: Term -> Term -> Term +from ... to = Range from (intToTerm 1) to + +infix 1 ... + +-- Old instances that didn't use astMap + +-- instance Free Term where +-- fv (Var x) = S.singleton x +-- fv (THole h) = fv h +-- fv (SoP sop) = foldMap (foldMap fv . fst) $ SoP.sopToLists sop +-- fv (Len e) = fv e +-- fv (Elems e) = fv e +-- fv (Set es) = foldMap fv es +-- fv (Array es) = foldMap fv es +-- fv (Range from step to) = fv from <> fv step <> fv to +-- fv (Idx arr i) = fv arr <> fv i +-- fv (Union x y) = fv x <> fv y +-- fv (Unions i s cond xs) = (fv s <> fv cond <> fv xs) S.\\ fv i +-- fv (Sigma i set e) = fv set <> fv e S.\\ fv i +-- fv (If c t f) = fv c <> fv t <> fv f +-- fv (BoolToInt x) = fv x +-- fv (e1 :== e2) = fv e1 <> fv e2 +-- fv (e1 :> e2) = fv e1 <> fv e2 +-- fv (PermutationOf e1 e2) = fv e1 <> fv e2 +-- fv (Forall x p1 p2) = (fv p1 <> fv p2) S.\\ S.singleton x +-- fv (Not p) = fv p +-- fv Bool {} = mempty +-- fv (CNFTerm cnf) = fv cnf +-- fv t = error $ show t + +-- instance Replace Term where +-- replace s t@(Var x) = +-- case termSubst s M.!? x of +-- Nothing -> t +-- Just t' -> t' +-- replace s h@(THole Unnamed) = h +-- replace s h@(THole (Hole x)) = +-- case termSubst s M.!? x of +-- Nothing -> h +-- Just t -> t +-- replace s (THole (CHole x term)) = +-- case context s M.!? x of +-- Nothing -> THole $ CHole x term' +-- Just ctx -> ctx term' +-- where +-- term' = replace s term +-- replace s (SoP sop) = SoP $ SoP.mapSymSoP_ (replace s) sop +-- replace s (Len t) = Len $ replace s t +-- replace s (Elems t) = Elems $ replace s t +-- replace s (Set ts) = Set $ S.map (replace s) ts +-- replace s (Array ts) = Array $ map (replace s) ts +-- replace s (Range from step to) = Range (replace s from) (replace s step) (replace s to) +-- replace s (Idx arr i) = Idx (replace s arr) (replace s i) +-- replace s (Union x y) = Union (replace s x) (replace s y) +-- replace s (Unions (Var i) set c xs) = Unions (Var i) (replace s set) (replace s c) (replace s xs) +-- replace s (Unions i set c xs) = Unions (replace s i) (replace s set) (replace s c) (replace s xs) +-- replace s (Sigma (Var i) set e) = Sigma (Var i) (replace s set) (replace s e) +-- replace s (Sigma i set e) = Sigma (replace s i) (replace s set) (replace s e) +-- replace s (If c t f) = If (replace s c) (replace s t) (replace s f) +-- replace s (BoolToInt x) = BoolToInt (replace s x) +-- replace s (x :== y) = replace s x :== replace s y +-- replace s (x :> y) = replace s x :> replace s y +-- replace s (Forall x p1 p2) = +-- Forall x (replace s p1) (replace s p2) +-- replace _ t@Bool {} = t +-- replace s (Not p) = Not $ replace s p +-- replace s (CNFTerm cnf) = CNFTerm $ replace s <$> cnf +-- replace _ t = error $ prettyString t + +-- instance Rename Term where +-- rename s (Var x) = +-- case termSubst s M.!? x of +-- Just t -> pure t +-- Nothing -> pure $ Var x +-- rename s (THole h) = THole <$> rename s h +-- rename s (SoP sop) = SoP <$> SoP.mapSymSoPM (rename s) sop +-- rename s (Len t) = Len <$> rename s t +-- rename s (Elems t) = Elems <$> rename s t +-- rename s (Set ts) = (Set . S.fromList) <$> mapM (rename s) (S.toList ts) +-- rename s (Array ts) = Array <$> mapM (rename s) ts +-- rename s (Range from step to) = Range <$> rename s from <*> rename s step <*> rename s to +-- rename s (Idx arr i) = Idx <$> rename s arr <*> rename s i +-- rename s (Union x y) = Union <$> rename s x <*> rename s y +-- rename s (Unions (Var i) set c xs) = do +-- i' <- newVName $ E.baseString i +-- let s' = addSubst i (Var i') s +-- Unions (Var i') <$> rename s' set <*> rename s' c <*> rename s' xs +-- rename s (Unions i set c xs) = +-- Unions <$> rename s i <*> rename s set <*> rename s c <*> rename s xs +-- rename s (Sigma (Var i) set e) = do +-- i' <- newVName $ E.baseString i +-- let s' = addSubst i (Var i') s +-- Sigma (Var i') <$> rename s' set <*> rename s' e +-- rename s (Sigma i set e) = +-- Sigma <$> rename s i <*> rename s set <*> rename s e +-- rename s (If c t f) = If <$> rename s c <*> rename s t <*> rename s f +-- rename s (BoolToInt x) = BoolToInt <$> rename s x +-- rename s (x :== y) = (:==) <$> rename s x <*> rename s y +-- rename s (x :> y) = (:>) <$> rename s x <*> rename s y +-- rename s (PermutationOf x y) = PermutationOf <$> rename s x <*> rename s y +-- rename s (Forall x p1 p2) = do +-- x' <- newVName $ E.baseString x +-- let s' = addSubst x (Var x') s +-- Forall x' <$> rename s' p1 <*> rename s' p2 +-- rename _ t@Bool {} = pure t +-- rename s (Not p) = Not <$> rename s p +-- rename s (CNFTerm cnf) = +-- CNFTerm <$> traverse (rename s) cnf +-- rename s t = error $ show t + +-- instance Holes Term where +-- holes Var {} = mempty +-- holes (THole h) = holes h +-- holes (SoP sop) = foldMap holes $ concatMap fst $ SoP.sopToLists sop +-- holes (Len t) = holes t +-- holes (Elems t) = holes t +-- holes (Set ts) = foldMap holes ts +-- holes (Array ts) = foldMap holes ts +-- holes (Range from step to) = holes from <> holes step <> holes to +-- holes (Idx arr i) = holes arr <> holes i +-- holes (Union x y) = holes x <> holes y +-- holes (Unions i s c xs) = holes i <> holes s <> holes c <> holes xs +-- holes (Sigma i set e) = holes i <> holes set <> holes e +-- holes (If c t f) = holes c <> holes t <> holes f +-- holes (BoolToInt x) = holes x +-- holes (x :== y) = holes x <> holes y +-- holes (x :> y) = holes x <> holes y +-- holes (PermutationOf x y) = holes x <> holes y +-- holes (Forall x p1 p2) = holes p1 <> holes p2 +-- holes Bool {} = mempty +-- holes (Not p) = holes p +-- holes (CNFTerm cnf) = holes cnf +-- +-- instHole t@(Var {}) = pure t +-- instHole (THole h) = THole <$> instHole h +-- instHole (SoP sop) = SoP <$> SoP.mapSymSoPM instHole sop +-- instHole (Len t) = Len <$> instHole t +-- instHole (Elems t) = Elems <$> instHole t +-- instHole (Set ts) = (Set . S.fromList) <$> mapM instHole (S.toList ts) +-- instHole (Array ts) = Array <$> mapM instHole ts +-- instHole (Range from step to) = Range <$> instHole from <*> instHole step <*> instHole to +-- instHole (Idx arr i) = Idx <$> instHole arr <*> instHole i +-- instHole (Union x y) = Union <$> instHole x <*> instHole y +-- instHole (Unions i s c xs) = Unions <$> instHole i <*> instHole s <*> instHole c <*> instHole xs +-- instHole (Sigma i set e) = Sigma <$> instHole i <*> instHole set <*> instHole e +-- instHole (If c t f) = If <$> instHole c <*> instHole t <*> instHole f +-- instHole (BoolToInt x) = BoolToInt <$> instHole x +-- instHole (x :== y) = (:==) <$> instHole x <*> instHole y +-- instHole (x :> y) = (:>) <$> instHole x <*> instHole y +-- instHole (PermutationOf x y) = PermutationOf <$> instHole x <*> instHole y +-- instHole (Forall x p1 p2) = Forall x <$> instHole p1 <*> instHole p2 +-- instHole t@Bool {} = pure t +-- instHole (Not p) = Not <$> instHole p +-- instHole (CNFTerm cnf) = CNFTerm <$> instHole cnf +-- +-- mkHole = THole <$> mkHole +-- +-- mkCHole = do +-- (THole .) <$> mkCHole +-- +-- hole = THole hole diff --git a/src/Futhark/Analysis/Refinement/Rules.hs b/src/Futhark/Analysis/Refinement/Rules.hs new file mode 100644 index 0000000000..2bf26648ca --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Rules.hs @@ -0,0 +1,472 @@ +module Futhark.Analysis.Refinement.Rules where + +import Control.Applicative +import Control.Monad.RWS +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Futhark.Analysis.Refinement.CNF +import Futhark.Analysis.Refinement.Convert +import Futhark.Analysis.Refinement.Match +import Futhark.Analysis.Refinement.Monad +import Futhark.Analysis.Refinement.Relations +import Futhark.Analysis.Refinement.Representation +import Futhark.MonadFreshNames +import Futhark.SoP.Refine +import Futhark.SoP.SoP qualified as SoP +import Futhark.SoP.Util +import Language.Futhark qualified as E + +withType :: (Monad m) => E.VName -> (E.PatType -> RefineT m (Maybe a)) -> RefineT m (Maybe a) +withType x f = do + me <- lookupVName x + mt <- lookupType x + case (me, mt) of + (Just e, _) -> f $ E.typeOf e + (_, Just t) -> f t + _ -> pure Nothing + +data Rule = Rule + { from :: Term, + to :: Subst -> CNFM (Maybe Term) + } + +nope :: CNFM (Maybe a) +nope = pure Nothing + +yep :: a -> CNFM (Maybe a) +yep = pure . Just + +yepM :: CNFM a -> CNFM (Maybe a) +yepM = fmap Just + +rules :: CNFM [(String, Rule)] +rules = do + h1 <- mkHole + h2 <- mkHole + h3 <- mkHole + h4 <- mkHole + h5 <- mkHole + h6 <- mkHole + h7 <- mkHole + ctx1 <- mkCHole + ctx2 <- mkCHole + let withCtx f subst = + (fmap . fmap) (replace subst . ctx1) $ f subst + withCtxAndSubst f subst = withCtx f subst @ subst + anywhere = ctx1 + pure $ + [ ( "eq", + Rule + { from = h1 :== h2, + to = \s -> + if (replace s h1 == replace s h2) + then yep $ Bool True + else nope + } + ), + ( "permutation_of", + Rule + { from = PermutationOf h1 h2, + to = \s -> + yepM $ + addGoals + [ pure $ (Len h1 :== Len h2) @ s, + pure $ (Elems h1 :== Elems h2) @ s + ] + } + ), + ( "len", + Rule + { from = anywhere $ Len h1, + to = withCtxAndSubst $ \s -> + case h1 @ s of + Var u -> do + withType u $ \t -> + toExp $ head $ E.shapeDims $ E.arrayShape t + Set xs -> yep $ intToTerm $ fromIntegral $ length xs + Array xs -> yep $ intToTerm $ fromIntegral $ length xs + _ -> nope + } + ), + ( "len_range_unit_step", + Rule + { from = anywhere $ Len (Range h1 (intToTerm 1) h2), + to = withCtxAndSubst $ \_ -> + yep $ h2 ~-~ h1 ~+~ intToTerm 1 + } + ), + ( "elems_range", + Rule + { from = anywhere $ Elems (Range h1 h2 h3), + to = withCtx $ \s -> + pure $ Just $ Range h1 h2 h3 @ s + } + ), + ( "elems_union", + Rule + { from = anywhere $ Elems h1, + to = + withCtx $ \s -> do + i <- newVName "i" + pure $ + Just $ + Unions + (Var i) + (intToTerm 0 ... Len h1 @ s ~-~ intToTerm 1) + (CNFTerm cnfTrue) + (Idx h1 @ s $ Var i) + } + ), + ( "map_index", + Rule + { from = anywhere $ Idx h1 h2, + to = + withCtxAndSubst $ \s -> + case h1 @ s of + Var arr' -> do + withBinding arr' $ \e -> do + case e of + E.AppExp (E.Apply f args _) _ + | Just fname <- getFun f, + "map" `L.isPrefixOf` fname, + E.Lambda params body _ _ _ : args' <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args -> do + let ps = map (\p -> let (E.Named x, _, _) = E.patternParam p in x) params + body' <- toExp body + argsm'' <- sequence <$> mapM toExp args' + case argsm'' of + Nothing -> pure Nothing + Just args'' -> + let subst = M.fromList $ zip ps (map (\arg -> Idx arg h2) args'') + in pure $ fmap (SoP.substitute subst) body' + _ -> pure Nothing + _ -> pure Nothing + } + ), + ( "union_if", + Rule + { from = anywhere $ Unions h1 h2 h3 (If h4 h5 h6), + to = + withCtxAndSubst $ \s -> + case (h3 @ s, h4 @ s) of + (CNFTerm cond, CNFTerm b) -> + pure $ + Just $ + Unions h1 h2 (CNFTerm $ cond &&& b) h5 + `Union` Unions h1 h2 (CNFTerm $ cond &&& negateCNF negateProp b) h6 + _ -> pure Nothing + } + ), + ( "combine_if_sop", + Rule + { from = anywhere $ h1 ~+~ (If h2 h3 h4), + to = withCtxAndSubst $ \_ -> + pure $ Just $ If h2 (h1 ~+~ h3) (h1 ~+~ h4) + } + ), + ( "split_on_if", + Rule + { from = + ctx1 $ If h1 h2 h3, + to = \s -> + if (not . S.null $ fv (h1 @ s) `S.intersection` fv (ctx1 (Bool True) @ s)) + then pure Nothing + else case (h1 @ s) of + CNFTerm c -> + Just + <$> addGoals + [ addInfo c >> (pure $ ctx1 $ h2 @ s), + addInfo (negateCNF (negateProp) c) + >> (pure $ ctx1 $ h3 @ s) + ] + _ -> pure Nothing + } + ), + ( "scan_sum", + Rule + { from = anywhere $ Idx h1 h2, + to = + withCtxAndSubst $ \s -> + case (h1 @ s, h2 @ s) of + (Var arr', i) -> do + withBinding arr' $ \e -> do + case e of + E.AppExp (E.Apply f args _) _ + | Just "scan" <- getFun f, + [E.OpSection (E.QualName [] vn) _ _, _, xs] <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args, + "+" <- E.baseString vn -> do + xsm <- toExp xs + case xsm of + Nothing -> pure Nothing + Just xs' -> do + k <- newVName "j" + yep $ + Sigma + (Var k) + (intToTerm 0 ... i) + (Idx xs' $ Var k) + _ -> nope + _ -> nope + } + ), + ( "union_sigma_bool", + Rule + { from = + ctx1 $ + Unions h1 h2 h3 $ + ctx2 $ + Sigma h4 h5 h6, + to = \s -> + do + let [i, range, conds, j, jset, e] = [h1, h2, h3, h4, h5, h6] @ s + case (conds, range, e, jset) of + (CNFTerm conds', Range from step to, BoolToInt (Idx cs j'), Range jstart jstep jend) + | [[cs']] <- cnfToLists conds' -> + ifM + ( andM + [ jend ^==^ i, + jstart ^==^ intToTerm 0, + jstep ^==^ intToTerm 1, + from ^==^ intToTerm 0, + step ^==^ intToTerm 1, + pure $ constCtx (ctx2 @ s), + j ^==^ j', + Idx cs i ^==^ cs' + ] + ) + ( yep $ + ( ctx1 $ + Range + (ctx2 $ intToTerm 1) + (intToTerm 1) + ( ctx2 $ + Sigma + j + (Range jstart jstep to) + e + ) + ) + @ s + ) + nope + _ -> nope + } + ), + ( "var", + Rule + { from = anywhere $ h1, + to = + withCtxAndSubst $ \s -> + case h1 @ s of + Var x -> + withBinding x $ \e -> + toExp e + _ -> nope + } + ), + ( "split_sigma", + Rule + { from = anywhere $ Sigma h1 h2 h3, + to = withCtxAndSubst $ + \s -> do + let [i, set, e] = [h1, h2, h3] @ s + case e of + SoP sop + | isNothing $ SoP.justSingleTerm sop -> + let sums = map (Sigma i set . SoP . uncurry SoP.term2SoP) $ SoP.sopToLists sop + in yep $ foldl1 (~+~) sums + _ -> nope + } + ), + ( "const_sigma", + Rule + { from = anywhere $ Sigma h1 h2 h3, + to = withCtxAndSubst $ + \s -> do + let [_, set, e] = [h1, h2, h3] @ s + case SoP.justConstant $ termToSoP e of + Just c -> + pure $ (intToTerm c ~*~) <$> setSize set + _ -> nope + } + ), + ( "mul_sigma", + Rule + { from = anywhere $ Sigma h1 h2 h3, + to = withCtxAndSubst $ + \s -> do + let [i, set, e] = [h1, h2, h3] @ s + case SoP.justSingleTerm $ termToSoP e of + Just (t, n) + | n /= 1 -> + pure $ Just $ intToTerm n ~*~ Sigma i set (SoP $ SoP.term2SoP t 1) + _ -> pure Nothing + } + ), + ( "basic_combine_sigma", + Rule + { from = + anywhere $ + Sigma h1 h2 h3 ~-~ Sigma h4 h5 h6, + to = withCtxAndSubst $ + \s -> do + let [Var x_i, x_set, x_e, Var y_i, y_set, y_e] = + [h1, h2, h3, h4, h5, h6] @ s + case (x_set, y_set) of + (Range x_start x_step x_end, Range y_start y_step y_end) -> + ifM + ( andM + [ x_e ^==^ SoP.substituteOne (y_i, Var x_i) y_e, + x_start ^==^ y_start, + x_step ^==^ y_step + ] + ) + ( yep $ + ( Sigma (Var x_i) ((y_end ~+~ intToTerm 1) ... x_end) x_e + ) + ) + nope + _ -> nope + } + ), + -- FIX: Needs to actually check the condition on the union + -- for compatability with the rule. + ( "i_+_sigma_bool_to_int", + Rule + { from = + ctx1 $ + Unions h1 h2 h3 $ + ctx2 $ + h4 ~+~ Sigma h5 h6 h7, + to = + \s -> do + let [u_i, range, cond, vi, y_j, y_set, y_e] = [h1, h2, h3, h4, h5, h6, h7] @ s + case (cond, flatten vi, range, y_e, y_set) of + (CNFTerm cnf, Var i, Range u_min u_step u_end, BoolToInt (Idx arr idx), Range (SoP y_start_sop) y_step y_end) + | Just (1, y_start_i, c) <- SoP.justAffine y_start_sop, + and + [ u_min == intToTerm 0, + u_step == intToTerm 1, + y_start_i == Var i, + u_end == y_end, + idx == y_j + ] -> do + let e_min = Sigma y_j ((SoP.substituteOne (i, intToTerm 0) y_start_i) ... y_end) y_e + e_max = y_end + yep $ + (ctx1 $ ctx2 $ (e_min ... e_max)) @ s + _ -> nope + } + ), + ( "empty_rset", + Rule + { from = ctx1 $ Range h1 h2 h3, + to = + \s -> do + let [from, step, to] = [h1, h2, h3] @ s + ifM + ( localS id $ do + mapM_ addToAlgEnv $ knownFromCtx (ctx1 @ s) + to ^<^ from + ) + (yep $ (ctx1 $ Set mempty) @ s) + nope + } + ), + ( "empty_unions", + Rule + { from = anywhere $ Unions h1 h2 h3 h4, + to = withCtxAndSubst $ \s -> do + case h2 @ s of + Set xs + | S.null xs -> yep $ Set mempty + _ -> nope + } + ), + ( "empty_union", + Rule + { from = anywhere $ Union h1 h2, + to = withCtxAndSubst $ \s -> do + case (h1 @ s, h2 @ s) of + (Set xs, _) + | S.null xs -> yep $ h2 + (_, Set xs) + | S.null xs -> yep $ h1 + _ -> nope + } + ), + ( "combine_ranges", + Rule + { from = anywhere $ Range h1 h2 h3 `Union` Range h4 h5 h6, + to = withCtxAndSubst $ \s -> do + let [from1, step1, to1, from2, step2, to2] = [h1, h2, h3, h4, h5, h6] @ s + ifM + ( andM + [ from1 ^<=^ from2, + from2 ^<=^ (to1 ~+~ intToTerm 1), + step1 ^==^ step2 + ] + ) + ( yep $ Range from1 step1 to2 + ) + nope + } + ) + ] + +constCtx :: (Term -> Term) -> Bool +constCtx ctx = + isJust $ SoP.justConstant $ termToSoP (ctx $ intToTerm 0) + +withBinding :: (Monad m) => E.VName -> (E.Exp -> RefineT m (Maybe a)) -> RefineT m (Maybe a) +withBinding x f = do + me <- lookupVName x + case me of + Just e -> f e + Nothing -> pure Nothing + +matchRule :: Rule -> Term -> CNFM (Maybe Term) +matchRule r p = do + subs <- match (from r) p + checkMatch subs + where + checkMatch [] = pure Nothing + checkMatch (s : ss) = do + mt <- to r s + case mt of + Just t' -> pure $ Just t' + Nothing -> checkMatch ss + +applyRules :: Term -> CNFM (Maybe (String, Term)) +applyRules p = do + rs <- rules + applyRules' rs + where + applyRules' [] = pure Nothing + applyRules' ((label, r) : rs) = do + mp' <- matchRule r p + case mp' of + Nothing -> applyRules' rs + Just p' -> pure $ Just (label, p') + +addInfo :: CNF Prop -> CNFM () +addInfo = + asum + . map + ( \props -> do + modify (\env -> env {known = known env ++ props}) + mapM_ addToAlgEnv props + ) + . dnfToLists + . toDNF + +addToAlgEnv :: Term -> CNFM () +addToAlgEnv (x :> y) = addRel $ termToSoP x SoP.:>: termToSoP y +addToAlgEnv (x :>= y) = addRel $ termToSoP x SoP.:>=: termToSoP y +addToAlgEnv (x :< y) = addRel $ termToSoP x SoP.:<: termToSoP y +addToAlgEnv (x :<= y) = addRel $ termToSoP x SoP.:<=: termToSoP y +addToAlgEnv (Not (x :> y)) = addRel $ termToSoP x SoP.:<=: termToSoP y +addToAlgEnv (Not (x :>= y)) = addRel $ termToSoP x SoP.:<: termToSoP y +addToAlgEnv p = pure () diff --git a/src/Futhark/Analysis/Refinement/Rules.hs.old b/src/Futhark/Analysis/Refinement/Rules.hs.old new file mode 100644 index 0000000000..0648ed783b --- /dev/null +++ b/src/Futhark/Analysis/Refinement/Rules.hs.old @@ -0,0 +1,1982 @@ +module Futhark.Analysis.Refinement.Rules where + +import Control.Applicative +import Control.Monad.RWS +import Data.Bifunctor +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S +import Debug.Trace +import Futhark.Analysis.Refinement.CNFNew +import Futhark.Analysis.Refinement.Convert +import Futhark.Analysis.Refinement.Monad +import Futhark.Analysis.Refinement.Relations +import Futhark.Analysis.Refinement.Representation +import Futhark.MonadFreshNames +import Futhark.SoP.FourierMotzkin +import Futhark.SoP.Refine +import Futhark.SoP.SoP (SoP) +import Futhark.SoP.SoP qualified as SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark qualified as E + +-- 'Exp' contexts for matching holes. +type ExpContext = Exp -> Exp + +-- 'Prop' contexts for matching holes. +type PropContext = [Exp] -> Prop + +type BoundVars = Set E.VName + +-- General rules on goal; yields a CNF formula. +data Rule = Rule + { conditions :: PropContext -> [(ExpContext, BoundVars)] -> [Exp] -> CNFM (Maybe Prop), + conclusion :: Prop + } + +-- Simple substitutions on 'Exp's. +data SubstRule = SubstRule + { to :: [(ExpContext, BoundVars)] -> [Exp] -> CNFM (Maybe Exp), + from :: Exp + } + +class Anywhere a where + anywhere :: a -> a + +instance Anywhere Exp where + anywhere = Nested + +instance Anywhere Prop where + anywhere = NestedProp + +addInfo :: CNF Prop -> CNFM () +addInfo = + asum + . map + ( \props -> do + modify (\env -> env {known = known env ++ props}) + mapM_ addToAlgEnv props + ) + . dnfToLists + . toDNF + where + addToAlgEnv :: Prop -> CNFM () + addToAlgEnv (x :> y) = addRel $ expToSoP x SoP.:>: expToSoP y + addToAlgEnv (x :>= y) = addRel $ expToSoP x SoP.:>=: expToSoP y + addToAlgEnv (x :< y) = addRel $ expToSoP x SoP.:<: expToSoP y + addToAlgEnv (x :<= y) = addRel $ expToSoP x SoP.:<=: expToSoP y + addToAlgEnv (Not (x :> y)) = addRel $ expToSoP x SoP.:<=: expToSoP y + addToAlgEnv (Not (x :>= y)) = addRel $ expToSoP x SoP.:<: expToSoP y + addToAlgEnv p = pure () + +nope :: CNFM (Maybe a) +nope = pure Nothing + +yep :: a -> CNFM (Maybe a) +yep = pure . Just + +yepM :: CNFM a -> CNFM (Maybe a) +yepM = fmap Just + +convertToForEach :: (String, Rule) -> (String, Rule) +convertToForEach (name, r) = + ( name ++ "_foreach", + Rule + { conditions = + \pctx -> \ctxs -> \(i : set : rest) -> do + mgoal <- conditions r pctx ctxs rest + case mgoal of + Nothing -> nope + Just goal -> yep $ ForEach i set goal, + conclusion = ForEach Hole Hole (conclusion r) + } + ) + +rules :: [(String, Rule)] +rules = + concatMap (\r -> [r, convertToForEach r]) rules' + where + constCtx :: ExpContext -> Bool + constCtx ctx = + isJust $ SoP.justConstant $ expToSoP (ctx $ intToExp 0) + rules' :: [(String, Rule)] + rules' = + [ ( "axiom", + Rule + { conditions = \pctx -> \_ -> \_ -> + yep $ Bool True, + conclusion = Axiom (anyProp Hole) + } + ), + ( "eq", + Rule + { conditions = \_ -> \_ -> \[x, y] -> + ifM + (x ^==^ y) + (yep $ Bool True) + nope, + conclusion = Hole :== Hole + } + ), + ( "true_foreach", + Rule + { conditions = \_ -> \_ -> \_ -> yep $ Bool True, + conclusion = ForEach Hole Hole (Bool True) + } + ), + ( "eq_sigma", + Rule + { conditions = \_ -> \_ -> + \[Var i, set_i, e_i, Var j, set_j, e_j] -> do + es <- massageInto i e_i e_j + case es of + [_] -> + ifM + (set_i ^==^ set_j) + (yep $ Bool True) + nope + _ -> nope, + conclusion = Sigma Hole Hole Hole :== Sigma Hole Hole Hole + } + ), + ( "permutation_of", + Rule + { conditions = + \_ -> + \_ -> + \[xs, ys] -> + yepM $ + addGoals + [ pure $ Len xs :== Len ys, + pure $ Elems xs :== Elems ys + ], + conclusion = PermutationOf Hole Hole + } + ), + ( "const_ordered", + Rule + { conditions = + \_ -> + \_ -> + \[xs] -> + case xs of + Lit (Array xs') + | all constExp xs' -> yep $ Bool True + Concats i set conds e + | constExp e -> yep $ Bool True + _ -> nope, + conclusion = Ordered Hole + } + ), + ( "empty_sigma", + toRule $ + SubstRule + { to = + withExpContext $ + \[i, set, e] -> do + aenv <- gets algenv + isempty <- isEmptySet set + ifM + (isEmptySet set) + (yep $ intToExp 0) + nope, + from = anywhere $ Sigma Hole Hole Hole + } + ), + ( "empty_rset", + toRule $ + SubstRule + { to = + withExpContext $ + \[from, step, to] -> + ifM + (isEmptySet $ Range from step to) + (yep Empty) + nope, + from = anywhere $ Range Hole Hole Hole + } + ), + ( "empty_concats", + toRule $ + SubstRule + { to = + withExpContext $ + \_ -> + yep (Lit $ Array $ []), + from = anywhere $ Concats Hole Empty Hole Hole + } + ), + ( "empty_concats2", + toRule $ + SubstRule + { to = + withExpContext $ + \[_, _, CNFExp cnf, _] -> do + r <- isFalse cnf + t <- SoP.int2SoP (-1) $<=$ SoP.int2SoP 0 + k <- intToExp (-1) ^<=^ intToExp 0 + ifM + (isFalse cnf) + (yep $ Lit $ Array $ []) + nope, + from = anywhere $ Concats Hole Hole Hole Hole + } + ), + ( "contract_union", + toRule $ + SubstRule + { to = + \[(ctx1, _), (ctx2, _)] -> + \[Var i, i_set, i_cond, i_exp, Var j, j_set, j_cond, j_exp] -> + ifM + ( andM + [ i_set ^==^ SoP.substituteOne (j, Var i) j_set, + i_cond ^==^ SoP.substituteOne (j, Var i) j_cond + ] + ) + ( yep $ ctx1 $ ctx2 $ Unions (Var i) i_set i_cond $ i_exp `union` SoP.substituteOne (j, Var i) j_exp + ) + nope, + from = + anywhere $ + InUnion $ + Unions Hole Hole Hole Hole + `U` Unions Hole Hole Hole Hole + } + ), + ( "empty_unions", + toRule $ + SubstRule + { to = + withExpContext $ + \_ -> + yep Empty, + from = anywhere $ Unions Hole Empty Hole Hole + } + ), + ( "singleton_union", + toRule $ + SubstRule + { to = + withExpContext $ + \xs -> + case xs of + [x] -> yep $ x + _ -> nope, + from = anywhere $ Union LHole + } + ), + ( "empty_union_filter", + toRule $ + SubstRule + { to = + withExpContext $ + \xs -> + let xs' = filter (/= Empty) xs + in if (xs' /= xs) + then yep $ Union $ Set $ S.fromList xs' + else nope, + from = anywhere $ Union LHole + } + ), + ( "elems_concat", + toRule $ + SubstRule + { to = + withExpContext $ + \[xs, ys] -> + yep $ Elems xs `union` Elems ys, + from = anywhere $ Elems $ Hole `Concat` Hole + } + ), + ( "elems_range", + toRule $ + SubstRule + { to = + withExpContext $ + \[from, step, to] -> + yep $ Range from step to, + from = anywhere $ Elems (Range Hole Hole Hole) + } + ), + ( "combine_range_l", + toRule $ + SubstRule + { to = + \[(ctx1, _), (ctx2, _)] -> \es -> + case es of + [from, step, to, e] -> + ifM + ((to ~+~ step) ^==^ e) + (yep $ ctx1 $ ctx2 $ Range from step e) + nope + _ -> nope, + from = anywhere $ InUnion $ Range Hole Hole Hole `U` Lit LHole + } + ), + ( "elems_array_lit", + toRule $ + SubstRule + { to = + withExpContext $ + yep . Lit . Set . S.fromList, + from = anywhere $ Elems $ Lit $ LHole + } + ), + ( "elems_union", + toRule $ + SubstRule + { to = + withExpContext $ + \[xs] -> do + i <- newVName "i" + yep $ + Unions + (Var i) + (intToExp 0 ... Len xs ~-~ intToExp 1) + (CNFExp cnfTrue) + (Idx xs (Var i)), + from = anywhere $ Elems Hole + } + ), + ( "len_range_unit_step", + toRule $ + SubstRule + { to = + withExpContext $ + \[from, to] -> yep $ to ~-~ from ~+~ intToExp 1, + from = anywhere $ Len (Range Hole (intToExp 1) Hole) + } + ), + ( "union_if", + toRule $ + SubstRule + { to = + \[(ctx1, _), (ctx2, _)] -> + \[i, range, CNFExp cond, CNFExp b, t, f] -> + yep $ + ctx1 $ + Unions i range (CNFExp $ cond &&& b) (ctx2 t) + `union` Unions i range (CNFExp $ cond &&& negateCNF negateProp b) (ctx2 f), + from = anywhere $ Unions Hole Hole Hole (InSum $ If Hole Hole Hole) + } + ), + ( "combine_unions", + toRule $ + SubstRule + { to = + \[(ctx1, _), (ctx2, _)] -> + \[i, range, CNFExp cond, CNFExp b, t, f] -> + yep $ + ctx1 $ + Unions i range (CNFExp $ cond &&& b) (ctx2 t) + `union` Unions i range (CNFExp $ cond &&& negateCNF negateProp b) (ctx2 f), + from = anywhere $ Unions Hole Hole Hole (InSum $ If Hole Hole Hole) + } + ), + ( "len", + toRule $ + SubstRule + { to = + withExpContext $ + \[xs] -> + case xs of + Var u -> do + withType u $ \t -> + toExp $ head $ E.shapeDims $ E.arrayShape t + Lit xs' -> pure $ Just $ intToExp $ fromIntegral $ litLength xs' + Idx xs' (Range start step end) -> + ifM + (step ^==^ intToExp 1) + (yep $ (end ~-~ start) ~+~ intToExp 1) + nope + _ -> pure Nothing, + from = anywhere $ Len Hole + } + ), + ( "iota_index", + toRule $ + SubstRule + { to = withExpContext $ + \[from, step, to, i] -> + ifM + (from ^==^ intToExp 0 ^&& step ^==^ intToExp 1) + (yep i) + nope, + from = anywhere $ Idx (Range Hole Hole Hole) Hole + } + ), + ( "map_index", + toRule $ + SubstRule + { to = + withExpContext $ + \[arr, i] -> + case arr of + Var arr' -> do + withBinding arr' $ \e -> do + case e of + E.AppExp (E.Apply f args _) _ + | Just fname <- getFun f, + "map" `L.isPrefixOf` fname, + E.Lambda params body _ _ _ : args' <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args -> do + let ps = map (\p -> let (E.Named x, _, _) = E.patternParam p in x) params + body' <- toExp body + argsm'' <- sequence <$> mapM toExp args' + case argsm'' of + Nothing -> pure Nothing + Just args'' -> + let subst = M.fromList $ zip ps (map (\arg -> Idx arg i) args'') + in pure $ fmap (SoP.substitute subst) body' + _ -> nope + _ -> nope, + from = anywhere $ Idx Hole Hole + } + ), + ( "map_unions", + toRule $ + SubstRule + { to = + withExpContext $ + \[i, set, conds, e] -> + case e of + Var arr' -> do + withBinding arr' $ \e -> do + case e of + E.AppExp (E.Apply f args _) _ + | Just fname <- getFun f, + "map" `L.isPrefixOf` fname, + E.Lambda params body _ _ _ : args' <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args -> do + let ps = map (\p -> let (E.Named x, _, _) = E.patternParam p in x) params + mbody' <- toExp body + argsm'' <- sequence <$> mapM toExp args' + case (argsm'', mbody') of + (Just args'', Just body') -> + let subst = M.fromList $ zip ps (map (\arg -> Idx arg i) args'') + body'' = SoP.substitute subst $ body' + in yep $ Unions i set conds body'' + _ -> nope + _ -> nope + _ -> nope, + from = anywhere $ Unions Hole Hole Hole Hole + } + ), + ( "map_concats", + toRule $ + SubstRule + { to = + withExpContext $ + \[i, set, conds, e] -> + case e of + Var arr' -> do + withBinding arr' $ \e -> do + case e of + E.AppExp (E.Apply f args _) _ + | Just fname <- getFun f, + "map" `L.isPrefixOf` fname, + E.Lambda params body _ _ _ : args' <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args -> do + let ps = map (\p -> let (E.Named x, _, _) = E.patternParam p in x) params + mbody' <- toExp body + argsm'' <- sequence <$> mapM toExp args' + case (argsm'', mbody') of + (Just args'', Just body') -> + let subst = M.fromList $ zip ps (map (\arg -> Idx arg i) args'') + body'' = SoP.substitute subst $ body' + in yep $ Concats i set conds body'' + _ -> nope + _ -> nope + _ -> nope, + from = anywhere $ Concats Hole Hole Hole Hole + } + ), + ( "unions_slice", + toRule $ + SubstRule + { to = + withExpContext $ + \[i, set, conds, xs, from, step, to, j] -> + yep $ + Unions + i + set + conds + (Idx xs $ from ~+~ i ~*~ step), + from = + anywhere $ + Unions + Hole + Hole + Hole + (Idx (Idx Hole (Range Hole Hole Hole)) Hole) + } + ), + ( "scatter_subeq", + Rule + { conditions = + \_ -> + \_ -> + \[l, r] -> + let dropWith (Without _ x) = x + dropWith x = x + in case dropWith l of + Var sct -> + withBinding sct $ \e -> + case e of + E.AppExp (E.Apply f args _) _ + | Just "scatter" <- getFun f, + [arr, idxs, vals] <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args -> do + mvals' <- toExp vals + midxs' <- toExp idxs + case (mvals', midxs') of + (Just vals', Just idxs') + | vals' == r -> do + f <- mkFilter (\i -> atomCNF $ (Idx idxs' i :>= intToExp 0)) idxs' + yep $ Ordered f + _ -> nope + _ -> nope + _ -> nope, + conclusion = SubEq Hole Hole + } + ), + ( "scan_sum", + toRule $ + SubstRule + { to = + \[(ctx1, bv1), (ctx2, bv2)] -> + \args -> + case args of + [j, js, c, Var arr', i] -> do + withBinding arr' $ \e -> do + case e of + E.AppExp (E.Apply f args _) _ + | Just "scan" <- getFun f, + [E.OpSection (E.QualName [] vn) _ _, _, xs] <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args, + "+" <- E.baseString vn -> do + xsm <- toExp xs + case xsm of + Nothing -> pure Nothing + Just xs' -> do + k <- newVName "j" + yep $ + ctx1 $ + Unions j js c $ + ctx2 $ + Sigma + (Var k) + (intToExp 0 ... i) + (Idx xs' $ Var k) + _ -> nope + _ -> nope, + from = anywhere $ Unions Hole Hole Hole (anywhere $ Idx Hole Hole) + } + ), + ( "scan_sum_concat", + toRule $ + SubstRule + { to = + \[(ctx1, bv1), (ctx2, bv2)] -> + \args -> + case args of + [j, js, c, Var arr', i] -> do + withBinding arr' $ \e -> do + case e of + E.AppExp (E.Apply f args _) _ + | Just "scan" <- getFun f, + [E.OpSection (E.QualName [] vn) _ _, _, xs] <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args, + "+" <- E.baseString vn -> do + xsm <- toExp xs + case xsm of + Nothing -> pure Nothing + Just xs' -> do + k <- newVName "j" + yep $ + ctx1 $ + Concats j js c $ + ctx2 $ + Sigma + (Var k) + (intToExp 0 ... i) + (Idx xs' $ Var k) + _ -> nope + _ -> nope, + from = anywhere $ Concats Hole Hole Hole (anywhere $ Idx Hole Hole) + } + ), + ( "scan_sum", + toRule $ + SubstRule + { to = + withExpContext $ + \args -> + case args of + [Var arr', i] -> do + withBinding arr' $ \e -> do + case e of + E.AppExp (E.Apply f args _) _ + | Just "scan" <- getFun f, + [E.OpSection (E.QualName [] vn) _ _, _, xs] <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args, + "+" <- E.baseString vn -> do + xsm <- toExp xs + case xsm of + Nothing -> pure Nothing + Just xs' -> do + k <- newVName "j" + yep $ + Sigma + (Var k) + (intToExp 0 ... i) + (Idx xs' $ Var k) + _ -> nope + _ -> nope, + from = anywhere $ Idx Hole Hole + } + ), + ( "var", + toRule $ + SubstRule + { to = + withExpContext $ + \args -> + case args of + [Var x] -> + withBinding x $ \e -> + toExp e + _ -> nope, + from = anywhere $ Hole + } + ), + ( "split_sigma", + toRule $ + SubstRule + { to = withExpContext $ + \[i, set, e] -> + case e of + SoP sop + | isNothing $ SoP.justSingleTerm sop -> + let sums = map (Sigma i set . SoP . uncurry SoP.term2SoP) $ SoP.sopToLists sop + in yep $ foldl1 (~+~) sums + _ -> nope, + from = anywhere $ Sigma Hole Hole Hole + } + ), + ( "const_sigma", + toRule $ + SubstRule + { to = withExpContext $ + \[i, set, e] -> + case SoP.justConstant $ expToSoP e of + Just c -> + pure $ (intToExp c ~*~) <$> setSize set + _ -> nope, + from = anywhere $ Sigma Hole Hole Hole + } + ), + -- ( "const_concats", + -- toRule $ + -- SubstRule + -- { to = + -- withExpContext $ + -- \[i, set, conds, e] -> + -- if (constExp e) + -- then yep $ Lit $ Array [e] + -- else nope, + -- from = anywhere $ Concats Hole Hole Hole Hole + -- } + -- ), + ( "union_sigma_bool", -- fix + toRule $ + SubstRule + { to = + \[(ctx1, bv1), (ctx2, bv2)] -> + \[i, range, conds, j, jset, e] -> + case (range, e, jset) of + (Range from step to, BoolToInt arg, Range jstart jstep jend) + | jend == i, + jstart == intToExp 0, + jstep == intToExp 1, + from == intToExp 0, + step == intToExp 1, + constCtx ctx2 -> + -- error $ show (ctx2 $ intToExp 1) ++ "\n" ++ show (flatten (ctx2 $ intToExp 1)) + yep $ + ctx1 $ + Range + (ctx2 $ intToExp 1) + (intToExp 1) + ( ctx2 $ + Sigma + j + (Range jstart jstep to) + e + ) + _ -> nope, + from = + anywhere $ + Unions Hole Hole Hole $ + anywhere $ + Sigma Hole Hole Hole + } + ), + ( "mul_sigma", + toRule $ + SubstRule + { to = withExpContext $ + \[i, set, e] -> + case SoP.justSingleTerm $ expToSoP e of + Just (t, n) + | n /= 1 -> + pure $ Just $ intToExp n ~*~ Sigma i set (SoP $ SoP.term2SoP t 1) + _ -> pure Nothing, + from = anywhere $ Sigma Hole Hole Hole + } + ), + ( "from_if_sigma", -- generalize better + toRule $ + SubstRule + { to = withExpContext $ + \[i, Range from step to, CNFExp cond, t, f] -> + case (i, cnfToLists cond) of + (Var i, [[Var i' :== from]]) + | i == i' -> + yep $ t ~+~ Sigma (Var i) (Range (from ~+~ step) step to) f + _ -> nope, + from = anywhere $ Sigma Hole Hole (If Hole Hole Hole) + } + ), + ( "combine_sigma_and_term", + toRule $ + SubstRule + { to = + \[(ctx1, bv1), (ctx2, bv2)] -> + \[e, Var i, Range from step to, sigma_e] -> do + es <- massageInto i sigma_e e + case es of + [new_elem] -> + ifM + (new_elem ^>^ to) + ( yep $ + ctx1 $ + ctx2 $ + Sigma + (Var i) + ( Range from step to + `union` Lit (Set (S.singleton new_elem)) + ) + sigma_e + ) + nope + _ -> nope, + from = + anywhere $ + InSum $ + Hole :+: Sigma Hole Hole Hole + } + ), + ( "basic_combine_sigma", -- Should really be adding info...maybe the matching + toRule $ + SubstRule -- should populate the algebraic environment? + { to = + \[(ctx1, bv1), (ctx2, bv2)] -> + \args@[ Var u_i, + range, + cond, + Var x_i, + x_set, + x_e, + Var y_i, + y_set, + y_e + ] -> do + case (range, x_set, y_set) of + (Range _ _ u_end, Range x_start x_step x_end, Range y_start y_step y_end) + | and + [ x_e == SoP.substituteOne (y_i, Var x_i) y_e, + y_end == Var u_i, + x_end == u_end, + x_start == y_start + ] -> + yep $ + ctx1 $ + Unions (Var u_i) range cond $ + ctx2 $ + Sigma (Var x_i) ((Var u_i ~+~ intToExp 1) ... x_end) x_e + _ -> nope, + from = + anywhere $ + Unions Hole Hole Hole $ + InSum $ + Sigma Hole Hole Hole :-: Sigma Hole Hole Hole + } + ), + ( "i + sigma bool", + toRule $ + SubstRule + { to = + withExpContext $ + \args -> do + case args of + [u_i, range, cond, vi, y_j, y_set, y_e] -> do + case (flatten vi, range, y_e, y_set) of + (Var i, Range u_min u_step u_end, BoolToInt (Idx arr idx), Range (SoP y_start_sop) y_step y_end) + | Just (1, y_start_i, c) <- SoP.justAffine y_start_sop, + and + [ u_min == intToExp 0, + u_step == intToExp 1, + y_start_i == Var i, + u_end == y_end, + idx == y_j + ] -> do + let e_min = Sigma y_j ((SoP.substituteOne (i, intToExp 0) y_start_i) ... y_end) y_e + e_max = y_end + yep $ + e_min ... e_max + _ -> nope + _ -> error $ prettyString args, + from = + anywhere $ + Unions Hole Hole Hole $ + Hole :+: Sigma Hole Hole Hole + } + ), + ( "combine_ranges", + toRule $ + SubstRule + { to = + \[(ctx1, _), (ctx2, _)] -> + \[from1, step1, to1, from2, step2, to2] -> + ifM + ( andM + [ from1 ^<=^ from2, + from2 ^<=^ (to1 ~+~ intToExp 1), + step1 ^==^ step2 + ] + ) + ( yep $ ctx1 $ ctx2 $ Range from1 step1 to2 + ) + nope, + from = anywhere $ InUnion $ Range Hole Hole Hole `U` Range Hole Hole Hole + } + ), + ( "simplify_if", + toRule $ + SubstRule + { to = + withExpContext $ + \[CNFExp cnf, t, f] -> + ifM + (isTrue cnf) + (yep t) + ( ifM + (isFalse cnf) + (yep f) + nope + ), + from = + anywhere $ If Hole Hole Hole + } + ), + ( "split_on_if", + Rule + { conditions = + \gctx -> \[(ctx, bvs)] -> \[CNFExp p, t, f] -> + if (not . S.null $ varsOf p `S.intersection` bvs) + then nope + else do + yepM $ + addGoals + [ addInfo p >> (pure (gctx [ctx $ t])), + addInfo (negateCNF (negateProp) p) >> (pure (gctx [ctx $ f])) + ], + conclusion = + anyProp $ anywhere $ If Hole Hole Hole + } + ), + ( "ordered_concat_sigma", + Rule + { conditions -- fix + = + \_ -> \_ -> \[i, i_set, conds, j, j_set, e] -> + case (i_set, j_set) of + (Range i_start i_step i_end, Range j_start j_step j_end) + | j_end == i -> do + ifM + (e ^>=^ intToExp 0) + (yep $ Bool True) + nope + _ -> nope, + conclusion = + Ordered $ Concats Hole Hole Hole $ Sigma Hole Hole Hole + } + ), + ( "sgmSumInt'", -- To be removed once typechecker bug is fixed + toRule $ + SubstRule + { to = + withExpContext $ + \args -> + case args of + [Var arr', k] -> do + withBinding arr' $ \e -> do + case e of + E.AppExp (E.Apply f args _) _ + | Just "sgmSumInt'" <- getFun f, + [shp, flags, input] <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args -> do + margs <- sequence <$> mapM toExp [shp, flags, input] + case margs of + Nothing -> error "" + Just [shp', flags', input'] -> do + i <- newVName "t" + j <- newVName "r" + l <- newVName "l" + + let segs x = do + j <- newVName "j" + pure $ Sigma (Var j) (intToExp 0 ... x ~-~ intToExp 1) shp' + + segs_start <- segs (Var i) + segs_end <- segs (Var i ~+~ intToExp 1) + segs_i <- segs $ Var i + let info = + ForEach + (Var i) + (intToExp 0 ... Len shp' ~-~ intToExp 2) + ( ForEach + (Var j) + (intToExp 0 ... segs_end ~-~ segs_start) + ( Idx (Var arr') (segs_i ~+~ Var j) + :== Sigma + (Var l) + (intToExp 0 ... Var l) + (Idx input' (segs_i ~+~ Var l)) + ) + ) + + -- modify $ \senv -> + -- senv + -- { known_map = + -- M.insert arr' [info] $ known_map senv + -- } + + -- km <- gets known_map + + test <- instantiateProp info (Idx (Var arr') k) + case test of + [] -> nope + ((l :== r) : _) -> yep $ r + _ -> error $ unlines $ map prettyString test + + -- pms <- instantiateProp info (Idx (Var arr') k) + -- case pms of + -- [] -> nope + -- ps -> + -- error $ + -- unlines + -- [ "sgmSumInt", + -- prettyString $ (Idx (Var arr') k), + -- prettyString i, + -- "\n" + -- ] + -- ++ unlines + -- (map prettyString ps) + _ -> nope + _ -> nope, + from = anywhere $ Idx Hole Hole + } + ) + -- addInfo :: CNF Prop -> CNFM () + + -- ( "sgmSumInt'", -- To be removed once typechecker bug is fixed + -- toRule $ + -- SubstRule + -- { to = + -- withExpContext $ + -- \args -> + -- case args of + -- [Var arr', i] -> do + -- withBinding arr' $ \e -> do + -- case e of + -- E.AppExp (E.Apply f args _) _ + -- | Just "sgmSumInt'" <- getFun f, + -- [shp, flags, xs] <- map ((\x -> fromMaybe x (E.stripExp x)) . snd) $ NE.toList args -> do + -- margs <- sequence <$> mapM toExp [shp, flags, xs] + -- case margs of + -- Nothing -> nope + -- Just [shp', flags', xs] -> + -- k <- newVName "j" + -- yep $ + -- Sigma + -- (Var k) + -- (intToExp 0 ... i) + -- (Idx xs' $ Var k) + -- _ -> nope + -- _ -> nope, + -- from = anywhere $ Idx Hole Hole + -- } + + -- + -- ) + -- ( "sigma_normalize_index", + -- toRule $ + -- SubstRule + -- { to = withExpContext $ + -- \[Var i, set, e] -> + -- case set of + -- Range from step to -> + -- ifM + -- (e `contains` (anywhere $ If Hole Hole Hole)) + -- nope + -- ( if (from == intToExp 0) + -- then nope + -- else + -- let range' = Range (intToExp 0) step (to ~-~ from) + -- e' = SoP.substituteOne (i, Var i ~+~ from) e + -- in yep $ Sigma (Var i) range' e' + -- ) + -- _ -> nope, + -- from = anywhere $ Sigma Hole Hole Hole + -- } + -- ) + ] + +anyProp = PHole + +class ToRule a where + toRule :: a -> Rule + +instance ToRule Rule where + toRule = id + +instance ToRule SubstRule where + toRule (SubstRule t f) = + Rule + { conditions = \pctx -> \ctxs -> \args -> + (fmap . fmap) (pctx . pure) $ t ctxs args, + conclusion = anyProp f + } + +addContextInfo :: (Monad m) => ExpContext -> RefineT m a -> RefineT m a +addContextInfo ctx m = + localS id $ do + astMap mapper (ctx Hole) + m + where + mapper = + ASTMapper + { mapOnLit = pure, + mapOnExp = \e -> + case e of + Concats i set (CNFExp cond) e + | [[prop]] <- cnfToLists cond -> do + modify $ \env -> env {known = known env ++ [prop]} + astMap mapper e + (Var x) -> pure $ Var x + _ -> astMap mapper e, + mapOnProp = pure + } + +withExpContext :: ([Exp] -> CNFM (Maybe Exp)) -> [(ExpContext, BoundVars)] -> [Exp] -> CNFM (Maybe Exp) +withExpContext f ctx_bvs holes = withExpContext' f (contractContexts ctx_bvs) holes + where + contractContexts ((ctx1, bvs1) : (ctx2, bvs2) : rest) = + contractContexts (((ctx1 . ctx2), bvs1 <> bvs2) : rest) + contractContexts [ctx_bv] = [ctx_bv] + + withExpContext' f [] holes = f holes + withExpContext' f [(ctx, _)] holes = (fmap . fmap) ctx $ f holes + withExpContext' _ ctxs es = + error $ + unlines + [ prettyString es, + prettyString $ map (($ Hole) . fst) ctxs + ] + +withBinding :: (Monad m) => E.VName -> (E.Exp -> RefineT m (Maybe a)) -> RefineT m (Maybe a) +withBinding x f = do + me <- lookupVName x + case me of + Just e -> f e + Nothing -> pure Nothing + +withType :: (Monad m) => E.VName -> (E.PatType -> RefineT m (Maybe a)) -> RefineT m (Maybe a) +withType x f = do + me <- lookupVName x + mt <- lookupType x + case (me, mt) of + (Just e, _) -> f $ E.typeOf e + (_, Just t) -> f t + _ -> pure Nothing + +mkFilter :: (Monad m) => (Exp -> CNF Prop) -> Exp -> RefineT m Exp +mkFilter p xs = do + i <- newVName "i" + pure $ Concats (Var i) (intToExp 0 ... Len xs) (CNFExp $ p $ Var i) xs + +-------------------------------------------------------------------------------- +-- Matching +-------------------------------------------------------------------------------- + +type Pattern = Exp + +-- Need to know which Exps belong to which Context +-- Need a more sophisticated handling of bound variables, e.g. by a Sigma +-- Each argument needs to know about bound vars... +data MatchRes = MatchRes + { match_ctxs :: [(ExpContext, BoundVars)], + match_exps :: [Exp], + match_known :: [Prop] + } + +instance Semigroup MatchRes where + MatchRes cs xs ks <> MatchRes ds ys ts = MatchRes (cs <> ds) (xs <> ys) (ks <> ts) + +instance Monoid MatchRes where + mempty = MatchRes mempty mempty mempty + +instance Show MatchRes where + show (MatchRes ctxs exps ks) = + unlines + [ "Contexts:", + prettyString $ map (first ($ Hole)) ctxs, + "Exps:", + prettyString exps, + "Known:", + prettyString ks + ] + +class Match a b c where + match :: (Monad m) => a -> b -> RefineT m [c] + +instance Match Pattern Exp MatchRes where + match :: forall m. (Monad m) => Pattern -> Exp -> RefineT m [MatchRes] + match x x' + | x == x' = pure [mempty] + match Hole x = pure [MatchRes mempty [x] mempty] + -- match x Hole = pure [MatchRes mempty [x] mempty] + match x (PExp p) = do + -- TODO: fix + (pms :: [PropMatch]) <- match (PHole x) p + mapM doOneMatch pms + where + doOneMatch pm = do + let m = prop_exp_match pm + -- pctxs = map (\pctx -> ((\e -> PExp $ pctx [e]), mempty)) $ prop_match_ctxs pm -- fix + [pctx] = prop_match_ctxs pm -- fix + -- pctxs = ((\e -> PExp $ pctx [e]), mempty) + ctxs = + case match_ctxs m of + [] -> [] + ((c, bvs) : ctxs) -> (\e -> PExp $ pctx [c e], bvs) : ctxs + + pure $ m {match_ctxs = ctxs} + match (Lit x) (Lit x') = matchLit x x' + where + matchLit (Array xs) (Array xs') = + matchs $ zip xs xs' + matchLit (Set xs) (Set xs') = + matchs $ zip (S.toList xs) (S.toList xs') + matchLit LHole (Array xs) = + pure [MatchRes mempty xs mempty] + matchLit LHole (Set xs) = + pure [MatchRes mempty (S.toList xs) mempty] + match (SoP sop) (SoP sop') = do + -- concat + -- <$> sequence + -- [reject <$> zipWithM matchTerm (SoP.sopToList sop) cand | cand <- L.permutations $ SoP.sopToList sop'] + + concat + <$> zipWithM matchTerm (SoP.sopToList sop) (SoP.sopToList sop') + where + -- concat + -- <$> zipWithM matchTerm (SoP.sopToList sop) (SoP.sopToList sop') + + -- reject ms + -- -- \| any null ms = mempty + -- | otherwise = mconcat ms + matchTerm (t, n) (t', n') + -- \| [Hole] <- SoP.termToList t = + -- pure [MatchRes mempty [foldl (~*~) (intToExp n') $ SoP.termToList t'] mempty] + -- \| [Hole] <- SoP.termToList t' = + -- pure [MatchRes mempty [foldl (~*~) (intToExp n) $ SoP.termToList t] mempty] + | n == n' = + matchs $ zip (SoP.termToList t) (SoP.termToList t') + | otherwise = pure mempty + match (Concat x y) (Concat x' y') = + matchs [(x, x'), (y, y')] + match (Elems x) (Elems x') = + match x x' + match (Union x) (Union x') = + match (Lit x) (Lit x') + match (Idx x y) (Idx x' y') = + matchs [(x, x'), (y, y')] + match (Unions x y z k) (Unions x' y' z' k') = + matchs [(x, x'), (y, y'), (z, z'), (k, k')] + match (Concats x y z k) (Concats x' y' z' k') = + matchs [(x, x'), (y, y'), (z, z'), (k, k')] + match (Range x y z) (Range x' y' z') = + matchs [(x, x'), (y, y'), (z, z')] + match (Len x) (Len x') = + match x x' + match (If x y z) (If x' y' z') = + matchs [(x, x'), (y, y'), (z, z')] + match (Sigma x y z) (Sigma x' y' z') = + -- addBoundVarsM (varsOf x') $ -- TODO FIX + matchs [(x, x'), (y, y'), (z, z')] + match (Nested x) x' = + matchNested x x' + match (InSum x) (SoP x') = + concat <$> sequence [doMatch ctx summands cand | (cand, ctx) <- sop_splits] + where + summands = flattenPlus x + sop_splits = + [L.splitAt (length summands) perm | perm <- L.permutations $ SoP.sopToLists x'] + doMatch ctx summands cand = + map (addExpContext (~+~ (SoP $ SoP.sopFromList ctx)) mempty) <$> match x (flatten $ SoP $ SoP.sopFromList cand) + match (InSum x) x' = + map (addExpContext id mempty) <$> match x x' + match e@(x :+: y) (SoP sop) + | length summands == length (SoP.sopToLists sop) = do + concat <$> sequence [matchs $ zip summands cand | cand <- L.permutations terms] + where + summands = flattenPlus e + + processTerm (_, 0) = intToExp 0 + processTerm ([], 1) = intToExp 1 + processTerm (ts, 1) = foldl1 (:*:) ts + processTerm (ts, n) = foldl (:*:) (intToExp n) ts + + terms = map processTerm $ SoP.sopToLists sop + match e@(x :-: y) z = + match (x :+: (intToExp (-1) :*: y)) z + match (x :*: y) (x' :*: y') = + matchs [(x, x'), (y, y')] + match e@(x :*: y) (SoP sop) + | [(ts, n)] <- SoP.sopToLists sop, + length ts + 1 == length multiplicands = do + let terms = intToExp n : ts + + concat <$> sequence [matchs $ zip cand terms | cand <- L.permutations multiplicands] + where + multiplicands = flattenMult e + match (InUnion x) (Union es) = + concat <$> sequence [doMatch ctx unionands cand | (cand, ctx) <- union_splits] + where + unionands = flattenUnion x + union_splits = + [L.splitAt (length unionands) perm | perm <- L.permutations $ litToList es] + doMatch ctx summands cand = + map (addExpContext (union (Lit $ Set $ S.fromList ctx)) mempty) <$> match x (Union $ Set $ S.fromList cand) + match (InUnion x) x' = + map (addExpContext id mempty) <$> match x x' + match e@(U x y) (Union es) + | length unionands == litLength es = do + concat <$> sequence [matchs $ zip unionands cand | cand <- L.permutations $ litToList es] + where + unionands = flattenUnion e + match _ _ = pure mempty + +matchs :: (Monad m) => [(Pattern, Exp)] -> RefineT m [MatchRes] +matchs [] = pure [mempty] +matchs ((x, x') : rest) = do + mx <- match x x' + mrest <- matchs rest + pure [m <> m' | m <- mx, m' <- mrest] + +addExpContext :: ExpContext -> BoundVars -> MatchRes -> MatchRes +addExpContext ctx bvs (MatchRes ctxs exps ks) = MatchRes ((ctx, bvs) : ctxs) exps ks + +-- Fix +addKnown :: CNF Prop -> MatchRes -> MatchRes +addKnown p (MatchRes ctx exps ks) + | [[prop]] <- cnfToLists p = + MatchRes ctx exps (ks ++ [prop]) + | otherwise = error $ prettyString p + +addBoundVarsM :: (Monad m) => BoundVars -> RefineT m [MatchRes] -> RefineT m [MatchRes] +addBoundVarsM bvs ms = + (fmap . fmap) f ms + where + f (MatchRes [] exps kn) = MatchRes [(id, bvs)] exps kn + f m = addBoundVars bvs m + +addBoundVars :: BoundVars -> MatchRes -> MatchRes +addBoundVars bvs m = + m + { match_ctxs = + map (second $ (<> bvs)) $ match_ctxs m + } + +matchNested :: forall m. (Monad m) => Pattern -> Exp -> RefineT m [MatchRes] +matchNested p exp = tryMatch id mempty exp + where + noMatch :: RefineT m [MatchRes] + noMatch = pure mempty + + tryMatch :: ExpContext -> BoundVars -> Exp -> RefineT m [MatchRes] + tryMatch ctx bvs x = do + topMatches <- match p x + nestedMatches <- doNested ctx bvs x + pure $ (map (addExpContext ctx bvs) topMatches) <> nestedMatches + + matchNestedLit :: ExpContext -> BoundVars -> Lit -> RefineT m [MatchRes] + matchNestedLit ctx bvs (Array xs) = do + matchMany (Lit . Array) ctx bvs xs + matchNestedLit ctx bvs (Set xs) = do + matchMany (Lit . Set . S.fromList) ctx bvs (S.toList xs) + + doNested :: ExpContext -> BoundVars -> Exp -> RefineT m [MatchRes] + doNested ctx bvs (Lit x) = matchNestedLit ctx bvs x + doNested ctx bvs Empty = noMatch + doNested ctx bvs Var {} = noMatch + doNested ctx bvs (SoP sop) = + matchSoP ctx bvs sop + doNested ctx bvs (Concat x y) = + matchTwo Concat ctx bvs x y + doNested ctx bvs (Elems x) = + matchOne Elems ctx bvs x + doNested ctx bvs (Union (Set xs)) = + matchMany (Union . Set . S.fromList) ctx bvs (S.toList xs) + doNested ctx bvs (Idx x y) = + matchTwo Idx ctx bvs x y + doNested ctx bvs (Unions x y z k) = + addBoundVarsM (varsOf x) $ + matchFour Unions ctx bvs x y z k + doNested ctx bvs (Concats x y (CNFExp cnf) k) = + asumM + [ matchThree (\x' y' z' -> Concats x' y' z' k) ctx bvs x y (CNFExp cnf), + (map $ addKnown cnf) <$> tryMatch (ctx . Concats x y (CNFExp cnf)) bvs k + ] + doNested ctx bvs (Range x y z) = + matchThree Range ctx bvs x y z + doNested ctx bvs (Len x) = + matchOne Len ctx bvs x + doNested ctx bvs (If x y z) = + matchThree If ctx bvs x y z + doNested ctx bvs (Sigma x y z) = + addBoundVarsM (varsOf x) $ + matchThree Sigma ctx bvs x y z + doNested ctx bvs (BoolToInt x) = + matchOne BoolToInt ctx bvs x + doNested ctx bvs (Without x y) = + matchTwo Without ctx bvs x y + doNested ctx bvs (CNFExp cnf) + | justAnds cnf = + -- traceM $ unlines ["Pattern:", prettyString p, show p] + -- exps_map <- gets exps + concat <$> mapM (uncurry doMatch) cands + where + -- traceM $ + -- unlines + -- [ prettyString cnf, + -- show ms, + -- prettyString cands + -- ] + -- pure ms + + cands = [(p', ps') | (p' : ps') <- L.permutations $ concat $ cnfToLists cnf] + doMatch cand ps' = do + -- traceM $ + -- unlines + -- [ "Trying match", + -- prettyString p, + -- prettyString cand + -- ] + pms <- match (PHole $ Nested p) cand + let cnfctx prop = + CNFExp $ listsToCNF $ [prop] : map pure ps' + pure $ map (fix cnfctx) pms + fix cnfctx (PropMatch [pctx] (MatchRes [] args ks)) = + MatchRes [(\e -> ctx $ cnfctx $ pctx $ pure $ e, mempty)] args ks + fix cnfctx (PropMatch [pctx] (MatchRes ctxs args ks)) = + MatchRes (map (\(ctx', bvs) -> (\e -> ctx $ cnfctx $ pctx $ pure $ ctx' e, bvs)) ctxs) args ks + doNested _ _ _ = noMatch + + matchMany :: ([Exp] -> Exp) -> ExpContext -> BoundVars -> [Exp] -> RefineT m [MatchRes] + matchMany mkExp ctx bvs = matchMany' [] + where + matchMany' _ [] = noMatch + matchMany' xs (y : ys) = + asumM + [ tryMatch (\y' -> ctx (mkExp $ xs ++ [y'] ++ ys)) bvs y, + matchMany' (xs ++ [y]) ys + ] + + matchOne mkExp ctx bvs x = + matchMany (\[x'] -> mkExp x') ctx bvs [x] + + matchTwo mkExp ctx bvs x y = + matchMany (\[x', y'] -> mkExp x' y') ctx bvs [x, y] + + matchThree mkExp ctx bvs x y z = + matchMany (\[x', y', z'] -> mkExp x' y' z') ctx bvs [x, y, z] + + matchFour mkExp ctx bvs x y z k = + matchMany (\[x', y', z', k'] -> mkExp x' y' z' k') ctx bvs [x, y, z, k] + + matchSoP :: ExpContext -> BoundVars -> SoP Exp -> RefineT m [MatchRes] + matchSoP ctx bvs = + matchSummond [] . SoP.sopToLists + where + matchSummond _ [] = noMatch + matchSummond xs ((y, n) : ys) = + asumM + [ matchTerm (\term -> ctx $ SoP $ SoP.sopFromList $ xs ++ [(term, n)] ++ ys) [] y, + matchSummond (xs ++ [(y, n)]) ys + ] + + matchTerm ctx _ [] = pure mempty + matchTerm ctx xs (y : ys) = + asumM + [ tryMatch (\rexp -> ctx $ xs ++ [rexp] ++ ys) bvs y, + matchTerm ctx (xs ++ [y]) ys + ] + +-- PropContext is to enable 'anyProp' +data PropMatch = PropMatch + { prop_match_ctxs :: [PropContext], + prop_exp_match :: MatchRes + } + +instance Semigroup PropMatch where + PropMatch cs xs <> PropMatch ds ys = PropMatch (cs <> ds) (xs <> ys) + +instance Monoid PropMatch where + mempty = PropMatch mempty mempty + +instance Show PropMatch where + show (PropMatch pctxs m) = show m + +-- instance Match Exp Prop PropMatch where + +instance Match Prop Prop PropMatch where + match = matchProp' + where + matchProp' (x :< y) (x' :< y') = + withPropContext (listifyTwo (:<)) $ + matchs [(x, x'), (y, y')] + matchProp' (x :<= y) (x' :<= y') = + withPropContext (listifyTwo (:<=)) $ + matchs [(x, x'), (y, y')] + matchProp' (x :> y) (x' :> y') = + withPropContext (listifyTwo (:>)) $ + matchs [(x, x'), (y, y')] + matchProp' (x :>= y) (x' :>= y') = + withPropContext (listifyTwo (:>=)) $ + matchs [(x, x'), (y, y')] + matchProp' (x :== y) (x' :== y') = + withPropContext (listifyTwo (:==)) $ + matchs [(x, x'), (y, y')] + matchProp' (x :/= y) (x' :/= y') = + withPropContext (listifyTwo (:/=)) $ + matchs [(x, x'), (y, y')] + matchProp' (PermutationOf x y) (PermutationOf x' y') = + withPropContext (listifyTwo PermutationOf) $ + matchs [(x, x'), (y, y')] + matchProp' (SubsetEq x y) (SubsetEq x' y') = + withPropContext (listifyTwo SubsetEq) $ + matchs [(x, x'), (y, y')] + matchProp' (SubEq x y) (SubEq x' y') = + withPropContext (listifyTwo SubEq) $ + matchs [(x, x'), (y, y')] + matchProp' (Ordered x) (Ordered x') = + withPropContext (listifyOne Ordered) $ + matchs [(x, x')] + matchProp' (Not x) (Not x') = + addPropContext Not $ + match x x' + matchProp' (Bool x) (Bool x') + | x == x' = + withPropContext (const $ Bool x) $ + pure [mempty] + matchProp' (GExp x) (GExp x') = + withPropContext (listifyOne GExp) $ + match x x' + matchProp' (ForEach i set p) (ForEach i' set' p') = do + ms <- matchs [(i, i'), (set, set')] + pms <- match p p' + pure + [ PropMatch + pctx + (m <> pmatches) + | PropMatch pctx pmatches <- pms, + m <- ms + ] + matchProp' (Axiom _) (Axiom _) = + withPropContext (const $ Bool True) $ + pure [mempty] + -- [ PropMatch + -- ( ( \xs -> case xs of + -- (x : y : rest) -> ForEach x y (pctx rest) + -- _ -> error $ prettyString xs + -- ) + -- : pctxs + -- ) + -- (m <> pmatches) + -- | PropMatch (pctx : pctxs) pmatches <- pms, + -- m <- ms + -- ] + matchProp' (PHole x) p = do + let pctx_xs = contextAndExps p + -- case (null $ head $ map snd pctx_xs) of + -- True -> pure mempty + -- _ -> do + -- -- (m :: [MatchRes]) <- match x (last $ head $ map snd pctx_xs) -- Rather, can be a prop + -- traceM $ + -- unlines + -- [ "pctx_xs", + -- prettyString $ map snd pctx_xs, + -- show (m :: [MatchRes]), + -- "x", + -- prettyString x, + -- prettyString (last $ head $ map snd pctx_xs), + -- show (last $ head $ map snd pctx_xs) + -- ] + concat <$> mapM (\(pctx, xs) -> matchHole pctx [] xs) pctx_xs + where + -- matchHole :: PropContext -> [Exp] -> [Exp] -> RefineT m [PropMatch] + matchHole pctx rs [] = pure [] + matchHole pctx rs (e : es) = do + ms <- match x e + rest <- matchHole pctx (rs ++ [e]) es + let pctx' es' = pctx $ rs ++ es' ++ es + pure $ map (PropMatch [pctx']) ms ++ rest + + contextAndExps :: Prop -> [(PropContext, [Exp])] + contextAndExps (x :< y) = [(listifyTwo (:<), [x, y])] + contextAndExps (x :<= y) = [(listifyTwo (:<=), [x, y])] + contextAndExps (x :> y) = [(listifyTwo (:>), [x, y])] + contextAndExps (x :>= y) = [(listifyTwo (:>=), [x, y])] + contextAndExps (x :== y) = [(listifyTwo (:==), [x, y])] + contextAndExps (x :/= y) = [(listifyTwo (:/=), [x, y])] + contextAndExps (PermutationOf x y) = + [(listifyTwo PermutationOf, [x, y])] + contextAndExps (SubsetEq x y) = + [(listifyTwo SubsetEq, [x, y])] + contextAndExps (SubEq x y) = + [(listifyTwo SubEq, [x, y])] + contextAndExps (Ordered x) = [(listifyOne Ordered, [x])] + contextAndExps (Not p) = + map (\(ctx, ex) -> (Not . ctx, ex)) $ contextAndExps p + contextAndExps p@Bool {} = [(const p, mempty)] + contextAndExps (GExp x) = [(listifyOne GExp, [x])] + contextAndExps (PHole x) = [(error "phole", [x])] + -- contextAndExps (ForEach i set pred) = + -- [(listifyTwo (\i' set' -> ForEach i' set' pred), [i, set])] + + contextAndExps (ForEach i set pred) = + [(listifyThree (\i' set' (PExp pred') -> ForEach i' set' pred'), [i, set, PExp pred])] + -- [(listifyTwo (\set' (PExp pred') -> ForEach i set' pred'), [set, PExp pred])] + contextAndExps p = error $ prettyString p + -- Hack to make instantiateForEach work + -- matchProp' p (PHole x) = do + -- let pctx_xs = contextAndExps p + -- concat <$> mapM (\(pctx, xs) -> matchHole pctx [] xs) pctx_xs + -- where + -- matchHole :: PropContext -> [Exp] -> [Exp] -> RefineT m [PropMatch] + -- matchHole pctx rs [] = pure [] + -- matchHole pctx rs (e : es) = do + -- ms <- match e x + -- rest <- matchHole pctx (rs ++ [e]) es + -- let pctx' es' = pctx $ rs ++ es' ++ es + -- pure $ map (PropMatch [pctx']) ms ++ rest + + -- contextAndExps :: Prop -> [(PropContext, [Exp])] + -- contextAndExps (x :< y) = [(listifyTwo (:<), [x, y])] + -- contextAndExps (x :<= y) = [(listifyTwo (:<=), [x, y])] + -- contextAndExps (x :> y) = [(listifyTwo (:>), [x, y])] + -- contextAndExps (x :>= y) = [(listifyTwo (:>=), [x, y])] + -- contextAndExps (x :== y) = [(listifyTwo (:==), [x, y])] + -- contextAndExps (x :/= y) = [(listifyTwo (:/=), [x, y])] + -- contextAndExps (PermutationOf x y) = + -- [(listifyTwo PermutationOf, [x, y])] + -- contextAndExps (SubsetEq x y) = + -- [(listifyTwo SubsetEq, [x, y])] + -- contextAndExps (SubEq x y) = + -- [(listifyTwo SubEq, [x, y])] + -- contextAndExps (Ordered x) = [(listifyOne Ordered, [x])] + -- contextAndExps (Not p) = + -- map (\(ctx, ex) -> (Not . ctx, ex)) $ contextAndExps p + -- contextAndExps p@Bool {} = [(const p, mempty)] + -- contextAndExps (GExp x) = [(listifyOne GExp, [x])] + -- contextAndExps (PHole x) = [(error "phole", [x])] + -- contextAndExps (ForEach i set pred) = + -- [(listifyThree (\i' set' (PExp pred') -> ForEach i' set' pred'), [i, set, PExp pred])] + -- contextAndExps p = error $ prettyString p + matchProp' _ _ = pure mempty + + withPropContext pctx = + fmap $ map (PropMatch [pctx]) + + addPropContext pctx = + fmap $ map (\(PropMatch pctx' ms) -> PropMatch (map (pctx .) pctx') ms) + + listifyOne f [x] = f x + + listifyTwo f [x, y] = f x y + + listifyThree f [x, y, z] = f x y z + +matchRule :: Rule -> Prop -> CNFM (Maybe Prop) +matchRule r g = + match (conclusion r) g >>= tryMatches + where + tryMatches [] = pure Nothing + tryMatches (PropMatch [pctx] (MatchRes ctxs exps ks) : pms) = do + mg <- do + -- FIX THIS + -- localS (\senv -> senv {known = known senv ++ ks}) $ + modify (\senv -> senv {known = known senv ++ ks}) + g <- conditions r pctx ctxs exps + modify (\senv -> senv {known = known senv L.\\ ks}) + pure g + case mg of + Nothing -> tryMatches pms + _ -> pure mg -- conditions r pctx ctxs exps + +applyRules :: [(String, Rule)] -> Prop -> CNFM (Maybe (String, Prop)) +applyRules [] _ = pure Nothing +applyRules ((label, r) : rs) g = do + mg' <- matchRule r g + case mg' of + Nothing -> applyRules rs g + Just g' -> pure $ Just (label, g') + +contains :: forall m. (Monad m) => Exp -> Exp -> RefineT m Bool +contains e p = (not . null) <$> (match p e :: RefineT m [MatchRes]) + +-- instantiateForEach :: (Monad m) => Prop -> Exp -> RefineT m [Prop] +-- instantiateForEach fe e = do +-- instantiateForEach' vars [p] +-- where +-- flattenForEach (ForEach (Var j) set_j k) = +-- let (vars, p_inner) = flattenForEach k +-- in ((j, set_j) : vars, p_inner) +-- flattenForEach p' = (mempty, p') +-- +-- (vars, p) = flattenForEach fe +-- +-- instantiateForEach' [] ps = pure $ ps +-- instantiateForEach' ((i, set) : res) ps = do +-- ps' <- concat <$> mapM (matchAndInstantiate i e) ps -- can't be recursive +-- instantiateForEach' res ps' +-- +---- Assumes all Holes are from the same (free) variable +-- matchAndInstantiate :: (Monad m) => E.VName -> Exp -> Prop -> RefineT m [Prop] +-- matchAndInstantiate i e p = do +-- let p' = SoP.substituteOne (i, Hole) p +-- ms <- match (PHole e) p' +-- traceM $ +-- unlines +-- [ prettyString e, +-- prettyString p', +-- show ms +-- ] +-- let es = map (head . match_exps . prop_exp_match) $ filter (\m -> length (L.nub $ match_exps $ prop_exp_match m) == 1) ms +-- cands = map (\e' -> SoP.substituteOne (i, e') p) es +-- pure cands + +-- ms <- filter (\z -> length z >= 1) <$> ((fmap . fmap) match_exps $ match p' e) + +massageInto :: (Monad m) => E.VName -> Exp -> Exp -> RefineT m [Exp] +massageInto x p e = do + ms <- filter (\z -> length z >= 1) <$> ((fmap . fmap) match_exps $ match p' e) + when (length ms > 1) $ + error $ + unlines + [ prettyString x, + prettyString p, + prettyString e, + show ms, + show $ length ms + ] + pure $ concat $ filter (\m -> length (L.nub m) == 1) ms + where + p' = SoP.substituteOne (x, Hole) p + +-- pms <- matchProp p' (PHole e) +-- let props = +-- map +-- ( \pm -> +-- let [pctx] = prop_match_ctxs pm +-- holes = match_exps $ prop_exp_match pm +-- ctx = +-- case map fst $ match_ctxs $ prop_exp_match pm of +-- [ctx'] -> ctx +-- _ -> id +-- in pctx $ map ctx $ holes +-- ) +-- pms +-- error $ +-- unlines +-- [ prettyString p', +-- prettyString props, +-- prettyString $ ForEach (Var i) set p, +-- prettyString e, +-- prettyString hole_candidates +-- ] + +-- instantiateForEach :: Monad m => Prop -> Exp -> RefineT m [Exp] +-- instantiateForEach prop e = do +-- +-- +-- where flattenForEach (ForEach i iset p@(ForEach{})) = +-- let (rest, p') = flattenForEach p +-- in ((i, iset) : rest, p') +-- flattenForEach (ForEach i iset p) = +-- ([(i, iset)], p) +-- (vars, prop') = flattenForEach prop + +-- convertMapBody :: Monad m => [E.Exp] -> E.Exp -> RefineT m (Maybe Exp) +-- convertMapBody args body = do +-- let (:arrays) = map ((\x -> fromMaybe x (E.stripExp x)) . snd) +-- $ NE.toList args +-- case args' of +-- (E.Lambda params +-- | E.Lambda params body _ _ _ : args' <- -> do + +sopToHole sop = foldr1 (:+:) $ replicate (SoP.numTerms sop) Hole + +holifyMapper = + identityMapper + { mapOnExp = + \e -> + case e of + SoP sop -> pure $ sopToHole sop + e -> pure $ Hole + } + +holify = idMap holifyMapper + +combinations [] = pure mempty +combinations [ms] = ms +combinations (ms : mss) = do + m <- ms + m' <- combinations mss + pure $ m <> m' + +instantiateProp :: (Monad m) => Prop -> Exp -> RefineT m [Prop] +instantiateProp = + instantiateProp' mempty + where + instantiateProp' free (ForEach (Var i) set p) e = + instantiateProp' (S.insert i free) p e + instantiateProp' free p e = do + ms <- reversePropMatch free p e + pure $ L.nub $ map (flip SoP.substitute p) ms + +-- instantiateForEach :: (Monad m) => Prop -> Exp -> RefineT m [Prop] +-- instantiateForEach fe e = do +-- instantiateForEach' vars [p] +-- where +-- flattenForEach (ForEach (Var j) set_j k) = +-- let (vars, p_inner) = flattenForEach k +-- in ((j, set_j) : vars, p_inner) +-- flattenForEach p' = (mempty, p') +-- +-- (vars, p) = flattenForEach fe +-- +-- instantiateForEach' [] ps = pure $ ps +-- instantiateForEach' ((i, set) : res) ps = do +-- ps' <- concat <$> mapM (matchAndInstantiate i e) ps -- can't be recursive +-- instantiateForEach' res ps' + +reversePropMatch :: forall m. (Monad m) => Set E.VName -> Prop -> Exp -> RefineT m [M.Map E.VName Exp] +reversePropMatch free p e = do + -- Match @e@ with holes on all arguments. + (pattern_matches :: [PropMatch]) <- match (PHole $ Nested $ holify e) p + -- Figure out what the arguments actually were matched + (e_args :: [MatchRes]) <- match (holify e) e + -- Form potential matching pairs + let match_pairs = [(prop_exp_match pm, em) | pm <- pattern_matches, em <- e_args] + -- traceM $ + -- unlines $ + -- [ "p", + -- prettyString p, + -- "e", + -- prettyString e, + -- "holify e", + -- prettyString $ holify e, + -- "pattern_matches", + -- show pattern_matches, + -- "e_args", + -- show e_args + -- ] + res <- + (L.nub . concat) + <$> mapM + ( \(pm, em) -> + assign free e pm em + ) + match_pairs + -- error $ unlines $ (prettyString . map M.toList) res + pure res + +---- Match @e@ with holes on all arguments. +-- (ms :: [PropMatch]) <- match (PHole $ Nested $ holify e) p +---- Figure out what the arguments actually were matched +-- (e_args :: [MatchRes]) <- match (holify e) e +-- let exps = L.nub [(match_exps $ prop_exp_match m1, match_exps m2) | m1 <- ms, m2 <- e_args] +-- traceM $ +-- unlines +-- [ "propmatch", +-- "p", +-- prettyString p, +-- "e", +-- prettyString e +-- ] +---- undefined +---- error $ show $ length exps +-- res <- +-- (L.nub . concat) +-- <$> mapM +-- ( \(m1, m2) -> +-- combinations <$> zipWithM (assign free e) m1 m2 +-- ) +-- exps +---- error $ show $ length res -- (prettyString . map M.toList) res +-- pure res + +assign :: forall m. (Monad m) => Set E.VName -> Exp -> MatchRes -> MatchRes -> RefineT m [M.Map E.VName Exp] +assign free e pm em = do + traceM $ "free':" <> prettyString free' + -- combinations <$> zipWithM processExp p_exps e_exps + doExp [(mempty, free')] (zip p_exps e_exps) + where + (p_exps, p_bvs) = toExpBVs pm + (e_exps, _) = toExpBVs em + free' = free <> mconcat p_bvs + toExpBVs m = + (match_exps m, map snd $ match_ctxs m) + + doExp :: [(M.Map E.VName Exp, Set E.VName)] -> [(Exp, Exp)] -> RefineT m [M.Map E.VName Exp] + doExp msf [] = pure $ map fst msf + doExp msf ((p, e') : rest) = + concat + <$> ( forM msf $ + \(m, f) -> do + new_maps <- processExp m f (SoP.substitute m p) (SoP.substitute m e') + let msf' = do + new <- new_maps + let f' = f S.\\ S.fromList (M.keys new) + pure (m <> new, f') + doExp msf' rest + ) + + processExp u f (Var i) (Var j) + | i == j = pure [mempty] + | i `S.member` free' && not (j `S.member` free') = pure [M.singleton i (Var j)] + | j `S.member` free' && not (i `S.member` free') = pure [M.singleton j (Var i)] + | otherwise = pure [] + processExp u f (Var i) y + | i `S.member` free' = pure [M.singleton i y] + | otherwise = pure [] + processExp u f x y = do + ifM + (x ^==^ y) + (pure [mempty]) + (if (e == y) then pure [] else reverseMatch free' x y) + +-- assign :: (Monad m) => Set E.VName -> Exp -> Exp -> Exp -> RefineT m [M.Map E.VName Exp] +-- assign free e (Var i) (Var j) +-- | i == j = pure [mempty] +-- | i `S.member` free && not (j `S.member` free) = pure [M.singleton i (Var j)] +-- | j `S.member` free && not (i `S.member` free) = pure [M.singleton j (Var i)] +-- | otherwise = pure [] +-- assign free e (Var i) y +-- | i `S.member` free = pure [M.singleton i y] +-- | otherwise = pure [] +-- assign free e x y = do +-- ifM +-- (x ^==^ y) +-- (pure [mempty]) +-- (if (e == y) then pure [] else reverseMatch free x y) + +-- The pattern is larger than the expression and we +-- attempt to instantiate the pattern in such a way that +-- the expression then appears within the pattern +reverseMatch :: forall m. (Monad m) => Set E.VName -> Pattern -> Exp -> RefineT m [M.Map E.VName Exp] +reverseMatch free p e = do + -- Match @e@ with holes on all arguments. + (pattern_matches :: [MatchRes]) <- match (Nested $ holify e) p + -- Figure out what the arguments actually were matched + (e_args :: [MatchRes]) <- match (holify e) e + -- Form potential matching pairs + let match_pairs = [(pm, em) | pm <- pattern_matches, em <- e_args] + traceM $ + unlines $ + [ "p", + prettyString p, + "e", + prettyString e, + "holify e", + prettyString $ holify e, + "pattern_matches", + show pattern_matches, + "e_args", + show e_args + ] + res <- + (L.nub . concat) + <$> mapM + ( \(pm, em) -> + assign free e pm em + ) + match_pairs + -- error $ unlines $ (prettyString . map M.toList) res + pure res + +-- reverseMatch' free p e +-- where +-- reverseMatch' _ p e +-- | p == e = pure [mempty] +-- reverseMatch' free (Idx xs i) e@(Idx xs' i') = +-- reverseAndSubMatches free [(xs, xs'), (i, i')] +-- -- reverseMatches free [(xs, xs'), (i, i')] +-- reverseMatch' free (Var x) e +-- | x `S.member` free = +-- pure [M.singleton x e] +-- | otherwise = pure mempty +-- -- Expensive to match +-- reverseMatch' free (SoP sop) e +-- | Just x <- SoP.justSym sop = +-- reverseMatch' free x e +-- reverseMatch' free (SoP sop) (SoP sop') = +-- (concat . concat) +-- <$> sequence +-- [ zipWithM matchTerm (SoP.sopToList sop) cand' +-- | cand <- S.toList $ S.filter ((== SoP.numTerms sop') . S.size) $ S.powerSet $ S.fromList $ SoP.sopToList sop', +-- cand' <- L.permutations $ S.toList cand +-- ] +-- where +-- matchTerm (t, n) (t', n') +-- | SoP.isConstTerm t', +-- [Var x] <- SoP.termToList t, +-- x `S.member` free, +-- n == 1 = +-- pure [M.singleton x $ intToExp n'] +-- | n == n' = +-- reverseMatches free $ zip (SoP.termToList t) (SoP.termToList t') +-- | otherwise = pure mempty +-- reverseMatch' free (SoP sop) e = do +-- concat <$> mapM (\term -> reverseMatch' free (SoP $ uncurry SoP.term2SoP term) e) (SoP.sopToList sop) +-- reverseMatch' _ _ _ = pure mempty + +-- combinations [] = pure mempty +-- combinations [ms] = ms +-- combinations (ms : mss) = do +-- m <- ms +-- m' <- combinations mss +-- pure $ m <> m' + +-- reverseAndSubMatches :: Set E.VName -> [(Pattern, Exp)] -> RefineT m [M.Map E.VName Exp] +-- reverseAndSubMatches free pes = +-- (++) +-- <$> reverseMatches free pes +-- <*> (concat <$> mapM (flip (reverseMatch free) e) (map fst pes)) + +-- reverseMatches :: Set E.VName -> [(Pattern, Exp)] -> RefineT m [M.Map E.VName Exp] +-- reverseMatches free pes = +-- combinations <$> mapM (uncurry $ reverseMatch free) pes diff --git a/src/Futhark/CLI/Dataset.hs b/src/Futhark/CLI/Dataset.hs index c3b615a441..92cea0e3da 100644 --- a/src/Futhark/CLI/Dataset.hs +++ b/src/Futhark/CLI/Dataset.hs @@ -208,6 +208,7 @@ toValueType (TEVar (QualName [] v) _) f t = (nameFromText (V.primTypeText t), t) toValueType (TEVar v _) = Left $ "Unknown type " <> prettyText v +toValueType TERefine {} = error "TERefine not implemented in toValueType" -- | Closed interval, as in @System.Random@. type Range a = (a, a) diff --git a/src/Futhark/CLI/Main.hs b/src/Futhark/CLI/Main.hs index a637cb1b2d..8f424e55ea 100644 --- a/src/Futhark/CLI/Main.hs +++ b/src/Futhark/CLI/Main.hs @@ -29,6 +29,7 @@ import Futhark.CLI.PyOpenCL qualified as PyOpenCL import Futhark.CLI.Python qualified as Python import Futhark.CLI.Query qualified as Query import Futhark.CLI.REPL qualified as REPL +import Futhark.CLI.Refinement qualified as Refinement import Futhark.CLI.Run qualified as Run import Futhark.CLI.Test qualified as Test import Futhark.CLI.WASM qualified as WASM @@ -69,6 +70,7 @@ commands = ("doc", (Doc.main, "Generate documentation for Futhark code.")), ("pkg", (Pkg.main, "Manage local packages.")), ("check", (Check.main, "Type-check a program.")), + ("refinement", (Refinement.main, "Perform refinement checking.")), ("check-syntax", (Misc.mainCheckSyntax, "Syntax-check a program.")), ("imports", (Misc.mainImports, "Print all non-builtin imported Futhark files.")), ("hash", (Misc.mainHash, "Print hash of program AST.")), diff --git a/src/Futhark/CLI/Refinement.hs b/src/Futhark/CLI/Refinement.hs new file mode 100644 index 0000000000..f15f7503ef --- /dev/null +++ b/src/Futhark/CLI/Refinement.hs @@ -0,0 +1,91 @@ +module Futhark.CLI.Refinement (main) where + +import Control.Monad +import Control.Monad.IO.Class +import Data.List qualified as L +import Futhark.Analysis.Refinement +import Futhark.Analysis.Refinement.Latex +import Futhark.Compiler +import Futhark.Util.Options +import Futhark.Util.Pretty (hPutDoc) +import Language.Futhark.Warnings +import System.IO + +data RefineConfig = RefineConfig + { printSuccesses :: Bool, + checkWarn :: Bool, + printAlg :: Bool, + printInfos :: Bool, + laTeX :: Maybe FilePath + } + +newRefineConfig :: RefineConfig +newRefineConfig = RefineConfig False True False False Nothing + +options :: [FunOptDescr RefineConfig] +options = + [ Option + "l" + ["filepath"] + ( ReqArg + (\fp -> Right $ \config -> config {laTeX = Just fp}) + "FILEPATH" + ) + "Print LaTeX trace." + ] + +-- [ Option +-- "v" +-- [] +-- (NoArg $ Right $ \cfg -> cfg {printSuccesses = True}) +-- "Print all checks.", +-- Option +-- "w" +-- [] +-- (NoArg $ Right $ \cfg -> cfg {checkWarn = False}) +-- "Disable all typechecker warnings.", +-- Option +-- "a" +-- [] +-- (NoArg $ Right $ \cfg -> cfg {printAlg = True}) +-- "Print the algebraic environment.", +-- Option +-- "i" +-- [] +-- (NoArg $ Right $ \cfg -> cfg {printSuccesses = True}) +-- "Print info." +-- ] + +-- | Run @futhark refinement@. +main :: String -> [String] -> IO () +main = mainWithOptions newRefineConfig options "program" $ \args cfg -> + case args of + [file] -> Just $ do + (warnings, imps, vns) <- readProgramOrDie file + when (checkWarn cfg && anyWarnings warnings) $ + liftIO $ + hPutDoc stderr $ + prettyWarnings warnings + -- putStrLn $ "Proved: " <> show (refineProg vns imps) + -- putStrLn $ unlines (refineProg vns imps) + let Just basename = reverse <$> L.stripPrefix (reverse ".fut") (reverse file) + res = refineProg vns imps + + -- mapM_ (putStrLn . prettyString) (refineProg vns imps) + + case laTeX cfg of + Just fp -> mkLaTeX fp res + _ -> pure () + -- putStrLn $ unlines (refineProg vns imps) + -- let (_, algenv, log) = refineProg vns imps + -- when (printInfos cfg) $ do + -- putStrLn "Info:" + -- liftIO $ mapM_ putStrLn $ infos log + -- putStrLn "Failed checks:" + -- liftIO $ mapM_ putStrLn $ fails log + -- when (printSuccesses cfg) $ do + -- putStrLn "Successful checks:" + -- liftIO $ mapM_ putStrLn $ successes log + -- when (printAlg cfg) $ do + -- liftIO $ putStrLn $ prettyString algenv + _ -> Nothing diff --git a/src/Futhark/Doc/Generator.hs b/src/Futhark/Doc/Generator.hs index a7c075dc62..ec613b9c0a 100644 --- a/src/Futhark/Doc/Generator.hs +++ b/src/Futhark/Doc/Generator.hs @@ -531,6 +531,9 @@ typeHtml t = case t of where ppClause (n, ts) = joinBy " " . (ppConstr n :) <$> mapM typeHtml ts ppConstr name = "#" <> toHtml (nameToString name) + Scalar (Refinement ty e) -> do + tyH <- typeHtml ty + pure $ "{" <> tyH <> "| " <> toHtml (prettyString e) <> "}" retTypeHtml :: StructRetType -> DocM Html retTypeHtml (RetType [] t) = typeHtml t @@ -665,6 +668,9 @@ typeExpHtml e = case e of TEDim dims t _ -> do t' <- typeExpHtml t pure $ "?" <> mconcat (map (brackets . renderName . baseName) dims) <> "." <> t' + TERefine te p _ -> do + teH <- typeExpHtml te + pure $ "{" <> teH <> "| " <> toHtml (prettyString p) <> "}" qualNameHtml :: QualName VName -> DocM Html qualNameHtml (QualName names vname@(VName name tag)) = diff --git a/src/Futhark/FreshNames.hs b/src/Futhark/FreshNames.hs index 5b0d5171ee..eb77e2ba18 100644 --- a/src/Futhark/FreshNames.hs +++ b/src/Futhark/FreshNames.hs @@ -1,6 +1,6 @@ -- | This module provides facilities for generating unique names. module Futhark.FreshNames - ( VNameSource, + ( VNameSource (..), blankNameSource, newNameSource, newName, diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 848745bca7..3a813d4f8e 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -173,6 +173,7 @@ replaceStaticValSizes globals orig_substs sv = TEParens (onTypeExp substs te) loc onTypeExp _ (TEVar v loc) = TEVar v loc + onTypeExp _ TERefine {} = error "TERefine not implemented in replaceStaticValSizes" onEnv substs = M.fromList @@ -295,6 +296,7 @@ arraySizes :: StructType -> S.Set VName arraySizes (Scalar Arrow {}) = mempty arraySizes (Scalar (Record fields)) = foldMap arraySizes fields arraySizes (Scalar (Sum cs)) = foldMap (foldMap arraySizes) cs +arraySizes (Scalar (Refinement ty _)) = arraySizes ty arraySizes (Scalar (TypeVar _ _ _ targs)) = mconcat $ map f targs where @@ -722,6 +724,7 @@ defuncExp (Constr name es (Info sum_t@(Scalar (Sum all_fs))) loc) = do defuncScalar (Sum fs) = Sum $ M.map (map defuncType) fs defuncScalar (Prim t) = Prim t defuncScalar (TypeVar as u tn targs) = TypeVar as u tn targs + defuncScalar (Refinement ty e) = Refinement (defuncType ty) e defuncExp (Constr name _ (Info t) loc) = error $ "Constructor " diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 7e99ee3d5d..0f6e687b54 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -2204,6 +2204,7 @@ typeExpForError (E.TESum cs _) = do onClause c = do c' <- mapM typeExpForError c pure $ intercalate [" "] c' +typeExpForError E.TERefine {} = error "TERefine not implemented in typeExpForError" sizeExpForError :: E.SizeExp Info VName -> InternaliseM [ErrorMsgPart SubExp] sizeExpForError (SizeExp e _) = do diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index 6b45910935..4605e92c38 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -476,6 +476,7 @@ transformTypeSizes typ = onArg (TypeArgDim dim) = TypeArgDim <$> onDim dim onArg (TypeArgType ty) = TypeArgType <$> transformTypeSizes ty transformScalarSizes ty@Prim {} = pure ty + transformScalarSizes Refinement {} = error "Refinement not implemented in transformTypeSizes" onDim e | e == anySize = pure e @@ -531,6 +532,7 @@ transformTypeExp (TESum cs loc) = TESum <$> traverse (traverse (traverse transformTypeExp)) cs <*> pure loc transformTypeExp (TEDim dims te loc) = TEDim dims <$> transformTypeExp te <*> pure loc +transformTypeExp TERefine {} = error "TERefine not implemented in transformTypeExp" -- This carries out record replacements in the alias information of a type. -- diff --git a/src/Futhark/Internalise/Refinement.hs b/src/Futhark/Internalise/Refinement.hs new file mode 100644 index 0000000000..939bf6c933 --- /dev/null +++ b/src/Futhark/Internalise/Refinement.hs @@ -0,0 +1,168 @@ +module Futhark.Internalise.Refinement (transformProg) where + +import Control.Monad +import Control.Monad.RWS (MonadReader (..), MonadWriter (..), RWS, asks, lift, runRWS) +import Data.List (find) +import Data.Maybe +import Debug.Trace +import Futhark.Analysis.PrimExp (PrimExp) +import Futhark.Analysis.PrimExp qualified as PE +import Futhark.Internalise.TypesValues (internalisePrimType, internalisePrimValue) +import Futhark.MonadFreshNames +import Futhark.SoP.Convert +import Futhark.SoP.FourierMotzkin +import Futhark.SoP.Monad +import Futhark.SoP.Refine +import Futhark.SoP.SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark +import Language.Futhark.Prop +import Language.Futhark.Semantic hiding (Env) + +type Env = () + +newtype RefineM a + = RefineM (SoPMT VName Exp (RWS Env () VNameSource) a) + deriving + ( Functor, + Applicative, + Monad, + MonadReader Env, + MonadSoP VName Exp + ) + +instance MonadFreshNames RefineM where + getNameSource = RefineM $ getNameSource + putNameSource = RefineM . putNameSource + +checkExp :: Exp -> RefineM Bool +checkExp e = do + (_, sop) <- toSoPCmp e + sop $>=$ zeroSoP + +runRefineM :: VNameSource -> RefineM a -> (a, AlgEnv VName Exp, VNameSource) +runRefineM src (RefineM m) = + let ((a, algenv), src', _) = runRWS (runSoPMT_ m) mempty src + in (a, algenv, src') + +considerSlice :: PatType -> Slice -> RefineM Bool +considerSlice (Array _ _ (Shape ds) _) is = + and <$> zipWithM check ds is + where + inBounds :: Exp -> Exp -> RefineM Bool + inBounds d i = do + d' <- toSoPNum_ d + i' <- toSoPNum_ i + andM + [ zeroSoP $<=$ i', + i' $<$ d' + ] + check :: Size -> DimIndexBase Info VName -> RefineM Bool + check _ (DimSlice Nothing Nothing Nothing) = + pure True + check d (DimFix i) = + inBounds d i + check d (DimSlice (Just start) (Just end) Nothing) = do + d' <- toSoPNum_ d + start' <- toSoPNum_ start + end' <- toSoPNum_ end + andM + [ inBounds d start, + end' $<=$ d', + start' $<=$ end' + ] + check d (DimSlice (Just i) Nothing Nothing) = + inBounds d i + check d (DimSlice Nothing (Just j) Nothing) = do + d' <- toSoPNum_ d + j' <- toSoPNum_ j + j' $<=$ d' + check _ (DimSlice Nothing Nothing (Just stride)) = do + stride' <- toSoPNum_ stride + stride' $/=$ zeroSoP + check d (DimSlice (Just i) Nothing (Just s)) = do + s' <- toSoPNum_ s + andM + [ inBounds d i, + zeroSoP $<=$ s' + ] + check d (DimSlice Nothing (Just j) (Just s)) = do + d' <- toSoPNum_ d + j' <- toSoPNum_ j + s' <- toSoPNum_ s + andM + [ j' $<=$ d', + zeroSoP $<=$ s' + ] + check d (DimSlice (Just i) (Just j) (Just s)) = do + d' <- toSoPNum_ d + i' <- toSoPNum_ i + j' <- toSoPNum_ j + s' <- toSoPNum_ s + let nonzero_stride = s' $/=$ zeroSoP + ok_or_empty = n $==$ zeroSoP ^|| slice_ok + slice_ok = backwards ^&& backwards_ok ^|| forwards_ok + backwards_ok = + andM + [ int2SoP (-1) $<=$ j', + j' $<=$ i', + zeroSoP $<=$ i_p_m_t_s, + i_p_m_t_s $<=$ d' + ] + forwards_ok = + andM + [ zeroSoP $<=$ i', + i' $<=$ j', + zeroSoP $<=$ i_p_m_t_s, + i_p_m_t_s $<$ d' + ] + backwards = fromJust (signumSoP s') $==$ int2SoP (-1) + i_p_m_t_s = i' .+. m .*. s' + m = n .-. int2SoP 1 + n = fromJust $ (j' .-. i') `divSoPInt` s' + andM + [ nonzero_stride, + ok_or_empty + ] +considerSlice t _ = error $ "considerSlice: not an array " <> show t + +mkUnsafe :: Exp -> Exp +mkUnsafe e = + Attr (AttrAtom (AtomName "unsafe") mempty) e mempty + +transformExp :: Exp -> RefineM Exp +transformExp (Assert cond e t loc) = do + e' <- transformExp e + safe <- checkExp cond + if safe + then pure e' + else pure $ Assert cond e' t loc +transformExp e@(AppExp (Index arr slice loc) res) = do + arr' <- transformExp arr + b <- considerSlice (typeOf arr) slice + let e' = AppExp (Index arr' slice loc) res + if b + then pure $ mkUnsafe e' + else pure e' +transformExp e = pure e + +transformValBind :: ValBind -> RefineM ValBind +transformValBind vb = do + body <- transformExp $ valBindBody vb + pure $ vb {valBindBody = body} + +transformDec :: Dec -> RefineM Dec +transformDec (ValDec vb) = ValDec <$> transformValBind vb +transformDec d = pure d + +transformImport :: (ImportName, FileModule) -> RefineM (ImportName, FileModule) +transformImport (name, imp) = do + let p = fileProg imp + decs <- mapM transformDec $ progDecs p + pure $ (name, imp {fileProg = p {progDecs = decs}}) + +transformProg :: MonadFreshNames m => Imports -> m Imports +transformProg prog = modifyNameSource $ \namesrc -> + let (prog', _, namesrc') = runRefineM namesrc $ mapM transformImport prog + in (prog', namesrc') diff --git a/src/Futhark/Internalise/TypesValues.hs b/src/Futhark/Internalise/TypesValues.hs index 84325bbeac..ee8c14af2e 100644 --- a/src/Futhark/Internalise/TypesValues.hs +++ b/src/Futhark/Internalise/TypesValues.hs @@ -274,6 +274,8 @@ internaliseTypeM exts orig_t = error $ "internaliseTypeM: cannot handle type variable: " ++ prettyString orig_t E.Scalar E.Arrow {} -> error $ "internaliseTypeM: cannot handle function type: " ++ prettyString orig_t + E.Scalar E.Refinement {} -> + error $ "internaliseTypeM: refinement type not implemented: " ++ prettyString orig_t E.Scalar (E.Sum cs) -> do (ts, _) <- internaliseConstructors diff --git a/src/Futhark/SoP/Convert.hs b/src/Futhark/SoP/Convert.hs new file mode 100644 index 0000000000..507bfaef3e --- /dev/null +++ b/src/Futhark/SoP/Convert.hs @@ -0,0 +1,202 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} + +-- | Translating to-and-from PrimExp to the sum-of-product representation. +module Futhark.SoP.Convert + ( FromSoP (..), + ToSoP (..), + toSoPNum_, + toSoPCmp_, + ) +where + +import Control.Monad.State +import Data.List (find) +import Data.Set (Set) +import Data.Set qualified as S +import Futhark.Analysis.PrimExp (PrimExp, PrimType, (~*~), (~+~), (~-~), (~/~), (~==~)) +import Futhark.Analysis.PrimExp qualified as PE +import Futhark.SoP.Monad +import Futhark.SoP.SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark.Core +import Language.Futhark.Prop +import Language.Futhark.Syntax (VName) +import Language.Futhark.Syntax qualified as E + +toSoPNum_ :: (ToSoP u e, MonadSoP u e m) => e -> m (SoP u) +toSoPNum_ e = snd <$> toSoPNum e + +toSoPCmp_ :: (ToSoP u e, MonadSoP u e m) => e -> m (SoP u >= 0) +toSoPCmp_ e = snd <$> toSoPNum e + +-- | Conversion from 'SoP's to other representations. +class FromSoP u e where + fromSoP :: MonadSoP u e m => SoP u -> m e + +-- instance Ord u => FromSoP u (PrimExp u) where +-- fromSoP sop = +-- foldr ((~+~) . fromTerm) (PE.ValueExp $ PE.IntValue $ PE.intValue PE.Int64 (0 :: Integer)) (sopToLists sop) +-- where +-- fromTerm (term, n) = +-- foldl (~*~) (PE.ValueExp $ PE.IntValue $ PE.intValue PE.Int64 n) $ +-- map fromSym term +-- fromSym sym = PE.LeafExp sym $ PE.IntType PE.Int64 + +-- | Conversion from some expressions to +-- 'SoP's. Monadic because it may involve look-ups in the +-- untranslatable expression environment. +class ToSoP u e where + toSoPNum :: MonadSoP u e m => e -> m (Integer, SoP u) + +instance (Nameable u, Ord u, Show u, Pretty u) => ToSoP u Integer where + toSoPNum x = pure (1, int2SoP x) + +instance (Nameable u, Ord u, Show u, Pretty u) => ToSoP u (PrimExp u) where + toSoPNum primExp = do + (f, sop) <- toSoPNum' 1 primExp + pure (abs f, signum f `scaleSoP` sop) + where + notIntType :: PrimType -> Bool + notIntType (PE.IntType _) = False + notIntType _ = True + + divideIsh :: PE.BinOp -> Bool + divideIsh (PE.UDiv _ _) = True + divideIsh (PE.UDivUp _ _) = True + divideIsh (PE.SDiv _ _) = True + divideIsh (PE.SDivUp _ _) = True + divideIsh (PE.FDiv _) = True + divideIsh _ = False + toSoPNum' _ pe + | notIntType (PE.primExpType pe) = + error "toSoPNum' applied to a PrimExp whose prim type is not Integer" + toSoPNum' f (PE.LeafExp vnm _) = + pure (f, sym2SoP vnm) + toSoPNum' f (PE.ValueExp (PE.IntValue iv)) = + pure (1, int2SoP $ getIntVal iv `div` f) + where + getIntVal :: PE.IntValue -> Integer + getIntVal (PE.Int8Value v) = fromIntegral v + getIntVal (PE.Int16Value v) = fromIntegral v + getIntVal (PE.Int32Value v) = fromIntegral v + getIntVal (PE.Int64Value v) = fromIntegral v + toSoPNum' f (PE.UnOpExp PE.Complement {} x) = do + (f', x_sop) <- toSoPNum' f x + pure (f', negSoP x_sop) + toSoPNum' f (PE.BinOpExp PE.Add {} x y) = do + (x_f, x_sop) <- toSoPNum x + (y_f, y_sop) <- toSoPNum y + let l_c_m = lcm x_f y_f + (x_m, y_m) = (l_c_m `div` x_f, l_c_m `div` y_f) + x_sop' = mulSoPs (int2SoP x_m) x_sop + y_sop' = mulSoPs (int2SoP y_m) y_sop + pure (f * l_c_m, addSoPs x_sop' y_sop') + toSoPNum' f (PE.BinOpExp PE.Sub {} x y) = do + (x_f, x_sop) <- toSoPNum x + (y_f, y_sop) <- toSoPNum y + let l_c_m = lcm x_f y_f + (x_m, y_m) = (l_c_m `div` x_f, l_c_m `div` y_f) + x_sop' = mulSoPs (int2SoP x_m) x_sop + n_y_sop' = mulSoPs (int2SoP (-y_m)) y_sop + pure (f * l_c_m, addSoPs x_sop' n_y_sop') + toSoPNum' f pe@(PE.BinOpExp PE.Mul {} x y) = do + (x_f, x_sop) <- toSoPNum x + (y_f, y_sop) <- toSoPNum y + case (x_f, y_f) of + (1, 1) -> pure (f, mulSoPs x_sop y_sop) + _ -> do + x' <- lookupUntransPE pe + toSoPNum' f $ PE.LeafExp x' $ PE.primExpType pe + -- pe / 1 == pe + toSoPNum' f (PE.BinOpExp divish pe q) + | divideIsh divish && PE.oneIshExp q = + toSoPNum' f pe + -- evaluate `val_x / val_y` + toSoPNum' f (PE.BinOpExp divish x y) + | divideIsh divish, + PE.ValueExp v_x <- x, + PE.ValueExp v_y <- y = do + let f' = v_x `vdiv` v_y + toSoPNum' f $ PE.ValueExp f' + -- Trivial simplifications: + -- (y * v) / y = v and (u * y) / y = u + | divideIsh divish, + PE.BinOpExp (PE.Mul _ _) u v <- x, + (is_fst, is_snd) <- (u == y, v == y), + is_fst || is_snd = do + toSoPNum' f $ if is_fst then v else u + where + vdiv (PE.IntValue (PE.Int64Value x')) (PE.IntValue (PE.Int64Value y')) = + PE.IntValue $ PE.Int64Value (x' `div` y') + vdiv (PE.IntValue (PE.Int32Value x')) (PE.IntValue (PE.Int32Value y')) = + PE.IntValue $ PE.Int32Value (x' `div` y') + vdiv (PE.IntValue (PE.Int16Value x')) (PE.IntValue (PE.Int16Value y')) = + PE.IntValue $ PE.Int16Value (x' `div` y') + vdiv (PE.IntValue (PE.Int8Value x')) (PE.IntValue (PE.Int8Value y')) = + PE.IntValue $ PE.Int8Value (x' `div` y') + -- vdiv (FloatValue (Float32Value x)) (FloatValue (Float32Value y)) = + -- FloatValue $ Float32Value $ x / y + -- vdiv (FloatValue (Float64Value x)) (FloatValue (Float64Value y)) = + -- FloatValue $ Float64Value $ x / y + vdiv _ _ = error "In vdiv: illegal type for division!" + -- try heuristic for exact division + toSoPNum' f pe@(PE.BinOpExp divish x y) + | divideIsh divish = do + (x_f, x_sop) <- toSoPNum x + (y_f, y_sop) <- toSoPNum y + case (x_f, y_f, divSoPs x_sop y_sop) of + (1, 1, Just res) -> pure (f, res) + _ -> do + x' <- lookupUntransPE pe + toSoPNum' f $ PE.LeafExp x' $ PE.primExpType pe + -- Anything that is not handled by specific cases of toSoPNum' + -- is handled by this default procedure: + -- If the target `pe` is in the unknwon `env` + -- Then return thecorresponding binding + -- Else make a fresh symbol `v`, bind it in the environment + -- and return it. + toSoPNum' f pe = do + x <- lookupUntransPE pe + toSoPNum' f $ PE.LeafExp x $ PE.primExpType pe + +instance ToSoP VName Exp where + toSoPNum (E.Literal v _) = + (pure . (1,)) $ + case v of + E.SignedValue x -> int2SoP $ PE.valueIntegral x + E.UnsignedValue x -> int2SoP $ PE.valueIntegral x + _ -> error "" + toSoPNum (E.IntLit v _ _) = pure (1, int2SoP v) + toSoPNum (E.Var (E.QualName [] v) _ _) = pure (1, sym2SoP v) + toSoPNum e@(E.AppExp (E.BinOp (op, _) _ (e_x, _) (e_y, _) _) _) + | E.baseTag (E.qualLeaf op) <= maxIntrinsicTag, + name <- E.baseString $ E.qualLeaf op, + Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = do + (_, x) <- toSoPNum e_x + (_, y) <- toSoPNum e_y + (1,) + <$> case bop of + E.Plus -> pure $ x .+. y + E.Minus -> pure $ x .-. y + E.Times -> pure $ x .*. y + _ -> sym2SoP <$> lookupUntransPE e + toSoPNum e = do + x <- lookupUntransPE e + pure (1, sym2SoP x) + +-- +-- {-- +---- This is a more refined treatment, but probably +---- an overkill (harmful if you get the type wrong) +-- fromSym unknowns sym +-- | Nothing <- M.lookup sym (dir unknowns) = +-- LeafExp sym $ IntType Integer +-- | Just pe1 <- M.lookup sym (dir unknowns), +-- IntType Integer <- PE.primExpType pe1 = +-- pe1 +-- fromSym unknowns sym = +-- error ("Type error in fromSym: type of " ++ +-- show sym ++ " is not Integer") +----} diff --git a/src/Futhark/SoP/Expression.hs b/src/Futhark/SoP/Expression.hs new file mode 100644 index 0000000000..4ea784eceb --- /dev/null +++ b/src/Futhark/SoP/Expression.hs @@ -0,0 +1,116 @@ +{-# LANGUAGE DataKinds #-} + +module Futhark.SoP.Expression + ( Expression (..), + processExps, + ) +where + +import Data.List (find) +import Data.Set (Set) +import Data.Set qualified as S +import Futhark.Analysis.PrimExp +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark qualified as E +import Language.Futhark.Prop + +class Expression e where + -- -- | Is this 'PrimType' not integral? + -- notIntType :: PrimType -> Bool + + -- | Is this expression @mod@? + moduloIsh :: e -> Maybe (e, e) + + -- -- | Is this 'PrimExp' @<@? + -- lthishType :: CmpOp -> Maybe IntType + + -- -- | Is this 'PrimExp' @<=@? + -- leqishType :: CmpOp -> Maybe IntType + + -- | Rewrite a mod expression into division. + divInsteadOfMod :: e -> e + + -- | Algebraically manipulates an 'e' into a set of equality + -- and inequality constraints. + processExp :: e -> (Set (e == 0), Set (e >= 0)) + +processExps :: (Ord e, Expression e, Foldable t) => t e -> (Set (e == 0), Set (e >= 0)) +processExps = foldMap processExp + +instance Expression Exp where + moduloIsh (E.AppExp (E.BinOp (op, _) _ (e_x, _) (e_y, _) _) _) + | E.baseTag (E.qualLeaf op) <= maxIntrinsicTag, + name <- E.baseString $ E.qualLeaf op, + Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp], + E.Mod <- bop = + Just (e_x, e_y) + moduloIsh _ = Nothing + +instance Ord u => Expression (PrimExp u) where + moduloIsh :: PrimExp u -> Maybe (PrimExp u, PrimExp u) + moduloIsh (BinOpExp (SMod _ _) pe1 pe2) = Just (pe1, pe2) + moduloIsh (BinOpExp (UMod _ _) pe1 pe2) = Just (pe1, pe2) + moduloIsh _ = Nothing + + processExp :: PrimExp u -> (Set (PrimExp u == 0), Set (PrimExp u >= 0)) + processExp (CmpOpExp (CmpEq ptp) x y) + -- x = y => x - y = 0 + | IntType {} <- ptp = + (S.singleton (x ~-~ y), mempty) + processExp (CmpOpExp lessop x y) + -- x < y => x + 1 <= y => y >= x + 1 => y - (x+1) >= 0 + | Just itp <- lthishType lessop = + let pe = y ~-~ (x ~+~ ValueExp (IntValue $ intValue itp (1 :: Integer))) + in (mempty, S.singleton pe) + -- x <= y => y >= x => y - x >= 0 + | Just _ <- leqishType lessop = + (mempty, S.singleton $ y ~-~ x) + where + -- Is this 'PrimExp' @<@? + lthishType :: CmpOp -> Maybe IntType + lthishType (CmpSlt itp) = Just itp + lthishType (CmpUlt itp) = Just itp + lthishType _ = Nothing + + -- Is this 'PrimExp' @<=@? + leqishType :: CmpOp -> Maybe IntType + leqishType (CmpUle itp) = Just itp + leqishType (CmpSle itp) = Just itp + leqishType _ = Nothing + processExp (BinOpExp LogAnd x y) = + processExps [x, y] + processExp (CmpOpExp CmpEq {} pe1 pe2) = + case (pe1, pe2) of + -- (x && y) == True => x && y + (BinOpExp LogAnd _ _, ValueExp (BoolValue True)) -> + processExp pe1 + -- True == (x && y) => x && y + (ValueExp (BoolValue True), BinOpExp LogAnd _ _) -> + processExp pe2 + -- (x || y) == False => !x && !y + (BinOpExp LogOr x y, ValueExp (BoolValue False)) -> + processExps [UnOpExp Not x, UnOpExp Not y] + -- False == (x || y) => !x && !y + (ValueExp (BoolValue False), BinOpExp LogOr x y) -> + processExps [UnOpExp Not x, UnOpExp Not y] + _ -> mempty + processExp (UnOpExp Not pe) = + case pe of + -- !(!x) => x + UnOpExp Not x -> + processExp x + -- !(x < y) => y <= x + CmpOpExp (CmpSlt itp) x y -> + processExp $ CmpOpExp (CmpSle itp) y x + -- !(x <= y) => y < x + CmpOpExp (CmpSle itp) x y -> + processExp $ CmpOpExp (CmpSlt itp) y x + -- !(x < y) => y <= x + CmpOpExp (CmpUlt itp) x y -> + processExp $ CmpOpExp (CmpUle itp) y x + -- !(x <= y) => y < x + CmpOpExp (CmpUle itp) x y -> + processExp $ CmpOpExp (CmpUlt itp) y x + _ -> mempty + processExp _ = mempty diff --git a/src/Futhark/SoP/FourierMotzkin.hs b/src/Futhark/SoP/FourierMotzkin.hs new file mode 100644 index 0000000000..c14dd81610 --- /dev/null +++ b/src/Futhark/SoP/FourierMotzkin.hs @@ -0,0 +1,132 @@ +-- | A sum-of-product representation of integral algebraic expressions +-- formed with addition and multiplication. +-- +-- This representation is intended to allow to statically querry whether +-- a symbolic expression is less than (or equal to) zero by means of +-- Fourier-Motzkin elimination, i.e., if we know that the range of +-- `i` is `[l, u]`, then the querry `a*i + b <= 0` +-- can be solved (recursively) by solving the four subproblems: +-- `( (a <= 0) && (a*l + b <= 0) ) || ( (a >= 0) && (a*u + b <= 0) )` +-- +-- A possible future extension of the representation and algebra is to +-- support min/max operations (as well). This would require (i) adding +-- a min-max (outermost) layer to the current representation and +-- (ii) extending the algebra. For example, translating a multiplication +-- such as `min(e1,e2) * sop` from PrimExp to sum-of-products form +-- would require to use Fourier-Motzkin to determine the sign of `sop`: +-- if positive than this is equivalent with `min(e1 * sop, e2 * sop)` and +-- if negative it is equivalent to `max(e1 * sop, e2 * sop)`, where +-- `e1 * sop` and `e2 * sop` are recursively translated. +module Futhark.SoP.FourierMotzkin + ( fmSolveLTh0, + fmSolveLEq0, + fmSolveGTh0, + fmSolveGEq0, + ($<$), + ($<=$), + ($>$), + ($>=$), + ($==$), + ($/=$), + ) +where + +import Data.Set qualified as S +import Futhark.SoP.Monad +import Futhark.SoP.SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty + +--------------------------------------------- +--- Fourier-Motzkin elimination algorithm --- +--- for solving inequation of the form: --- +--- `a*i + b <= 0` and `a*i + b < 0` --- +--- assumming we have an environment of --- +--- ranges containing `i =[l,u]` inclusive--- +--------------------------------------------- + +-- | Solves the inequation `sop < 0` by reducing it to +-- `sop + 1 <= 0`, where `sop` denotes an expression +-- in sum-of-product form. +fmSolveLTh0 :: MonadSoP u e m => SoP u -> m Bool +fmSolveLTh0 = fmSolveLEq0 . (.+. int2SoP 1) + +-- | Solves the inequation `sop > 0` by reducing it to +-- `(-1)*sop < 0`, where `sop` denotes an expression +-- in sum-of-product form. +fmSolveGTh0 :: MonadSoP u e m => SoP u -> m Bool +fmSolveGTh0 = fmSolveLTh0 . negSoP + +-- | Solves the inequation `sop >= 0` by reducing it to +-- `(-1)*sop <= 0`, where `sop` denotes an expression +-- in sum-of-product form. +fmSolveGEq0 :: MonadSoP u e m => SoP u -> m Bool +fmSolveGEq0 = fmSolveLEq0 . negSoP + +-- | Assuming `sop` an expression in sum-of-products (SoP) form, +-- this solves the inequation `sop <= 0` as follows: +-- 1. find `i`, the most dependent variable in `sop`, i.e., whose +-- transitive closure of the symbols appearing in its range is +-- maximal. +-- 2. re-write `sop = a*i + b`, where `a` and `b` are in SoP form. +-- 3. assumming the range of `i` to be `[l, u]`, we rewrite our +-- problem as below and solve it recursively: +-- `(a <= 0 && a*l + b <= 0) || (a >= 0 && a*u + b <= 0)` +-- If one of the ranges is missing, then we solve only the +-- subrpoblem that does not use the missing range. +-- If the result is +-- (i) `True` if the inequality is found to always holds; +-- (ii) `False` if there is an `i` for which the inequality does +-- not hold or if the answer is unknown. +fmSolveLEq0 :: MonadSoP u e m => SoP u -> m Bool +fmSolveLEq0 sop = do + sop' <- substEquivs sop + let syms = S.toList $ free sop' + case (justConstant sop', not (null syms)) of + (Just v, _) -> pure (v <= 0) + (_, True) -> do + rs <- getRanges + -- step 1: find `i` + let i = + snd $ + maximum $ + map (\s -> (length $ transClosInRanges rs $ S.singleton s, s)) syms + -- step 2: re-write `sop = a*i + b` + (a, b) = factorSoP [i] sop' + Range lb k ub <- lookupRange i + a_leq_0 <- fmSolveLEq0 a + al_leq_0 <- + anyM + ( \l -> + fmSolveLEq0 $ + a .*. l .+. int2SoP k .*. b + ) + (S.toList lb) + a_geq_0 <- fmSolveLEq0 $ negSoP a + au_leq_0 <- + anyM + ( \u -> + fmSolveLEq0 $ + a .*. u .+. int2SoP k .*. b + ) + (S.toList ub) + pure (a_leq_0 && al_leq_0 || a_geq_0 && au_leq_0) + _ -> pure False + +($<$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool +x $<$ y = fmSolveLTh0 $ x .-. y + +($<=$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool +x $<=$ y = fmSolveLEq0 $ x .-. y + +($>$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool +x $>$ y = fmSolveGTh0 $ x .-. y + +($>=$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool +x $>=$ y = fmSolveGEq0 $ x .-. y + +($==$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool +x $==$ y = (&&) <$> (x $<=$ y) <*> (x $>=$ y) + +($/=$) :: MonadSoP u e m => SoP u -> SoP u -> m Bool +x $/=$ y = (||) <$> (x $<$ y) <*> (x $>$ y) diff --git a/src/Futhark/SoP/Monad.hs b/src/Futhark/SoP/Monad.hs new file mode 100644 index 0000000000..ae1934a7a0 --- /dev/null +++ b/src/Futhark/SoP/Monad.hs @@ -0,0 +1,348 @@ +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | The Algebraic Environment, which is in principle +-- maintained during program traversal, is used to +-- solve symbolically algebraic inequations. +module Futhark.SoP.Monad + ( Nameable (..), + mkNameM, + RangeEnv, + EquivEnv, + UntransEnv (..), + AlgEnv (..), + addUntrans, + transClosInRanges, + lookupUntransPE, + lookupUntransSym, + lookupRange, + addRange, + SoPMT, + SoPM, + lookupSoP, + runSoPMT, + runSoPMT_, + runSoPM, + runSoPM_, + evalSoPMT, + evalSoPMT_, + evalSoPM, + evalSoPM_, + MonadSoP (..), + substEquivs, + addEquiv, + delFromEnv, + ) +where + +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer +import Data.Map (Map) +import Data.Map.Strict qualified as M +import Data.Set (Set) +import Data.Set qualified as S +import Futhark.Analysis.PrimExp +import Futhark.FreshNames +import Futhark.MonadFreshNames +import Futhark.SoP.Expression +import Futhark.SoP.SoP +import Futhark.Util.Pretty +import Language.Futhark.Syntax hiding (Range) + +-------------------------------------------------------------------------------- +-- Names; probably will remove in the end. +-------------------------------------------------------------------------------- + +-- | Types which can use a fresh source to generate +-- unique names. +class Nameable u where + mkName :: VNameSource -> (u, VNameSource) + +instance Nameable String where + mkName (VNameSource i) = ("x" <> show i, VNameSource $ i + 1) + +instance Nameable VName where + mkName (VNameSource i) = (VName "x" i, VNameSource $ i + 1) + +instance Nameable Name where + mkName = mkName + +mkNameM :: (Nameable u, MonadFreshNames m) => m u +mkNameM = modifyNameSource mkName + +-------------------------------------------------------------------------------- +-- Monad +-------------------------------------------------------------------------------- + +class + ( Ord u, + Ord e, + Nameable u, + Show u, -- To be removed + Pretty u, -- To be removed + MonadFreshNames m, + Expression e + ) => + MonadSoP u e m + | m -> u, + m -> e + where + getUntrans :: m (UntransEnv u e) + getRanges :: m (RangeEnv u) + getEquivs :: m (EquivEnv u) + modifyEnv :: (AlgEnv u e -> AlgEnv u e) -> m () + +-- | The algebraic monad; consists of a an algebraic +-- environment along with a fresh variable source. +newtype SoPMT u e m a = SoPMT (StateT (AlgEnv u e) m a) + deriving + ( Functor, + Applicative, + Monad + ) + +instance MonadTrans (SoPMT u e) where + lift = SoPMT . lift + +instance MonadFreshNames m => MonadFreshNames (SoPMT u e m) where + getNameSource = lift getNameSource + putNameSource = lift . putNameSource + +instance (MonadFreshNames m) => MonadFreshNames (StateT (AlgEnv u e) m) where + getNameSource = lift getNameSource + putNameSource = lift . putNameSource + +instance MonadReader r m => MonadReader r (SoPMT u e m) where + ask = SoPMT $ lift ask + local f (SoPMT m) = + SoPMT $ do + env <- get + (a, env') <- lift $ local f $ runStateT m env + put env' + pure a + +instance MonadState s m => MonadState s (SoPMT u e m) where + get = SoPMT $ lift get + put = SoPMT . lift . put + +instance MonadWriter w m => MonadWriter w (SoPMT u e m) where + tell = SoPMT . lift . tell + listen (SoPMT m) = SoPMT $ listen m + pass (SoPMT m) = SoPMT $ pass m + +type SoPM u e = SoPMT u e (State VNameSource) + +runSoPMT :: MonadFreshNames m => AlgEnv u e -> SoPMT u e m a -> m (a, AlgEnv u e) +runSoPMT env (SoPMT sm) = runStateT sm env + +runSoPMT_ :: (Ord u, Ord e, MonadFreshNames m) => SoPMT u e m a -> m (a, AlgEnv u e) +runSoPMT_ = runSoPMT mempty + +runSoPM :: (Ord u, Ord e) => AlgEnv u e -> SoPM u e a -> (a, AlgEnv u e) +runSoPM env = flip evalState mempty . runSoPMT env + +runSoPM_ :: (Ord u, Ord e) => SoPM u e a -> (a, AlgEnv u e) +runSoPM_ = runSoPM mempty + +evalSoPMT :: MonadFreshNames m => AlgEnv u e -> SoPMT u e m a -> m a +evalSoPMT env m = fst <$> runSoPMT env m + +evalSoPMT_ :: (Ord u, Ord e, MonadFreshNames m) => SoPMT u e m a -> m a +evalSoPMT_ = evalSoPMT mempty + +evalSoPM :: (Ord u, Ord e) => AlgEnv u e -> SoPM u e a -> a +evalSoPM env = fst . runSoPM env + +evalSoPM_ :: (Ord u, Ord e) => SoPM u e a -> a +evalSoPM_ = evalSoPM mempty + +instance + ( Ord u, + Ord e, + Nameable u, + Show u, + Pretty u, + MonadFreshNames m, + Expression e + ) => + MonadSoP u e (SoPMT u e m) + where + getUntrans = SoPMT $ gets untrans + + getRanges = SoPMT $ gets ranges + + getEquivs = SoPMT $ gets equivs + + modifyEnv f = SoPMT $ modify f + +-- \| Insert a symbol equal to an untranslatable 'PrimExp'. +addUntrans :: MonadSoP u e m => u -> e -> m () +addUntrans sym pe = + modifyEnv $ \env -> + env + { untrans = + (untrans env) + { dir = M.insert sym pe (dir (untrans env)), + inv = M.insert pe sym (inv (untrans env)) + } + } + +-- \| Look-up the sum-of-products representation of a symbol. +lookupSoP :: MonadSoP u e m => u -> m (Maybe (SoP u)) +lookupSoP x = (M.!? x) <$> getEquivs + +-- \| Look-up the symbol for a 'PrimExp'. If no symbol is bound +-- to the expression, bind a new one. +lookupUntransPE :: MonadSoP u e m => e -> m u +lookupUntransPE pe = do + inv_map <- inv <$> getUntrans + case inv_map M.!? pe of + Nothing -> do + x <- mkNameM + addUntrans x pe + pure x + Just x -> pure x + +-- \| Look-up the untranslatable 'PrimExp' bound to the given symbol. +lookupUntransSym :: MonadSoP u e m => u -> m (Maybe e) +lookupUntransSym sym = ((M.!? sym) . dir) <$> getUntrans + +-- \| Look-up the range of a symbol. If no such range exists, +-- return the empty range (and add it to the environment). +lookupRange :: MonadSoP u e m => u -> m (Range u) +lookupRange sym = do + mr <- (M.!? sym) <$> getRanges + case mr of + Nothing -> do + let r = Range mempty 1 mempty + addRange sym r + pure r + Just r + | rangeMult r <= 0 -> error "Non-positive constant encountered in range." + | otherwise -> pure r + +-- \| Add range information for a symbol; augments the existing +-- range. +addRange :: MonadSoP u e m => u -> Range u -> m () +addRange sym r = + modifyEnv $ \env -> + env {ranges = M.insertWith (<>) sym r (ranges env)} + +-- \| Add equivalent information for a symbol; unsafe and +-- should only be used for newly introduced variables. +addEquiv :: MonadSoP u e m => u -> SoP u -> m () +addEquiv sym sop = do + -- sop' <- substEquivs sop + modifyEnv $ \env -> + env {equivs = M.insert sym sop (equivs env)} + +-------------------------------------------------------------------------------- +-- Environment +-------------------------------------------------------------------------------- + +-- | The environment of untranslatable 'PrimeExp's. It maps both +-- ways: +-- +-- 1. A fresh symbol is generated and mapped to the +-- corresponding 'PrimeExp' @pe@ in 'dir'. +-- 2. The target @pe@ is mapped backed to the corresponding symbol in 'inv'. +data UntransEnv u e = Unknowns + { dir :: Map u e, + inv :: Map e u + } + deriving (Eq, Show, Ord) + +instance (Ord u, Ord e) => Semigroup (UntransEnv u e) where + Unknowns d1 i1 <> Unknowns d2 i2 = Unknowns (d1 <> d2) (i1 <> i2) + +instance (Ord u, Ord e) => Monoid (UntransEnv u e) where + mempty = Unknowns mempty mempty + +instance (Pretty u, Pretty e) => Pretty (UntransEnv u e) where + pretty env = + "dir:" + <> line + <> pretty (M.toList $ dir env) + <> line + <> "inv:" + <> line + <> pretty (M.toList $ inv env) + +-- | The equivalence environment binds a variable name to +-- its equivalent 'SoP' representation. +type EquivEnv u = Map u (SoP u) + +-- | The range environment binds a variable name to a range. +type RangeEnv u = Map u (Range u) + +instance Pretty u => Pretty (RangeEnv u) where + pretty = pretty . M.toList + +-- | The main algebraic environment. +data AlgEnv u e = AlgEnv + { -- | Binds untranslatable PrimExps to fresh symbols. + untrans :: UntransEnv u e, + -- | Binds symbols to their sum-of-product representation.. + equivs :: EquivEnv u, + -- | Binds symbols to ranges (in sum-of-product form). + ranges :: RangeEnv u + } + deriving (Ord, Show, Eq) + +instance (Ord u, Ord e) => Semigroup (AlgEnv u e) where + AlgEnv u1 s1 r1 <> AlgEnv u2 s2 r2 = + AlgEnv (u1 <> u2) (s1 <> s2) (r1 <> r2) + +instance (Ord u, Ord e) => Monoid (AlgEnv u e) where + mempty = AlgEnv mempty mempty mempty + +instance (Pretty u, Pretty e) => Pretty (AlgEnv u e) where + pretty (env) = + "Untranslatable environment:" + <> line + <> pretty (untrans env) + <> line + <> "Equivalence environment:" + <> line + <> pretty (M.toList $ equivs env) + <> line + <> "Ranges:" + <> line + <> pretty (M.toList $ ranges env) + +transClosInRanges :: (Ord u) => RangeEnv u -> Set u -> Set u +transClosInRanges rs syms = + transClosHelper rs syms S.empty syms + where + transClosHelper rs' clos_syms seen active + | S.null active = clos_syms + | (sym, active') <- S.deleteFindMin active, + seen' <- S.insert sym seen = + case M.lookup sym rs' of + Nothing -> + transClosHelper rs' clos_syms seen' active' + Just range -> + let new_syms = free range S.\\ seen + clos_syms' = S.union clos_syms new_syms + active'' = S.union new_syms active' + in transClosHelper rs' clos_syms' seen' active'' + +substEquivs :: MonadSoP u e m => SoP u -> m (SoP u) +substEquivs sop = flip substitute sop <$> getEquivs + +-- | Removes a symbol from the environment +delFromEnv :: MonadSoP u e m => u -> m () +delFromEnv x = + modifyEnv $ \env -> + env + { untrans = delFromUntrans $ untrans env, + equivs = M.delete x $ equivs env, + ranges = M.delete x $ ranges env + } + where + delFromUntrans ut = + ut + { dir = M.delete x $ dir ut, + inv = M.filter (/= x) $ inv ut + } diff --git a/src/Futhark/SoP/Refine.hs b/src/Futhark/SoP/Refine.hs new file mode 100644 index 0000000000..8b78e65892 --- /dev/null +++ b/src/Futhark/SoP/Refine.hs @@ -0,0 +1,49 @@ +{-# LANGUAGE DataKinds #-} + +-- | Top-Level functionality for adding info to the +-- algebraic environment, which is in principle +-- constructed during program traversal, and used +-- to solve symbolically algebraic inequations. +module Futhark.SoP.Refine + ( addRel, + addRels, + ) +where + +import Control.Monad.State +import Data.Set (Set) +import Data.Set qualified as S +import Debug.Trace +import Futhark.Analysis.PrimExp +import Futhark.SoP.Convert +import Futhark.SoP.Expression +import Futhark.SoP.Monad +import Futhark.SoP.RefineEquivs +import Futhark.SoP.RefineRanges +import Futhark.SoP.SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty + +constraintToSoP :: (Ord u, MonadSoP u e m) => Rel u -> m (Set (SoP u == 0), Set (SoP u >= 0)) +constraintToSoP (x :<=: y) = pure (mempty, S.singleton $ y .-. x) +constraintToSoP (x :<: y) = pure (mempty, S.singleton $ y .-. (x .+. int2SoP 1)) +constraintToSoP (x :>: y) = pure (mempty, S.singleton $ x .-. (y .+. int2SoP 1)) +constraintToSoP (x :>=: y) = pure (mempty, S.singleton $ x .-. y) +constraintToSoP (x :==: y) = pure (S.singleton $ x .-. y, mempty) +constraintToSoP (x :&&: y) = (<>) <$> constraintToSoP x <*> constraintToSoP y +constraintToSoP c = error $ "constraintToSoP: " <> prettyString c + +addRel :: (ToSoP u e, MonadSoP u e m) => Rel u -> m () +addRel c = do + (eqZs, ineqZs) <- constraintToSoP c + extra_ineqZs <- addEqZeros eqZs + addIneqZeros $ ineqZs <> extra_ineqZs + +addRels :: (FromSoP u e, ToSoP u e, MonadSoP u e m) => Set (Rel u) -> m () +addRels cs = do + -- Split candidates into equality and inequality sets. + (eqZs, ineqZs) <- mconcat <$> mapM constraintToSoP (S.toList cs) + -- Refine the environment with the equality set. + extra_ineqZs <- addEqZeros eqZs + -- Refine the environment with the extended inequality set. + addIneqZeros $ ineqZs <> extra_ineqZs diff --git a/src/Futhark/SoP/RefineEquivs.hs b/src/Futhark/SoP/RefineEquivs.hs new file mode 100644 index 0000000000..2db4a698bf --- /dev/null +++ b/src/Futhark/SoP/RefineEquivs.hs @@ -0,0 +1,212 @@ +{-# LANGUAGE DataKinds #-} + +-- | Functionality for refining the equivalence environment +module Futhark.SoP.RefineEquivs + ( addEqZeros, + addEq, + ) +where + +import Control.Monad +import Data.Foldable (minimumBy) +import Data.Map.Strict qualified as M +import Data.MultiSet qualified as MS +import Data.Set (Set) +import Data.Set qualified as S +import Debug.Trace +import Futhark.Analysis.PrimExp +import Futhark.SoP.Convert +import Futhark.SoP.Expression +import Futhark.SoP.FourierMotzkin +import Futhark.SoP.Monad +import Futhark.SoP.SoP +import Futhark.SoP.Util +import Futhark.Util.Pretty + +addEq :: forall u e m. (ToSoP u e, MonadSoP u e m) => u -> SoP u -> m () +addEq sym sop = do + -- cands <- mkEquivCands (/= sym) $ sop .-. sym2SoP sym + addLegalCands $ S.singleton $ EquivCand sym sop + +-- | Refine the environment with a set of 'PrimExp's with the assertion that @pe = 0@ +-- for each 'PrimExp' in the set. +addEqZeros :: forall u e m. (ToSoP u e, MonadSoP u e m) => Set (SoP u == 0) -> m (Set (SoP u >= 0)) +addEqZeros sops = do + -- Make equivalence candidates along with any extra constraints. + (extra_inEqZs :: Set (SoP u >= 0), equiv_cands) <- + mconcat <$> mapM addEquiv2CandSet (S.toList sops) + -- Add one-by-one all legal equivalences to the algebraic + -- environment, i.e., range and equivalence envs are updated + -- as long as the new substitutions do not introduce cycles. + addLegalCands equiv_cands + -- Return the newly generated constraints. + pure extra_inEqZs + +-- | An equivalence candidate; a candidate @'EquivCand' sym sop@ means +-- @sym = sop@. +data EquivCand u = EquivCand + { equivCandSym :: u, + equivCandSoP :: SoP u + } + deriving (Eq, Show, Ord) + +instance Ord u => Free u (EquivCand u) where + free = free . equivCandSoP + +instance Ord u => Substitute u (SoP u) (EquivCand u) where + substitute subst (EquivCand sym sop) = + EquivCand sym $ substitute subst sop + +-- | A candidate for the equivalence env is found when: +-- +-- (1) A term of the SoP has one symbol (here, named @sym@). +-- (2) @sym@'s value factor is @1@ or @-1@. +-- (3) @sym@ does not appear in the other terms of the 'SoP'. +-- (4) @sym@ is not already present in the equivalence environment. +-- +-- ToDo: try to give common factor first, e.g., +-- nx - nbq - n = 0 => n*(x-bq-1) = 0 => x = bq+1, +-- if we can prove that n != 0 +mkEquivCands :: MonadSoP u e m => (u -> Bool) -> SoP u -> m (Set (EquivCand u)) +mkEquivCands p sop = + pure (getTerms sop) + >>= M.foldrWithKey mkEquivCand (pure mempty) + where + -- getTerms <$> substEquivs sop + -- >>= M.foldrWithKey mkEquivCand (pure mempty) + + mkEquivCand (Term term) v mcands + | abs v == 1, + [sym] <- MS.toList term, + sop' <- deleteTerm term sop, + sym `notElem` free sop', + p sym = do + msop <- lookupSoP sym + case msop of + Nothing -> + S.insert (EquivCand sym (scaleSoP (-v) sop')) <$> mcands + Just {} -> + mcands + | otherwise = mcands + +-- | Algebraic manipulation of 'EquivCand's. Potentially yields +-- additional constraints (inequalities). +-- +-- Currently supports two cases: +-- +-- 1. Refinement of the modulo expression: +-- Assume the equivalence @sym = sop@, where @sym@ is +-- bound in the untranslatable environment as @sym -> pe1 % pe2@. +-- This means `pe1 % pe2 = sop` and we can do the +-- following re-writting: +-- +-- (1) Check @pe2 >= 0@. +-- (2) Check the sum-of-products representation of @pe1@ +-- is a single symbol, which is not in equivalence environment. +-- +-- If (1) and (2) hold then: +-- +-- * We rewrite: @pe1 = pe2 * q + sop@. +-- * Possibly add the constraints @0 <= sop <= pe2 - 1@. +-- +-- 2: TODO: try to give common factors and get simpler. +refineEquivCand :: forall u e m. (ToSoP u e, MonadSoP u e m) => EquivCand u -> m (Set (SoP u >= 0), EquivCand u) +-- refineEquivCand cand@(EquivCand sym sop) +-- | justPositive sop = pure (S.singleton $ sym2SoP sym, cand) +refineEquivCand cand@(EquivCand sym sop) = do + mpe <- lookupUntransSym sym + case mpe of + Just pe + | Just (pe1, pe2) <- moduloIsh pe -> do + (f1, sop1) <- toSoPNum pe1 + (f2, sop2) <- toSoPNum pe2 + is_pos <- fmSolveGEq0 sop2 + case (f1, f2, justSym sop1, is_pos) of + (1, 1, Just sym1, True) -> do + msop <- lookupSoP sym1 + case msop of + Just {} -> pure (mempty, cand) + Nothing -> do + q <- mkNameM + let div_pe = divInsteadOfMod pe + q_sop = sym2SoP q + new_cand = EquivCand sym1 $ sop .+. q_sop .*. sop2 + new_ineq = sop2 .-. (sop .+. int2SoP 1) + pe_ineq = S.fromList [new_ineq, sop] + addUntrans q div_pe + pure (pe_ineq, new_cand) + _ -> pure (mempty, cand) + _ -> do + if (justPositive sop) + then pure (S.singleton $ sym2SoP sym, cand) + else pure (mempty, cand) + +-- | Takes a 'PrimExp' @pe@ with the property that @pe = 0@ and +-- returns two sets @'addEquiv2CandSet' pe = (ineqs,cand)@: +-- +-- * @ineqs@: a set of extra inequality constraints generated during the +-- creation/refinement of the mapping. +-- * @cands@: set of equivalence candidates. +addEquiv2CandSet :: + (ToSoP u e, MonadSoP u e m) => + SoP u == 0 -> + m (Set (SoP u >= 0), Set (EquivCand u)) +addEquiv2CandSet sop = do + cands <- mkEquivCands (const True) sop + (ineqss, cands') <- mapAndUnzipM refineEquivCand $ S.toList cands + pure (mconcat ineqss, S.fromList cands') + +-- | Add legal equivalence candidates to the environment. +addLegalCands :: MonadSoP u e m => Set (EquivCand u) -> m () +addLegalCands cand_set + | S.null cand_set = pure () +addLegalCands cand_set = do + rs <- getRanges + eqs <- getEquivs + let -- Chose the candidate @sym -> sop@ whose @sop@ has + -- the smallest number of symbols in the environment. + env_syms = M.keysSet eqs <> M.keysSet rs + cand = minimumBy (scoreCand env_syms) cand_set + -- Check whether target substitution does not create cycles + -- in the equivalence and range environments. + if not $ validCand rs eqs cand + then addLegalCands $ S.delete cand cand_set + else do + -- Apply substitution to equivalence and range envs + -- and add the new binding to the equivalence env. + modifyEnv $ \env -> + env + { equivs = + M.insert (equivCandSym cand) (equivCandSoP cand) $ + M.map (subCand cand) eqs, + ranges = M.map (subCand cand) rs + } + addLegalCands $ subCand cand $ S.delete cand cand_set + where + subCand (EquivCand sym sop) = substituteOne (sym, sop) + scoreCand env_syms cand1 cand2 = + let score cand = length $ free (equivCandSoP cand) S.\\ env_syms + in score cand1 `compare` score cand2 + validCand rs eqs cand = + not $ + -- Check if a candidate is already present in the + -- equivalence environment. + equivCandSym cand + `elem` M.keysSet eqs + -- Detect if an equivalence candidate would introduce + -- a cycle into the equivalence environment. + && any hasEquivCycle (M.toList eqs) + -- Detect if an equivalence candidate would introduce + -- a cycle into the range environment. + && any hasRangeCycle (M.toList rs) + where + -- Since the equivalence environment contains the fully + -- substituted bindings (accounting for predecessor + -- substitutions), we do not need to (explicitly) compute the + -- transititve closures. + hasEquivCycle (sym, sop) = + (sym `elem` free cand) + && (equivCandSym cand `elem` free sop) + hasRangeCycle (sym, range) = + (sym `elem` transClosInRanges rs (free cand)) + && (equivCandSym cand `elem` free range) diff --git a/src/Futhark/SoP/RefineRanges.hs b/src/Futhark/SoP/RefineRanges.hs new file mode 100644 index 0000000000..98846eef3d --- /dev/null +++ b/src/Futhark/SoP/RefineRanges.hs @@ -0,0 +1,210 @@ +{-# LANGUAGE DataKinds #-} + +-- | Functionality for processing the range environment +module Futhark.SoP.RefineRanges + ( addIneqZeros, + ) +where + +import Control.Monad.State +import Data.Map.Strict qualified as M +import Data.MultiSet qualified as MS +import Data.Set (Set) +import Data.Set qualified as S +import Debug.Trace +import Futhark.Analysis.PrimExp +import Futhark.Analysis.PrimExp.Convert +import Futhark.SoP.Convert +import Futhark.SoP.FourierMotzkin +import Futhark.SoP.Monad +import Futhark.SoP.SoP +import Futhark.SoP.Util + +-- | Refine the environment with a set of 'PrimExp's with the assertion that @pe >= 0@ +-- for each 'PrimExp' in the set. +addIneqZeros :: forall u e m. (ToSoP u e, MonadSoP u e m) => Set (SoP u >= 0) -> m () +addIneqZeros sops = do + ineq_cands <- + mconcat + <$> mapM mkRangeCands (S.toList sops) + addRangeCands ineq_cands + +-- | A range candidate; a candidate @'RangeCand' v sym sop@ means +-- +-- > v*sym + sop >= 0 +data RangeCand u = RangeCand + { rangeCandScale :: Integer, + rangeCandSym :: u, + rangeCandSoP :: SoP u + } + deriving (Eq, Show, Ord) + +instance Ord u => Free u (RangeCand u) where + free = free . rangeCandSoP + +-- | Make range candidates from a 'SoP' from its 'Term's. A candidate +-- 'Term' for the range env is found when: +-- +-- 1. It consists of a single symbol, @sym@. +-- 2. @sym@ does not appear in the other 'Term's of the 'SoP'. +-- +-- TODOs: try to give common factor first, e.g., +-- @nx - nbq - n = 0@ is equivalent to +-- @n*(x-bq-1) >= 0@, hence, if we can prove +-- that @n >= 0@ we can derive @x >= bq+1@. +mkRangeCands :: MonadSoP u e m => (SoP u >= 0) -> m (Set (RangeCand u)) +mkRangeCands sop = do + -- sop' <- substEquivs sop + let sop' = sop + let singleSymCands = mkSingleSymCands sop' + factorCands <- factorCandsM sop' + pure $ singleSymCands <> factorCands + where + factorCandsM sop' = + mconcat + <$> forM + (sopFactors sop') + ( \(rem, term) -> do + ifM + (term2SoP term 1 $>=$ zeroSoP) + (pure $ mkSingleSymCands rem) + (pure mempty) + ) + mkSingleSymCands sop' = + M.foldrWithKey mkRangeCand mempty $ getTerms $ sop' + mkRangeCand (Term term) v cands + | [sym] <- MS.toList term, + sop' <- deleteTerm term sop, + sym `notElem` free sop' = + S.insert (RangeCand v sym sop') cands + | otherwise = cands + +-- | Refines a range in the range environment from a range +-- canditate. +-- +-- @'refineRangeInEnv' (j, sym, sop)@ refines the existing range of +-- the symbol @sym@ +-- +-- > max{lbs} <= k*sym <= min{ubs} +-- +-- by computing the 'lcm' of @j@ and @k@ to obtain the bounds +-- +-- (1) @max{k_z*lbs} <= z*sym <= min{k_z*ubs}@ +-- (2) @z*sym + j_z*sop <= 0@ +-- +-- where @z = 'lcm' k j@. If (2) refines (1) (i.e., it tightens the +-- upper or lower bounds on @sym@), it's merged with (1) and any +-- bounds that are looser than the bound introduced by (2) are +-- removed. If @j < 0@ (@j >= 0@), the upper (lower) bound on @sym@ +-- may be tightened. +-- +-- Returns a set of new range canditates: if @j < 0@ (@j >= 0@) +-- these are @lbs' <= -j_z * sop@ (@j_z * sop <= ubs'@) where @lbs'@ +-- (@ubs'@) are the refined bounds from the previous step. +refineRangeInEnv :: + MonadSoP u e m => + RangeCand u -> + m (Set (RangeCand u)) +refineRangeInEnv (RangeCand j sym sop) = do + Range lbs k ubs <- lookupRange sym + let z = lcm k j + j_z = z `div` j + k_z = z `div` k + lbs' = S.map (scaleSoP k_z) lbs + ubs' = S.map (scaleSoP k_z) ubs + sop' = (-j_z) `scaleSoP` sop + if j < 0 + then do + -- reject: ∃b.sop' >= b? + -- remove: only keep b with b < new_bound = !(new_bound <= b) + ubs'' <- mergeBound (sop' $>=$) ($<=$ sop') ubs' sop' + addRange sym $ Range lbs' z ubs'' + -- New candidates: lbs <= sop' --> sop' - lbs >= 0 + mconcat <$> mapM (mkRangeCands . (sop' .-.)) (S.toList lbs') + else do + -- reject: ∃b.new_bound <= b? + -- remove: only keep b with new_bound < b = !(new_bound >= b) + lbs'' <- mergeBound (sop' $<=$) ($>=$ sop') lbs' sop' + addRange sym $ Range lbs'' z ubs' + -- New candidates: sop' <= ubs --> ubs - sop' >= 0 + mconcat <$> mapM (mkRangeCands . (.-. sop')) (S.toList ubs') + where + mergeBound reject remove bs sop' = + ifM + (anyM reject bs) + (pure bs) + (S.insert sop' <$> (S.fromList <$> filterM (fmap not . remove) (S.toList bs))) + +-- | Candidate ranking. @'SymNotBound' > 'CompletesRange' > 'Default'@. +-- +-- * 'SymNotBound': candidate symbol doesn't appear in the range environment. +-- * 'CompetesRange': candidate completes the range of a symbol in the range environment +-- with a partial range. +-- * 'Default': all other candidates. +data CandRank + = Default + | CompletesRange + | SymNotBound + deriving (Ord, Eq) + +addRangeCands :: MonadSoP u e m => Set (RangeCand u) -> m () +addRangeCands cand_set + | S.null cand_set = pure () +addRangeCands cand_set = do + rs <- getRanges + let cands = S.toList cand_set + -- 1. Compute the transitive closure of the 'SoP' of + -- each candidate through the range environment. + tcs = map (transClosInRanges rs . free) cands + -- 2. Filter out the candidates that introduce cycles + -- through the range environment. A cycle appears iff + -- @sym@ appears in the transitive closure of the + -- symbols appearing in its ranges. + cands_tcs = filter (not . uncurry hasCycle) $ zip cands tcs + -- 3. Choose the candidates whose transitive closure through the + -- range env have the lowest number of free symbols (symbols + -- which do not belong to the keys f the range env.) + cands' = fst $ foldr (compareNumFreeInRangeEnv rs) (mempty, maxBound) cands_tcs + -- 4. Rank the candidates. + cands'' = fst $ foldr (compareRank rs) (mempty, Default) cands' + case S.toList cands'' of + [] -> pure () + cand : _ -> do + -- Incorporate the constraints imposed by the top-scoring + -- candidate into the range environment and continue, + -- adding any newly generated candidates into the candidate + -- set. + new_cands <- refineRangeInEnv cand + addRangeCands $ + S.delete cand $ + new_cands <> S.fromList (map fst cands_tcs) + where + -- A cycle appears if and only if @sym@ is in its + -- transitive closure. Proof: Suppose a cycle appears + -- via some @x@. Since the range environment is + -- (inductively) assumed to be cycle-free, this means + -- that the range of @sym@ depends on @x@ + -- and the range of @x@ depends on @sym@. It follows that + -- @sym@ will be in its transitive closure. + hasCycle cand tc = + rangeCandSym cand `S.member` tc + compareNumFreeInRangeEnv rs (cand, tc) (acc, n_acc) = + case n_acc `compare` n_free of + LT -> (acc, n_acc) + EQ -> (S.insert cand acc, n_acc) + GT -> (S.singleton cand, n_free) + where + n_free = length (tc S.\\ M.keysSet rs) + rankCand rs (RangeCand k sym _) + | sym `M.notMember` rs = SymNotBound + | completesRange k (rs M.! sym) = CompletesRange + | otherwise = Default + completesRange v (Range lbs _ ubs) = + (v < 0 && not (S.null lbs) && S.null ubs) + || (v > 0 && S.null lbs && not (S.null ubs)) + compareRank rs cand (acc, acc_rank) + | cand_rank > acc_rank = (S.singleton cand, cand_rank) + | cand_rank == acc_rank = (S.insert cand acc, acc_rank) + | otherwise = (acc, acc_rank) + where + cand_rank = rankCand rs cand diff --git a/src/Futhark/SoP/SoP.hs b/src/Futhark/SoP/SoP.hs new file mode 100644 index 0000000000..e71fb24e2a --- /dev/null +++ b/src/Futhark/SoP/SoP.hs @@ -0,0 +1,530 @@ +-- | The sum-of-products representation and related operations. +module Futhark.SoP.SoP + ( Term (..), + SoP (..), + Range (..), + toTerm, + mapSoP, + mapTermSoP, + mapTermSoPM, + mapSymSoP_, + mapSymSoPM, + mapSymSoP, + isConstTerm, + filterSoP, + term2SoP, + sym2SoP, + int2SoP, + scaleSoP, + zeroSoP, + negSoP, + addSoPs, + (.+.), + subSoPs, + (.-.), + mulSoPs, + (.*.), + divSoPs, + (./.), + divSoPInt, + signumSoP, + factorSoP, + sopFactors, + numTerms, + justSym, + justConstant, + justAffine, + justSingleTerm, + justSingleTerm_, + justPositive, + deleteTerm, + insertTerm, + powerSoP, + Free (..), + Substitute (..), + substituteOne, + sopToList, + sopToLists, + sopFromList, + termToList, + Rel (..), + orRel, + andRel, + normalize, + ) +where + +import Data.Map (Map) +import Data.Map.Strict qualified as M +import Data.Maybe +import Data.MultiSet (MultiSet) +import Data.MultiSet qualified as MS +import Data.Set (Set) +import Data.Set qualified as S +import Futhark.Analysis.PrimExp.Convert +import Futhark.SoP.Util +import Futhark.Util.Pretty +import Language.Futhark.Core +import Language.Futhark.Prop + +-- | A 'Term' is a product of symbols. +newtype Term u = Term {getTerm :: MultiSet u} + deriving (Eq, Ord, Monoid, Semigroup, Foldable, Show) + +-- | A sum-of-products is a constant value added to a sum of terms, +-- which are (by construction) +-- +-- 1. Lexicographically sorted. +-- 2. Contain no duplicated terms, i.e., @2*x*y + 3*x*y@ is +-- illegal. +data SoP u = SoP {getTerms :: Map (Term u) Integer} + deriving (Ord, Show) + +instance (Ord u, Eq u) => Eq (SoP u) where + x == y = getTerms (normalize x) == getTerms (normalize y) + +-- | A symbol @sym@ with range @'Range' lbs k ubs@ means @max{lbs} <= +-- k*sym <= min{ubs}@. 'lbs' and 'ubs' are (potentially empty) sets +-- of 'SoP's. +data Range u = Range + { lowerBound :: Set (SoP u), + rangeMult :: Integer, + upperBound :: Set (SoP u) + } + deriving (Eq, Ord) + +-- | Should probably make this smarter +instance (Ord u) => Semigroup (Range u) where + Range lb1 k1 ub1 <> Range lb2 k2 ub2 = + Range + (S.map (int2SoP m1 .*.) lb1 <> S.map (int2SoP m2 .*.) lb2) + (lcm k1 k2) + (S.map (int2SoP m1 .*.) ub1 <> S.map (int2SoP m2 .*.) ub2) + where + m1 = lcm k1 k2 `div` k1 + m2 = lcm k1 k2 `div` k2 + +instance (Ord u) => Monoid (Range u) where + mempty = Range mempty 1 mempty + +instance (Pretty u) => Pretty (Term u) where + pretty (Term t) = + mconcat $ punctuate "*" $ map pretty $ MS.toList t + +instance (Pretty u) => Pretty (SoP u) where + pretty (SoP ts) + | M.null ts = "0" + | otherwise = + mconcat $ + punctuate " + " $ + map (uncurry pTerm) $ + M.toList ts + where + pTerm term n + | isConstTerm term = pretty n + | n == 1 = pretty term + | otherwise = pretty n <> "*" <> pretty term + +instance (Pretty a) => Pretty (Set a) where + pretty as = "{" <> mconcat (punctuate ", " (map pretty $ S.toList as)) <> "}" + +instance (Pretty u) => Pretty (Range u) where + pretty (Range lb k ub) = + pretty_lb <> pretty k <> pretty_ub + where + pretty_lb = + -- \| S.null lb = mempty + "max" <> pretty lb <+> "<= " + pretty_ub = + -- \| S.null ub = mempty + " <=" <+> "min" <> pretty ub + +instance {-# OVERLAPS #-} (Pretty u) => Pretty (u, Range u) where + pretty (sym, Range lb k ub) = + pretty_lb <> psym <> pretty_ub + where + psym + | k == 1 = pretty sym + | otherwise = pretty k <> "*" <> pretty sym + pretty_lb = + -- \| S.null lb = mempty + "max" <> pretty lb <+> "<= " + pretty_ub = + -- \| S.null ub = mempty + " <=" <+> "min" <> pretty ub + +-- instance Pretty u => Show (Term u) where +-- show = prettyString +-- +-- instance Pretty u => Show (SoP u) where +-- show = prettyString + +instance (Pretty u) => Show (Range u) where + show = prettyString + +-------------------------------------------------------------------------------- +-- Term operations +-------------------------------------------------------------------------------- + +-- | Is the term a constant? +isConstTerm :: Term u -> Bool +isConstTerm (Term t) = MS.null t + +-- | Converts anything list-like into a term. +toTerm :: (Foldable t, Ord u) => t u -> Term u +toTerm = Term . toMS + +termToList :: Term u -> [u] +termToList (Term t) = MS.toList t + +-- | Is 'x' a factor of 'y'? +isFactorOf :: (Ord u) => Term u -> Term u -> Bool +isFactorOf (Term x) (Term y) = x `MS.isSubsetOf` y + +-- | Divides 'x' by 'y'. +divTerm :: (Ord u) => Term u -> Term u -> Maybe (Term u) +divTerm xt@(Term x) yt@(Term y) + | yt `isFactorOf` xt = Just $ Term $ x MS.\\ y + | otherwise = Nothing + +termPowers :: Term u -> [(u, Int)] +termPowers (Term t) = MS.toOccurList t + +-------------------------------------------------------------------------------- +-- Basic operations +-------------------------------------------------------------------------------- + +-- | Pads a SoP with a 0 constant, if it doesn't have one. I.e., +-- transforms sop into sop + 0. Useful for pattern matching the constant +-- term of SoPs. +padWithZero :: (Ord u) => SoP u -> SoP u +padWithZero sop@(SoP ts) = + case ts M.!? mempty of + Nothing -> + SoP $ M.insert mempty 0 ts + Just {} -> sop + +-- | Filters the terms of an 'SoP'. +filterSoP :: (Term u -> Integer -> Bool) -> SoP u -> SoP u +filterSoP p (SoP ts) = SoP $ M.filterWithKey p ts + +-- | Normalizes a SoP. Here, that just means removing any keys of the +-- form @0 * term@. (i.e., superfluous zeros). +normalize :: (Ord u) => SoP u -> SoP u +normalize sop + | Just {} <- justConstant sop = sop + | otherwise = filterSoP (\_ n -> n /= 0) sop + +mapSoP :: (Integer -> Integer) -> SoP u -> SoP u +mapSoP f (SoP ts) = SoP $ fmap f ts + +mapTermSoP :: (Foldable t, Ord u, Ord (t u)) => ([u] -> Integer -> (t u, Integer)) -> SoP u -> SoP u +mapTermSoP f = sopFromList . map (uncurry f) . sopToLists + +mapSymSoP_ :: (Ord u) => (u -> u) -> SoP u -> SoP u +mapSymSoP_ f = SoP . M.mapKeys (Term . MS.map f . getTerm) . getTerms + +mapSymSoPM :: (Ord u, Monad m) => (u -> m u) -> SoP u -> m (SoP u) +mapSymSoPM f = fmap sopFromList . mapM (\(ts, a) -> (,a) <$> mapM f ts) . sopToLists + +mapSymSoP :: (Ord u) => (u -> SoP u) -> SoP u -> SoP u +mapSymSoP f = + foldr + ( \(ts, a) -> + ( ( scaleSoP a $ + foldr (.*.) (int2SoP 1) $ + map f ts + ) + .+. + ) + ) + (int2SoP 0) + . sopToLists + +mapTermSoPM :: (Foldable t, Ord u, Ord (t u), Monad m) => ([u] -> Integer -> m (t u, Integer)) -> SoP u -> m (SoP u) +mapTermSoPM f = + fmap sopFromList . mapM (uncurry f) . sopToLists + +sopToList :: SoP u -> [(Term u, Integer)] +sopToList (SoP ts) = M.toList ts + +sopTerms :: SoP u -> [Term u] +sopTerms = map fst . sopToList + +sopToLists :: (Ord u) => SoP u -> [([u], Integer)] +sopToLists (SoP ts) = M.toList $ M.mapKeys termToList ts + +sopFromList :: (Foldable t, Ord u, Ord (t u)) => [(t u, Integer)] -> SoP u +sopFromList = SoP . M.mapKeys toTerm . M.fromList + +-- | An 'SoP' composed of a single term. +term2SoP :: (Foldable t, Ord u) => t u -> Integer -> SoP u +term2SoP t n = SoP $ M.singleton (toTerm t) n + +-- | An 'SoP' composed of a single symbol. +sym2SoP :: (Ord u) => u -> SoP u +sym2SoP sym = term2SoP (MS.singleton sym) 1 + +-- | An 'SoP' composed of a single constant. +int2SoP :: (Ord u) => Integer -> SoP u +int2SoP = term2SoP MS.empty + +-- | Deletes a term from an 'SoP'. Warning: ignores the multiplicity +-- of the term---__not__ the same as subtraction! +deleteTerm :: (Foldable t, Ord u) => t u -> SoP u -> SoP u +deleteTerm t (SoP ts) = SoP $ M.delete (toTerm t) ts + +-- | Inserts a term into an 'SoP'. Warning: ignores the multiplicity +-- of the term---__not__ the same as addition! +insertTerm :: (Foldable t, Ord u) => t u -> SoP u -> SoP u +insertTerm t (SoP ts) = SoP $ M.insert (toTerm t) 1 ts + +-- | Power set analogue of a 'SoP'. +powerSoP :: (Ord u) => SoP u -> Set (SoP u) +powerSoP sop = + S.map (sopFromList . S.toList) $ S.powerSet $ S.fromList $ sopToList sop + +-------------------------------------------------------------------------------- +-- SoP arithmetic +-------------------------------------------------------------------------------- + +zeroSoP :: (Ord u) => SoP u +zeroSoP = SoP $ M.singleton (toTerm MS.empty) 0 + +scaleSoP :: Integer -> SoP u -> SoP u +scaleSoP k = mapSoP (* k) + +negSoP :: SoP u -> SoP u +negSoP = scaleSoP (-1) + +addSoPs :: (Ord u) => SoP u -> SoP u -> SoP u +addSoPs (SoP xs) (SoP ys) = normalize $ SoP $ M.unionWith (+) xs ys + +(.+.) :: (Ord u) => SoP u -> SoP u -> SoP u +(.+.) = addSoPs + +infixl 6 .+. + +subSoPs :: (Ord u) => SoP u -> SoP u -> SoP u +subSoPs x y = x .+. negSoP y + +(.-.) :: (Ord u) => SoP u -> SoP u -> SoP u +(.-.) = subSoPs + +infixl 6 .-. + +mulSoPs :: (Ord u) => SoP u -> SoP u -> SoP u +mulSoPs (SoP xs) (SoP ys) = normalize $ SoP $ M.fromListWith (+) $ do + (x_term, x_n) <- M.toList xs + (y_term, y_n) <- M.toList ys + pure (x_term <> y_term, x_n * y_n) + +(.*.) :: (Ord u) => SoP u -> SoP u -> SoP u +(.*.) = mulSoPs + +infixl 7 .*. + +-- | @'factorSoP' term sop = (a, b)@ where @sop = a*term + b@. +factorSoP :: (Foldable t, Ord u) => t u -> SoP u -> (SoP u, SoP u) +factorSoP fact sop = (sopFromList as, sopFromList bs) + where + fact' = toTerm fact + as = mapMaybe (\(t, n) -> (,n) <$> t `divTerm` fact') $ sopToList sop + bs = filter (not . (fact' `isFactorOf`) . fst) $ sopToList sop + +-- | The factors of an 'SoP'. +sopFactors :: (Ord u) => SoP u -> [(SoP u, Term u)] +sopFactors sop = + map (\(t, (a, _)) -> (a, t)) $ + filter ((zeroSoP ==) . snd . snd) $ + map (\t -> (t, factorSoP t sop)) $ + sopTerms sop + +-- | Division of 'SoP's. Handles the following cases: +-- +-- 1. @(qv + qv_1 * t_1 + ... + qv_n*t_n) / q@ results in +-- @'Just' (v + v_1*t_1 + ... + v_n*t_n)@ +-- +-- 2. @(0 + v_1 * t_1 * t_q + ... + v_n * t_n * t_q) / t_q@ +-- results in @'Just' (0 + v_1 * t_1 + ... + v_n * t_n)@. +-- +-- Otherwise results in 'Nothing'. A possible generalization would +-- be to perform symbolically division with reminder, i.e., the +-- result would be two sum-of-products representing the quotient and +-- the reminder. +divSoPs :: (Ord u) => SoP u -> SoP u -> Maybe (SoP u) +divSoPs (SoP x) (SoP q_sop) + | [q] <- M.toList q_sop = SoP . M.fromList <$> mapM (`divSoPTerm` q) (M.toList x) + | otherwise = Nothing + where + divideVal v qv + | v `mod` qv == 0 = Just $ v `div` qv + | otherwise = Nothing + divSoPTerm (term, v) (qterm, qv) = + (,) <$> term `divTerm` qterm <*> v `divideVal` qv + +(./.) :: (Ord u) => SoP u -> SoP u -> Maybe (SoP u) +(./.) = divSoPs + +infixl 7 ./. + +-- | Integer division of 'SoP's. Both 'SoP's must be constants. +divSoPInt :: (Ord u) => SoP u -> SoP u -> Maybe (SoP u) +divSoPInt x y = + int2SoP <$> (div <$> justConstant x <*> justConstant y) + +-- | Sign of a constant 'SoP'. +signumSoP :: (Ord u) => SoP u -> Maybe (SoP u) +signumSoP = fmap (int2SoP . signum) . justConstant + +-------------------------------------------------------------------------------- +-- SoP queries +-------------------------------------------------------------------------------- + +-- | How many terms does the 'SoP' have? +numTerms :: (Ord u) => SoP u -> Int +numTerms = length . getTerms . normalize + +-- | Is the 'SoP' just a constant? +justConstant :: (Ord u) => SoP u -> Maybe Integer +justConstant sop + | [(term, n)] <- sopToList $ padWithZero sop, + isConstTerm term = + Just n + | otherwise = Nothing + +-- | Is the 'SoP' just a single symbol? +justSym :: (Ord u) => SoP u -> Maybe u +justSym sop + | [([x], 1)] <- sopToLists $ normalize sop = Just x + | otherwise = Nothing + +-- | Is the 'SoP' of the form a*x + b? +justAffine :: (Ord u) => SoP u -> Maybe (Integer, u, Integer) +justAffine sop + | [([], a), ([x], m)] <- sopToLists $ padWithZero sop = Just (m, x, a) + | otherwise = Nothing + +-- | Is the 'SoP' a single term? +justSingleTerm :: (Ord u) => SoP u -> Maybe (Term u, Integer) +justSingleTerm sop + | [t] <- sopToList $ normalize sop = Just t + | otherwise = Nothing + +justSingleTerm_ :: (Ord u) => SoP u -> Maybe ([u], Integer) +justSingleTerm_ sop + | [(t, a)] <- sopToList $ normalize sop = Just (termToList t, a) + | otherwise = Nothing + +-- | Can we guarantee the 'SoP' is positive? TODO: This can be more sophisticated. +justPositive :: (Ord u) => SoP u -> Bool +justPositive sop + | Just x <- justConstant sop = x > 0 + | Just (t, a) <- justSingleTerm sop = + a > 0 && all (even . snd) (termPowers t) + | otherwise = False + +-------------------------------------------------------------------------------- +-- Free symbols in SoPs +-------------------------------------------------------------------------------- + +class Free u a where + free :: a -> Set u + +instance (Ord u, Free u a) => Free u (Set a) where + free = foldMap free + +instance (Ord u) => Free u (SoP u) where + free = foldMap (MS.toSet . getTerm) . M.keys . getTerms + +instance (Ord u) => Free u (Range u) where + free r = free (lowerBound r) <> free (upperBound r) + +-------------------------------------------------------------------------------- +-- Substitutions in SoPs +-------------------------------------------------------------------------------- + +class Substitute a b c where + substitute :: Map a b -> c -> c + +substituteOne :: (Substitute a b c) => (a, b) -> c -> c +substituteOne (a, b) = substitute (M.singleton a b) + +instance (Ord c, Substitute a b c) => Substitute a b (Set c) where + substitute subst = S.map (substitute subst) + +instance (Ord c, Substitute a b c) => Substitute a b [c] where + substitute subst = map (substitute subst) + +instance (Substitute a b c) => Substitute a b (Map k c) where + substitute subst = fmap (substitute subst) + +instance (Ord u) => Substitute u (SoP u) (SoP u) where + substitute subst = + SoP + . M.unionsWith (+) + . map + ( \(term, n) -> + getTerms $ + foldr (mulSoPs . lookupSubst) (int2SoP n) (termToList term) + ) + . sopToList + where + lookupSubst u = + case subst M.!? u of + Nothing -> sym2SoP u + Just sop -> sop + +instance (Ord u) => Substitute u (SoP u) (Range u) where + substitute subst (Range lb k ub) = + Range (substitute subst lb) k (substitute subst ub) + +instance (Ord u) => Substitute u u (SoP u) where + substitute subst = substitute (fmap sym2SoP subst) + +data Rel u + = (:<:) (SoP u) (SoP u) + | (:<=:) (SoP u) (SoP u) + | (:>:) (SoP u) (SoP u) + | (:>=:) (SoP u) (SoP u) + | (:==:) (SoP u) (SoP u) + | (:/=:) (SoP u) (SoP u) + | (:&&:) (Rel u) (Rel u) + | (:||:) (Rel u) (Rel u) + deriving (Eq, Ord, Show) + +infixr 4 :<: + +infixr 4 :<=: + +infixr 4 :>: + +infixr 4 :>=: + +infixr 4 :==: + +infixr 4 :/=: + +infixr 3 :&&: + +infixr 2 :||: + +andRel :: [Rel u] -> Rel u +andRel = foldr1 (:&&:) + +orRel :: [Rel u] -> Rel u +orRel = foldr1 (:||:) + +instance (Pretty u) => Pretty (Rel u) where + pretty c = + case c of + x :<: y -> op "<" x y + x :<=: y -> op "<=" x y + x :>: y -> op ">" x y + x :>=: y -> op ">=" x y + x :==: y -> op "==" x y + x :/=: y -> op "/=" x y + x :&&: y -> op "&&" x y + x :||: y -> op "||" x y + where + op s x y = pretty x <+> s <+> pretty y diff --git a/src/Futhark/SoP/Util.hs b/src/Futhark/SoP/Util.hs new file mode 100644 index 0000000000..9a4c8b1c4c --- /dev/null +++ b/src/Futhark/SoP/Util.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeOperators #-} + +module Futhark.SoP.Util + ( anyM, + allM, + ifM, + toMS, + localS, + type (>=), + type (==), + (^&&), + (^||), + andM, + orM, + asumM, + ) +where + +import Control.Applicative +import Control.Monad.State +import Data.Foldable +import Data.MultiSet (MultiSet) +import Data.MultiSet qualified as MS +import GHC.TypeLits (Natural) + +ifM :: (Monad m) => m Bool -> m a -> m a -> m a +ifM mb mt mf = do + b <- mb + if b then mt else mf + +(^&&) :: (Monad m) => m Bool -> m Bool -> m Bool +x ^&& y = ifM x y (pure False) + +infixr 3 ^&& + +(^||) :: (Monad m) => m Bool -> m Bool -> m Bool +x ^|| y = ifM x (pure True) y + +infixr 2 ^|| + +andM :: (Monad m, Foldable t) => t (m Bool) -> m Bool +andM = allM id + +orM :: (Monad m, Foldable t) => t (m Bool) -> m Bool +orM = anyM id + +anyM :: (Monad m, Foldable t) => (a -> m Bool) -> t a -> m Bool +anyM p = foldr (\a b -> ifM (p a) (pure True) b) (pure False) + +allM :: (Monad m, Foldable t) => (a -> m Bool) -> t a -> m Bool +allM p = foldr (\a b -> ifM (p a) b (pure False)) (pure True) + +toMS :: (Ord a, Foldable t) => t a -> MultiSet a +toMS = MS.fromList . Data.Foldable.toList + +localS :: (MonadState s m) => (s -> s) -> m a -> m a +localS f m = do + env <- get + modify f + a <- m + put env + pure a + +-- | A type label to indicate @a >= 0@. +type a >= (b :: Natural) = a + +-- | A type label to indicate @a = 0@. +type a == (b :: Natural) = a + +asumM :: (Monad m, Traversable t, Alternative f) => t (m (f a)) -> m (f a) +asumM = (fmap asum) . sequence diff --git a/src/Language/Futhark/FreeVars.hs b/src/Language/Futhark/FreeVars.hs index 3fdb6a522c..185081df8f 100644 --- a/src/Language/Futhark/FreeVars.hs +++ b/src/Language/Futhark/FreeVars.hs @@ -153,6 +153,8 @@ freeInType t = freeInType t1 <> freeInType t2 Scalar (TypeVar _ _ _ targs) -> foldMap typeArgDims targs + Scalar (Refinement ty e) -> + freeInType ty <> freeInExp e where typeArgDims (TypeArgDim d) = freeInExp d typeArgDims (TypeArgType at) = freeInType at diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index be98c16614..a9edb415a2 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -626,6 +626,7 @@ expandType env (Scalar (TypeVar () u tn args)) = expandArg (TypeArgDim s) = TypeArgDim s expandArg (TypeArgType t) = TypeArgType $ expandType env t expandType env (Scalar (Sum cs)) = Scalar $ Sum $ (fmap . fmap) (expandType env) cs +expandType _ (Scalar Refinement {}) = error "Refinement not implemented in expandType" evalWithExts :: Env -> EvalM Eval evalWithExts env = do diff --git a/src/Language/Futhark/Parser/Parser.y b/src/Language/Futhark/Parser/Parser.y index f6ea390204..f3fccd193d 100644 --- a/src/Language/Futhark/Parser/Parser.y +++ b/src/Language/Futhark/Parser/Parser.y @@ -467,6 +467,7 @@ TypeExpTerm :: { UncheckedTypeExp } | TypeExpApply %prec typeprec { $1 } | SumClauses %prec sumprec { let (cs, loc) = $1 in TESum cs (srclocOf loc) } + | '{' TypeExp '|' Exp '}' { TERefine $2 $4 (srcspan $1 $>) } SumClauses :: { ([(Name, [UncheckedTypeExp])], Loc) } : SumClauses '|' SumClause %prec sumprec diff --git a/src/Language/Futhark/Prelude.hs b/src/Language/Futhark/Prelude.hs index b633f6c6ff..e6a25c6e0f 100644 --- a/src/Language/Futhark/Prelude.hs +++ b/src/Language/Futhark/Prelude.hs @@ -1,7 +1,7 @@ {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} --- | The Futhark Prelude Library embedded embedded as strings read +-- | The Futhark Prelude Library embedded embedded as strings read . -- during compilation of the Futhark compiler. The advantage is that -- the prelude can be accessed without reading it from disk, thus -- saving users from include path headaches. diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 7a6ba0b52c..2a4e8f46e8 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -142,6 +142,7 @@ prettyScalarType p (Sum cs) = where ppConstr (name, fs) = sep $ ("#" <> pretty name) : map (prettyType 2) fs cs' = map ppConstr $ M.toList cs +prettyScalarType _ (Refinement ty p) = "{" <> pretty ty <> "| " <> pretty p <> "}" instance Pretty (Shape dim) => Pretty (ScalarTypeBase dim as) where pretty = prettyScalarType 0 @@ -182,6 +183,8 @@ instance (Eq vn, IsName vn, Annot f) => Pretty (TypeExp f vn) where ppConstr (name, fs) = "#" <> pretty name <+> sep (map pretty fs) pretty (TEDim dims te _) = "?" <> mconcat (map (brackets . prettyName) dims) <> "." <> pretty te + pretty (TERefine ty p _) = + "{" <> pretty ty <> "| " <> pretty p <> "}" instance (Eq vn, IsName vn, Annot f) => Pretty (TypeArgExp f vn) where pretty (TypeArgExpSize d) = pretty d diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index a0ac4e1d52..7131789761 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -159,7 +159,8 @@ arrayRank = shapeRank . arrayShape -- | Return the shape of a type - for non-arrays, this is 'mempty'. arrayShape :: TypeBase dim as -> Shape dim -arrayShape (Array _ _ ds _) = ds +arrayShape (Array _ _ ds scal) = ds <> arrayShape (Scalar scal) +arrayShape (Scalar (Refinement t _)) = arrayShape t arrayShape _ = mempty -- | Change the shape of a type to be just the rank. @@ -195,6 +196,8 @@ traverseDims f = go mempty PosImmediate DimPos -> TypeBase fdim als' -> f (TypeBase tdim als') + go bound b t@(Scalar Refinement {}) = + bitraverse (f bound b) pure t go bound b t@Array {} = bitraverse (f bound b) pure t go bound b (Scalar (Record fields)) = @@ -249,6 +252,7 @@ diet (Array _ Nonunique _ _) = Observe diet (Scalar (TypeVar _ Unique _ _)) = Consume diet (Scalar (TypeVar _ Nonunique _ _)) = Observe diet (Scalar (Sum cs)) = foldl max Observe $ foldMap (map diet) cs +diet (Scalar (Refinement ty _)) = diet ty -- | Convert any type to one that has rank information, no alias -- information, and no embedded names. @@ -427,6 +431,8 @@ setUniqueness (Scalar (Record ets)) u = Scalar $ Record $ fmap (`setUniqueness` u) ets setUniqueness (Scalar (Sum ets)) u = Scalar $ Sum $ fmap (map (`setUniqueness` u)) ets +setUniqueness (Scalar (Refinement ty e)) u = + Scalar $ Refinement (setUniqueness ty u) e setUniqueness t _ = t -- | @t \`setAliases\` als@ returns @t@, but with @als@ substituted for @@ -582,6 +588,7 @@ typeVars t = Scalar (Arrow _ _ _ t1 (RetType _ t2)) -> typeVars t1 <> typeVars t2 Scalar (Record fields) -> foldMap typeVars fields Scalar (Sum cs) -> mconcat $ (foldMap . fmap) typeVars cs + Scalar (Refinement ty _) -> typeVars ty Array _ _ _ rt -> typeVars $ Scalar rt where typeArgFree (TypeArgType ta) = typeVars ta @@ -597,6 +604,7 @@ orderZero (Scalar (Record fs)) = all orderZero $ M.elems fs orderZero (Scalar TypeVar {}) = True orderZero (Scalar Arrow {}) = False orderZero (Scalar (Sum cs)) = all (all orderZero) cs +orderZero (Scalar (Refinement ty _)) = orderZero ty -- | @patternOrderZero pat@ is 'True' if all of the types in the given pattern -- have order 0. diff --git a/src/Language/Futhark/Query.hs b/src/Language/Futhark/Query.hs index 3350baf0ad..7adb7b2686 100644 --- a/src/Language/Futhark/Query.hs +++ b/src/Language/Futhark/Query.hs @@ -230,6 +230,7 @@ atPosInTypeExp te pos = msum $ map (`atPosInTypeExp` pos) $ concatMap snd cs TEDim _ t _ -> atPosInTypeExp t pos + TERefine t e _ -> atPosInTypeExp t pos `mplus` atPosInExp e pos where inArg (TypeArgExpSize dim) = inDim dim inArg (TypeArgExpType e2) = atPosInTypeExp e2 pos diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index 26aadd53cb..50126acc7d 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -310,6 +310,7 @@ data ScalarTypeBase dim as | -- | The aliasing corresponds to the lexical -- closure of the function. Arrow as PName Diet (TypeBase dim ()) (RetTypeBase dim as) + | Refinement (TypeBase dim as) (ExpBase Info VName) deriving (Eq, Ord, Show) instance Bitraversable ScalarTypeBase where @@ -320,6 +321,7 @@ instance Bitraversable ScalarTypeBase where bitraverse f g (Arrow als v d t1 t2) = Arrow <$> g als <*> pure v <*> pure d <*> bitraverse f pure t1 <*> bitraverse f g t2 bitraverse f g (Sum cs) = Sum <$> (traverse . traverse) (bitraverse f g) cs + bitraverse f g (Refinement ty predicate) = Refinement <$> bitraverse f g ty <*> pure predicate instance Functor (ScalarTypeBase dim) where fmap = fmapDefault @@ -452,6 +454,7 @@ data TypeExp f vn | TEArrow (Maybe vn) (TypeExp f vn) (TypeExp f vn) SrcLoc | TESum [(Name, [TypeExp f vn])] SrcLoc | TEDim [vn] (TypeExp f vn) SrcLoc + | TERefine (TypeExp f vn) (ExpBase f vn) SrcLoc deriving instance Show (TypeExp Info VName) @@ -476,6 +479,7 @@ instance Located (TypeExp f vn) where locOf (TEArrow _ _ _ loc) = locOf loc locOf (TESum _ loc) = locOf loc locOf (TEDim _ _ loc) = locOf loc + locOf (TERefine _ _ loc) = locOf loc -- | A type argument expression passed to a type constructor. data TypeArgExp f vn diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index f9602aaa14..7b51b81af1 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -255,6 +255,8 @@ instance ASTMappable (TypeExp Info VName) where TESum <$> traverse (traverse $ astMap tv) cs <*> pure loc astMap tv (TEDim dims t loc) = TEDim dims <$> astMap tv t <*> pure loc + astMap tv (TERefine ty p loc) = + TERefine <$> astMap tv ty <*> astMap tv p <*> pure loc instance ASTMappable (TypeArgExp Info VName) where astMap tv (TypeArgExpSize dim) = TypeArgExpSize <$> astMap tv dim @@ -290,48 +292,52 @@ type TypeTraverser f t dim1 als1 dim2 als2 = (QualName VName -> f (QualName VName)) -> (dim1 -> f dim2) -> (als1 -> f als2) -> + (ExpBase Info VName -> f (ExpBase Info VName)) -> t dim1 als1 -> f (t dim2 als2) traverseScalarType :: Applicative f => TypeTraverser f ScalarTypeBase dim1 als1 dims als2 -traverseScalarType _ _ _ (Prim t) = pure $ Prim t -traverseScalarType f g h (Record fs) = Record <$> traverse (traverseType f g h) fs -traverseScalarType f g h (TypeVar als u t args) = - TypeVar <$> h als <*> pure u <*> f t <*> traverse (traverseTypeArg f g) args -traverseScalarType f g h (Arrow als v u t1 (RetType dims t2)) = +traverseScalarType _ _ _ _ (Prim t) = pure $ Prim t +traverseScalarType f g h e (Record fs) = Record <$> traverse (traverseType f g h e) fs +traverseScalarType f g h e (TypeVar als u t args) = + TypeVar <$> h als <*> pure u <*> f t <*> traverse (traverseTypeArg f g e) args +traverseScalarType f g h e (Arrow als v u t1 (RetType dims t2)) = Arrow <$> h als <*> pure v <*> pure u - <*> traverseType f g pure t1 - <*> (RetType dims <$> traverseType f g h t2) -traverseScalarType f g h (Sum cs) = - Sum <$> (traverse . traverse) (traverseType f g h) cs + <*> traverseType f g pure e t1 + <*> (RetType dims <$> traverseType f g h e t2) +traverseScalarType f g h e (Sum cs) = + Sum <$> (traverse . traverse) (traverseType f g h e) cs +traverseScalarType f g h e (Refinement ty p) = + Refinement <$> traverseType f g h e ty <*> e p traverseType :: Applicative f => TypeTraverser f TypeBase dim1 als1 dims als2 -traverseType f g h (Array als u shape et) = - Array <$> h als <*> pure u <*> traverse g shape <*> traverseScalarType f g pure et -traverseType f g h (Scalar t) = - Scalar <$> traverseScalarType f g h t +traverseType f g h e (Array als u shape et) = + Array <$> h als <*> pure u <*> traverse g shape <*> traverseScalarType f g pure e et +traverseType f g h e (Scalar t) = + Scalar <$> traverseScalarType f g h e t traverseTypeArg :: Applicative f => (QualName VName -> f (QualName VName)) -> (dim1 -> f dim2) -> + (ExpBase Info VName -> f (ExpBase Info VName)) -> TypeArg dim1 -> f (TypeArg dim2) -traverseTypeArg _ g (TypeArgDim d) = +traverseTypeArg _ g _ (TypeArgDim d) = TypeArgDim <$> g d -traverseTypeArg f g (TypeArgType t) = - TypeArgType <$> traverseType f g pure t +traverseTypeArg f g e (TypeArgType t) = + TypeArgType <$> traverseType f g pure e t instance ASTMappable StructType where - astMap tv = traverseType (astMap tv) (mapOnExp tv) pure + astMap tv = traverseType (astMap tv) (mapOnExp tv) pure (mapOnExp tv) instance ASTMappable PatType where - astMap tv = traverseType (astMap tv) (mapOnExp tv) (astMap tv) + astMap tv = traverseType (astMap tv) (mapOnExp tv) (astMap tv) (mapOnExp tv) instance ASTMappable StructRetType where astMap tv (RetType ext t) = RetType ext <$> astMap tv t @@ -453,6 +459,7 @@ bareTypeExp (TEApply ty ta loc) = TEApply (bareTypeExp ty) (bareTypeArgExp ta) l bareTypeExp (TEArrow arg tya tyr loc) = TEArrow arg (bareTypeExp tya) (bareTypeExp tyr) loc bareTypeExp (TESum cs loc) = TESum (map (second $ map bareTypeExp) cs) loc bareTypeExp (TEDim names ty loc) = TEDim names (bareTypeExp ty) loc +bareTypeExp (TERefine ty p loc) = TERefine (bareTypeExp ty) (bareExp p) loc -- | Remove all annotations from an expression, but retain the -- name/scope information. diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index a09b23832c..0a5abdf862 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -51,7 +51,7 @@ checkProg :: UncheckedProg -> (Warnings, Either TypeError (FileModule, VNameSource)) checkProg files src name prog = - runTypeM initialEnv files' name src checkSizeExp $ checkProgM prog + runTypeM initialEnv files' name src checkSizeExp checkPredExp $ checkProgM prog where files' = M.map fileEnv $ M.fromList files @@ -67,7 +67,7 @@ checkExp :: UncheckedExp -> (Warnings, Either TypeError ([TypeParam], Exp)) checkExp files src env e = - second (fmap fst) $ runTypeM env files' (mkInitialImport "") src checkSizeExp $ checkOneExp e + second (fmap fst) $ runTypeM env files' (mkInitialImport "") src checkSizeExp checkPredExp $ checkOneExp e where files' = M.map fileEnv $ M.fromList files @@ -84,7 +84,7 @@ checkDec :: (Warnings, Either TypeError (Env, Dec, VNameSource)) checkDec files src env name d = second (fmap massage) $ - runTypeM env files' name src checkSizeExp $ do + runTypeM env files' name src checkSizeExp checkPredExp $ do (_, env', d') <- checkOneDec d pure (env' <> env, d') where @@ -103,7 +103,7 @@ checkModExp :: ModExpBase NoInfo Name -> (Warnings, Either TypeError (MTy, ModExpBase Info VName)) checkModExp files src env me = - second (fmap fst) . runTypeM env files' (mkInitialImport "") src checkSizeExp $ do + second (fmap fst) . runTypeM env files' (mkInitialImport "") src checkSizeExp checkPredExp $ do (_abs, mty, me') <- checkOneModExp me pure (mty, me') where diff --git a/src/Language/Futhark/TypeChecker/Modules.hs b/src/Language/Futhark/TypeChecker/Modules.hs index 3a75b7f558..da589853a8 100644 --- a/src/Language/Futhark/TypeChecker/Modules.hs +++ b/src/Language/Futhark/TypeChecker/Modules.hs @@ -145,6 +145,8 @@ newNamesForMTy orig_mty = do map substituteInTypeArg targs substituteInType (Scalar (Prim t)) = Scalar $ Prim t + substituteInType (Scalar (Refinement ty e)) = + Scalar $ Refinement (substituteInType ty) e substituteInType (Scalar (Record ts)) = Scalar $ Record $ fmap substituteInType ts substituteInType (Scalar (Sum ts)) = diff --git a/src/Language/Futhark/TypeChecker/Monad.hs b/src/Language/Futhark/TypeChecker/Monad.hs index 3d3c21812a..45468fe4b3 100644 --- a/src/Language/Futhark/TypeChecker/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Monad.hs @@ -163,7 +163,8 @@ data Context = Context -- | Currently type-checking at the top level? If false, we are -- inside a module. contextAtTopLevel :: Bool, - contextCheckExp :: UncheckedExp -> TypeM Exp + contextCheckExpForSize :: UncheckedExp -> TypeM Exp, + contextCheckExpForPred :: StructType -> UncheckedExp -> TypeM Exp } data TypeState = TypeState @@ -207,10 +208,11 @@ runTypeM :: ImportName -> VNameSource -> (UncheckedExp -> TypeM Exp) -> + (StructType -> UncheckedExp -> TypeM Exp) -> TypeM a -> (Warnings, Either TypeError (a, VNameSource)) -runTypeM env imports fpath src checker (TypeM m) = do - let ctx = Context env imports fpath True checker +runTypeM env imports fpath src sizeChecker predChecker (TypeM m) = do + let ctx = Context env imports fpath True sizeChecker predChecker s = TypeState src mempty 0 case runExcept $ runStateT (runReaderT m ctx) s of Left (ws, e) -> (ws, Left e) @@ -289,6 +291,7 @@ class Monad m => MonadTypeChecker m where lookupVar :: SrcLoc -> QualName Name -> m (QualName VName, PatType) checkExpForSize :: UncheckedExp -> m Exp + checkExpForPred :: StructType -> UncheckedExp -> m Exp typeError :: Located loc => loc -> Notes -> Doc () -> m a @@ -372,9 +375,13 @@ instance MonadTypeChecker TypeM where ) checkExpForSize e = do - checker <- asks contextCheckExp + checker <- asks contextCheckExpForSize checker e + checkExpForPred t e = do + checker <- asks contextCheckExpForPred + checker t e + typeError loc notes s = throwError $ TypeError (locOf loc) notes s -- | Extract from a type a first-order type. @@ -435,6 +442,8 @@ qualifyTypeVars outer_env orig_except ref_qs = onType (S.fromList orig_except) except' = case p of Named p' -> S.insert p' except Unnamed -> except + onScalar except (Refinement ty p) = + Refinement (onType except ty) p onTypeArg except (TypeArgDim d) = TypeArgDim $ onDim except d diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 3f5241146c..968d06e7cc 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -9,6 +9,7 @@ module Language.Futhark.TypeChecker.Terms ( checkOneExp, checkSizeExp, + checkPredExp, checkFunDef, ) where @@ -41,18 +42,6 @@ import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify import Prelude hiding (mod) -hasBinding :: Exp -> Bool -hasBinding Lambda {} = True -hasBinding (AppExp LetPat {} _) = True -hasBinding (AppExp LetFun {} _) = True -hasBinding (AppExp DoLoop {} _) = True -hasBinding (AppExp LetWith {} _) = True -hasBinding (AppExp Match {} _) = True -hasBinding e = isNothing $ astMap m e - where - m = - identityMapper {mapOnExp = \e' -> if hasBinding e' then Nothing else Just e'} - overloadedTypeVars :: Constraints -> Names overloadedTypeVars = mconcat . map f . M.elems where @@ -81,11 +70,17 @@ sliceShape :: [(DimIndex, Maybe Occurrence)] -> TypeBase Size as -> TermTypeM (TypeBase Size as, [VName]) -sliceShape r slice t@(Array als u (Shape orig_dims) et) = - runStateT (setDims <$> adjustDims slice orig_dims) [] +sliceShape r slice (Array als u (Shape orig_dims) et) = do + (ty, (exts, _)) <- runStateT (setDims =<< adjustDims slice orig_dims) ([], (et, als, u)) + pure (ty, exts) where - setDims [] = stripArray (length orig_dims) t - setDims dims' = Array als u (Shape dims') et + setDims :: [Size] -> StateT ([VName], (ScalarTypeBase Size (), as, Uniqueness)) TermTypeM (TypeBase Size as) + setDims [] = do + (et', als', u') <- gets snd + pure $ stripArray 0 $ Array als' u' (Shape []) et' + setDims dims' = do + (et', als', u') <- gets snd + pure $ Array als' u' (Shape dims') et' -- If the result is supposed to be a nonrigid size variable, then -- don't bother trying to create non-existential sizes. This is @@ -101,7 +96,7 @@ sliceShape r slice t@(Array als u (Shape orig_dims) et) = (d, ext) <- lift . extSize loc $ SourceSlice orig_d' (bareExp <$> i) (bareExp <$> j) (bareExp <$> stride) - modify (maybeToList ext ++) + modify $ first (maybeToList ext ++) pure d Just (loc, Nonrigid) -> lift $ @@ -109,7 +104,7 @@ sliceShape r slice t@(Array als u (Shape orig_dims) et) = <$> newFlexibleDim (mkUsage loc "size of slice") "slice_dim" Nothing -> do v <- lift $ newID "slice_anydim" - modify (v :) + modify $ first (v :) pure $ sizeFromName (qualName v) mempty where -- The original size does not matter if the slice is fully specified. @@ -134,6 +129,9 @@ sliceShape r slice t@(Array als u (Shape orig_dims) et) = (_, False) -> pure (size :) + adjustDims :: [(DimIndex, Maybe Occurrence)] -> [Size] -> StateT ([VName], (ScalarTypeBase Size (), as, Uniqueness)) TermTypeM [Size] + adjustDims [] dims = + pure dims adjustDims ((DimFix {}, _) : idxes') (_ : dims) = adjustDims idxes' dims -- Pat match some known slices to be non-existential. @@ -164,8 +162,19 @@ sliceShape r slice t@(Array als u (Shape orig_dims) et) = -- existential adjustDims ((DimSlice i j stride, _) : idxes') (d : dims) = (:) <$> sliceSize d i j stride <*> adjustDims idxes' dims - adjustDims _ dims = - pure dims + -- go through Refinement + adjustDims idxes [] = do + (et', als', u') <- gets snd + let throughRefine ty = + case ty of + Scalar (Refinement ty' _) -> throughRefine ty' + _ -> ty + underlying = throughRefine $ stripArray 0 $ Array als' u' (Shape []) et' + case underlying of + Array als'' u'' (Shape dims'') et'' -> do + modify $ second $ const (et'', als'', u'') + adjustDims idxes dims'' + _ -> error $ "no more dimension to take from " ++ prettyString et' sizeMinus j i = AppExp @@ -180,6 +189,7 @@ sliceShape r slice t@(Array als u (Shape orig_dims) et) = $ AppRes i64 [] i64 = Scalar $ Prim $ Signed Int64 sizeBinOpInfo = Info $ foldFunType [(Observe, i64), (Observe, i64)] $ RetType [] i64 +sliceShape r slice (Scalar (Refinement ty _)) = sliceShape r slice ty sliceShape _ _ t = pure (t, []) --- Main checkers @@ -314,6 +324,8 @@ sizeFree tloc expKiller orig_t = do e' <- replacing e local ((e, e') :) m + onScalar (Refinement ty e) = + Refinement <$> onType ty <*> pure e onScalar (Record fs) = Record <$> traverse onType fs onScalar (Sum cs) = @@ -1108,6 +1120,7 @@ boundInsideType (Scalar (TypeVar _ _ _ targs)) = foldMap f targs f TypeArgDim {} = mempty boundInsideType (Scalar (Record fs)) = foldMap boundInsideType fs boundInsideType (Scalar (Sum cs)) = foldMap (foldMap boundInsideType) cs +boundInsideType (Scalar (Refinement ty _)) = boundInsideType ty boundInsideType (Scalar (Arrow _ pn _ t1 (RetType dims t2))) = pn' <> boundInsideType t1 <> S.fromList dims <> boundInsideType t2 where @@ -1293,6 +1306,15 @@ checkSizeExp e = fmap fst . runTermTypeM checkExp $ do unify (mkUsage e' "Size expression") t (Scalar (Prim (Signed Int64))) updateTypes e' +-- | Type-check a single predicate expression in isolation. This expression may +-- turn out to be polymorphic, in which case it is unified with t -> bool. +checkPredExp :: StructType -> UncheckedExp -> TypeM Exp +checkPredExp t e = fmap fst . runTermTypeM checkExp $ do + e' <- noUnique $ checkExp e + let t' = toStruct $ typeOf e' + unify (mkUsage e' "Refinement Predicate") t' (Scalar (Arrow () Unnamed Observe t (RetType [] $ Scalar $ Prim Bool))) + updateTypes e' + -- Verify that all sum type constructors and empty array literals have -- a size that is known (rigid or a type parameter). This is to -- ensure that we can actually determine their shape at run-time. @@ -1660,6 +1682,7 @@ checkReturnAlias loc rettp params = consumableParamType (Scalar (TypeVar _ u _ _)) = u == Unique consumableParamType (Scalar (Record fs)) = all consumableParamType fs consumableParamType (Scalar (Sum fs)) = all (all consumableParamType) fs + consumableParamType (Scalar (Refinement ty _)) = consumableParamType ty consumableParamType (Scalar Arrow {}) = False checkBinding :: @@ -1793,6 +1816,7 @@ boundArrayAliases (Scalar (TypeVar als _ _ _)) = boundAliases als boundArrayAliases (Scalar Arrow {}) = mempty boundArrayAliases (Scalar (Sum fs)) = mconcat $ concatMap (map boundArrayAliases) $ M.elems fs +boundArrayAliases (Scalar (Refinement ty _)) = boundArrayAliases ty nothingMustBeUnique :: SrcLoc -> TypeBase () () -> TermTypeM () nothingMustBeUnique loc = check @@ -1857,6 +1881,7 @@ injectExt ext ret = RetType ext_here $ deeper ret (ext_here, ext_there) = partition (`S.member` immediate) ext deeper (Scalar (Prim t)) = Scalar $ Prim t deeper (Scalar (Record fs)) = Scalar $ Record $ M.map deeper fs + deeper (Scalar (Refinement ty e)) = Scalar $ Refinement (deeper ty) e deeper (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map deeper) cs deeper (Scalar (Arrow als p d1 t1 (RetType t2_ext t2))) = Scalar $ Arrow als p d1 t1 $ injectExt (ext_there <> t2_ext) t2 diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 4935d091e5..0cfb729c7c 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -27,6 +27,7 @@ module Language.Futhark.TypeChecker.Terms.Monad newArrayType, allDimsFreshInType, updateTypes, + hasBinding, -- * Primitive checking unifies, @@ -451,6 +452,18 @@ nameReason loc (NameAppRes fname apploc) = <+> dquotes (pretty fname) <+> parens ("at" <+> pretty (locStrRel loc apploc)) +hasBinding :: Exp -> Bool +hasBinding Lambda {} = True +hasBinding (AppExp LetPat {} _) = True +hasBinding (AppExp LetFun {} _) = True +hasBinding (AppExp DoLoop {} _) = True +hasBinding (AppExp LetWith {} _) = True +hasBinding (AppExp Match {} _) = True +hasBinding e = isNothing $ astMap m e + where + m = + identityMapper {mapOnExp = \e' -> if hasBinding e' then Nothing else Just e'} + -- | The state is a set of constraints and a counter for generating -- type names. This is distinct from the usual counter we use for -- generating unique names, as these will be user-visible. @@ -621,6 +634,16 @@ instance MonadTypeChecker TermTypeM where e' <- noUnique $ checker e let t = toStruct $ typeOf e' unify (mkUsage (srclocOf e') "Size expression") t (Scalar (Prim (Signed Int64))) + when (hasBinding e') $ + typeError (srclocOf e') mempty . withIndexLink "size-expression-bind" $ + "Size expression with binding is forbidden." + updateTypes e' + + checkExpForPred t e = do + checker <- asks termChecker + e' <- noUnique $ checker e + let t' = toStruct $ typeOf e' + unify (mkUsage e' "Refinement Predicate") t' (Scalar (Arrow () Unnamed Observe t (RetType [] $ Scalar $ Prim Bool))) updateTypes e' warn loc problem = liftTypeM $ warn loc problem diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index cf17b36b2c..18fafde6d4 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -74,6 +74,7 @@ checkIfUsed allow_consume occs v consumes = maybe False (identName v `S.member`) . consumed consumable (Scalar (Record fs)) = all consumable fs + consumable (Scalar (Refinement ty _)) = consumable ty consumable (Scalar (Sum cs)) = all (all consumable) cs consumable (Scalar (TypeVar _ u _ _)) = u == Unique consumable (Scalar Arrow {}) = True diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 7fd4327bbc..c87e47efca 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -12,6 +12,7 @@ module Language.Futhark.TypeChecker.Types TypeSubs, Substitutable (..), substTypesAny, + removeRefinement, -- * Witnesses mustBeExplicitInType, @@ -34,6 +35,15 @@ import Language.Futhark import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Monad +removeRefinement :: TypeBase dim () -> TypeBase dim () +removeRefinement (Array () u shape scal) = + let ty' = removeRefinement (Scalar scal) + in case ty' of + Array () u' shape' scal' -> Array () (u <> u') (shape <> shape') scal' + Scalar scal' -> Array () u shape scal' +removeRefinement (Scalar (Refinement ty _)) = removeRefinement ty +removeRefinement t = t + mustBeExplicitAux :: StructType -> M.Map VName Bool mustBeExplicitAux t = execState (traverseDims onDim t) mempty @@ -101,6 +111,8 @@ returnType _ (Scalar (Arrow old_als v pd t1 (RetType dims t2))) d arg = als = old_als <> aliases (maskAliases arg d) returnType appres (Scalar (Sum cs)) d arg = Scalar $ Sum $ (fmap . fmap) (\et -> returnType appres et d arg) cs +returnType appres (Scalar (Refinement ty e)) d arg = + Scalar $ Refinement (returnType appres ty d arg) e -- @t `maskAliases` d@ removes aliases (sets them to 'mempty') from -- the parts of @t@ that are denoted as consumed by the 'Diet' @d@. @@ -239,6 +251,7 @@ evalTypeExp (TEUnique t loc) = do mayContainArray (Scalar TypeVar {}) = True mayContainArray (Scalar Arrow {}) = False mayContainArray (Scalar (Sum cs)) = (any . any) mayContainArray cs + mayContainArray (Scalar (Refinement ty _)) = mayContainArray ty -- evalTypeExp (TEArrow (Just v) t1 t2 loc) = do (t1', svars1, RetType dims1 st1, _) <- evalTypeExp t1 @@ -379,6 +392,10 @@ evalTypeExp ote@TEApply {} = do <+> pretty a <+> "not valid for a type parameter" <+> pretty p <> "." +evalTypeExp (TERefine te e loc) = do + (te', svars, RetType dims ty, ls) <- evalTypeExp te + e' <- checkExpForPred ty e + pure (TERefine te' e' loc, svars, RetType dims (Scalar $ Refinement ty e'), ls) -- | Check a type expression, producing: -- @@ -472,6 +489,7 @@ checkForDuplicateNamesInType = check mempty check _ TEArray {} = pure () check _ TEVar {} = pure () check seen (TEParens te _) = check seen te + check seen (TERefine te _ _) = check seen te -- | @checkTypeParams ps m@ checks the type parameters @ps@, then -- invokes the continuation @m@ with the checked parameters, while @@ -672,6 +690,8 @@ substTypesRet lookupSubst ot = Scalar <$> (Arrow als v d <$> onType t1 <*> onRetType t2) onType (Scalar (Sum ts)) = Scalar . Sum <$> traverse (traverse onType) ts + onType (Scalar (Refinement ty e)) = + Scalar . (`Refinement` applySubst lookupSubst' e) <$> onType ty onRetType (RetType dims t) = do ext <- get diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 71bb059be0..2117eae4a7 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -546,6 +546,8 @@ unifyWith onDims usage = subunify False unifySharedConstructors onDims usage bound bcs cs arg_cs | otherwise -> unifyError usage mempty bcs $ unsharedConstructorsMsg arg_cs cs + (Scalar (Refinement ty _), _) -> subunify ord bound bcs ty t2' + (_, Scalar (Refinement ty _)) -> subunify ord bound bcs t1' ty _ | t1' == t2' -> pure () | otherwise -> failure @@ -649,16 +651,17 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do tp <- normTypeFully tp_unnorm occursCheck usage bcs vn tp scopeCheck usage bcs vn lvl tp + let tp_unrefined = removeRefinement tp constraints <- getConstraints - let link = do - let (witnessed, not_witnessed) = determineSizeWitnesses tp + let link ty = do + let (witnessed, not_witnessed) = determineSizeWitnesses ty used v = v `S.member` witnessed || v `S.member` not_witnessed ext = filter used bound case filter (`notElem` witnessed) ext of [] -> modifyConstraints $ - M.insert vn (lvl, Constraint (RetType ext tp) usage) + M.insert vn (lvl, Constraint (RetType ext ty) usage) problems -> unifyError usage mempty bcs . withIndexLink "unify-param-existential" $ "Parameter(s) " @@ -677,7 +680,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do case snd <$> M.lookup vn constraints of Just (NoConstraint Unlifted unlift_usage) -> do - link + link tp arrayElemTypeWith usage (unliftedBcs unlift_usage) tp when (any (`elem` bound) (fvVars (freeInType tp))) $ @@ -688,12 +691,12 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do indent 2 (pretty tp) textwrap "This is usually because the size of an array returned by a higher-order function argument cannot be determined statically. This can also be due to the return size being a value parameter. Add type annotation to clarify." Just (Equality _) -> do - link + link tp equalityType usage tp Just (Overloaded ts old_usage) - | tp `notElem` map (Scalar . Prim) ts -> do - link - case tp of + | tp_unrefined `notElem` map (Scalar . Prim) ts -> do + link tp_unrefined + case tp_unrefined of Scalar (TypeVar _ _ (QualName [] v) []) | not $ isRigid v constraints -> linkVarToTypes usage v ts @@ -711,7 +714,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do <+> pretty old_usage <> "." Just (HasFields l required_fields old_usage) -> do when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp - case tp of + case tp_unrefined of Scalar (Record tp_fields) | all (`M.member` tp_fields) $ M.keys required_fields -> do required_fields' <- mapM normTypeFully required_fields @@ -729,7 +732,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do _ -> do notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v noRecordType notes - link + link tp_unrefined modifyConstraints $ M.insertWith combineFields @@ -754,7 +757,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do -- See Note [Linking variables to sum types] Just (HasConstrs l required_cs old_usage) -> do when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp - case tp of + case tp_unrefined of Scalar (Sum ts) | all (`M.member` ts) $ M.keys required_cs -> do let tp' = Scalar $ Sum $ required_cs <> ts -- Crucially left-biased. @@ -773,7 +776,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do _ -> do notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v noSumType notes - link + link tp_unrefined modifyConstraints $ M.insertWith combineConstrs @@ -784,7 +787,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do (lvl, HasConstrs (l1 `min` l2) (M.union cs1 cs2) usage1) combineConstrs hasCs _ = hasCs _ -> noSumType =<< typeVarNotes vn - _ -> link + _ -> link tp_unrefined where unsharedConstructors cs1 cs2 notes = unifyError @@ -855,7 +858,7 @@ mustBeOneOf ts usage t = do constraints <- getConstraints let isRigid' v = isRigid v constraints - case t' of + case removeRefinement t' of Scalar (TypeVar _ _ (QualName [] v) []) | not $ isRigid' v -> linkVarToTypes usage v ts Scalar (Prim pt) | pt `elem` ts -> pure () diff --git a/tests/refinement/indirect0.fut b/tests/refinement/indirect0.fut new file mode 100644 index 0000000000..8caf369165 --- /dev/null +++ b/tests/refinement/indirect0.fut @@ -0,0 +1,5 @@ +def g [n] (xs: [n]i64) : i64 = + xs[0] + +entry f (xs: [2]i64) : i64 = + g xs diff --git a/tests/refinement/indirect1.fut b/tests/refinement/indirect1.fut new file mode 100644 index 0000000000..adddb9078a --- /dev/null +++ b/tests/refinement/indirect1.fut @@ -0,0 +1,6 @@ +def g [n] (xs: [n]i64) : i64 = + let k = n + 1 + in xs[0] + +entry f (xs: [2]i64) : i64 = + g xs diff --git a/tests/refinement/iota0.fut b/tests/refinement/iota0.fut new file mode 100644 index 0000000000..bb1bda2b4c --- /dev/null +++ b/tests/refinement/iota0.fut @@ -0,0 +1,3 @@ +entry f : i64 = + let xs = iota 5 + in xs[0] diff --git a/tests/refinement/part2indices.fut b/tests/refinement/part2indices.fut new file mode 100644 index 0000000000..a5634b1b13 --- /dev/null +++ b/tests/refinement/part2indices.fut @@ -0,0 +1,9 @@ +def part2Indices [n] 't (conds: [n]bool) : {[n]i64 | \res-> permutationOf res (0...n-1)} = + let tflgs = map (\c -> i64.bool c) conds + let fflgs = map (\ b -> 1 - b) tflgs + let indsT = scan (+) 0 tflgs + let tmp = scan (+) 0 fflgs + let lst = if n > 0 then indsT[n-1] else 0 + let indsF = map (\t -> t +lst) tmp + let inds = map3 (\ c indT indF -> if c then indT-1 else indF-1) conds indsT indsF + in inds diff --git a/tests/refinement/safe-index0.fut b/tests/refinement/safe-index0.fut new file mode 100644 index 0000000000..2cfee7a977 --- /dev/null +++ b/tests/refinement/safe-index0.fut @@ -0,0 +1,3 @@ +entry f (xs: [2]i64) : i64 = + xs[0] + diff --git a/tests/refinement/safe-index1.fut b/tests/refinement/safe-index1.fut new file mode 100644 index 0000000000..aca421f17e --- /dev/null +++ b/tests/refinement/safe-index1.fut @@ -0,0 +1,4 @@ +entry f (xs: [2]i64) : i64 = + let i = 0 + in xs[i + 1] + diff --git a/unittests/Futhark/SoP/FourierMotzkinTests.hs b/unittests/Futhark/SoP/FourierMotzkinTests.hs new file mode 100644 index 0000000000..5c71839343 --- /dev/null +++ b/unittests/Futhark/SoP/FourierMotzkinTests.hs @@ -0,0 +1,45 @@ +module Futhark.SoP.FourierMotzkinTests (tests) where + +import Futhark.Analysis.PrimExp +import Futhark.SoP.FourierMotzkin +import Futhark.SoP.Monad +import Futhark.SoP.Parse +import Futhark.SoP.SoP +import Test.Tasty +import Test.Tasty.HUnit + +fmSolveLTh0_ :: RangeEnv String -> SoP String -> Bool +fmSolveLTh0_ rs = evalSoPM mempty {ranges = rs} . (fmSolveLTh0 :: SoP String -> SoPM String (PrimExp String) Bool) + +fmSolveGTh0_ :: RangeEnv String -> SoP String -> Bool +fmSolveGTh0_ rs = evalSoPM mempty {ranges = rs} . (fmSolveGTh0 :: SoP String -> SoPM String (PrimExp String) Bool) + +fmSolveGEq0_ :: RangeEnv String -> SoP String -> Bool +fmSolveGEq0_ rs = evalSoPM mempty {ranges = rs} . (fmSolveGEq0 :: SoP String -> SoPM String (PrimExp String) Bool) + +fmSolveLEq0_ :: RangeEnv String -> SoP String -> Bool +fmSolveLEq0_ rs = evalSoPM mempty {ranges = rs} . (fmSolveLEq0 :: SoP String -> SoPM String (PrimExp String) Bool) + +tests :: TestTree +tests = + testGroup + "Solving inequalities with basic ranges" + [ testCase "Ranges 1" $ + let sop = parseSoP "i*N + j - N*N" + rs = + parseRangeEnv + [ "0 <= 2*i <= 2*N - 2", + "0 <= 2*j <= 2*N - 2", + "0 <= N" + ] + in fmSolveLTh0_ rs sop @?= True, + testCase "Ranges 2" $ + let sop = parseSoP "i*N + j" + rs = + parseRangeEnv + [ "0 <= i <= N - 1", + "0 <= j <= N - 1", + "0 <= N" + ] + in fmSolveGEq0_ rs sop @?= True + ] diff --git a/unittests/Futhark/SoP/Parse.hs b/unittests/Futhark/SoP/Parse.hs new file mode 100644 index 0000000000..5cb20bb8f5 --- /dev/null +++ b/unittests/Futhark/SoP/Parse.hs @@ -0,0 +1,156 @@ +{-# LANGUAGE GADTs #-} + +module Futhark.SoP.Parse + ( parseSoP, + parseRange, + parseRangeEnv, + parsePrimExp, + parsePrimExpToSoP, + ) +where + +import Control.Applicative ((<|>)) +import Control.Monad +import Data.Char +import Data.Functor +import Data.Map qualified as M +import Data.Set qualified as S +import Data.String +import Data.Text qualified as T +import Data.Void +import Futhark.Analysis.PrimExp +import Futhark.Analysis.PrimExp.Parse +import Futhark.SoP.Convert +import Futhark.SoP.Monad +import Futhark.SoP.SoP +import Language.Futhark.Primitive.Parse +import Text.Megaparsec (Parsec, manyTill_, notFollowedBy, parse, try, ()) +import Text.Megaparsec qualified as MP +import Text.Megaparsec.Char.Lexer qualified as L +import Text.ParserCombinators.ReadP + +tokenize :: ReadP a -> ReadP a +tokenize p = p <* skipSpaces + +pChar :: Char -> ReadP Char +pChar = tokenize . char + +pString :: String -> ReadP String +pString = tokenize . string + +pInt :: (Read a, Integral a) => ReadP a +pInt = + tokenize $ do + sign <- option 1 (char '-' >> pure (-1)) + n <- read <$> munch1 isDigit + pure $ sign * n + +pSym :: ReadP String +pSym = + tokenize $ + (:) + <$> satisfy isLetter + <*> munch (\c -> isAlphaNum c || c == '_') + +pAtom :: ReadP (SoP String) +pAtom = + choice + [ between (pChar '(') (pChar ')') pSoP, + sym2SoP <$> pSym, + int2SoP <$> pInt + ] + +pMult :: ReadP (SoP String) +pMult = chainl1 pAtom (pChar '*' $> (.*.)) + +pPlus :: ReadP (SoP String) +pPlus = chainl1 pMult $ do + op <- option '+' $ pChar '+' <|> pChar '-' + let sign + | op == '-' = (-1) + | otherwise = 1 + pure $ \l r -> l .+. scaleSoP sign r + +pSoP :: ReadP (SoP String) +pSoP = pPlus + +pRange :: ReadP (String, Range String) +pRange = do + lb <- pLowerBound + k <- pK + sym <- pSym + ub <- pUpperBound + pure (sym, Range lb k ub) + where + pK = option 1 $ pInt <* pChar '*' + pLowerBound = option mempty $ do + bset <- pBoundSet + sign <- pString "<=" <|> pString "<" + case sign of + "<" -> pure $ S.map (.+. int2SoP 1) bset + _ -> pure bset + pUpperBound = option mempty $ do + sign <- pString "<=" <|> pString "<" + bset <- pBoundSet + case sign of + "<" -> pure $ S.map (.+. int2SoP (-1)) bset + _ -> pure bset + pBoundSet = + S.fromList + <$> choice + [ pure <$> pSoP, + between (pChar '{') (pChar '}') $ + sepBy pSoP (pChar ',') + ] + +parse' :: Show a => ReadP a -> String -> a +parse' p s = + case readP_to_S (tokenize p <* eof) s of + [(e, "")] -> e + res -> error $ show res ++ "\n" ++ show s + +parseSoP :: String -> SoP String +parseSoP = parse' pSoP + +parseRange :: String -> (String, Range String) +parseRange = parse' pRange + +parseRangeEnv :: [String] -> RangeEnv String +parseRangeEnv = M.fromList . map parseRange + +pLeaf :: Parsec Void T.Text String +pLeaf = try $ lexeme $ MP.choice [try pVNameString, pFun] + where + pVNameString = do + (s, tag) <- + MP.satisfy constituent + `manyTill_` try pTag + "variable name" + pure $ s <> "_" <> show tag + pTag :: Parsec Void T.Text Integer + pTag = + "_" *> L.decimal <* notFollowedBy (MP.satisfy constituent) + pFun = do + fun <- lexeme $ T.unpack <$> MP.takeWhileP Nothing constituent + guard (fun `elem` M.keys primFuns) + args <- pBalanced (0 :: Integer) (0 :: Integer) "" + pure (fun <> args) + pBalanced open close s + | open > 0 && open == close = pure s + | open < close = fail "" + | otherwise = do + c <- MP.anySingle + let s' = s ++ [c] + case c of + '(' -> pBalanced (open + 1) close s' + ')' -> pBalanced open (close + 1) s' + _ -> pBalanced open close s' + +parsePrimExp :: String -> PrimExp String +parsePrimExp s = + case parse (pPrimExp (IntType Int64) pLeaf) "" (fromString s) of + Left bundle -> error $ show bundle + Right pe -> pe + +parsePrimExpToSoP :: String -> SoPM String (PrimExp String) (Integer, SoP String) +parsePrimExpToSoP = toSoPNum . parsePrimExp diff --git a/unittests/Futhark/SoP/RefineTests.hs b/unittests/Futhark/SoP/RefineTests.hs new file mode 100644 index 0000000000..ce7ef7a34d --- /dev/null +++ b/unittests/Futhark/SoP/RefineTests.hs @@ -0,0 +1,293 @@ +module Futhark.SoP.RefineTests (tests) where + +import Data.Set qualified as S +import Futhark.Analysis.PrimExp +import Futhark.SoP.Convert +import Futhark.SoP.FourierMotzkin +import Futhark.SoP.Monad +import Futhark.SoP.Parse +import Futhark.SoP.Refine +import Test.Tasty +import Test.Tasty.HUnit + +tests :: TestTree +tests = + testGroup + "Environment refinement tests" + [test_nw, test_lud] + +test_nw :: TestTree +test_nw = + testGroup + "Tests based on NW logs" + $ let lessThans = S.fromList . map parsePrimExp + nonNegatives = + S.fromList + . map (parsePrimExp . (\s -> "(sle64 " <> "(sub64 (0i64) (" <> s <> ")) " <> "(0i64))")) + in [ testCase + "Example 1" + $ let less_thans = + [ "(slt64 (mul64 (64i64) (i_13617)) (sub64 (sub64 (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (n_13434)))) (1i64)) (128i64)))", + "(slt64 (mul64 (64i64) (gtid_14374)) (sub64 (sub64 (sub64 (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (n_13434)))) (1i64)) (mul64 (64i64) (i_13617))) (128i64)))" + ] + non_negatives = ["n_13434", "i_13617", "gtid_14374"] + pes = lessThans less_thans <> nonNegatives non_negatives + goal = parsePrimExp "(slt64 (64i64) (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (n_13434)))))" + resM :: SoPM String (PrimExp String) Bool + resM = do + refineAlgEnv pes + fmSolveGEq0 . snd =<< toSoPCmp goal + in fst (runSoPM_ resM) @?= True, + testCase + "Example 2" + $ let less_thans = + [ "(slt64 (mul64 (64i64) (i_13969)) (sub64 (sub64 (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (mul64 (n_13783) (n_13783))))) (1i64)) (64i64)))", + "(slt64 (mul64 (64i64) (gtid_14718)) (sub64 (sub64 (sub64 (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (mul64 (n_13783) (n_13783))))) (1i64)) (mul64 (64i64) (i_13969))) (64i64)))" + ] + non_negatives = ["n_13783", "mul64 (n_13783) (n_13783)", "i_13969", "gtid_14718"] + pes = lessThans less_thans <> nonNegatives non_negatives + goal = parsePrimExp "(slt64 (add_nw64 (add_nw64 (64i64) (mul_nw64 (64i64) (gtid_14718))) (mul_nw64 (62i64) (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (mul64 (n_13783) (n_13783))))))) (add_nw64 (-64i64) (mul_nw64 (64i64) (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (mul64 (n_13783) (n_13783))))))))" + + resM :: SoPM String (PrimExp String) Bool + resM = do + refineAlgEnv pes + fmSolveGEq0 . snd =<< toSoPCmp goal + in fst (runSoPM_ resM) @?= True, + testCase + "Example 3 (test the limits of the range of Example 1)" + $ let less_thans = + [ "(slt64 (mul64 (64i64) (i_13617)) (sub64 (sub64 (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (n_13434)))) (1i64)) (128i64)))", + "(slt64 (mul64 (64i64) (gtid_14374)) (sub64 (sub64 (sub64 (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (n_13434)))) (1i64)) (mul64 (64i64) (i_13617))) (128i64)))" + ] + non_negatives = ["n_13434", "i_13617", "gtid_14374"] + pes = lessThans less_thans <> nonNegatives non_negatives + maximalTrue = + fst $ (runSoPM_ :: SoPM String (PrimExp String) a -> (a, AlgEnv String (PrimExp String))) $ do + refineAlgEnv pes + fmSolveGEq0 . snd + =<< toSoPCmp + ( parsePrimExp "(slt64 (129i64) (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (n_13434)))))" + ) + + minimalFalse = + fst $ (runSoPM_ :: SoPM String (PrimExp String) a -> (a, AlgEnv String (PrimExp String))) $ do + refineAlgEnv pes + fmSolveGEq0 . snd + =<< toSoPCmp + ( parsePrimExp "(slt64 (130i64) (fptosi_f64_i64 (sqrt64 (sitofp_i64_f64 (n_13434)))))" + ) + in maximalTrue && not minimalFalse @?= True + ] + +test_lud :: TestTree +test_lud = + testGroup + "Tests based on LUD logs" + $ let lessThans = + S.fromList + . map (parsePrimExp . (\(i, b) -> "(slt64 (" <> i <> ") (" <> b <> "))")) + nonNegatives = + S.fromList + . map (parsePrimExp . (\s -> "(sle64 " <> "(sub64 (0i64) (" <> s <> ")) " <> "(0i64))")) + in [ testCase + "Example 1" + $ let less_thans = + [ ("step_14910", "sub64 (num_blocks_14902) (1i64)"), + ("gtid_16211", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gtid_16212", "32i64"), + ("gtid_16304", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gtid_16305", "32i64"), + ("gtid_16448", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gtid_16449", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gid_y_17184", "1i64"), + ("gid_x_17183", "1i64") + ] + non_negatives = + [ "num_blocks_14902", + "opaque_res_14906", + "step_14910", + "sub64 (num_blocks_14902) (i_14963)", + "opaque_res_15034", + "max_group_size_15534", + "segmap_group_size_16060", + "segmap_group_size_16089", + "segmap_group_size_16103", + "segmap_group_size_16130", + "segmap_group_size_16207", + "gtid_16211", + "gtid_16212", + "segmap_group_size_16300", + "gtid_16304", + "gtid_16305", + "gtid_16448", + "gtid_16449", + "gid_x_17183", + "gid_y_17184", + "smax64 (0i64) (binop_y_17516)", + "mul_nw64 (4096i64) (j_m_i_14964)", + "smax64 (0i64) (binop_y_17765)", + "smax64 (0i64) (binop_y_17784)", + "smax64 (0i64) (binop_y_17913)" + ] + pes = lessThans less_thans <> nonNegatives non_negatives + goal = parsePrimExp "(slt64 (1023i64) (mul_nw64 (1024i64) (num_blocks_14902)) )" + resM :: SoPM String (PrimExp String) Bool + resM = do + refineAlgEnv pes + fmSolveGEq0 . snd =<< toSoPCmp goal + in fst (runSoPM_ resM) @?= True, + testCase + "Example 2" + $ let less_thans = + [ ("step_14910", "sub64 (num_blocks_14902) (1i64)"), + ("gtid_16211", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gtid_16212", "32i64"), + ("gtid_16304", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gtid_16305", "32i64"), + ("gtid_16448", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gtid_16449", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gid_y_17184", "1i64"), + ("gid_x_17183", "1i64") + ] + non_negatives = + [ "num_blocks_14902", + "opaque_res_14906", + "step_14910", + "sub64 (num_blocks_14902) (i_14963)", + "opaque_res_15034", + "max_group_size_15534", + "segmap_group_size_16060", + "segmap_group_size_16089", + "segmap_group_size_16103", + "segmap_group_size_16130", + "segmap_group_size_16207", + "gtid_16211", + "gtid_16212", + "segmap_group_size_16300", + "gtid_16304", + "gtid_16305", + "gtid_16448", + "gtid_16449", + "gid_x_17183", + "gid_y_17184", + "smax64 (0i64) (binop_y_17516)", + "mul_nw64 (4096i64) (j_m_i_14964)", + "smax64 (0i64) (binop_y_17765)", + "smax64 (0i64) (binop_y_17784)", + "smax64 (0i64) (binop_y_17913)" + ] + pes = lessThans less_thans <> nonNegatives non_negatives + goal = parsePrimExp "(slt64 (add_nw64 (add_nw64 (add_nw64 (1023i64) (mul_nw64 (1024i64) (gtid_16449))) (mul_nw64 (1024i64) (gid_y_17184))) (mul_nw64 (32i64) (gid_x_17183))) (mul_nw64 (1024i64) (num_blocks_14902)))" + + resM :: SoPM String (PrimExp String) Bool + resM = do + refineAlgEnv pes + fmSolveGEq0 . snd =<< toSoPCmp goal + in fst (runSoPM_ resM) @?= True, + testCase + "Example 3" + $ let less_thans = + [ ("step_14910", "sub64 (num_blocks_14902) (1i64)"), + ("gtid_16211", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gtid_16212", "32i64"), + ("gtid_16304", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gtid_16305", "32i64"), + ("gtid_16448", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gtid_16449", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gid_y_17184", "1i64"), + ("gid_x_17183", "1i64") + ] + + non_negatives = + [ "num_blocks_14902", + "opaque_res_14906", + "step_14910", + "sub64 (num_blocks_14902) (i_14963)", + "opaque_res_15034", + "max_group_size_15534", + "segmap_group_size_16060", + "segmap_group_size_16089", + "segmap_group_size_16103", + "segmap_group_size_16130", + "segmap_group_size_16207", + "gtid_16211", + "gtid_16212", + "segmap_group_size_16300", + "gtid_16304", + "gtid_16305", + "gtid_16448", + "gtid_16449", + "gid_x_17183", + "gid_y_17184", + "smax64 (0i64) (binop_y_17516)", + "mul_nw64 (4096i64) (j_m_i_14964)", + "smax64 (0i64) (binop_y_17765)", + "smax64 (0i64) (binop_y_17784)", + "smax64 (0i64) (binop_y_17913)" + ] + + pes = lessThans less_thans <> nonNegatives non_negatives + goal = parsePrimExp "(slt64 (add_nw64 (add_nw64 (mul_nw64 (1024i64) (gid_y_17184)) (mul_nw64 (32i64) (gid_x_17183))) (1024i64) ) (1025i64 ))" + + resM :: SoPM String (PrimExp String) Bool + resM = do + refineAlgEnv pes + fmSolveGEq0 . snd =<< toSoPCmp goal + in fst (runSoPM_ resM) @?= True, + testCase + "Example 4" + $ let less_thans = + [ ("step_14910", "sub64 (num_blocks_14902) (1i64)"), + ( "gtid_16211", + "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))" + ), + ("gtid_16212", "32i64"), + ( "gtid_16304", + "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))" + ), + ("gtid_16305", "32i64"), + ( "gtid_16448", + "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))" + ), + ("gtid_16449", "sub64 (num_blocks_14902) (add64 (1i64) (step_14910))"), + ("gid_y_17184", "1i64"), + ("gid_x_17183", "1i64") + ] + + non_negatives = + [ "num_blocks_14902", + "opaque_res_14906", + "step_14910", + "sub64 (num_blocks_14902) (i_14963)", + "opaque_res_15034", + "max_group_size_15534", + "segmap_group_size_16060", + "segmap_group_size_16089", + "segmap_group_size_16103", + "segmap_group_size_16130", + "segmap_group_size_16207", + "gtid_16211", + "gtid_16212", + "segmap_group_size_16300", + "gtid_16304", + "gtid_16305", + "gtid_16448", + "gtid_16449", + "gid_x_17183", + "gid_y_17184", + "smax64 (0i64) (binop_y_17516)", + "mul_nw64 (4096i64) (j_m_i_14964)", + "smax64 (0i64) (binop_y_17765)", + "smax64 (0i64) (binop_y_17784)", + "smax64 (0i64) (binop_y_17913)" + ] + + pes = lessThans less_thans <> nonNegatives non_negatives + goal = parsePrimExp "(slt64 (mul_nw64 (1024i64) (gtid_16449) ) (add_nw64 (add_nw64 (add_nw64 (mul_nw64 (1024i64) (gtid_16449)) (mul_nw64 (1024i64) (gid_y_17184))) (mul_nw64 (32i64) (gid_x_17183))) (1i64) ))" + + resM :: SoPM String (PrimExp String) Bool + resM = do + refineAlgEnv pes + fmSolveGEq0 . snd =<< toSoPCmp goal + in fst (runSoPM_ resM) @?= True + ] diff --git a/unittests/Futhark/SoP/SoPTests.hs b/unittests/Futhark/SoP/SoPTests.hs new file mode 100644 index 0000000000..3e4023582c --- /dev/null +++ b/unittests/Futhark/SoP/SoPTests.hs @@ -0,0 +1,31 @@ +module Futhark.SoP.SoPTests (tests) where + +import Futhark.SoP.Parse +import Futhark.SoP.SoP +import Test.Tasty +import Test.Tasty.HUnit + +tests :: TestTree +tests = + testGroup + "Arithmetic tests" + $ let sop1 = parseSoP "x + y" + sop2 = parseSoP "5 + x*x + 2*x + 3*y + y*y" + sop3 = parseSoP "-3 + x + y" + in [ testCase "Addition" $ + addSoPs sop1 sop2 @?= parseSoP "5 + x*x + 3*x + 4*y + y*y", + testCase "Multiplication 1" $ + mulSoPs sop1 sop2 + @?= parseSoP + "5*x + x*x*x + 2*x*x + 3*y*x + y*y*x + 5*y + x*x*y + 2*x*y + 3*y*y + y*y*y", + testCase "Multiplication 2" $ + mulSoPs sop1 sop2 + @?= parseSoP + "(x+y) * (5 + x*x + 2*x + 3*y + y*y)", + testCase "Negation 1" $ + negSoP sop3 + @?= parseSoP + "3 - x - y", + testCase "Negation 2" $ + negSoP (negSoP sop3) @?= sop3 + ] diff --git a/unittests/Language/Futhark/TypeChecker/TypesTests.hs b/unittests/Language/Futhark/TypeChecker/TypesTests.hs index 283eb921a9..712a6198b1 100644 --- a/unittests/Language/Futhark/TypeChecker/TypesTests.hs +++ b/unittests/Language/Futhark/TypeChecker/TypesTests.hs @@ -32,7 +32,7 @@ evalTest te expected = assertFailure $ "Expected error, got: " <> show actual_t where extract (_, svars, t, _) = (svars, t) - run = snd . runTypeM env mempty (mkInitialImport "") blankNameSource checkSizeExp + run = snd . runTypeM env mempty (mkInitialImport "") blankNameSource checkSizeExp checkPredExp -- We hack up an environment with some predefined type -- abbreviations for testing. This is all prettyString sensitive to the -- specific unique names, so we have to be careful! diff --git a/unittests/futhark_tests.hs b/unittests/futhark_tests.hs index 32e22272cf..6b79c467cf 100644 --- a/unittests/futhark_tests.hs +++ b/unittests/futhark_tests.hs @@ -10,6 +10,9 @@ import Futhark.IR.Syntax.CoreTests qualified import Futhark.Internalise.TypesValuesTests qualified import Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests qualified import Futhark.Pkg.SolveTests qualified +import Futhark.SoP.RefineTests qualified +import Futhark.SoP.FourierMotzkinTests qualified +import Futhark.SoP.SoPTests qualified import Language.Futhark.PrimitiveTests qualified import Language.Futhark.SyntaxTests qualified import Language.Futhark.TypeCheckerTests qualified @@ -31,7 +34,10 @@ allTests = Language.Futhark.PrimitiveTests.tests, Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests.tests, Futhark.Analysis.AlgSimplifyTests.tests, - Language.Futhark.TypeCheckerTests.tests + Language.Futhark.TypeCheckerTests.tests, + Futhark.SoP.RefineTests.tests, + Futhark.SoP.FourierMotzkinTests.tests, + Futhark.SoP.SoPTests.tests ] main :: IO ()