diff --git a/futhark.cabal b/futhark.cabal index 036ceec612..1dd3e0aaeb 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -379,6 +379,14 @@ library Futhark.Pass.ExtractKernels.ToGPU Futhark.Pass.ExtractMulticore Futhark.Pass.FirstOrderTransform + Futhark.Pass.Flatten + Futhark.Pass.Flatten.Builtins + Futhark.Pass.Flatten.PreProcess + Futhark.Pass.Flatten.ISRWIM + Futhark.Pass.Flatten.Distribute + Futhark.Pass.Flatten.Match + Futhark.Pass.Flatten.Monad + Futhark.Pass.Flatten.WithAcc Futhark.Pass.LiftAllocations Futhark.Pass.LowerAllocations Futhark.Pass.Simplify @@ -490,6 +498,7 @@ library , lsp-types >= 2.4.0.0 , mainland-pretty >=0.7.1 , cmark-gfm >=0.2.1 + , OneTuple , megaparsec >=9.0.0 , mtl >=2.2.1 , neat-interpolation >=0.3 diff --git a/rewritefut/segupdate.fut b/rewritefut/segupdate.fut new file mode 100644 index 0000000000..980d3702db --- /dev/null +++ b/rewritefut/segupdate.fut @@ -0,0 +1,36 @@ +-- Flat-Parallel Segmented Update +-- == +-- compiled input { [1i64,2i64,3i64,1i64,2i64,1i64,2i64,3i64,4i64] [3i64,2i64,4i64] [0i64,0i64,0i64,0i64,0i64] [2i64,1i64,2i64] [0i64, 1i64, 0i64] [1i64, 1i64, 2i64] } output { [0i64,0i64,3i64,1i64,0i64,0i64,2i64,0i64,4i64] } + +let sgmSumI64 [n] (flg : [n]i64) (arr : [n]i64) : [n]i64 = + let flgs_vals = + scan ( \ (f1, x1) (f2,x2) -> + let f = f1 | f2 in + if f2 != 0 then (f, x2) + else (f, x1 + x2) ) + (0, 0i64) (zip flg arr) + let (_, vals) = unzip flgs_vals + in vals + +let mkFlagArray [m] (aoa_shp: [m]i64) (zero: i64) + (aoa_val: [m]i64) : []i64 = + let shp_rot = map(\i -> if i==0i64 then 0i64 else aoa_shp[i-1]) (iota m) + let shp_scn = scan (+) 0i64 shp_rot + let aoa_len = shp_scn[m-1]+aoa_shp[m-1] + let shp_ind = map2 (\shp ind -> if shp==0 then -1i64 else ind) aoa_shp shp_scn + in scatter (replicate aoa_len zero) shp_ind aoa_val + +let segUpdate [n][m][t] (xss_val : *[n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) + (shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 = + let fvss = (mkFlagArray shp_vss 0 (1...t :> 
[t]i64)) :> [m]i64 + let II1 = sgmSumI64 fvss fvss |> map (\x -> x - 1) + let shp_xss_rot = map(\i -> if i==0i64 then 0i64 else shp_xss[i-1]) (iota t) + let bxss = scan (+) 0 shp_xss_rot + let II2 = sgmSumI64 fvss (replicate m 1) |> map (\x -> x - 1) + let iss = map (\i -> bxss[II1[i]] + bs[II1[i]] + (II2[i] * ss[II1[i]])) (iota m) + in scatter xss_val iss vss_val + + +let main [n][m][t] (xss_val : *[n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) + (shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 = + segUpdate xss_val shp_xss vss_val shp_vss bs ss \ No newline at end of file diff --git a/src/Futhark/CLI/Dev.hs b/src/Futhark/CLI/Dev.hs index 0c6332921a..54864eb572 100644 --- a/src/Futhark/CLI/Dev.hs +++ b/src/Futhark/CLI/Dev.hs @@ -56,6 +56,7 @@ import Futhark.Pass.ExplicitAllocations.Seq qualified as Seq import Futhark.Pass.ExtractKernels import Futhark.Pass.ExtractMulticore import Futhark.Pass.FirstOrderTransform +import Futhark.Pass.Flatten (flattenSOACs) import Futhark.Pass.LiftAllocations as LiftAllocations import Futhark.Pass.LowerAllocations as LowerAllocations import Futhark.Pass.Simplify @@ -723,6 +724,7 @@ commandLineOptions = sinkOption [], kernelsPassOption reduceDeviceSyncs [], typedPassOption soacsProg GPU extractKernels [], + typedPassOption soacsProg GPU flattenSOACs [], typedPassOption soacsProg MC extractMulticore [], allocateOption "a", kernelsMemPassOption doubleBufferGPU [], diff --git a/src/Futhark/IR/Pretty.hs b/src/Futhark/IR/Pretty.hs index 270bfec54c..35c2efabaf 100644 --- a/src/Futhark/IR/Pretty.hs +++ b/src/Futhark/IR/Pretty.hs @@ -7,6 +7,7 @@ module Futhark.IR.Pretty ( prettyTuple, prettyTupleLines, prettyString, + prettyRet, PrettyRep (..), ) where diff --git a/src/Futhark/IR/TypeCheck.hs b/src/Futhark/IR/TypeCheck.hs index 701bfe4314..43f2ef5caa 100644 --- a/src/Futhark/IR/TypeCheck.hs +++ b/src/Futhark/IR/TypeCheck.hs @@ -58,7 +58,7 @@ import Futhark.Analysis.PrimExp import Futhark.Construct (instantiateShapes) import 
Futhark.IR.Aliases hiding (lookupAliases) import Futhark.Util -import Futhark.Util.Pretty (align, docText, indent, ppTuple', pretty, (<+>), ()) +import Futhark.Util.Pretty hiding (width) -- | Information about an error during type checking. The 'Show' -- instance for this type produces a human-readable description. @@ -738,7 +738,7 @@ checkSubExp (Var ident) = context ("In subexp " <> prettyText ident) $ do lookupType ident checkCerts :: (Checkable rep) => Certs -> TypeM rep () -checkCerts (Certs cs) = mapM_ (requireI (Prim Unit)) cs +checkCerts = mapM_ lookupType . unCerts checkSubExpRes :: (Checkable rep) => SubExpRes -> TypeM rep Type checkSubExpRes (SubExpRes cs se) = do @@ -1028,9 +1028,9 @@ checkExp (Apply fname args rettype_annot _) = do when (rettype_derived /= rettype_annot) $ bad . TypeError . docText $ "Expected apply result type:" - indent 2 (pretty $ map fst rettype_derived) + indent 2 (braces $ commasep $ map prettyRet rettype_derived) "But annotation is:" - indent 2 (pretty $ map fst rettype_annot) + indent 2 (braces $ commasep $ map prettyRet rettype_annot) consumeArgs paramtypes argflows checkExp (Loop merge form loopbody) = do let (mergepat, mergeexps) = unzip merge @@ -1258,9 +1258,8 @@ checkStm :: TypeM rep a -> TypeM rep a checkStm stm@(Let pat aux e) m = do - let Certs cs = stmAuxCerts aux - (_, dec) = stmAuxDec aux - context "When checking certificates" $ mapM_ (requireI $ Prim Unit) cs + let (_, dec) = stmAuxDec aux + context "When checking certificates" $ checkCerts $ stmAuxCerts aux context "When checking expression annotation" $ checkExpDec dec context ("When matching\n" <> message " " pat <> "\nwith\n" <> message " " e) $ matchPat pat e diff --git a/src/Futhark/Pass/ExtractKernels/ToGPU.hs b/src/Futhark/Pass/ExtractKernels/ToGPU.hs index 667b68a420..ea4ab3586a 100644 --- a/src/Futhark/Pass/ExtractKernels/ToGPU.hs +++ b/src/Futhark/Pass/ExtractKernels/ToGPU.hs @@ -5,6 +5,7 @@ module Futhark.Pass.ExtractKernels.ToGPU segThread, 
soacsLambdaToGPU, soacsStmToGPU, + soacsExpToGPU, scopeForGPU, scopeForSOACs, injectSOACS, @@ -74,6 +75,9 @@ injectSOACS f = soacsStmToGPU :: Stm SOACS -> Stm GPU soacsStmToGPU = runIdentity . rephraseStm (injectSOACS OtherOp) +soacsExpToGPU :: Exp SOACS -> Exp GPU +soacsExpToGPU = runIdentity . rephraseExp (injectSOACS OtherOp) + soacsLambdaToGPU :: Lambda SOACS -> Lambda GPU soacsLambdaToGPU = runIdentity . rephraseLambda (injectSOACS OtherOp) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs new file mode 100644 index 0000000000..5e0f05a722 --- /dev/null +++ b/src/Futhark/Pass/Flatten.hs @@ -0,0 +1,3054 @@ +{-# LANGUAGE TypeFamilies #-} + +-- The idea is to perform distribution on one level at a time, and +-- produce "irregular Maps" that can accept and produce irregular +-- arrays. These irregular maps will then be transformed into flat +-- parallelism based on their contents. This is a sensitive detail, +-- but if irregular maps contain only a single Stm, then it is fairly +-- straightforward, as we simply implement flattening rules for every +-- single kind of expression. Of course that is also somewhat +-- inefficient, so we want to support multiple Stms for things like +-- scalar code. 
+module Futhark.Pass.Flatten (flattenSOACs) where + +import Control.Monad +import Control.Monad.Reader +import Data.Bifunctor (first) +import Data.Foldable +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe (isNothing, mapMaybe) +import Data.Set qualified as S +import Data.Tuple.Solo +import Debug.Trace +import Futhark.IR.GPU +import Futhark.IR.SOACS +import Futhark.MonadFreshNames +import Futhark.Pass +import Futhark.Pass.ExtractKernels.ToGPU (soacsLambdaToGPU, soacsStmToGPU) +import Futhark.Pass.Flatten.Builtins +import Futhark.Pass.Flatten.Distribute +import Futhark.Pass.Flatten.Match +import Futhark.Pass.Flatten.Monad +import Futhark.Pass.Flatten.PreProcess +import Futhark.Pass.Flatten.WithAcc +import Futhark.Tools +import Futhark.Transform.Rename +import Futhark.Transform.Substitute +import Futhark.Util.IntegralExp +import Prelude hiding (div, quot, rem) + +data InnerMapMode + = MultiDim + | SingleDim + +flattenOps :: FlattenOps +flattenOps = FlattenOps {flattenDistStm = transformDistStm} + +transformScalarStms :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + Stms SOACS -> + Builder GPU DistEnv +transformScalarStms segments env inps distres stms = do + let bound_in_batch = namesFromList $ concatMap (patNames . stmPat) $ stmsToList stms + allCerts = foldMap (\stm -> distCerts inps (stmAux stm) env) (stmsToList stms) + certs = Certs $ filter (`notNameIn` bound_in_batch) $ unCerts allCerts + vs <- certifying certs $ letTupExp "scalar_dist" <=< renameExp <=< segMap segments $ \is -> do + readInputs segments env (toList is) inps + addStms $ fmap soacsStmToGPU stms + pure $ subExpsRes $ map (Var . 
distResName) distres + pure $ insertReps (zip (map distResTag distres) $ map Regular vs) env + +transformScalarStm :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + Stm SOACS -> + Builder GPU DistEnv +transformScalarStm segments env inps res stm = + transformScalarStms segments env inps res (oneStm stm) + +-- Do 'map2 (++) A B' where 'A' and 'B' are irregular arrays and have the same +-- number of subarrays +concatIrreg :: + Segments -> + DistEnv -> + VName -> + [IrregularRep] -> + Builder GPU IrregularRep +concatIrreg _segments _env ns reparr = do + -- Concatenation does not change the number of segments - it simply + -- makes each of them larger. + + num_segments <- arraySize 0 <$> lookupType ns + + -- Constructs the full list size / shape that should hold the final results. + let zero = Constant $ IntValue $ intValue Int64 (0 :: Int) + ns_full <- letExp (baseName ns <> "_full") <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + old_segments <- + forM reparr $ \rep -> + letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp i] + new_segment <- + letSubExp "new_segment" + =<< toExp (foldl (+) (pe64 zero) $ map pe64 old_segments) + pure $ subExpsRes [new_segment] + + (ns_full_F, ns_full_O, _ns_II1) <- doRepIota ns_full + + repIota <- mapM (doRepIota . irregularS) reparr + segIota <- mapM (doSegIota . irregularS) reparr + + let (_, _, rep_II1) = unzip3 repIota + let (_, _, rep_II2) = unzip3 segIota + + n_arr <- mapM (fmap (arraySize 0) . 
lookupType) rep_II1 + + -- Calculate offsets for the scatter operations + let shapes = map irregularS reparr + scatter_offsets <- + letTupExp "irregular_scatter_offsets" <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + segment_sizes <- + forM shapes $ \shape -> + letSubExp "segment_size" =<< eIndex shape [eSubExp i] + let prefixes = L.init $ L.inits segment_sizes + sumprefix <- + mapM + ( letSubExp "segment_prefix" + <=< foldBinOp (Add Int64 OverflowUndef) (intConst Int64 0) + ) + prefixes + pure $ subExpsRes sumprefix + + scatter_offsets_T <- + letTupExp "irregular_scatter_offsets_T" <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + columns <- + forM scatter_offsets $ \offsets -> + letSubExp "segment_offset" =<< eIndex offsets [eSubExp i] + pure $ subExpsRes columns + + m <- arraySize 0 <$> lookupType ns_full_F + data_t <- lookupType (irregularD (head reparr)) + let pt = elemType data_t + let resultType = Array pt (Shape [m]) NoUniqueness + elems_blank <- letExp "blank_res" =<< eBlank resultType + + -- Scatter data into result array + elems <- + foldlM + ( \elems (reparr1, scatter_offset, n, ii1, ii2) -> do + letExp "irregular_scatter_elems" <=< genScatter elems n $ \gid -> do + -- Which segment we are in. 
+ segment_i <- + letSubExp "segment_i" =<< eIndex ii1 [eSubExp gid] + + -- Get segment offset in final array + segment_o <- + letSubExp "segment_o" =<< eIndex ns_full_O [eSubExp segment_i] + + -- Get local segment offset + segment_local_o <- + letSubExp "segment_local_o" + =<< eIndex scatter_offset [eSubExp segment_i] + + o' <- letSubExp "o" =<< eIndex ii2 [eSubExp gid] + src_segment_o <- + letSubExp "src_segment_o" =<< eIndex (irregularO reparr1) [eSubExp segment_i] + src_i <- + letSubExp "src_i" <=< toExp $ pe64 src_segment_o + pe64 o' + v' <- + letSubExp "v" =<< eIndex (irregularD reparr1) [eSubExp src_i] + + -- Index to write `v'` at + i <- + letExp "i" =<< toExp (pe64 o' + pe64 segment_local_o + pe64 segment_o) + + pure (i, v') + ) + elems_blank + $ L.zip5 reparr scatter_offsets_T n_arr rep_II1 rep_II2 + + pure $ + IrregularRep + { irregularS = ns_full, + irregularF = ns_full_F, + irregularO = ns_full_O, + irregularD = elems, + irregularK = Dense + } + +-- We also can do reearange -> concat -> rearrange but this should be more efficient +concatIrregAlongDim :: + Segments -> + DistEnv -> + VName -> + [IrregularRep] -> + [Type] -> + DistInputs -> + Int -> + Builder GPU IrregularRep +concatIrregAlongDim segments env ns reparr typearr inps d = do + num_segments <- arraySize 0 <$> lookupType ns + + let zero = Constant $ IntValue $ intValue Int64 (0 :: Int) + ns_full <- letExp (baseName ns <> "_full") <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + old_segments <- + forM reparr $ \rep -> + letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp i] + new_segment <- + letSubExp "new_segment" + =<< toExp (foldl (+) (pe64 zero) $ map pe64 old_segments) + pure $ subExpsRes [new_segment] + + (ns_full_F, ns_full_O, _ns_II1) <- doRepIota ns_full + + repIota <- mapM (doRepIota . irregularS) reparr + segIota <- mapM (doSegIota . 
irregularS) reparr + + let (_, _, rep_II1) = unzip3 repIota + let (_, _, rep_II2) = unzip3 segIota + + n_arr <- mapM (fmap (arraySize 0) . lookupType) rep_II1 + + scatter_info <- + letTupExp "irregular_scatter_offsets" <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + seg_is <- segmentCoordsFromFlat segments i + + block_sizes <- + forM typearr $ \t -> do + v_dims <- readTypeDims segments env seg_is inps t + letSubExp "block_size" =<< toExp (product $ map pe64 $ drop d v_dims) + + total_block <- + letSubExp "total_block" + <=< foldBinOp (Add Int64 OverflowUndef) (intConst Int64 0) + $ block_sizes + + let prefixes = L.init $ L.inits block_sizes + + sumprefix <- + mapM + ( letSubExp "segment_prefix" + <=< foldBinOp (Add Int64 OverflowUndef) (intConst Int64 0) + ) + prefixes + + pure $ subExpsRes (block_sizes <> sumprefix <> [total_block]) + + let k = length typearr + (scatter_blocks, rest) = splitAt k scatter_info + (scatter_offsets, [total_block_size]) = splitAt k rest + + m <- arraySize 0 <$> lookupType ns_full_F + data_t <- lookupType (irregularD (head reparr)) + let pt = elemType data_t + let resultType = Array pt (Shape [m]) NoUniqueness + elems_blank <- letExp "blank_res" =<< eBlank resultType + + -- Scatter data into result array + elems <- + foldlM + ( \elems (reparr1, scatter_block, scatter_offset, n, ii1, ii2) -> do + letExp "irregular_scatter_elems" <=< genScatter elems n $ \gid -> do + -- Which segment we are in. 
+ segment_i <- + letSubExp "segment_i" =<< eIndex ii1 [eSubExp gid] + + -- Get segment offset in final array + segment_o <- + letSubExp "segment_o" =<< eIndex ns_full_O [eSubExp segment_i] + + -- Get local segment offset + segment_local_o <- + letSubExp "segment_local_o" + =<< eIndex scatter_offset [eSubExp segment_i] + + o' <- letSubExp "o" =<< eIndex ii2 [eSubExp gid] + src_segment_o <- + letSubExp "src_segment_o" =<< eIndex (irregularO reparr1) [eSubExp segment_i] + src_i <- + letSubExp "src_i" <=< toExp $ pe64 src_segment_o + pe64 o' + v' <- + letSubExp "v" =<< eIndex (irregularD reparr1) [eSubExp src_i] + + scatter_block_size <- + letSubExp "scatter_block_size" =<< eIndex scatter_block [eSubExp segment_i] + + scatter_total_block_size <- + letSubExp "scatter_total_block_size" =<< eIndex total_block_size [eSubExp segment_i] + + outer_i <- + letSubExp "outer_i" =<< toExp (pe64 o' `div` pe64 scatter_block_size) + + i <- + letExp "i" + =<< toExp + ( pe64 o' + + pe64 outer_i * (pe64 scatter_total_block_size - pe64 scatter_block_size) + + pe64 segment_local_o + + pe64 segment_o + ) + pure (i, v') + ) + elems_blank + $ L.zip6 reparr scatter_blocks scatter_offsets n_arr rep_II1 rep_II2 + + pure $ + IrregularRep + { irregularS = ns_full, + irregularF = ns_full_F, + irregularO = ns_full_O, + irregularD = elems, + irregularK = Dense + } + +-- Do 'map2 replicate ns A', where 'A' is an irregular array (and so +-- is the result, obviously). +replicateIrreg :: + Segments -> + DistEnv -> + VName -> + Name -> + IrregularRep -> + Builder GPU IrregularRep +replicateIrreg _segments _env ns desc rep = do + -- Replication does not change the number of segments - it simply + -- makes each of them larger. + + num_segments <- arraySize 0 <$> lookupType ns + + -- ns multipled with existing segment sizes. 
+ ns_full <- letExp (baseName ns <> "_full") <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + n <- + letSubExp "n" =<< eIndex ns [eSubExp i] + old_segment <- + letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp i] + full_segment <- + letSubExp "new_segment" =<< toExp (pe64 n * pe64 old_segment) + pure $ subExpsRes [full_segment] + + (ns_full_F, ns_full_O, ns_full_D) <- doRepIota ns_full + (_, _, flat_to_segs) <- doSegIota ns_full + + w <- arraySize 0 <$> lookupType ns_full_D + + elems <- letExp (desc <> "_rep_D") <=< segMap (MkSolo w) $ \(MkSolo i) -> do + -- Which segment we are in. + segment_i <- + letSubExp "segment_i" =<< eIndex ns_full_D [eSubExp i] + -- Size of original segment. + old_segment <- + letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp segment_i] + -- Index of value inside *new* segment. + j_new <- + letSubExp "j_new" =<< eIndex flat_to_segs [eSubExp i] + -- Index of value inside *old* segment. + j_old <- + letSubExp "j_old" =<< toExp (pe64 j_new `rem` pe64 old_segment) + -- Offset of values in original segment. + offset <- + letSubExp "offset" =<< eIndex (irregularO rep) [eSubExp segment_i] + v <- + letSubExp "v" + =<< eIndex (irregularD rep) [toExp $ pe64 offset + pe64 j_old] + pure $ subExpsRes [v] + + pure $ + IrregularRep + { irregularS = ns_full, + irregularF = ns_full_F, + irregularO = ns_full_O, + irregularD = elems, + irregularK = Dense + } + +-- | Flatten the arrays of an IrregularRep to be entirely one-dimensional. +flattenIrregularRep :: IrregularRep -> Builder GPU IrregularRep +flattenIrregularRep ir@(IrregularRep shape flags offsets elems kind) = do + elems_t <- lookupType elems + if arrayRank elems_t == 1 + then pure ir + else do + n <- arraySize 0 <$> lookupType shape + m' <- letSubExp "flat_m" <=< toExp $ product $ map pe64 $ arrayDims elems_t + elems' <- + letExp (baseName elems <> "_flat") . 
BasicOp $ + Reshape elems (reshapeAll (arrayShape elems_t) (Shape [m'])) + + shape' <- letExp (baseName shape <> "_flat") <=< renameExp <=< segMap (MkSolo n) $ + \(MkSolo i) -> do + old_shape <- letSubExp "old_shape" =<< eIndex shape [toExp i] + segment_shape <- + letSubExp "segment_shape" <=< toExp $ + pe64 old_shape * product (map pe64 $ tail $ arrayDims elems_t) + pure [subExpRes segment_shape] + + offsets' <- letExp (baseName offsets <> "_flat") <=< renameExp <=< segMap (MkSolo n) $ + \(MkSolo i) -> do + old_offsets <- letSubExp "old_offsets" =<< eIndex offsets [toExp i] + segment_offsets <- + letSubExp "segment_offsets" <=< toExp $ + pe64 old_offsets * product (map pe64 $ tail $ arrayDims elems_t) + pure [subExpRes segment_offsets] + + flags' <- letExp (baseName flags <> "_flat") <=< renameExp <=< segMap (MkSolo m') $ + \(MkSolo i) -> do + let head_i = head $ unflattenIndex (map pe64 $ arrayDims elems_t) (pe64 i) + flag <- letSubExp "flag" =<< eIndex flags [toExp head_i] + pure [subExpRes flag] + pure $ IrregularRep shape' flags' offsets' elems' kind + +rearrangeFlat :: (IntegralExp num) => [Int] -> [num] -> num -> num +rearrangeFlat perm dims i = + -- TODO? Maybe we need to invert one of these permutations. + flattenIndex dims $ + rearrangeShape perm $ + unflattenIndex (rearrangeShape perm dims) i + +segmentCoordsFromFlat :: Segments -> SubExp -> Builder GPU [SubExp] +segmentCoordsFromFlat segments seg_i = + mapM (letSubExp "seg_coord" <=< toExp) $ + unflattenIndex (map pe64 $ shapeDims $ segmentsShape segments) (pe64 seg_i) + +segmentCount :: Segments -> TPrimExp Int64 VName +segmentCount = product . map pe64 . shapeDims . 
segmentsShape + +-- TODO: We do not need to actully make this Dense +rearrangeIrreg :: + Segments -> + DistEnv -> + DistInputs -> + TypeBase Shape u -> + [Int] -> + IrregularRep -> + Builder GPU IrregularRep +rearrangeIrreg segments env inps v_t perm ir = do + (IrregularRep shape _ offsets elems _) <- flattenIrregularRep ir + (new_F, new_O, ii1_vss) <- doRepIota shape + (_, _, ii2_vss) <- doSegIota shape + m <- arraySize 0 <$> lookupType ii1_vss + elems' <- letExp "elems_rearrange" <=< renameExp <=< segMap (MkSolo m) $ + \(MkSolo i) -> do + seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp i] + offset <- letSubExp "offset" =<< eIndex offsets [eSubExp seg_i] + in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp i] + seg_is <- segmentCoordsFromFlat segments seg_i + v_dims <- readTypeDims segments env seg_is inps v_t + let v_dims' = map pe64 v_dims + in_seg_is_tr = rearrangeFlat perm v_dims' $ pe64 in_seg_i + v' <- + letSubExp "v" + =<< eIndex elems [toExp $ pe64 offset + in_seg_is_tr] + pure [subExpRes v'] + pure $ + IrregularRep + { irregularS = shape, + irregularF = new_F, + irregularO = new_O, + irregularD = elems', + irregularK = Dense + } + +sufficientParallelism :: + Name -> + [SubExp] -> + KernelPath -> + Maybe Int64 -> + Builder GPU (SubExp, Name) +sufficientParallelism desc ws path def = do + size_key <- nameFromText . 
prettyText <$> newVName desc + + amount <- + letSubExp "comparatee" + =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws + + cmp_res <- + letSubExp desc $ + Op $ + SizeOp $ + CmpSizeLe size_key (SizeThreshold path def) amount + + pure (cmp_res, size_key) + +kernelAlternatives :: + Name -> + [Type] -> + Body GPU -> + [(SubExp, Body GPU)] -> + Builder GPU [VName] +kernelAlternatives desc _ default_body [] = do + ses <- bodyBind default_body + forM ses $ \(SubExpRes cs se) -> + certifying cs $ + letExp desc $ + BasicOp $ + SubExp se +kernelAlternatives desc result_ts default_body ((cond, alt) : alts) = do + fallback_body <- do + (fallback_vs, fallback_stms) <- + collectStms $ + kernelAlternatives desc result_ts default_body alts + pure $ mkBody fallback_stms $ varsRes fallback_vs + + letTupExp desc $ + Match [cond] [Case [Just $ BoolValue True] alt] fallback_body $ + MatchDec (staticShapes result_ts) MatchEquiv + +regularResultVars :: [DistResult] -> DistEnv -> [VName] +regularResultVars ress env = + map onRes ress + where + onRes res = + case resVar (distResTag res) env of + Regular v -> v + Irregular {} -> + error "regularResultVars: expected regular result" + +regularRepVars :: [ResRep] -> [VName] +regularRepVars = + map onRep + where + onRep (Regular v) = v + onRep Irregular {} = + error "regularRepVars: expected regular result" + +isVersionableRegularResult :: DistResult -> Bool +isVersionableRegularResult = isRegularDistResult + +regularBranchBody :: + Builder GPU [VName] -> + Builder GPU (Body GPU) +regularBranchBody m = do + (vs, stms) <- collectStms m + renameBody $ mkBody stms $ varsRes vs + +versionedRegularMap :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + Pat Type -> + StmAux () -> + SubExp -> + [VName] -> + ScremaForm SOACS -> + Lambda SOACS -> + Builder GPU DistEnv +versionedRegularMap segments env inps ress pat aux w arrs form map_lam = do + (outer_suff, _) <- + sufficientParallelism "suff_outer_map" (NE.toList segments) 
mempty Nothing + + let fullFlatten = + regularRepVars <$> transformInnerMap segments env inps pat w arrs map_lam + + outerOnly = do + env' <- + transformScalarStm segments env inps ress $ + Let pat aux $ + Op $ + Screma w arrs form + pure $ regularResultVars ress env' + + full_body <- regularBranchBody fullFlatten + outer_body <- regularBranchBody outerOnly + + let result_ts = + [ t `arrayOfShape` segmentsShape segments + | DistResult _ (DistType _ _ t) _ <- ress + ] + + match_res <- + certifying (distCerts inps aux env) $ + kernelAlternatives "match_res" result_ts full_body [(outer_suff, outer_body)] + + pure $ insertRegulars (map distResTag ress) match_res env + +transformDistBasicOp :: + Segments -> + DistEnv -> + ( DistInputs, + DistResult, + PatElem Type, + StmAux (), + BasicOp + ) -> + Builder GPU DistEnv +transformDistBasicOp segments env (inps, res, pe, aux, e) = + case e of + BinOp {} -> + scalarCase + CmpOp {} -> + scalarCase + ConvOp {} -> + scalarCase + UnOp {} -> + scalarCase + Assert {} -> + scalarCase + -- Potentially no need for this + ArrayLit [] row_type + | not $ any (isVariant inps env) (arrayDims row_type) -> do + let resultType = + Array + (elemType row_type) + (segmentsShape segments <> Shape [intConst Int64 0] <> arrayShape row_type) + NoUniqueness + v <- letExp "arraylit_empty_reg" =<< eBlank resultType + pure $ insertRegulars [distResTag res] [v] env + | otherwise -> do + ns <- dataArr segments env inps $ intConst Int64 0 + (flags, offsets, _elems) <- doRepIota ns + let resultType = Array (elemType row_type) (Shape [intConst Int64 0]) NoUniqueness + elems <- letExp "arraylit_empty_elems" =<< eBlank resultType + pure $ insertIrregular ns flags offsets (distResTag res) elems Dense env + -- TODO: not sure about this + ArrayVal vs row_type -> do + base_v <- letExp "arraylit_base" $ BasicOp $ ArrayVal vs row_type + res_v <- letExp "arraylit_reg" $ BasicOp $ Replicate (segmentsShape segments) (Var base_v) + pure $ insertRegulars [distResTag res] 
[res_v] env + ArrayLit vs row_type + | not $ any (isVariant inps env) (arrayDims row_type) -> do + res_v <- + if any (isVariant inps env) vs + then do + let seg_shape = segmentsShape segments + one = intConst Int64 1 + arr_outer_dim = intConst Int64 $ toInteger $ length vs + expected = seg_shape <> arrayShape row_type + stacked = seg_shape <> Shape [one] <> arrayShape row_type + d = segmentsRank segments + + vs_reg <- mapM (liftSubExpRegular segments inps env expected) vs + + vs_reg_1 <- + forM vs_reg $ \v -> do + v_t <- lookupType v + letExp (baseName v <> "_stack") $ + BasicOp $ + Reshape v $ + reshapeAll (arrayShape v_t) stacked + + case vs_reg_1 of + [] -> undefined + [v] -> + pure v + v : vs' -> + letExp "arraylit_reg" $ BasicOp $ Concat d (v NE.:| vs') arr_outer_dim + else do + base_v <- letExp "arraylit_base" $ BasicOp $ ArrayLit vs row_type + letExp "arraylit_reg" $ + BasicOp $ + Replicate (segmentsShape segments) (Var base_v) + pure $ insertRegulars [distResTag res] [res_v] env + | otherwise -> do + let arr_outer_dim = intConst Int64 $ fromIntegral $ length vs + vs_reparr <- mapM (dataArr segments env inps) vs + dim_arrs <- mapM (dataArr segments env inps) (arrayDims row_type) + num_segments <- letSubExp "num_segments" =<< toExp (segmentCount segments) + ~[row_size, full_size] <- letTupExp "arraylit_row_size" <=< segMap (MkSolo num_segments) $ \(MkSolo i) -> do + vals <- mapM (\dim_arr -> letSubExp "dim_i" =<< eIndex dim_arr [eSubExp i]) dim_arrs + n <- letSubExp "n" <=< toExp $ product $ map pe64 vals + fs <- letSubExp "fs" <=< toExp $ pe64 n * pe64 arr_outer_dim + pure $ subExpsRes [n, fs] + + (_, _, row_II1) <- doRepIota row_size + (_, _, row_II2) <- doSegIota row_size + + row_flat_size <- arraySize 0 <$> lookupType row_II1 + + (full_flags, full_offset, full_II1) <- doRepIota full_size + + m <- arraySize 0 <$> lookupType full_II1 + data_t <- lookupType (head vs_reparr) + let pt = elemType data_t + let resultType = Array pt (Shape [m]) NoUniqueness + 
elems_blank <- letExp "blank_res" =<< eBlank resultType + + elems <- + foldlM + ( \elems (var_num, arr) -> do + letExp "irregular_scatter_elems" <=< genScatter elems row_flat_size $ \gid -> do + -- Which segment we are in. + segment_i <- + letSubExp "segment_i" =<< eIndex row_II1 [eSubExp gid] + + row_size_i <- + letSubExp "row_size_i" =<< eIndex row_size [eSubExp segment_i] + + segment_global_o <- + letSubExp "segment_global_o" + =<< eIndex full_offset [eSubExp segment_i] + + v' <- + letSubExp "v" =<< eIndex arr [eSubExp gid] + + o' <- letSubExp "o" =<< eIndex row_II2 [eSubExp gid] + + i <- + letExp "i" + =<< toExp + ( pe64 o' + + pe64 segment_global_o + + pe64 row_size_i * pe64 (intConst Int64 var_num) + ) + + pure (i, v') + ) + elems_blank + $ zip [0 ..] vs_reparr + + pure $ insertIrregular full_size full_flags full_offset (distResTag res) elems Dense env + Opaque _op se + | Var v <- se, + Just (DistInput rt_in _) <- lookup v inps -> + -- TODO: actually insert opaques + pure $ insertRep (distResTag res) (resVar rt_in env) env + | otherwise -> + scalarCase + -- TODO: Probably have to change this. + Reshape arr _ -> do + irreg_v <- getIrregRep segments env inps arr + pure $ insertRep (distResTag res) (Irregular irreg_v) env + Index arr slice + | null $ sliceDims slice -> + scalarCase + | otherwise -> do + -- Maximally irregular case. + num_segments <- letSubExp "num_segments" =<< toExp (segmentCount segments) + ns <- letExp "slice_sizes" <=< renameExp <=< segMap (MkSolo num_segments) $ \(MkSolo segment) -> do + segment_is <- segmentCoordsFromFlat segments segment + slice_ns <- mapM (readInput segments env segment_is inps) $ sliceDims slice + fmap varsRes . 
letTupExp "n" <=< toExp $ product $ map pe64 slice_ns + (_n, offsets, m) <- exScanAndSum ns + (_, _, repiota_D) <- doRepIota ns + flags <- genFlags m offsets + elems <- letExp "elems" <=< renameExp <=< segMap (NE.singleton m) $ \is -> do + segment <- letSubExp "segment" =<< eIndex repiota_D (toList $ fmap eSubExp is) + segment_start <- letSubExp "segment_start" =<< eIndex offsets [eSubExp segment] + segment_is <- segmentCoordsFromFlat segments segment + readInputs segments env segment_is inps + let slice' = + fixSlice (fmap pe64 slice) $ + unflattenIndex (map pe64 (sliceDims slice)) $ + subtract (pe64 segment_start) . pe64 $ + NE.head is + auxing aux $ + fmap (subExpsRes . pure) . letSubExp "v" + =<< eIndex arr (map toExp slice') + pure $ insertIrregular ns flags offsets (distResTag res) elems Dense env + Iota n (Constant x) (Constant s) Int64 + | zeroIsh x, + oneIsh s -> do + ns <- dataArr segments env inps n + (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns + pure $ insertIrregular ns flags offsets (distResTag res) elems Dense env + Iota n x s it -> do + ns <- dataArr segments env inps n + xs <- dataArr segments env inps x + ss <- dataArr segments env inps s + (res_F, res_O, res_D) <- certifying (distCerts inps aux env) $ doSegIota ns + (_, _, repiota_D) <- doRepIota ns + m <- arraySize 0 <$> lookupType res_D + res_D' <- letExp "iota_D_fixed" <=< segMap (MkSolo m) $ \(MkSolo i) -> do + segment <- letSubExp "segment" =<< eIndex repiota_D [eSubExp i] + v' <- letSubExp "v" =<< eIndex res_D [eSubExp i] + x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] + s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] + fmap (subExpsRes . pure) . 
letSubExp "v" <=< toExp $ + primExpFromSubExp (IntType it) x' + ~+~ sExt it (untyped (pe64 v')) + ~*~ primExpFromSubExp (IntType it) s' + pure $ insertIrregular ns res_F res_O (distResTag res) res_D' Dense env + Concat 0 arr shp -> do + ns <- dataArr segments env inps shp + reparr <- mapM (getIrregRep segments env inps) (NE.toList arr) + rep' <- concatIrreg segments env ns reparr + pure $ insertRep (distResTag res) (Irregular rep') env + -- TODO: add invariant special handling + Concat d arr shp -> do + ns <- dataArr segments env inps shp + reparr <- mapM (getIrregRep segments env inps) (NE.toList arr) + -- typearr <- mapM lookupType arr + typearr <- + forM arr $ \v -> + case lookup v inps of + Just inp -> pure $ distInputType inp + Nothing -> lookupType v + rep' <- concatIrregAlongDim segments env ns reparr (NE.toList typearr) inps d + pure $ insertRep (distResTag res) (Irregular rep') env + + -- TODO: add invaraint special handling + Replicate (Shape [n]) (Var v) -> do + ns <- dataArr segments env inps n + rep <- getIrregRep segments env inps v + rep' <- replicateIrreg segments env ns (baseName v) rep + pure $ insertRep (distResTag res) (Irregular rep') env + Replicate (Shape [n]) (Constant v) -> do + ns <- dataArr segments env inps n + (res_F, res_O, res_D) <- + certifying (distCerts inps aux env) $ doSegIota ns + w <- arraySize 0 <$> lookupType res_D + res_D' <- letExp "rep_const" $ BasicOp $ Replicate (Shape [w]) (Constant v) + pure $ insertIrregular ns res_F res_O (distResTag res) res_D' Dense env + Replicate (Shape dims) (Constant v) -> do + dim_arrs <- mapM (dataArr segments env inps) dims + seg_number <- arraySize 0 <$> lookupType (head dim_arrs) + mul_dims <- letExp "mul_dims" <=< segMap (MkSolo seg_number) $ \(MkSolo i) -> do + vals <- mapM (\dim_arr -> letSubExp "dim_i" =<< eIndex dim_arr [eSubExp i]) dim_arrs + n <- letSubExp "n" <=< toExp $ product $ map pe64 vals + pure [subExpRes n] + (res_F, res_O, res_D) <- + certifying (distCerts inps aux env) $ 
doSegIota mul_dims + w <- arraySize 0 <$> lookupType res_D + res_D' <- letExp "rep_const" $ BasicOp $ Replicate (Shape [w]) (Constant v) + pure $ insertIrregular mul_dims res_F res_O (distResTag res) res_D' Dense env + Replicate (Shape []) (Var v) -> + case lookup v inps of + Just (DistInputFree v' _) -> do + v'' <- + letExp (baseName v' <> "_copy") . BasicOp $ + Replicate mempty (Var v') + pure $ insertRegulars [distResTag res] [v''] env + Just (DistInput rt _) -> + case resVar rt env of + Irregular r -> do + let name = baseName (irregularD r) <> "_copy" + elems_copy <- + letExp name . BasicOp $ + Replicate mempty (Var $ irregularD r) + let rep = Irregular $ r {irregularD = elems_copy} + pure $ insertRep (distResTag res) rep env + Regular v' -> do + v'' <- + letExp (baseName v' <> "_copy") . BasicOp $ + Replicate mempty (Var v') + pure $ insertRegulars [distResTag res] [v''] env + Nothing -> do + v' <- + letExp (baseName v <> "_copy_free") . BasicOp $ + Replicate (segmentsShape segments) (Var v) + pure $ insertRegulars [distResTag res] [v'] env + Replicate (Shape dims) (Var v) -> do + dim_arrs <- mapM (dataArr segments env inps) dims + seg_number <- arraySize 0 <$> lookupType (head dim_arrs) + mul_dims <- letExp "mul_dims" <=< segMap (MkSolo seg_number) $ \(MkSolo i) -> do + vals <- mapM (\dim_arr -> letSubExp "dim_i" =<< eIndex dim_arr [eSubExp i]) dim_arrs + n <- letSubExp "n" <=< toExp $ product $ map pe64 vals + pure [subExpRes n] + rep <- getIrregRep segments env inps v + rep' <- replicateIrreg segments env mul_dims (baseName v) rep + pure $ insertRep (distResTag res) (Irregular rep') env + Update _ as slice se + | Just as_t <- distInputType <$> lookup as inps -> do + num_segments <- letSubExp "num_segments" =<< toExp (segmentCount segments) + ns <- letExp "slice_sizes" + <=< renameExp + <=< segMap (MkSolo num_segments) + $ \(MkSolo seg_i) -> do + seg_is <- segmentCoordsFromFlat segments seg_i + readInputs segments env seg_is $ + filter ((`elem` sliceDims 
slice) . Var . fst) inps + slice_dims <- mapM (readInput segments env seg_is inps) $ sliceDims slice + n <- letSubExp "n" <=< toExp $ product $ map pe64 slice_dims + pure [subExpRes n] + -- Irregular representation of `as` + as_rep <- getIrregRep segments env inps as + IrregularRep shape flags offsets elems _ <- + ensureDenseIrregular (baseName as <> "_update") as_rep + -- Inner indices (1 and 2) of `ns` + (_, _, ii1_vss) <- doRepIota ns + (_, _, ii2_vss) <- certifying (distCerts inps aux env) $ doSegIota ns + -- Number of updates to perform + m <- arraySize 0 <$> lookupType ii2_vss + elems' <- letExp "elems_scatter" <=< renameExp <=< genScatter elems m $ \gid -> do + seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp gid] + in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp gid] + seg_is <- segmentCoordsFromFlat segments seg_i + readInputs segments env seg_is $ filter ((/= as) . fst) inps + as_dims <- readTypeDims segments env seg_is inps as_t + slice_dims <- mapM (readInput segments env seg_is inps) $ sliceDims slice + case se of + Var v -> do + let in_seg_is = + unflattenIndex (map pe64 slice_dims) (pe64 in_seg_i) + slice' = fmap pe64 slice + flat_i = + flattenIndex + (map pe64 as_dims) + (fixSlice slice' in_seg_is) + -- Value to write + v' <- letSubExp "v" =<< eIndex v (map toExp in_seg_is) + o' <- letSubExp "o" =<< eIndex offsets [eSubExp seg_i] + -- Index to write `v'` at + i <- letExp "i" =<< toExp (pe64 o' + flat_i) + pure (i, v') + Constant c -> do + let slice' = fmap pe64 slice + flat_i = flattenIndex (map pe64 as_dims) (fixSlice slice' []) + o' <- letSubExp "o" =<< eIndex offsets [eSubExp seg_i] + i <- letExp "i" =<< toExp (pe64 o' + flat_i) + pure (i, Constant c) + pure $ insertIrregular shape flags offsets (distResTag res) elems' Dense env + | otherwise -> + error "Flattening update: destination is not input." + Rearrange v perm -> do + case lookup v inps of + Just (DistInputFree v' _) -> do + v'' <- + letExp (baseName v' <> "_tr") . 
BasicOp $ + Rearrange v' perm + pure $ insertRegulars [distResTag res] [v''] env + Just (DistInput rt v_t) -> do + case resVar rt env of + Irregular rep -> do + rep' <- + certifying (distCerts inps aux env) $ + rearrangeIrreg segments env inps v_t perm rep + pure $ insertRep (distResTag res) (Irregular rep') env + Regular v' -> do + let r = segmentsRank segments + v'' <- + letExp (baseName v' <> "_tr") . BasicOp $ + Rearrange v' ([0 .. r - 1] ++ map (+ r) perm) + pure $ insertRegulars [distResTag res] [v''] env + Nothing -> do + let r = segmentsRank segments + v' <- + letExp (baseName v <> "_tr") . BasicOp $ + Rearrange v ([0 .. r - 1] ++ map (+ r) perm) + pure $ insertRegulars [distResTag res] [v'] env + Scratch pt dims + | not $ any (isVariant inps env) dims -> do + -- All dims are invariant result is regular across segments. + v' <- + letExp "scratch" . BasicOp $ + Scratch pt (shapeDims (segmentsShape segments) ++ dims) + pure $ insertRegulars [distResTag res] [v'] env + | [n] <- dims -> do + ns <- dataArr segments env inps n + (_n, offsets, m) <- exScanAndSum ns + flags <- genFlags m offsets + res_D <- letExp "scratch_D" $ BasicOp $ Scratch pt [m] + pure $ insertIrregular ns flags offsets (distResTag res) res_D Dense env + | otherwise -> do + dim_arrs <- mapM (dataArr segments env inps) dims + w <- arraySize 0 <$> lookupType (head dim_arrs) + ns <- letExp "scratch_sizes" <=< segMap (MkSolo w) $ \(MkSolo i) -> do + vals <- mapM (\arr -> letSubExp "d" =<< eIndex arr [eSubExp i]) dim_arrs + n <- letSubExp "n" <=< toExp $ product $ map pe64 vals + pure [subExpRes n] + (_n, offsets, m) <- exScanAndSum ns + flags <- genFlags m offsets + res_D <- letExp "scratch_D" $ BasicOp $ Scratch pt [m] + pure $ insertIrregular ns flags offsets (distResTag res) res_D Dense env + UpdateAcc {} -> + -- TODO: handle irregular case, which is however rare, and also needs + -- modifications to WithAcc. The only irregularity that is possible is in + -- the values to be written. 
+ scalarCase + _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e + where + scalarCase = + transformScalarStm segments env inps [res] $ + Let (Pat [pe]) aux (BasicOp e) + +-- Replicates inner dimension for inputs. +onMapFreeVar :: + Segments -> + DistEnv -> + DistInputs -> + VName -> + (VName, VName, VName) -> + VName -> + Maybe (Builder GPU (VName, MapArray IrregularRep)) +onMapFreeVar segments env inps _ws (_ws_F, _ws_O, ws_data) v = do + let segments_per_elem = ws_data + v_inp <- lookup v inps + pure $ do + ws_prod <- arraySize 0 <$> lookupType ws_data + fmap (v,) $ case v_inp of + DistInputFree v' t -> do + fmap (`MapArray` t) + . letExp (baseName v <> "_rep_free_free_inp") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + segment_is <- segmentCoordsFromFlat segments segment + subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' (map eSubExp segment_is)) + -- subExpsRes . pure <$> readInput segments env segment_is inps (Var v) + DistInput rt t -> case resVar rt env of + Irregular rep -> do + ~[new_S, offsets] <- letTupExp (baseName v <> "_rep_free_irreg") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + segment <- letSubExp "segment" =<< eIndex ws_data [eSubExp i] + s <- letSubExp "s" =<< eIndex (irregularS rep) [eSubExp segment] + o <- letSubExp "o" =<< eIndex (irregularO rep) [eSubExp segment] + pure $ subExpsRes [s, o] + let rep' = + IrregularRep + { irregularS = new_S, + irregularF = irregularF rep, + irregularO = offsets, + irregularD = irregularD rep, + irregularK = Replicated + } + pure $ MapOther rep' t + Regular vs -> + fmap (`MapArray` t) + . letExp (baseName v <> "_rep_free_reg_inp") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + segment_is <- segmentCoordsFromFlat segments segment + subExpsRes . pure <$> (letSubExp "v" =<< eIndex vs (map eSubExp segment_is)) + +-- subExpsRes . 
pure <$> readInput segments env segment_is inps (Var v) + +onMapInputArr :: + Segments -> + DistEnv -> + DistInputs -> + VName -> + VName -> + VName -> + Param Type -> + VName -> + Builder GPU (MapArray IrregularRep) +onMapInputArr segments env inps ws ws_O ws_data p arr = do + ws_prod <- arraySize 0 <$> lookupType ws_data + case lookup arr inps of + Just v_inp -> + case v_inp of + DistInputFree vs t -> do + let inner_shape = arrayShape $ paramType p + vs_t <- lookupType vs + v <- + if isAcc vs_t + then pure vs + else + letExp (baseName vs <> "_flat") . BasicOp . Reshape vs $ + reshapeAll (arrayShape vs_t) (Shape [ws_prod] <> inner_shape) + pure $ MapArray v t + DistInput rt _ -> + case resVar rt env of + Irregular rep -> do + onMapIrregularInputArr SingleDim segments ws ws_O ws_data p arr rep ws_prod + Regular vs -> do + let inner_shape = arrayShape $ paramType p + vs_t <- lookupType vs + v <- + letExp (baseName arr <> "_reg_flat") . BasicOp . Reshape vs $ + reshapeAll (arrayShape vs_t) (Shape [ws_prod] <> inner_shape) + pure $ MapArray v (stripArray 1 vs_t) + -- undefined + Nothing -> do + arr_row_t <- rowType <$> lookupType arr + arr_rep <- + letExp (baseName arr <> "_inp_rep") . BasicOp $ + Replicate (segmentsShape segments) (Var arr) + arr_rep_t <- lookupType arr_rep + v <- + letExp (baseName arr <> "_inp_rep_flat") . BasicOp . 
Reshape arr_rep $ + reshapeAll (arrayShape arr_rep_t) (Shape [ws_prod] <> arrayShape arr_row_t) + pure $ MapArray v arr_row_t + +transformInnerMap :: + Segments -> + DistEnv -> + DistInputs -> + Pat Type -> + SubExp -> + [VName] -> + Lambda SOACS -> + Builder GPU [ResRep] +transformInnerMap segments env inps pat w arrs map_lam + | not (isVariant inps env w) = do + traceM "transformInnerMap: w is invariant, treating as multi-dim map" + transformInnerMapMultiDim segments env inps pat w arrs map_lam + | otherwise = do + traceM "transformInnerMap: w is variant, treating as single-dim map" + transformInnerMapSingleDim segments env inps pat w arrs map_lam + +transformInnerMapSingleDim :: + Segments -> + DistEnv -> + DistInputs -> + Pat Type -> + SubExp -> + [VName] -> + Lambda SOACS -> + Builder GPU [ResRep] +transformInnerMapSingleDim segments env inps pat w arrs map_lam = do + ws <- dataArr segments env inps w + (ws_F, ws_O, ws_data) <- doRepIota ws + new_segment <- arraySize 0 <$> lookupType ws_data + arrs' <- + zipWithM + (onMapInputArr segments env inps ws ws_O ws_data) + (lambdaParams map_lam) + arrs + distributeAndTransformInnerMap + SingleDim + (ws_F, ws_O, ws) + (NE.singleton new_segment) + inps + pat + arrs' + (onMapFreeVar segments env inps ws (ws_F, ws_O, ws_data)) + map_lam + +onMapIrregularInputArr :: + InnerMapMode -> + Segments -> + VName -> + VName -> + VName -> + Param Type -> + VName -> + IrregularRep -> + SubExp -> + Builder GPU (MapArray IrregularRep) +onMapIrregularInputArr mode new_segments ws ws_O ws_data p arr rep ws_prod = do + -- new_segments has already has the the new w inside unlike other functions + rep_t <- lookupType $ irregularD rep + when (arrayRank rep_t > 1) $ + error $ + error "onMapIrregularInputArr: irregularD is not 1D" + if null (arrayDims $ paramType p) + then do + -- assuimg the irregD is 1D size(irregularD rep) == ws_prod should hold and this should be fine + let old_shape = arrayShape rep_t + new_shape = + case mode of + 
SingleDim -> Shape [ws_prod] + MultiDim -> segmentsShape new_segments + case irregularK rep of + Dense -> do + v_reshaped <- letExp (baseName (paramName p) <> "_reshaped") $ BasicOp $ Reshape (irregularD rep) $ reshapeAll old_shape new_shape + pure $ MapArray v_reshaped (stripArray 1 rep_t) + Replicated -> do + new_flat <- + letExp (baseName arr <> "_flat_expand") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + j <- letSubExp "j" =<< eIndex ws_data [eSubExp i] + data_off <- letSubExp "data_off" =<< eIndex (irregularO rep) [eSubExp j] + seg_start <- letSubExp "seg_start" =<< eIndex ws_O [eSubExp j] + local_pos <- letSubExp "local_pos" <=< toExp $ pe64 i - pe64 seg_start + flat_idx <- letSubExp "flat_idx" <=< toExp $ pe64 data_off + pe64 local_pos + fmap (subExpsRes . pure) $ letSubExp "elem" =<< eIndex (irregularD rep) [eSubExp flat_idx] + v_reshaped <- letExp (baseName (paramName p) <> "_reshaped") $ BasicOp $ Reshape new_flat $ reshapeAll old_shape new_shape + pure $ MapArray v_reshaped (stripArray 1 rep_t) + else do + -- We need to split multi-dimensional irregular segments + -- into per-row segments. Compute per-row size by dividing + -- each segment's total size by the number of inner iterations. + -- Important TODO: I should ask troels about this. + -- we should make this consistent. + -- we can avoid getting per_row_size by division. + num_segments <- arraySize 0 <$> lookupType ws + -- per_row_size[s] = irregularS[s] / ws[s] + per_row_size <- + letExp (baseName (paramName p) <> "_per_row_size") + <=< segMap (MkSolo num_segments) + $ \(MkSolo s) -> do + total_s <- letSubExp "total_s" =<< eIndex (irregularS rep) [eSubExp s] + num_rows_s <- letSubExp "num_rows_s" =<< eIndex ws [eSubExp s] + row_size <- + letSubExp "row_size" + =<< eIf + (toExp $ pe64 num_rows_s .==. 
0) + (eBody [toExp $ intConst Int64 0]) + (eBody [toExp $ pe64 total_s `div` pe64 num_rows_s]) + pure $ subExpsRes [row_size] + new_S <- + letExp (baseName (paramName p) <> "_new_S") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + seg_i <- letSubExp "seg_i" =<< eIndex ws_data [eSubExp i] + sz <- letSubExp "sz" =<< eIndex per_row_size [eSubExp seg_i] + pure $ subExpsRes [sz] + rep' <- case irregularK rep of + Dense -> do + (new_F, new_O, _new_elems) <- doSegIota new_S + pure $ + IrregularRep + { irregularD = irregularD rep, + irregularF = new_F, + irregularS = new_S, + irregularO = new_O, + irregularK = Dense + } + Replicated -> do + new_O <- + letExp (baseName (paramName p) <> "_new_O") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + seg_i <- letSubExp "seg_i" =<< eIndex ws_data [eSubExp i] + row_size <- letSubExp "row_size" =<< eIndex per_row_size [eSubExp seg_i] + seg_row_start <- letSubExp "seg_row_start" =<< eIndex ws_O [eSubExp seg_i] + row_in_seg <- letSubExp "row_in_seg" <=< toExp $ pe64 i - pe64 seg_row_start + base_off <- letSubExp "base_off" =<< eIndex (irregularO rep) [eSubExp seg_i] + off <- letSubExp "off" <=< toExp $ pe64 base_off + pe64 row_in_seg * pe64 row_size + pure $ subExpsRes [off] + m <- arraySize 0 <$> lookupType (irregularD rep) + -- we will have mutliple write but it is the same value so it should be fine. 
+ new_F <- genFlags m new_O + pure $ + IrregularRep + { irregularD = irregularD rep, + irregularF = new_F, + irregularS = new_S, + irregularO = new_O, + irregularK = Replicated + } + pure $ MapOther rep' rep_t + +onMapInputArrMultiDim :: + Segments -> + SubExp -> + DistEnv -> + DistInputs -> + VName -> + VName -> + VName -> + Param Type -> + VName -> + Builder GPU (MapArray IrregularRep) +onMapInputArrMultiDim old_segments w env inps ws ws_O ws_data p arr = do + case lookup arr inps of + Just v_inp -> + case v_inp of + DistInputFree vs t -> pure $ MapArray vs t + DistInput rt t -> case resVar rt env of + Irregular rep -> do + ws_prod <- arraySize 0 <$> lookupType ws_data + onMapIrregularInputArr MultiDim (old_segments <> pure w) ws ws_O ws_data p arr rep ws_prod + Regular vs -> do + vs_t <- lookupType vs + -- let's be cautious and make sure it has the correct shape + let expected_shape = segmentsShape old_segments <> arrayShape t + if arrayShape vs_t == expected_shape + then pure $ MapArray vs t + else do + v <- + letExp (baseName arr <> "_reg_reshape") . BasicOp . Reshape vs $ + reshapeAll (arrayShape vs_t) expected_shape + pure $ MapArray v t + Nothing -> do + arr_row_t <- rowType <$> lookupType arr + arr_rep <- + letExp (baseName arr <> "_inp_rep") . 
BasicOp $ + Replicate (segmentsShape old_segments) (Var arr) + pure $ MapArray arr_rep arr_row_t + +onMapFreeVarMultiDim :: + Segments -> + SubExp -> + DistEnv -> + DistInputs -> + VName -> + Maybe (Builder GPU (VName, MapArray IrregularRep)) +onMapFreeVarMultiDim segments w env inps v = do + v_inp <- lookup v inps + pure $ fmap (v,) $ case v_inp of + DistInputFree v' t -> do + v_rep <- replicateForW segments w v' + pure $ MapArray v_rep t + DistInput rt t -> case resVar rt env of + Regular v' -> do + v_rep <- replicateForW segments w v' + pure $ MapArray v_rep t + Irregular rep -> do + -- Can replicate as well + old_nseg <- arraySize 0 <$> lookupType (irregularS rep) + new_nseg <- letSubExp "new_nseg" <=< toExp $ pe64 old_nseg * pe64 w + ~[new_S, offsets] <- letTupExp (baseName v <> "_rep_free_irreg") + <=< segMap (MkSolo new_nseg) + $ \(MkSolo i) -> do + old_seg <- letSubExp "old_seg" <=< toExp $ pe64 i `quot` pe64 w + s <- letSubExp "s" =<< eIndex (irregularS rep) [eSubExp old_seg] + o <- letSubExp "o" =<< eIndex (irregularO rep) [eSubExp old_seg] + pure $ subExpsRes [s, o] + let rep' = + IrregularRep + { irregularS = new_S, + irregularF = irregularF rep, + irregularO = offsets, + irregularD = irregularD rep, + irregularK = Replicated + } + pure $ MapOther rep' t + +-- old_nseg <- arraySize 0 <$> lookupType (irregularS rep) +-- new_nseg <- letSubExp "new_nseg" <=< toExp $ pe64 old_nseg * pe64 w + +-- new_S <- +-- letExp (baseName v <> "_new_S") +-- <=< segMap (MkSolo new_nseg) +-- $ \(MkSolo i) -> do +-- old_seg <- letSubExp "old_seg" <=< toExp $ pe64 i `quot` pe64 w +-- s <- letSubExp "s" =<< eIndex (irregularS rep) [eSubExp old_seg] +-- pure [subExpRes s] + +-- (new_F, new_O, new_II1) <- doRepIota new_S +-- m <- arraySize 0 <$> lookupType new_II1 + +-- new_D <- +-- letExp (baseName v <> "_new_D") +-- <=< segMap (MkSolo m) +-- $ \(MkSolo i) -> do +-- new_seg <- letSubExp "new_seg" =<< eIndex new_II1 [eSubExp i] +-- old_seg <- letSubExp "old_seg" <=< toExp $ 
pe64 new_seg `quot` pe64 w +-- new_off <- letSubExp "new_off" =<< eIndex new_O [eSubExp new_seg] +-- old_off <- letSubExp "old_off" =<< eIndex (irregularO rep) [eSubExp old_seg] +-- j <- letSubExp "j" <=< toExp $ pe64 i - pe64 new_off +-- x <- letSubExp "x" =<< eIndex (irregularD rep) [toExp $ pe64 old_off + pe64 j] +-- pure [subExpRes x] + +-- pure $ +-- MapOther +-- IrregularRep +-- { irregularS = new_S, +-- irregularF = new_F, +-- irregularO = new_O, +-- irregularD = new_D +-- } +-- t + +-- | Replicate an array to insert a new inner dimension after the +-- existing segment dimensions. +replicateForW :: Segments -> SubExp -> VName -> Builder GPU VName +replicateForW segments w v = do + v_t <- lookupType v + let seg_rank = length (NE.toList segments) + v_rank = arrayRank v_t + perm = [1 .. seg_rank] ++ [0] ++ [seg_rank + 1 .. v_rank] + v_rep <- + letExp (baseName v <> "_free_rep") . BasicOp $ + Replicate (Shape [w]) (Var v) + letExp (baseName v <> "_free_rep_tr") . BasicOp $ + Rearrange v_rep perm + +transformInnerMapMultiDim :: + Segments -> + DistEnv -> + DistInputs -> + Pat Type -> + SubExp -> + [VName] -> + Lambda SOACS -> + Builder GPU [ResRep] +transformInnerMapMultiDim segments env inps pat w arrs map_lam = do + ws <- dataArr segments env inps w + (ws_F, ws_O, ws_data) <- doRepIota ws + arrs' <- + zipWithM + (onMapInputArrMultiDim segments w env inps ws ws_O ws_data) + (lambdaParams map_lam) + arrs + distributeAndTransformInnerMap + MultiDim + (ws_F, ws_O, ws) + (segments <> pure w) + inps + pat + arrs' + (onMapFreeVarMultiDim segments w env inps) + map_lam + +distributeAndTransformInnerMap :: + InnerMapMode -> + (VName, VName, VName) -> + Segments -> + DistInputs -> + Pat Type -> + [MapArray IrregularRep] -> + (VName -> Maybe (Builder GPU (VName, MapArray IrregularRep))) -> + Lambda SOACS -> + Builder GPU [ResRep] +distributeAndTransformInnerMap mode ws_triple new_segment inps pat arrs' onFreeVar map_lam = do + let free = freeIn map_lam + outer_scope <- 
askScope + let input_scope = scopeOfDistInputs inps `M.difference` outer_scope + free_sizes <- + localScope input_scope $ + foldMap freeIn <$> mapM lookupType (namesToList free) + let free_and_sizes = namesToList $ free <> free_sizes + traceM "distributing inner map with free variables\n" + traceM $ unlines ["inputs: ", prettyString inps, "free variables:", prettyString free_and_sizes] + (free_replicated, replicated) <- + fmap unzip . sequence $ + mapMaybe + onFreeVar + free_and_sizes + free_ps <- + zipWithM + newParam + (map ((<> "_free") . baseName) free_and_sizes) -- this should free_replicated? + (map mapArrayRowType replicated) + scope <- askScope + let substs = M.fromList $ zip free_replicated $ map paramName free_ps + map_lam' = + substituteNames + substs + ( map_lam + { lambdaParams = free_ps <> lambdaParams map_lam + } + ) + (distributed, arrmap) = + distributeMap scope pat new_segment (replicated <> arrs') map_lam' + m = + transformDistributedInnerMap mode ws_triple arrmap new_segment distributed + traceM $ unlines ["inner map distributed", prettyString distributed] + (res, stms) <- runReaderT (runBuilder m) scope + addStms stms + -- order the result representations in the same order as the pattern + pure $ resRepsInPatOrder pat res + +-- Reduction or scan operators may not have any free variables that are variant +-- to the nest (that is, are inputs to the distributed operation). This is +-- because we would be unable to express them as SegScan/SegReds. 
Fixing this +-- would require modifications to the SegOp representation, but it is likely not +-- worth it, as such operators are extremely rare - and we can just fall back on +suitableOperator :: DistEnv -> DistInputs -> Lambda SOACS -> [SubExp] -> Bool +suitableOperator env inps lam nes = + allNames notVariant (freeIn lam) + -- && not (any (isVariant inps env) nes) -- TODO: maybe not needed + && all primType (lambdaReturnType lam) -- TODO + where + notVariant v = isNothing $ M.lookup v $ inputReps inps env + +suitableSegOpMap :: DistEnv -> DistInputs -> Lambda SOACS -> Bool +suitableSegOpMap env inps map_lam = + not (any isParallelStm (bodyStms $ lambdaBody map_lam)) + -- TODO: do we want to add variants as inputs? + && allNames (not . isVariant inps env . Var) (freeIn map_lam) + +-- doSegScan :: [Scan SOACS] -> VName -> [VName] -> Builder GPU [VName] +doSegScan :: [Scan SOACS] -> VName -> [VName] -> Segments -> DistInputs -> DistEnv -> Builder GPU [VName] +doSegScan scans flags elems segments inps env = do + let scan = singleScan scans + -- TODO: FixME: this is temp hack + let zeros = replicate (segmentsRank segments) (Constant $ IntValue $ intValue Int64 (0 :: Int)) + let nes = scanNeutral scan + nes' <- mapM (readInput segments env zeros inps) nes + genSegScan "scan" (soacsLambdaToGPU $ scanLambda scan) nes' flags elems + +doSegScanomap :: + [Scan SOACS] -> + VName -> + [VName] -> + Lambda SOACS -> + Segments -> + DistInputs -> + DistEnv -> + Builder GPU [VName] +doSegScanomap scans flags elems map_lam segments inps env = do + let scan = singleScan scans + let zeros = replicate (segmentsRank segments) (Constant $ IntValue $ intValue Int64 (0 :: Int)) + let nes = scanNeutral scan + nes' <- mapM (readInput segments env zeros inps) nes + genSegScanomap + "scanomap" + (soacsLambdaToGPU $ scanLambda scan) + nes' + flags + (soacsLambdaToGPU map_lam) + elems + +doSegMaposcanomap :: + [Scan SOACS] -> + VName -> + [VName] -> + Lambda SOACS -> + Lambda SOACS -> + 
Segments -> + DistInputs -> + DistEnv -> + Builder GPU [VName] +doSegMaposcanomap scans flags elems post_lam map_lam segments inps env = do + let scan = singleScan scans + let zeros = replicate (segmentsRank segments) (Constant $ IntValue $ intValue Int64 (0 :: Int)) + let nes = scanNeutral scan + nes' <- mapM (readInput segments env zeros inps) nes + genSegScanomapWithPost + "maposcanomap" + (soacsLambdaToGPU $ scanLambda scan) + nes' + flags + (soacsLambdaToGPU post_lam) + (soacsLambdaToGPU map_lam) + elems + +postMapResultRep :: Segments -> DistEnv -> DistInputs -> SubExp -> VName -> VName -> VName -> VName -> IrregularKind -> Param Type -> Builder GPU ResRep +postMapResultRep segments env inps w ws_F ws_O ws_S elems elems_kind post_param + | isVariant inps env w || any isTypeVariant (arrayDims p_t) = + pure $ + Irregular $ + IrregularRep + { irregularS = ws_S, + irregularF = ws_F, + irregularO = ws_O, + irregularD = elems, + irregularK = elems_kind + } + | otherwise = do + elem_t <- lookupType elems + elem_v' <- + letExp (baseName elems <> "_reshaped") . 
BasicOp $ + Reshape elems $ + reshapeAll (arrayShape elem_t) expected_shape + pure $ Regular elem_v' + where + p_t = paramType post_param + expected_shape = segmentsShape segments <> Shape [w] <> arrayShape p_t + + isTypeVariant (Var v) = isVariant inps env (Var v) + isTypeVariant Constant {} = False + +transformPostMaposcanomap :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + Pat Type -> + SubExp -> + [VName] -> + Lambda SOACS -> + [Scan SOACS] -> + Lambda SOACS -> + Builder GPU DistEnv +transformPostMaposcanomap segments env inps res pat w arrs post_lam scans map_lam = do + reps <- mapM (segOpInputRep segments env inps) arrs + (ws_F, ws_O, ws_S, elems, elems_kind) <- + prepareSegOpInputs segments env inps w reps arrs + elems' <- doSegScanomap scans ws_F elems map_lam segments inps env + + let post_params = lambdaParams post_lam + + post_reps <- + zipWithM + (\elem' post_param -> + postMapResultRep + segments env inps w ws_F ws_O ws_S + elem' elems_kind post_param) + elems' + post_params + + transformMaposcanomapPostReps segments env inps res pat w post_reps post_lam + +transformMaposcanomapPostReps :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + Pat Type -> + SubExp -> + [ResRep] -> + Lambda SOACS -> + Builder GPU DistEnv +transformMaposcanomapPostReps segments env inps res pat w post_reps post_lam = do + let (inps_local, env_local, next_tag) = localiseInputs env inps + post_params = lambdaParams post_lam + post_tags = map ResTag [next_tag ..] 
+ post_inputs = + zipWith + (\p tag -> (paramName p, DistInput tag (paramType p `arrayOfRow` w))) + post_params + post_tags + <> inps_local + post_env = insertReps (zip post_tags post_reps) env_local + post_arrs = map paramName post_params + traceM "transforming post maposcanomap results with inputs\n" + post_res <- transformInnerMap segments post_env post_inputs pat w post_arrs post_lam + pure $ insertReps (zip (map distResTag res) post_res) env + +transformPreMaposcanomap :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + SubExp -> + [VName] -> + Lambda SOACS -> + [Scan SOACS] -> + Lambda SOACS -> + Builder GPU DistEnv +transformPreMaposcanomap segments env inps res w arrs post_lam scans map_lam = do + map_pat <- fmap Pat $ forM (lambdaReturnType map_lam) $ \t -> + PatElem <$> newVName "map" <*> pure (t `arrayOfRow` w) + map_res_all <- transformInnerMap segments env inps map_pat w arrs map_lam + (ws_F, ws_O, ws_S, elems, elems_kind) <- + prepareSegOpInputs segments env inps w map_res_all (patNames map_pat) + id_lam <- mkIdentityLambda $ lambdaReturnType map_lam + elems' <- doSegMaposcanomap scans ws_F elems post_lam id_lam segments inps env + insertSegOpMapResults + segments + ws_S + ws_F + ws_O + elems_kind + (zip res elems') + env + +transformPrePostMaposcanomap :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + Pat Type -> + SubExp -> + [VName] -> + Lambda SOACS -> + [Scan SOACS] -> + Lambda SOACS -> + Builder GPU DistEnv +transformPrePostMaposcanomap segments env inps res pat w arrs post_lam scans map_lam = do + map_pat <- fmap Pat $ forM (lambdaReturnType map_lam) $ \t -> + PatElem <$> newVName "map" <*> pure (t `arrayOfRow` w) + + map_res_all <- transformInnerMap segments env inps map_pat w arrs map_lam + + let num_scan_results = scanResults scans + post_params = lambdaParams post_lam + (scan_res_names, _) = splitAt num_scan_results $ patNames map_pat + (scan_params, _) = splitAt num_scan_results post_params + (scan_res, 
map_res) = splitAt num_scan_results map_res_all + + (ws_F, ws_O, ws_S, scan_elems, scan_elems_kind) <- + prepareSegOpInputs segments env inps w scan_res scan_res_names + + scan_elems' <- doSegScan scans ws_F scan_elems segments inps env + + scan_res' <- + zipWithM + ( \elem' scan_param -> + postMapResultRep + segments + env + inps + w + ws_F + ws_O + ws_S + elem' + scan_elems_kind + scan_param + ) + scan_elems' + scan_params + + transformMaposcanomapPostReps + segments + env + inps + res + pat + w + (scan_res' <> map_res) + post_lam + +-- Hacky fix to get result representations in the same order as the pattern +resRepsInPatOrder :: Pat Type -> [(VName, ResRep)] -> [ResRep] +resRepsInPatOrder pat reps = + let rep_map = M.fromList reps + lookupRes v = + case M.lookup v rep_map of + Just rep -> rep + Nothing -> + error $ + "resRepsInPatOrder: missing result for " + ++ prettyString v + in map lookupRes (patNames pat) + +segOpInputRep :: + Segments -> + DistEnv -> + DistInputs -> + VName -> + Builder GPU ResRep +segOpInputRep segments env inps arr = + case lookup arr inps of + Just (DistInput rt _) -> + pure $ resVar rt env + Just (DistInputFree arr' _) -> + pure $ Regular arr' + Nothing -> + Irregular <$> getIrregRep segments env inps arr + +-- Basically we need to make our arrays ready for our segscan/segred. +-- Regular arrays are flattened only across the outer segment dimensions and +-- the SOAC width; any row shape expected by the consumer is preserved. +-- we need to check the dense/replicated status of the input. +-- if all of scan inputs are replicated we are fine. +-- otherwise, we need to make the replicated inputs dense. +-- for regulars we can just use the segment descriptor and this should be also the same descriptor for dense irregulars. 
+prepareSegOpInputs :: + Segments -> + DistEnv -> + DistInputs -> + SubExp -> + [ResRep] -> + [VName] -> + Builder GPU (VName, VName, VName, [VName], IrregularKind) +prepareSegOpInputs segments env inps w reps names + | all isRegular reps = do + ws <- dataArr segments env inps w + (ws_F, ws_O, ws_data) <- doRepIota ws + m <- arraySize 0 <$> lookupType ws_data + names' <- mapM (flattenRegularRep m) reps + pure (ws_F, ws_O, ws, names', Dense) + | all isReplicatedIrregular reps = do + let Irregular rep0 = head reps + pure (irregularF rep0, irregularO rep0, irregularS rep0, map getData reps, Replicated) + | otherwise = do + desc_rep <- findOrMakeDense reps + m <- arraySize 0 <$> lookupType (irregularD desc_rep) + names' <- zipWithM (normalise m) reps names + pure (irregularF desc_rep, irregularO desc_rep, irregularS desc_rep, names', Dense) + where + isRegular (Regular _) = True + isRegular _ = False + + isReplicatedIrregular (Irregular rep) = irregularK rep == Replicated + isReplicatedIrregular _ = False + + flattenRegularRep m (Regular v) = + flattenRegularToRows segments m v + flattenRegularRep _ _ = + error "prepareSegOpInputs: impossible irregular regular input" + getData (Irregular rep) = irregularD rep + getData _ = error "prepareSegOpInputs: impossible" + + findOrMakeDense rs = + case [rep | Irregular rep <- rs, irregularK rep == Dense] of + rep : _ -> pure rep + [] -> + case [rep | Irregular rep <- rs] of + rep : _ -> ensureDenseIrregular "segop_desc" rep + [] -> error "prepareSegOpInputs: impossible" + + normalise m rep v = + case rep of + Regular v' -> + flattenRegularToRows segments m v' + Irregular ir + | irregularK ir == Dense -> + pure $ irregularD ir + | otherwise -> + irregularD <$> ensureDenseIrregular (baseName v <> "_dense") ir + +flattenRegularToRows :: Segments -> SubExp -> VName -> Builder GPU VName +flattenRegularToRows segments m v = do + v_t <- lookupType v + when (arrayRank v_t < segmentsRank segments + 1) $ + error "prepareSegOpInputs: 
regular input rank too small" + let row_shape = arrayShape $ stripArray (segmentsRank segments + 1) v_t + letExp (baseName v <> "_flat") . BasicOp $ + Reshape v $ + reshapeAll (arrayShape v_t) (Shape [m] <> row_shape) + +insertSegOpMapResults :: + Segments -> + VName -> + VName -> + VName -> + IrregularKind -> + [(DistResult, VName)] -> + DistEnv -> + Builder GPU DistEnv +insertSegOpMapResults segments segs flags offsets kind bnds env0 = + foldM insert env0 bnds + where + insert env (dist_res, v) + | isRegularDistResult dist_res = do + let DistType _ _ t = distResType dist_res + expected_shape = segmentsShape segments <> arrayShape t + v_t <- lookupType v + v' <- + letExp (baseName v <> "_reshaped") . BasicOp $ + Reshape v $ + reshapeAll (arrayShape v_t) expected_shape + pure $ insertRegulars [distResTag dist_res] [v'] env + | otherwise = + pure $ insertIrregular segs flags offsets (distResTag dist_res) v kind env + +transformDistStm :: Segments -> DistEnv -> DistStm -> Builder GPU DistEnv +transformDistStm segments env (DistStm inps res (ScalarStm stms)) = + transformScalarStms segments env inps res (stmsFromList stms) +transformDistStm segments env (DistStm inps res (ParallelStm stm)) = do + case stm of + Let pat aux (BasicOp e) -> do + let ~[res'] = res + ~[pe] = patElems pat + transformDistBasicOp segments env (inps, res', pe, aux, e) + Let pat aux (Op (Screma w arrs form)) + | Just reds <- isReduceSOAC form, + all (\red -> suitableOperator env inps (redLambda red) (redNeutral red)) reds -> do + traceM "HELLO REDUCE" + reps <- mapM (segOpInputRep segments env inps) arrs + (flags, offsets, arr_segments, elems, _elems_kind) <- + prepareSegOpInputs segments env inps w reps arrs + -- TODO: FixME: this is temp hack + let sing_red = singleReduce reds + let zeros = replicate (length segments) (Constant $ IntValue $ intValue Int64 (0 :: Int)) + nes' <- mapM (readInput segments env zeros inps) (redNeutral sing_red) + let sing_red' = sing_red {redNeutral = nes'} + elems' 
<- genSegRed arr_segments flags offsets elems sing_red' + elems'' <- forM elems' $ \v -> do + v_t <- lookupType v + letExp (baseName v <> "_reshaped") . BasicOp $ + Reshape v $ + reshapeAll (arrayShape v_t) (segmentsShape segments) + pure $ insertRegulars (map distResTag res) elems'' env + | Just (reds, map_lam) <- isRedomapSOAC form, + suitableSegOpMap env inps map_lam, + all (\red -> suitableOperator env inps (redLambda red) (redNeutral red)) reds -> do + traceM "HELLO Fast REDOMAP" + reps <- mapM (segOpInputRep segments env inps) arrs + (ws_F, ws_O, ws_S, elems, elems_kind) <- + prepareSegOpInputs segments env inps w reps arrs + let sing_red = singleReduce reds + let zeros = replicate (length segments) (Constant $ IntValue $ intValue Int64 (0 :: Int)) + nes' <- mapM (readInput segments env zeros inps) (redNeutral sing_red) + let sing_red' = sing_red {redNeutral = nes'} + (red_elems, mapout_elems) <- + genSegRedomap ws_S ws_F ws_O elems sing_red' (soacsLambdaToGPU map_lam) + red_elems' <- forM red_elems $ \v -> do + v_t <- lookupType v + letExp (baseName v <> "_reshaped") . 
BasicOp $ + Reshape v $ + reshapeAll (arrayShape v_t) (segmentsShape segments) + let (red_res, map_res) = splitAt (redResults reds) res + env' <- + insertSegOpMapResults + segments + ws_S + ws_F + ws_O + elems_kind + (zip map_res mapout_elems) + env + pure $ insertRegulars (map distResTag red_res) red_elems' env' + | Just (reds, map_lam) <- isRedomapSOAC form, + all (\red -> suitableOperator env inps (redLambda red) (redNeutral red)) reds -> do + traceM "HELLO REDOMAP" + map_pat <- fmap Pat $ forM (lambdaReturnType map_lam) $ \t -> + PatElem <$> newVName "map" <*> pure (t `arrayOfRow` w) + map_res_all <- + transformInnerMap segments env inps map_pat w arrs map_lam + let (redout_names, _) = splitAt (redResults reds) (patNames map_pat) + (redout_res, mapout_res) = splitAt (redResults reds) map_res_all + (ws_F, ws_O, ws_S, redout_names', _redout_kind) <- + prepareSegOpInputs segments env inps w redout_res redout_names + -- For multi-dim (Regular) results, flatten to 1D before segmented reduce. + -- TODO: FixME: this is temp hack + let sing_red = singleReduce reds + let zeros = replicate (length segments) (Constant $ IntValue $ intValue Int64 (0 :: Int)) + nes' <- mapM (readInput segments env zeros inps) (redNeutral sing_red) + let sing_red' = sing_red {redNeutral = nes'} + elems' <- + genSegRed ws_S ws_F ws_O redout_names' sing_red' + elems'' <- forM elems' $ \v -> do + v_t <- lookupType v + letExp (baseName v <> "_reshaped") . 
BasicOp $ + Reshape v $ + reshapeAll (arrayShape v_t) (segmentsShape segments) + let (red_tags, map_tags) = splitAt (redResults reds) $ map distResTag res + pure $ + insertRegulars red_tags elems'' $ + insertReps (zip map_tags mapout_res) env + | Just scans <- isScanSOAC form, + all (\scan -> suitableOperator env inps (scanLambda scan) (scanNeutral scan)) scans -> do + reps <- mapM (segOpInputRep segments env inps) arrs + (flags, offsets, arr_segments, elems, elems_kind) <- + prepareSegOpInputs segments env inps w reps arrs + elems' <- doSegScan scans flags elems segments inps env + pure $ + insertIrregulars arr_segments flags offsets (zip (map distResTag res) elems') elems_kind env + | Just (post_lam, scans, map_lam) <- isMaposcanomapSOAC form, + suitableSegOpMap env inps map_lam, + suitableSegOpMap env inps post_lam, + all (\scan -> suitableOperator env inps (scanLambda scan) (scanNeutral scan)) scans -> do + traceM "Status: everything integetrated" + reps <- mapM (segOpInputRep segments env inps) arrs + (ws_F, ws_O, ws_S, elems, elems_kind) <- + prepareSegOpInputs segments env inps w reps arrs + elems' <- doSegMaposcanomap scans ws_F elems post_lam map_lam segments inps env + insertSegOpMapResults + segments + ws_S + ws_F + ws_O + elems_kind + (zip res elems') + env + | Just (post_lam, scans, map_lam) <- isMaposcanomapSOAC form, + suitableSegOpMap env inps map_lam, + not $ suitableSegOpMap env inps post_lam, + all (\scan -> suitableOperator env inps (scanLambda scan) (scanNeutral scan)) scans -> do + traceM "Status: pre map integetrated" + transformPostMaposcanomap segments env inps res pat w arrs post_lam scans map_lam + | Just (post_lam, scans, map_lam) <- isMaposcanomapSOAC form, + suitableSegOpMap env inps post_lam, + not $ suitableSegOpMap env inps map_lam, + all (\scan -> suitableOperator env inps (scanLambda scan) (scanNeutral scan)) scans -> do + traceM "Status: post map integetrated" + transformPreMaposcanomap segments env inps res w arrs post_lam scans 
map_lam + | Just (post_lam, scans, map_lam) <- isMaposcanomapSOAC form, + all (\scan -> suitableOperator env inps (scanLambda scan) (scanNeutral scan)) scans -> do + traceM "Status Nothing integerated" + transformPrePostMaposcanomap segments env inps res pat w arrs post_lam scans map_lam + | Just map_lam <- isMapSOAC form, + all isVersionableRegularResult res -> + versionedRegularMap segments env inps res pat aux w arrs form map_lam + | Just map_lam <- isMapSOAC form -> do + map_res <- + transformInnerMap segments env inps pat w arrs map_lam + pure $ insertReps (zip (map distResTag res) map_res) env + | otherwise -> do + -- XXX: here we silently sequentialise any SOAC that is not handled + -- above. We need to make sure that we actually handle everything we + -- care about! + error "unhandled SOAC" + -- transformScalarStm segments env inps res $ + -- Let { stmPat = pat, stmAux = aux, stmExp = Op (Screma w arrs form) } + Let _ aux (Match scrutinees cases defaultCase rt) -> + if any (isVariant inps env) scrutinees + then + transformMatch flattenOps segments env inps res scrutinees cases defaultCase + -- else error $ unlines ["scrutinees: ", prettyString scrutinees, "cases:", prettyString cases, "defaultCase:", prettyString defaultCase] + else do + scope <- askScope + new_cases <- forM cases $ \(Case c body) -> do + let (case_body_inputs, case_dstms) = distributeBody scope segments inps body + + (case_body_res, case_body_stms) <- + runReaderT + ( runBuilder $ + liftBodyWithDistResults segments case_body_inputs env case_dstms res (bodyResult body) + ) + scope + pure $ Case c $ Body () case_body_stms case_body_res + new_default_body <- do + let (new_default_body_inputs, new_default_dstms) = distributeBody scope segments inps defaultCase + (new_default_body_res, new_default_body_stms) <- + runReaderT + ( runBuilder $ + liftBodyWithDistResults segments new_default_body_inputs env new_default_dstms res (bodyResult defaultCase) + ) + scope + pure $ Body () 
new_default_body_stms new_default_body_res + + -- Maybe it is better to build MatchDec ourselves + match_e <- + eMatch' + scrutinees + [Case c (pure body) | Case c body <- new_cases] + (pure new_default_body) + (matchSort rt) + + match_res <- + certifying (distCerts inps aux env) $ + letTupExp "match_res" match_e + + rets <- expExtType match_e + -- get rid of the existential context + traceM $ unlines ["match res type:", prettyString rets] + let payload_res = drop (S.size (shapeContext rets)) match_res + let reps = distResultsToResReps res payload_res + pure $ insertReps (zip (map distResTag res) reps) env + Let _ _ (Apply name args rettype s) -> do + let name' = liftFunName name + w <- letSubExp "num_segments" =<< toExp (segmentCount segments) + args' <- ((w, Observe) :) . concat <$> mapM (liftArg segments w inps env) args + args_ts <- mapM (subExpType . fst) args' + let dietToUnique Consume = Unique + dietToUnique Observe = Nonunique + dietToUnique ObservePrim = Nonunique + param_ts = zipWith toDecl args_ts $ map (dietToUnique . 
snd) args' + rettype' = addRetAls param_ts $ liftRetType w $ map fst rettype + result <- letTupExp (name' <> "_res") $ Apply name' args' rettype' s + reps <- + zipWithM (reshapeLiftedApplyResult segments) (map fst rettype) $ + resultToResReps (map fst rettype) result + pure $ insertReps (zip (map distResTag res) reps) env + Let _ aux (Loop merge (ForLoop i it n) body) -> do + if isVariant inps env n + then transformFortoWhile segments env inps res aux merge i it n body + else do + let old_loop_params = map fst merge + old_loop_inits = map snd merge + loopParamNames = S.fromList $ map paramName old_loop_params + + num_segments <- letSubExp "num_segments" =<< toExp (segmentCount segments) + (lifted_loop_params, lifted_loop_reps, lifted_init) <- + unzip3 <$> mapM (liftLoopParam segments num_segments inps env loopParamNames) (zip old_loop_params old_loop_inits) + + let lifted_loop_params' = concat lifted_loop_params + lifted_init' = concat lifted_init + + traceM $ unlines ["lifted_loop_params:", prettyString lifted_loop_params', "lifted_init:", prettyString lifted_init'] + + let (inps_local, env_local0, next0) = localiseInputs env inps + loop_param_inputs_local = + zipWith + (\p j -> (paramName p, DistInput (ResTag j) (paramType p))) + old_loop_params + [next0 ..] + + loop_param_reps_local = + zipWith + (\j rep -> (ResTag j, rep)) + [next0 ..] 
+ lifted_loop_reps + loop_new_inputs = inps_local <> loop_param_inputs_local + loop_env_local = insertReps loop_param_reps_local env_local0 + + let i_param = Param mempty i (Prim (IntType it)) + let build_scope = scopeOfFParams lifted_loop_params' <> scopeOfLParams [i_param] + scope <- askScope + let (loop_new_inputs', loop_dstms) = + distributeBody scope segments loop_new_inputs body + + (loop_body_res, loop_body_stms) <- + runReaderT + ( runBuilder $ + liftLoopBody segments num_segments loop_new_inputs' loop_env_local loop_dstms res (bodyResult body) + ) + (scope <> build_scope) + + let loop_body_gpu = Body () loop_body_stms loop_body_res + loop_exp_gpu = + Loop + (zip lifted_loop_params' lifted_init') + (ForLoop i it n) + loop_body_gpu + + loop_out_vs <- + certifying (distCerts inps aux env) $ + letTupExp "loop_res_out" loop_exp_gpu + + let out_reps = loopResultToResReps res loop_out_vs + pure $ insertReps (zip (map distResTag res) out_reps) env + Let _ aux (Loop merge (WhileLoop cond) body) -> do + -- TODO: + -- 4) Use reduction rather than scan for any_active + -- 5) Consider updating the active segment so we don't go over w everytime + + -- inside the body we should compute the indices for which the condition is true and for which it is false, and then distribute the body based on that. + -- We can then merge the results of the two branches by writing them back to a blank space like we do for the branches of a match. 
+ + let old_loop_params = map fst merge + old_loop_inits = map snd merge + loopParamNames = S.fromList $ map paramName old_loop_params + w <- letSubExp "num_segments" =<< toExp (segmentCount segments) + (lifted_loop_params, lifted_loop_reps, lifted_init) <- + unzip3 <$> mapM (liftLoopParam segments w inps env loopParamNames) (zip old_loop_params old_loop_inits) + + let lifted_loop_params' = concat lifted_loop_params + lifted_init' = concat lifted_init + + -- find cond_lifted_param in old_lifted_loop_params to get the lifted_loop_reps + let (inps_local, env_local0, next0) = localiseInputs env inps + loop_param_inputs_local = + zipWith + (\p j -> (paramName p, DistInput (ResTag j) (paramType p))) + old_loop_params + [next0 ..] + loop_param_reps_local = + zipWith + (\j rep -> (ResTag j, rep)) + [next0 ..] + lifted_loop_reps + loop_new_inputs = inps_local <> loop_param_inputs_local + loop_env_local = insertReps loop_param_reps_local env_local0 + + let maybe_cond = lookup cond (zip (map paramName old_loop_params) (zip lifted_loop_reps lifted_init)) + scope <- askScope + case maybe_cond of + -- infinite loop . later can be uniform case as well. + -- !!! TODO: !!!! update this as well. + Nothing -> do + let build_scope = scopeOfFParams lifted_loop_params' + let (loop_new_inputs', loop_dstms) = distributeBody scope segments loop_new_inputs body + (loop_body_res, loop_body_stms) <- + runReaderT + (runBuilder $ liftBody w loop_new_inputs' loop_env_local loop_dstms (bodyResult body)) + (scope <> build_scope) + let loop_body_gpu = Body () loop_body_stms loop_body_res + loop_exp_gpu = Loop (zip lifted_loop_params' lifted_init') (WhileLoop cond) loop_body_gpu + loop_out_vs <- certifying (distCerts inps aux env) $ letTupExp "loop_res_out" loop_exp_gpu + let result_types = map ((\(DistType _ _ t) -> t) . 
distResType) res + out_reps = resultToResReps result_types loop_out_vs + pure $ insertReps (zip (map distResTag res) out_reps) env + Just (cond_lifted_rep, cond_init) -> do + let [cond_init_se] = cond_init + + -- Compute initial any_active + cond_init_arr_v <- letExp "cond_init_arr" $ BasicOp $ SubExp cond_init_se + let cond_lifted_param = case cond_lifted_rep of + Regular v -> v + Irregular {} -> error "WhileLoop condition cannot be irregular" + + -- latter chagne to reduction + cond_init_arr_t <- lookupType cond_init_arr_v + cond_init_flat <- + letExp "cond_init_flat" . BasicOp $ + Reshape cond_init_arr_v $ + reshapeAll (arrayShape cond_init_arr_t) (Shape [w]) + + or_lam <- binOpLambda LogOr Bool + cond_scanned <- genScan "any_scan" (NE.singleton w) or_lam [constant False] [cond_init_flat] + let [cond_scanned_v] = cond_scanned + + any_active_init <- + letSubExp "any_active_init" + =<< eIf + (toExp $ pe64 w .==. 0) + (eBody [eSubExp $ constant False]) + (eBody [eIndex cond_scanned_v [toExp $ pe64 w - 1]]) + + any_active_param <- newParam "any_active" (Prim Bool) + let build_scope = scopeOfFParams lifted_loop_params' <> scopeOfFParams [any_active_param] + -- ‌build body + (loop_body_res, loop_body_stms) <- + runReaderT + ( runBuilder $ do + -- (num_data, active_inds) <- genFilter cond_lifted_param + equiv_classes <- letExp "equiv_classes" <=< segMap (MkSolo w) $ \(MkSolo i) -> do + let seg_is = unflattenIndex (segmentDims segments) (pe64 i) + c <- letSubExp "c" =<< eIndex cond_lifted_param (map toExp seg_is) + cls <- + letSubExp "cls" + =<< eIf + (eSubExp c) + (eBody [toExp $ intConst Int64 1]) + (eBody [toExp $ intConst Int64 0]) + pure [subExpRes cls] + n_cases <- letExp "n_cases" <=< toExp $ intConst Int64 2 + (partition_sizes, partition_offs, partition_inds) <- doPartition n_cases equiv_classes + inds_t <- lookupType partition_inds + + let getInds nm k = do + sz <- + letSubExp (nm <> "_sz") + =<< eIndex partition_sizes [toExp $ intConst Int64 k] + off <- + 
letSubExp (nm <> "_off") + =<< eIndex partition_offs [toExp $ intConst Int64 k] + inds <- + letExp (nm <> "_inds") $ + BasicOp $ + Index partition_inds $ + fullSlice inds_t [DimSlice off sz (intConst Int64 1)] + pure (sz, inds) + + (_, inactive_inds) <- getInds "inactive" 0 + (active_size, active_inds) <- getInds "active" 1 + + inactive_reps <- forM old_loop_params $ \p -> do + (_, _, rep) <- splitInput segments loop_new_inputs loop_env_local inactive_inds (paramName p) + pure rep + + let free_in_body = + filter + (isVariant loop_new_inputs loop_env_local . Var) + (namesToList $ freeIn body) + (ts, vs, reps) <- unzip3 <$> mapM (splitInput segments loop_new_inputs loop_env_local active_inds) free_in_body + let subset_inputs = do + (v, t, i) <- zip3 vs ts [0 ..] + pure (v, DistInput (ResTag i) t) + env_subset = DistEnv $ M.fromList $ zip (map ResTag [0 ..]) reps + let subset_segments = NE.singleton active_size + let (subset_inputs', subset_dstms) = distributeBody scope subset_segments subset_inputs body + env_subset' <- foldM (transformDistStm subset_segments) env_subset subset_dstms + active_reps <- + zipWithM + (liftDistResultRep subset_segments subset_inputs' env_subset') + res + (bodyResult body) + + let mergeOneLifted t rep0 rep1 = + case (rep0, rep1) of + (Regular x0, Regular x1) -> do + let initial_shape = Shape [w] <> arrayShape t + let final_shape = segmentsShape segments <> arrayShape t + let pt = elemType t + space <- letExp "blank" =<< eBlank (Array pt initial_shape NoUniqueness) + + out <- + foldM + scatterRegular + space + [(inactive_inds, x0), (active_inds, x1)] + + out_type <- arrayShape <$> lookupType out + out_reshaped <- + letExp "out_reshaped" . 
BasicOp $ + Reshape out $ + reshapeAll out_type final_shape + + pure [SubExpRes mempty (Var out_reshaped)] + (Irregular ir0, Irregular ir1) -> do + segsSpace <- + letExp "blank_segs" + =<< eBlank (Array int64 (Shape [w]) NoUniqueness) + + segs <- + foldM + scatterRegular + segsSpace + [(inactive_inds, irregularS ir0), (active_inds, irregularS ir1)] + + (_, offsets, num_data) <- exScanAndSum segs + + let pt = elemType t + elemsSpace <- + letExp "blank_elems" + =<< eBlank (Array pt (Shape [num_data]) NoUniqueness) + + elems <- + foldM + (scatterIrregular offsets) + elemsSpace + [(inactive_inds, ir0), (active_inds, ir1)] + + flags <- genFlags num_data offsets + + pure + [ SubExpRes mempty num_data, + SubExpRes mempty (Var segs), + SubExpRes mempty (Var flags), + SubExpRes mempty (Var offsets), + SubExpRes mempty (Var elems) + ] + _ -> error "mergeOneLifted: mismatched reps" + + merged_results <- + concat + <$> zipWithM + (\p (r0, r1) -> mergeOneLifted (declTypeOf p) r0 r1) + old_loop_params + (zip inactive_reps active_reps) + + -- we have one extra iteration but it is better than extra reduction in the loop body, + any_active <- + letSubExp "any_active" + =<< eIf + (toExp $ pe64 active_size .==. 
0) + (eBody [eSubExp $ constant False]) + (eBody [eSubExp $ constant True]) + + pure $ merged_results ++ [SubExpRes mempty any_active] + ) + (scope <> build_scope) + + let loop_body_gpu = Body () loop_body_stms loop_body_res + loop_exp_gpu = + Loop + (zip (lifted_loop_params' ++ [any_active_param]) (lifted_init' ++ [any_active_init])) + (WhileLoop (paramName any_active_param)) + loop_body_gpu + + loop_out_vs <- + certifying (distCerts inps aux env) $ + letTupExp "loop_res_out" loop_exp_gpu + let loop_out_vs' = L.init loop_out_vs + let out_reps = loopResultToResReps res loop_out_vs' + pure $ insertReps (zip (map distResTag res) out_reps) env + Let pat aux (WithAcc inputs lam) -> + transformWithAcc flattenOps segments env inps res pat aux inputs lam + Let _ _ (Op (Stream {})) -> error "transformDistStm: Stream should have been removed" + Let _ _ (Op (Hist {})) -> error "Unhandled Hist" + Let _ _ (Op (JVP {})) -> error "Unhandled JVP" + Let _ _ (Op (VJP {})) -> error "Unhandled VJP" + +-- helper to not mess up the tags when generating new ones for the loop parameters +-- probably won't be used in future +localiseInputs :: DistEnv -> DistInputs -> (DistInputs, DistEnv, Int) +localiseInputs env_outer inps = + let step (i, env_acc) (v, inp) = + case inp of + DistInputFree arr t -> + ((i, env_acc), (v, DistInputFree arr t)) + DistInput oldrt t -> + let newrt = ResTag i + rep = resVar oldrt env_outer + env_acc' = insertRep newrt rep env_acc + in ((i + 1, env_acc'), (v, DistInput newrt t)) + + ((next, env_local), inps_local) = + L.mapAccumL step (0, mempty) inps + in (inps_local, env_local, next) + +distResCerts :: DistEnv -> [DistInput] -> Certs +distResCerts env = Certs . map f + where + f (DistInputFree v _) = v + f (DistInput rt _) = case resVar rt env of + Regular v -> v + Irregular {} -> error "resCerts: irregular" + +reshapeAndBind :: VName -> VName -> Shape -> Builder GPU () +reshapeAndBind v src shape = do + v_copy <- letExp (baseName v) . 
BasicOp $ Replicate mempty (Var src) + v_copy_shape <- arrayShape <$> lookupType v_copy + letBindNames [v] $ BasicOp $ Reshape v_copy $ reshapeAll v_copy_shape shape + +mapResultRep :: InnerMapMode -> (VName, VName, VName) -> VName -> Builder GPU ResRep +mapResultRep MultiDim _ v = pure $ Regular v +mapResultRep SingleDim (ws, ws_F, ws_O) v = + -- Forcing the irregular rep to be 1D because in some places that is my assumption + -- and also this will make the metadata consistent. + Irregular + <$> flattenIrregularRep + IrregularRep + { irregularS = ws, + irregularF = ws_F, + irregularO = ws_O, + irregularD = v, + irregularK = Dense + } + +resultMapMode :: InnerMapMode -> DistInputs -> Type -> InnerMapMode +resultMapMode SingleDim _ _ = SingleDim +resultMapMode MultiDim new_inps v_t + | any isTypeVariant (arrayDims v_t) = SingleDim + | otherwise = MultiDim + where + new_inp_var = S.fromList $ map fst new_inps + isTypeVariant se = case se of + Var v -> v `S.member` new_inp_var + _ -> False + +irregularMapResult :: + InnerMapMode -> + (VName, VName, VName) -> + Segments -> + IrregularRep -> + VName -> + Type -> + DistInputs -> + Builder GPU ResRep +irregularMapResult mode (ws, ws_F, ws_O) segments irreg v v_t new_inps = + do + irreg_dense <- ensureDenseIrregular (baseName v <> "_map_result") irreg + if any (isTypeVariant new_inp_var) (arrayShape v_t) + then do + old_segment <- arraySize 0 <$> lookupType ws + new_shape <- letExp (baseName v <> "_outer_shape") <=< segMap (MkSolo old_segment) $ \(MkSolo is) -> do + outer_ind <- letSubExp "outer_ind" =<< eIndex ws_O [eSubExp is] + outer_ws_i <- letSubExp "outer_ws" =<< eIndex ws [eSubExp is] + sz <- + letSubExp "sz" + =<< eIf + (toExp $ pe64 outer_ws_i .==. 
0) + (eBody [toExp $ intConst Int64 0]) + ( do + last_row <- letSubExp "last_row" <=< toExp $ pe64 outer_ind + pe64 outer_ws_i - 1 + start <- letSubExp "start" =<< eIndex (irregularO irreg_dense) [eSubExp outer_ind] + last_offset <- letSubExp "last_offset" =<< eIndex (irregularO irreg_dense) [eSubExp last_row] + last_size <- letSubExp "last_size" =<< eIndex (irregularS irreg_dense) [eSubExp last_row] + eBody [toExp $ pe64 last_offset - pe64 start + pe64 last_size] + ) + pure [subExpRes sz] + (new_ws_F, new_ws_O, _) <- doRepIota new_shape + letBindNames [v] $ BasicOp $ Replicate mempty $ Var $ irregularD irreg_dense + mapResultRep SingleDim (new_shape, new_ws_F, new_ws_O) v + else case mode of + MultiDim -> do + reshapeAndBind v (irregularD irreg_dense) (segmentsShape segments <> arrayShape v_t) + mapResultRep MultiDim (ws, ws_F, ws_O) v + SingleDim -> do + -- TODO: have to do this even it seems very annoying should think something better + reshapeAndBind v (irregularD irreg_dense) (segmentsShape segments <> arrayShape v_t) + mapResultRep SingleDim (ws, ws_F, ws_O) v + where + isTypeVariant vin se = case se of + Var v' -> S.member v' vin + _ -> False + new_inp_var = S.fromList $ map fst new_inps + +transformDistributedInnerMap :: + InnerMapMode -> + (VName, VName, VName) -> + M.Map ResTag IrregularRep -> + Segments -> + Distributed -> + Builder GPU [(VName, ResRep)] +transformDistributedInnerMap mode (ws_F, ws_O, ws) irregs segments dist = do + let Distributed dstms (DistResults resmap reps) = dist + let new_inps = concatMap distStmInputs dstms + env <- foldM (transformDistStm segments) env_initial dstms + resmap_res <- fmap concat $ forM (M.toList resmap) $ \(rt, binds) -> + forM binds $ \(cs_inps, v, v_t) -> + certifying (distResCerts env cs_inps) $ + -- FIXME: the copies are because we have too liberal aliases on + -- lifted functions. 
+ case (resultMapMode mode new_inps v_t, resVar rt env) of + (MultiDim, Regular v') -> do + if isAcc v_t + then do + letBindNames [v] $ BasicOp $ Replicate mempty $ Var v' + pure (v, Regular v) + else do + reshapeAndBind v v' (segmentsShape segments <> arrayShape v_t) + pure (v, Regular v) + (SingleDim, Regular v') -> do + if isAcc v_t + then do + letBindNames [v] $ BasicOp $ Replicate mempty $ Var v' + pure (v, Regular v) + else do + letBindNames [v] $ BasicOp $ Replicate mempty $ Var v' + rep <- mapResultRep SingleDim (ws, ws_F, ws_O) v + pure (v, rep) + (result_mode, Irregular irreg) -> do + rep <- irregularMapResult result_mode (ws, ws_F, ws_O) segments irreg v v_t new_inps + pure (v, rep) + reps_res <- forM reps $ \(v, r) -> do + case r of + Left se -> do + letBindNames [v] $ BasicOp $ Replicate (segmentsShape segments) se + -- the se is not part of input so this should be fine + rep <- mapResultRep mode (ws, ws_F, ws_O) v + pure (v, rep) + Right (DistInputFree arr t) -> do + letBindNames [v] $ BasicOp $ SubExp $ Var arr + rep <- mapResultRep (resultMapMode mode new_inps t) (ws, ws_F, ws_O) v + pure (v, rep) + Right (DistInput rt t) -> + let result_mode = resultMapMode mode new_inps t + in case resVar rt env of + Regular v' -> do + letBindNames [v] $ BasicOp $ SubExp $ Var v' + rep <- mapResultRep result_mode (ws, ws_F, ws_O) v + pure (v, rep) + Irregular irreg -> do + rep <- irregularMapResult result_mode (ws, ws_F, ws_O) segments irreg v t new_inps + pure (v, rep) + pure $ resmap_res <> reps_res + where + env_initial = DistEnv {distResMap = M.map Irregular irregs} + +transformDistributed :: + M.Map ResTag IrregularRep -> + Segments -> + Distributed -> + Builder GPU () +transformDistributed irregs segments dist = do + let Distributed dstms (DistResults resmap reps) = dist + env <- foldM (transformDistStm segments) env_initial dstms + forM_ (M.toList resmap) $ \(rt, binds) -> + forM_ binds $ \(cs_inps, v, v_t) -> + certifying (distResCerts env cs_inps) $ + -- 
FIXME: the copies are because we have too liberal aliases on + -- lifted functions. + case resVar rt env of + Regular v' -> letBindNames [v] $ BasicOp $ Replicate mempty $ Var v' + Irregular irreg -> + -- It might have an irregular representation, but we know + -- that it is actually regular because it is a result. + do + irreg' <- ensureDenseIrregular (baseName v <> "_dist_res") irreg + reshapeAndBind v (irregularD irreg') (segmentsShape segments <> arrayShape v_t) + forM_ reps $ \(v, r) -> + case r of + Left se -> + letBindNames [v] $ BasicOp $ Replicate (segmentsShape segments) se + Right (DistInputFree arr _) -> + letBindNames [v] $ BasicOp $ SubExp $ Var arr + -- This can happen. ask Troels + Right (DistInput rt t) -> + case resVar rt env of + Regular v' -> letBindNames [v] $ BasicOp $ SubExp $ Var v' + Irregular irreg -> + do + irreg' <- ensureDenseIrregular (baseName v <> "_dist_rep") irreg + reshapeAndBind v (irregularD irreg') (segmentsShape segments <> arrayShape t) + where + env_initial = DistEnv {distResMap = M.map Irregular irregs} + +-- Check whether a loop parameter array needs irregular representation. +-- we need the irregular representation when any of its dimensions are either: +-- a loop parameter name or variant in the outer map context + +needsIrregular :: DistInputs -> DistEnv -> S.Set VName -> DeclType -> Bool +needsIrregular inps env loopParamNames t = + case t of + Array {} -> any dimIsVariant (arrayDims t) + _ -> False + where + dimIsVariant (Constant _) = False + dimIsVariant (Var v) = v `S.member` loopParamNames || isVariant inps env (Var v) + +-- Lift a loop parameter and its initial value together. +-- If the parameter is an array whose dimensions are all invariant, +-- we lift it to a regular array. Otherwise we fall back to irregular. 
+liftLoopParam :: + Segments -> + SubExp -> + DistInputs -> + DistEnv -> + S.Set VName -> + (FParam SOACS, SubExp) -> + Builder GPU ([FParam GPU], ResRep, [SubExp]) +liftLoopParam segments num_segments inps env loopParamNames (fparam, initSE) = do + let t = declTypeOf fparam + case t of + Prim pt -> do + param <- + newParam + (baseName (paramName fparam) <> "_lifted") + (arrayOf (Prim pt) (segmentsShape segments) Nonunique) + initV <- liftSubExpRegular segments inps env (segmentsShape segments) initSE + pure ([param], Regular $ paramName param, [Var initV]) + Array pt _ u + | needsIrregular inps env loopParamNames t -> do + (params, rep) <- liftParam num_segments fparam + initVals <- liftLoopInit segments inps env initSE num_segments + pure (params, rep, initVals) + | otherwise -> do + -- Regular case: all dims are invariant, just add w as outermost dim + let pShape = segmentsShape segments <> arrayShape t + p <- + newParam + (baseName (paramName fparam) <> "_lifted") + (arrayOf (Prim pt) pShape u) + initV <- liftSubExpRegular segments inps env pShape initSE + pure ([p], Regular $ paramName p, [Var initV]) + Acc {} -> + error "liftLoopParam: Acc" + Mem {} -> + error "liftLoopParam: Mem" + +liftParam :: (MonadFreshNames m) => SubExp -> FParam SOACS -> m ([FParam GPU], ResRep) +liftParam w fparam = + case declTypeOf fparam of + Prim pt -> do + p <- + newParam + (desc <> "_lifted") + (arrayOf (Prim pt) (Shape [w]) Nonunique) + pure ([p], Regular $ paramName p) + Array pt _ u -> do + num_data <- + newParam (desc <> "_num_data") $ Prim int64 + segments <- + newParam (desc <> "_segments") $ + arrayOf (Prim int64) (Shape [w]) Nonunique + flags <- + newParam (desc <> "_F") $ + arrayOf (Prim Bool) (Shape [Var (paramName num_data)]) Nonunique + offsets <- + newParam (desc <> "_O") $ + arrayOf (Prim int64) (Shape [w]) Nonunique + elems <- + newParam (desc <> "_data") $ + arrayOf (Prim pt) (Shape [Var (paramName num_data)]) u + pure + ( [num_data, segments, flags, offsets, 
elems], + Irregular $ + IrregularRep + { irregularS = paramName segments, + irregularF = paramName flags, + irregularO = paramName offsets, + irregularD = paramName elems, + irregularK = Dense + } + ) + Acc {} -> + error "liftParam: Acc" + Mem {} -> + error "liftParam: Mem" + where + desc = baseName (paramName fparam) + +liftArg :: Segments -> SubExp -> DistInputs -> DistEnv -> (SubExp, Diet) -> Builder GPU [(SubExp, Diet)] +liftArg segments w inps env (se, d) = do + (_, rep) <- liftSubExp segments inps env se + case rep of + Regular v -> do + v_t <- lookupType v + v' <- + if arrayShape v_t == Shape [w] + then pure v + else + letExp "lifted_arg_flat" . BasicOp $ + Reshape v $ + reshapeAll (arrayShape v_t) (Shape [w]) + pure [(Var v', d)] + Irregular irreg -> mkIrrep irreg + where + mkIrrep + ( IrregularRep + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems + } + ) = do + t <- lookupType elems + t_o <- lookupType offsets + flags_t <- lookupType flags + num_data <- letExp "num_data" =<< toExp (product $ map pe64 $ arrayDims t) + let shape = Shape [Var num_data] + flags' <- letExp "flags" $ BasicOp $ Reshape flags $ reshapeAll (arrayShape flags_t) shape + elems' <- letExp "elems" $ BasicOp $ Reshape elems $ reshapeAll (arrayShape t) shape + segs' <- letExp "segs" $ BasicOp $ Reshape segs $ reshapeAll (arrayShape t_o) (Shape [w]) + offsets' <- letExp "offsets" $ BasicOp $ Reshape offsets $ reshapeAll (arrayShape t_o) (Shape [w]) + + -- Only apply the original diet to the 'elems' array + let diets = replicate 4 Observe ++ [d] + pure $ zipWith (curry (first Var)) [num_data, segs', flags', offsets', elems'] diets + +reshapeLiftedApplyResult :: Segments -> RetType SOACS -> ResRep -> Builder GPU ResRep +reshapeLiftedApplyResult segments Prim {} (Regular v) = do + v_t <- lookupType v + let expectedShape = segmentsShape segments + v' <- + if arrayShape v_t == expectedShape + then pure v + else + letExp "lifted_apply_res" . 
BasicOp $ + Reshape v $ + reshapeAll (arrayShape v_t) expectedShape + pure $ Regular v' +reshapeLiftedApplyResult _ _ rep = + pure rep + +liftLoopInit :: Segments -> DistInputs -> DistEnv -> SubExp -> SubExp -> Builder GPU [SubExp] +liftLoopInit segments inps env se num_segments = do + (_, rep) <- liftSubExp segments inps env se + case rep of + Regular v -> pure [Var v] + Irregular irreg -> mkIrrep irreg + where + mkIrrep + ( IrregularRep + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems + } + ) = do + t <- lookupType elems + t_o <- lookupType offsets + flags_t <- lookupType flags + num_data <- letExp "num_data" =<< toExp (product $ map pe64 $ arrayDims t) + let shape = Shape [Var num_data] + flags' <- letExp "flags" $ BasicOp $ Reshape flags $ reshapeAll (arrayShape flags_t) shape + elems' <- letExp "elems" $ BasicOp $ Reshape elems $ reshapeAll (arrayShape t) shape + -- I'm not sure why I need this reshapes + segs' <- letExp "segs" $ BasicOp $ Reshape segs $ reshapeAll (arrayShape t_o) (Shape [num_segments]) + offsets' <- letExp "offsets" $ BasicOp $ Reshape offsets $ reshapeAll (arrayShape t_o) (Shape [num_segments]) + pure $ map Var [num_data, segs', flags', offsets', elems'] + +-- Lifts a functions return type such that it matches the lifted functions return type. +liftRetType :: SubExp -> [RetType SOACS] -> [RetType GPU] +liftRetType w = concat . snd . 
L.mapAccumL liftType 0 + where + liftType i rettype = + let lifted = case rettype of + Prim pt -> pure $ arrayOf (Prim pt) (Shape [Free w]) Nonunique + Array pt _ u -> + let num_data = Prim int64 + segs = arrayOf (Prim int64) (Shape [Free w]) Nonunique + flags = arrayOf (Prim Bool) (Shape [Ext i]) Nonunique + offsets = arrayOf (Prim int64) (Shape [Free w]) Nonunique + elems = arrayOf (Prim pt) (Shape [Ext i]) u + in [num_data, segs, flags, offsets, elems] + Acc {} -> error "liftRetType: Acc" + Mem {} -> error "liftRetType: Mem" + in (i + length lifted, lifted) + +loopResultToResReps :: [DistResult] -> [VName] -> [ResRep] +loopResultToResReps dist_res results = + snd $ + L.mapAccumL + ( \rs dist_res' -> + if isRegularDistResult dist_res' + then + let (v : rs') = rs + in (rs', Regular v) + else + let (_ : segs : flags : offsets : elems : rs') = rs + in (rs', Irregular $ IrregularRep segs flags offsets elems Dense) + ) + results + dist_res + +liftLoopResult :: Segments -> SubExp -> DistInputs -> DistEnv -> DistResult -> SubExpRes -> Builder GPU Result +liftLoopResult segments num_segments inps env dist_res res = + if isRegularDistResult dist_res + then do + let (DistType _ _ t) = distResType dist_res + let expectedShape = segmentsShape segments <> arrayShape t + v <- liftSubExpRegular segments inps env expectedShape (resSubExp res) + pure [SubExpRes mempty (Var v)] + else case resSubExp res of + Var v -> do + irreg <- getIrregRep segments env inps v + map (SubExpRes mempty . 
Var) <$> mkIrrep irreg + _ -> undefined + where + mkIrrep + ( IrregularRep + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems + } + ) = do + flags_t <- lookupType flags + t <- lookupType elems + t_o <- lookupType offsets + num_data <- letExp "num_data" =<< toExp (product $ map pe64 $ arrayDims t) + let shape = Shape [Var num_data] + flags' <- letExp "flags" $ BasicOp $ Reshape flags $ reshapeAll (arrayShape flags_t) shape + elems' <- letExp "elems" $ BasicOp $ Reshape elems $ reshapeAll (arrayShape t) shape + segs' <- letExp "segs" $ BasicOp $ Reshape segs $ reshapeAll (arrayShape t_o) (Shape [num_segments]) + offsets' <- letExp "offsets" $ BasicOp $ Reshape offsets $ reshapeAll (arrayShape t_o) (Shape [num_segments]) + pure [num_data, segs', flags', offsets', elems'] + +liftLoopBody :: Segments -> SubExp -> DistInputs -> DistEnv -> [DistStm] -> [DistResult] -> Result -> Builder GPU Result +liftLoopBody segments num_segments inputs env dstms dist_res result = do + env' <- foldM (transformDistStm segments) env dstms + results <- zipWithM (liftLoopResult segments num_segments inputs env') dist_res result + pure $ concat results + +distResultsToResReps :: [DistResult] -> [VName] -> [ResRep] +distResultsToResReps dist_res results = + snd $ + L.mapAccumL + ( \rs dist_res' -> + if isRegularDistResult dist_res' + then + let (v : rs') = rs + in (rs', Regular v) + else + let (segs : flags : offsets : elems : rs') = rs + in (rs', Irregular $ IrregularRep segs flags offsets elems Dense) + ) + results + dist_res + +liftDistResult :: Segments -> DistInputs -> DistEnv -> DistResult -> SubExpRes -> Builder GPU Result +liftDistResult segments inps env dist_res res = + if isRegularDistResult dist_res + then do + let (DistType _ _ t) = distResType dist_res + let expectedShape = segmentsShape segments <> arrayShape t + v <- liftSubExpRegular segments inps env expectedShape (resSubExp res) + pure [SubExpRes mempty (Var v)] + else case resSubExp 
res of + Var v -> do + irreg <- getIrregRep segments env inps v + pure $ map (SubExpRes mempty . Var) [irregularS irreg, irregularF irreg, irregularO irreg, irregularD irreg] + _ -> undefined + +liftBodyWithDistResults :: Segments -> DistInputs -> DistEnv -> [DistStm] -> [DistResult] -> Result -> Builder GPU Result +liftBodyWithDistResults segments inputs env dstms dist_res result = do + env' <- foldM (transformDistStm segments) env dstms + result' <- zipWithM (liftDistResult segments inputs env') dist_res result + pure $ concat result' + +liftBody :: SubExp -> DistInputs -> DistEnv -> [DistStm] -> Result -> Builder GPU Result +liftBody w inputs env dstms result = do + let segments = NE.singleton w + env' <- foldM (transformDistStm segments) env dstms + result' <- mapM (liftResult segments inputs env') result + pure $ concat result' + +liftFunName :: Name -> Name +liftFunName name = name <> "_lifted" + +addRetAls :: [DeclType] -> [RetType GPU] -> [(RetType GPU, RetAls)] +addRetAls params rettype = zip rettype $ map possibleAliases rettype + where + aliasable (Array _ _ Nonunique) = True + aliasable _ = False + aliasable_params = + map snd $ filter (aliasable . fst) $ zip params [0 ..] + aliasable_rets = + map snd $ filter (aliasable . declExtTypeOf . fst) $ zip rettype [0 ..] + possibleAliases t + | aliasable t = RetAls aliasable_params aliasable_rets + | otherwise = mempty + +liftFunDef :: Scope SOACS -> FunDef SOACS -> PassM (FunDef GPU) +liftFunDef const_scope fd = do + let FunDef + { funDefBody = body, + funDefParams = fparams, + funDefRetType = rettype + } = fd + wp <- newParam "w" $ Prim int64 + let w = Var $ paramName wp + (fparams', reps) <- mapAndUnzipM (liftParam w) fparams + let fparams'' = wp : concat fparams' + let inputs = do + (p, i) <- zip fparams [0 ..] 
+ pure (paramName p, DistInput (ResTag i) (paramType p)) + let rettype' = + addRetAls (map paramDeclType fparams'') $ + liftRetType w (map fst rettype) + let (inputs', dstms) = + distributeBody const_scope (NE.singleton (Var (paramName wp))) inputs body + env = DistEnv $ M.fromList $ zip (map ResTag [0 ..]) reps + -- Lift the body of the function and get the results + (result, stms) <- + runReaderT + (runBuilder $ liftBody w inputs' env dstms $ bodyResult body) + (const_scope <> scopeOfFParams fparams'') + let name = liftFunName $ funDefName fd + pure $ + fd + { funDefName = name, + funDefBody = Body () stms result, + funDefParams = fparams'', + funDefRetType = rettype' + } + +transformLambda :: Scope SOACS -> Lambda SOACS -> PassM (Lambda GPU) +transformLambda scope (Lambda params ret body) = do + body' <- transformBody (scopeOfLParams params <> scope) body + pure $ Lambda params ret body' + +transformStm :: Scope SOACS -> Stm SOACS -> PassM (Stms GPU) +transformStm scope (Let pat _ (Op (Screma w arrs form))) + | Just lam <- isMapSOAC form = do + let arrs' = + zipWith MapArray arrs $ + map paramType (lambdaParams (scremaLambda form)) + (distributed, _) = distributeMap scope pat (NE.singleton w) arrs' lam + m = transformDistributed mempty (NE.singleton w) distributed + traceM $ prettyString distributed + runReaderT (runBuilder_ m) scope +transformStm scope (Let pat aux (Loop params form body)) = + oneStm . Let pat aux . Loop params form <$> transformBody scope' body + where + scope' = scopeOfLoopForm form <> scopeOfFParams (map fst params) <> scope +transformStm scope (Let pat aux (Match ses cases def_body ret)) = + oneStm . Let pat aux + <$> (Match ses <$> mapM onCase cases <*> transformBody scope def_body <*> pure ret) + where + onCase = traverse (transformBody scope) +transformStm scope (Let pat aux (WithAcc inputs withacc_lam)) = + oneStm . 
Let pat aux + <$> (WithAcc (map onInput inputs) <$> transformLambda scope withacc_lam) + where + onInput (shape, arrs, Nothing) = + (shape, arrs, Nothing) + onInput (shape, arrs, Just (lam, nes)) = + (shape, arrs, Just (soacsLambdaToGPU lam, nes)) +transformStm _ stm = pure $ oneStm $ soacsStmToGPU stm + +transformStms :: Scope SOACS -> Stms SOACS -> PassM (Stms GPU) +transformStms scope stms = + fold <$> traverse (transformStm (scope <> scopeOf stms)) stms + +transformBody :: Scope SOACS -> Body SOACS -> PassM (Body GPU) +transformBody scope (Body () stms res) = do + stms' <- transformStms scope stms + pure $ Body () stms' res + +transformFunDef :: Scope SOACS -> FunDef SOACS -> PassM (FunDef GPU) +transformFunDef consts_scope fd = do + let FunDef + { funDefBody = body, + funDefParams = fparams, + funDefRetType = rettype + } = fd + body' <- transformBody (scopeOfFParams fparams <> consts_scope) body + pure $ + fd + { funDefBody = body', + funDefRetType = rettype, + funDefParams = fparams + } + +transformProg :: Prog SOACS -> PassM (Prog GPU) +transformProg prog = do + progAfterPreProcessing <- preprocessProg prog + traceM $ "After preprocessProg:" <> prettyString progAfterPreProcessing + consts' <- transformStms mempty $ progConsts progAfterPreProcessing + funs' <- mapM (transformFunDef $ scopeOf (progConsts progAfterPreProcessing)) $ progFuns progAfterPreProcessing + lifted_funs <- + mapM (liftFunDef $ scopeOf (progConsts progAfterPreProcessing)) $ + filter (isNothing . funDefEntryPoint) $ + progFuns progAfterPreProcessing + -- In extremely unlikely cases (mostly empty programs), we may end up having a + -- name source that overlaps the names used in the builtin functions. Avoid + -- that by bumping it by enough that we probably will not have a conflict. 
+ modifyNameSource $ \src -> ((), mappend (newNameSource 1000) src) + pure $ + prog + { progConsts = consts', + progFuns = flatteningBuiltins <> lifted_funs <> funs' + } + +-- transform a for-loop with a variant iteration count into a while-loop +transformFortoWhile :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + StmAux () -> + [(FParam SOACS, SubExp)] -> + VName -> + IntType -> + SubExp -> + Body SOACS -> + Builder GPU DistEnv +transformFortoWhile segments env inps res aux merge i it n body = do + let old_loop_params = map fst merge + -- Fresh names used only in the synthetic rewritten body. + cond_param_v <- newVName "for_cond" + cond0_v <- newVName "for_cond0" + cond_next_v <- newVName "for_cond_next" + i_next_v <- newVName "for_i_next" + loop_old_out_vs <- replicateM (length merge) $ newVName "for_out" + i_out_v <- newVName "for_i_out" + cond_out_v <- newVName "for_cond_out" + + let zero = intConst it 0 + one = intConst it 1 + aux_no_certs = aux {stmAuxCerts = mempty} + + cond0_stm = + Let + (Pat [PatElem cond0_v (Prim Bool)]) + aux_no_certs + (BasicOp $ CmpOp (CmpSlt it) zero n) + + -- Extend the loop parameters with iteration variable and condition variable + i_param = Param mempty i (Prim (IntType it)) + cond_param = Param mempty cond_param_v (Prim Bool) + + Body loop_body_dec loop_body_stms loop_body_res = body + + i_next_stm = + Let + (Pat [PatElem i_next_v (Prim (IntType it))]) + aux_no_certs + -- OverflowWrap or OverflowUndef? 
+ (BasicOp $ BinOp (Add it OverflowUndef) (Var i) one) + + cond_next_stm = + Let + (Pat [PatElem cond_next_v (Prim Bool)]) + aux_no_certs + (BasicOp $ CmpOp (CmpSlt it) (Var i_next_v) n) + + loop_new_body = + Body + loop_body_dec + (loop_body_stms <> oneStm i_next_stm <> oneStm cond_next_stm) + ( [ SubExpRes mempty (Var cond_next_v), + SubExpRes mempty (Var i_next_v) + ] + <> loop_body_res + ) + + merge' = + [ (cond_param, Var cond0_v), + (i_param, zero) + ] + <> merge + + loop_out_tys = [Prim Bool, Prim (IntType it)] ++ map paramType old_loop_params + + loop_pat = + Pat $ + zipWith + PatElem + ([cond_out_v, i_out_v] ++ loop_old_out_vs) + loop_out_tys + + while_stm = + Let + loop_pat + aux + (Loop merge' (WhileLoop (paramName cond_param)) loop_new_body) + + synthetic_body = + Body + () + (oneStm cond0_stm <> oneStm while_stm) + (map (SubExpRes mempty . Var) loop_old_out_vs) + + let (inps_local, env_local, _) = localiseInputs env inps + + scope <- askScope + let (inps_dist, dstms) = distributeBody scope segments inps_local synthetic_body + + lifted_res <- liftBodyWithDistResults segments inps_dist env_local dstms res (bodyResult synthetic_body) + lifted_vs <- mapM (letExp "for_variant_res" <=< toExp . resSubExp) lifted_res + let reps = distResultsToResReps res lifted_vs + pure $ insertReps (zip (map distResTag res) reps) env + +splitInput :: + Segments -> + DistInputs -> + DistEnv -> + VName -> + VName -> + Builder GPU (Type, VName, ResRep) +splitInput segments inps env is v = do + (t, rep) <- liftSubExpPreserveRep segments inps env (Var v) + (t,v,) <$> case rep of + Regular arr -> do + n <- letSubExp "n" =<< (toExp . arraySize 0 =<< lookupType is) + -- isnt' it better to do the segmap over all dims? + arr' <- letExp "split_arr" <=< segMap (MkSolo n) $ \(MkSolo i) -> do + idx <- letSubExp "idx" =<< eIndex is [eSubExp i] + let arr_is = unflattenIndex (segmentDims segments) (pe64 idx) + subExpsRes . 
pure <$> (letSubExp "arr" =<< eIndex arr (map toExp arr_is)) + pure $ Regular arr' + Irregular (IrregularRep segs flags offsets elems _) -> do + n <- letSubExp "n" =<< (toExp . arraySize 0 =<< lookupType is) + segs' <- letExp "split_segs" <=< segMap (MkSolo n) $ \(MkSolo i) -> do + idx <- letExp "idx" =<< eIndex is [eSubExp i] + subExpsRes . pure <$> (letSubExp "segs" =<< eIndex segs [toExp idx]) + (_, offsets', num_data) <- exScanAndSum segs' + (_, _, ii1) <- doRepIota segs' + (_, _, ii2) <- doSegIota segs' + ~[flags', elems'] <- letTupExp "split_F_data" <=< segMap (MkSolo num_data) $ \(MkSolo i) -> do + offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp i]]] + idx <- letExp "idx" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp i]) + flags_split <- letSubExp "flags" =<< eIndex flags [toExp idx] + elems_split <- letSubExp "elems" =<< eIndex elems [toExp idx] + pure $ subExpsRes [flags_split, elems_split] + pure $ + Irregular $ + IrregularRep + { irregularS = segs', + irregularF = flags', + irregularO = offsets', + irregularD = elems', + irregularK = Dense + } + +-- | Transform a SOACS program to a GPU program, using flattening. 
+flattenSOACs :: Pass SOACS GPU +flattenSOACs = + Pass + { passName = "flatten", + passDescription = "Perform full flattening", + passFunction = transformProg + } +{-# NOINLINE flattenSOACs #-} diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs new file mode 100644 index 0000000000..26da66145b --- /dev/null +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -0,0 +1,662 @@ +{-# LANGUAGE TypeFamilies #-} + +module Futhark.Pass.Flatten.Builtins + ( flatteningBuiltins, + segMap, + genFlags, + genScan, + genFilter, + genSegScan, + genSegScanomap, + genSegScanomapWithPost, + genSegRed, + genSegRedomap, + genScatter, + genShapeIota, + exScanAndSum, + doSegIota, + doPrefixSum, + doRepIota, + doPartition, + ) +where + +import Control.Monad (forM, forM_, (<=<)) +import Control.Monad.State.Strict +import Data.Foldable (toList) +import Data.Maybe (fromMaybe) +import Data.Text qualified as T +import Futhark.IR.GPU +import Futhark.IR.SOACS +import Futhark.MonadFreshNames +import Futhark.Pass.ExtractKernels.BlockedKernel (mkSegSpace) +import Futhark.Pass.ExtractKernels.ToGPU (soacsLambdaToGPU) +import Futhark.Tools +import Futhark.Util (unsnoc) + +builtinName :: T.Text -> Name +builtinName = nameFromText . ("builtin#" <>) + +segIotaName, repIotaName, prefixSumName, partitionName :: Name +segIotaName = builtinName "segiota" +repIotaName = builtinName "repiota" +prefixSumName = builtinName "prefixsum" +partitionName = builtinName "partition" + +segMap :: (Traversable f) => f SubExp -> (f SubExp -> Builder GPU Result) -> Builder GPU (Exp GPU) +segMap segments f = do + gtids <- traverse (const $ newVName "gtid") segments + space <- mkSegSpace $ zip (toList gtids) (toList segments) + ((res, ts), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do + res <- f $ fmap Var gtids + ts <- mapM (subExpType . 
resSubExp) res + pure (map mkResult res, ts) + let kbody = Body () stms res + pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space ts kbody + where + mkResult (SubExpRes cs se) = Returns ResultMaySimplify cs se + +genScanWithKernelBody :: + (Traversable f) => + Name -> + f SubExp -> + Lambda GPU -> + [SubExp] -> + (f SubExp -> Builder GPU Result) -> + Builder GPU [VName] +genScanWithKernelBody desc segments lam nes = + genScanWithKernelBodyAndPost desc segments lam nes mkIdentityLambda + +genScanWithKernelBodyAndPost :: + (Traversable f) => + Name -> + f SubExp -> + Lambda GPU -> + [SubExp] -> + ([Type] -> Builder GPU (Lambda GPU)) -> + (f SubExp -> Builder GPU Result) -> + Builder GPU [VName] +genScanWithKernelBodyAndPost desc segments lam nes mkPostLam m = do + gtids <- traverse (const $ newVName "gtid") segments + space <- mkSegSpace $ zip (toList gtids) (toList segments) + ((res, res_t), stms) <- runBuilder . localScope (scopeOfSegSpace space) $ do + res <- m $ fmap Var gtids + res_t <- mapM (subExpType . 
resSubExp) res + pure (map mkResult res, res_t) + let kbody = Body () stms res + op = SegBinOp Commutative lam nes mempty + post_lam <- mkPostLam res_t + letTupExp desc $ Op $ SegOp $ SegScan lvl space res_t kbody [op] (SegPostOp post_lam) + where + lvl = SegThread SegVirt Nothing + mkResult (SubExpRes cs se) = Returns ResultMaySimplify cs se + +bindLambdaInputArrays :: + (Traversable f) => + f SubExp -> + Lambda GPU -> + [VName] -> + Builder GPU () +bindLambdaInputArrays gtids lam arrs = do + let idxs = toList gtids + forM_ (zip (lambdaParams lam) arrs) $ \(p, arr) -> + letBindNames [paramName p] + =<< case paramType p of + Acc {} -> + eSubExp $ Var arr + _ -> + eIndex arr $ map eSubExp idxs + +genScan :: (Traversable f) => Name -> f SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] +genScan desc segments lam nes arrs = + genScanWithKernelBody desc segments lam nes $ \gtids -> + subExpsRes + <$> forM + arrs + ( \arr -> + letSubExp (baseName arr <> "_elem") =<< eIndex arr (toList $ fmap eSubExp gtids) + ) + +-- Also known as a prescan. +genExScan :: (Traversable f) => Name -> f SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] +genExScan desc segments lam nes arrs = + genScanWithKernelBody desc segments lam nes $ \gtids -> + let Just (outerDims, innerDim) = unsnoc $ toList gtids + in do + prescan <- + letTupExp' "to_prescan" + =<< eIf + (toExp $ pe64 innerDim .==. 
0) + (eBody (map eSubExp nes)) + (eBody (map (`eIndex` (map toExp outerDims ++ [toExp $ pe64 innerDim - 1])) arrs)) + pure $ subExpsRes prescan + +segScanLambda :: + (MonadBuilder m, BranchType (Rep m) ~ ExtType, LParamInfo (Rep m) ~ Type) => + Lambda (Rep m) -> + m (Lambda (Rep m)) +segScanLambda lam = do + x_flag_p <- newParam "x_flag" $ Prim Bool + y_flag_p <- newParam "y_flag" $ Prim Bool + let ts = lambdaReturnType lam + (xps, yps) = splitAt (length ts) $ lambdaParams lam + mkLambda ([x_flag_p] ++ xps ++ [y_flag_p] ++ yps) $ + bodyBind + =<< eBody + [ eBinOp LogOr (eParam x_flag_p) (eParam y_flag_p), + eIf + (eParam y_flag_p) + (eBody (map eParam yps)) + (pure $ lambdaBody lam) + ] + +genSegScan :: Name -> Lambda GPU -> [SubExp] -> VName -> [VName] -> Builder GPU [VName] +genSegScan desc lam nes flags arrs = do + w <- arraySize 0 <$> lookupType flags + lam' <- segScanLambda lam + drop 1 <$> genScan desc [w] lam' (constant False : nes) (flags : arrs) + +segScanomapPostLambda :: + (MonadBuilder m, LParamInfo (Rep m) ~ Type) => + Lambda (Rep m) -> + m (Lambda (Rep m)) +segScanomapPostLambda lam = do + flag_p <- newParam "seg_flag" $ Prim Bool + mkLambda (flag_p : lambdaParams lam) $ + bodyBind $ lambdaBody lam +genSegScanomap :: + Name -> + Lambda GPU -> + [SubExp] -> + VName -> + Lambda GPU -> + [VName] -> + Builder GPU [VName] +genSegScanomap desc scan_lam nes flags map_lam arrs = do + post_lam <- mkIdentityLambda $ lambdaReturnType map_lam + genSegScanomapWithPost desc scan_lam nes flags post_lam map_lam arrs + +genSegScanomapWithPost :: + Name -> + Lambda GPU -> + [SubExp] -> + VName -> + Lambda GPU -> + Lambda GPU -> + [VName] -> + Builder GPU [VName] +genSegScanomapWithPost desc scan_lam nes flags post_lam map_lam arrs = do + w <- arraySize 0 <$> lookupType flags + scan_lam' <- segScanLambda scan_lam + post_lam' <- segScanomapPostLambda post_lam + genScanWithKernelBodyAndPost + desc + [w] + scan_lam' + (constant False : nes) + (const $ pure post_lam') + ( 
\gtids -> do + let [gtid] = toList gtids + flag <- letSubExp "flag" =<< eIndex flags [eSubExp gtid] + bindLambdaInputArrays gtids map_lam arrs + map_res <- bodyBind (lambdaBody map_lam) + pure (subExpRes flag : map_res) + ) + +genPrefixSum :: Name -> VName -> Builder GPU VName +genPrefixSum desc ns = do + ws <- arrayDims <$> lookupType ns + add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 + head <$> genScan desc ws add_lam [intConst Int64 0] [ns] + +genExPrefixSum :: Name -> VName -> Builder GPU VName +genExPrefixSum desc ns = do + ws <- arrayDims <$> lookupType ns + add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 + head <$> genExScan desc ws add_lam [intConst Int64 0] [ns] + +genSegPrefixSum :: Name -> VName -> VName -> Builder GPU VName +genSegPrefixSum desc flags ns = do + add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 + head <$> genSegScan desc add_lam [intConst Int64 0] flags [ns] + +genScatter :: VName -> SubExp -> (SubExp -> Builder GPU (VName, SubExp)) -> Builder GPU (Exp GPU) +genScatter dest n f = do + gtid <- newVName "gtid" + space <- mkSegSpace [(gtid, n)] + withAcc [dest] 1 $ \ ~[acc] -> do + kbody <- buildBody_ $ localScope (scopeOfSegSpace space) $ do + (i, v) <- f $ Var gtid + acc' <- letExp (baseName acc) $ BasicOp $ UpdateAcc Safe acc [Var i] [v] + pure [Returns ResultMaySimplify mempty $ Var acc'] + acc_t <- lookupType acc + letTupExp' "scatter" $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space [acc_t] kbody + +genTabulate :: SubExp -> (SubExp -> Builder GPU [SubExp]) -> Builder GPU (Exp GPU) +genTabulate w m = do + gtid <- newVName "gtid" + space <- mkSegSpace [(gtid, w)] + ((res, ts), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do + ses <- m $ Var gtid + ts <- mapM subExpType ses + pure (map (Returns ResultMaySimplify mempty) ses, ts) + let kbody = Body () stms res + pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space ts kbody + +genFlags :: SubExp -> VName -> Builder GPU VName +genFlags m 
offsets = do + flags_allfalse <- + letExp "flags_allfalse" . BasicOp $ + Replicate (Shape [m]) (constant False) + n <- arraySize 0 <$> lookupType offsets + letExp "flags" <=< genScatter flags_allfalse n $ \gtid -> do + i <- letExp "i" =<< eIndex offsets [eSubExp gtid] + pure (i, constant True) + +genSegRed :: VName -> VName -> VName -> [VName] -> Reduce SOACS -> Builder GPU [VName] +genSegRed segments flags offsets elems red = do + scanned <- + genSegScan + "red" + (soacsLambdaToGPU $ redLambda red) + (redNeutral red) + flags + elems + num_segments <- arraySize 0 <$> lookupType offsets + letTupExp "segred" <=< genTabulate num_segments $ \i -> do + n <- letSubExp "n" =<< eIndex segments [eSubExp i] + offset <- letSubExp "offset" =<< eIndex offsets [toExp (pe64 i)] + letTupExp' "segment_res" <=< eIf (toExp $ pe64 n .==. 0) (eBody $ map eSubExp nes) $ + eBody $ + map (`eIndex` [toExp $ pe64 offset + pe64 n - 1]) scanned + where + nes = redNeutral red + +genSegRedomap :: + VName -> + VName -> + VName -> + [VName] -> + Reduce SOACS -> + Lambda GPU -> + Builder GPU ([VName], [VName]) +genSegRedomap segments flags offsets elems red map_lam = do + scanned_and_map <- + genSegScanomap + "redomap" + (soacsLambdaToGPU $ redLambda red) + (redNeutral red) + flags + map_lam + elems + let (scanned, mapout) = splitAt (length nes) scanned_and_map + num_segments <- arraySize 0 <$> lookupType offsets + reds <- letTupExp "segred" <=< genTabulate num_segments $ \i -> do + n <- letSubExp "n" =<< eIndex segments [eSubExp i] + offset <- letSubExp "offset" =<< eIndex offsets [toExp (pe64 i)] + letTupExp' "segment_res" <=< eIf (toExp $ pe64 n .==. 0) (eBody $ map eSubExp nes) $ + eBody $ + map (`eIndex` [toExp $ pe64 offset + pe64 n - 1]) scanned + pure (reds, mapout) + where + nes = redNeutral red + +-- | Produces a multidimensional iota for the given shape. +genShapeIota :: Shape -> Builder GPU VName +genShapeIota shape = + letExp "shape_iota" =<< segMap (shapeDims shape) (pure . 
subExpsRes) + +-- Returns (#segments, segment start offsets, sum of segment sizes) +-- Note: If given a multi-dimensional array, +-- `#segments` and `sum of segment sizes` will be arrays, not scalars. +-- `segment start offsets` will always have the same shape as `ks`. +exScanAndSum :: VName -> Builder GPU (SubExp, VName, SubExp) +exScanAndSum ks = do + ns <- arrayDims <$> lookupType ks + -- If `ks` only has a single dimension + -- the size will be a scalar, otherwise it's an array. + ns' <- letExp "ns" $ BasicOp $ case ns of + [] -> error $ "exScanAndSum: Given non-array argument: " ++ prettyString ks + [n] -> SubExp n + _ -> ArrayLit ns (Prim int64) + -- Check if the innermost dimension is empty. + is_empty <- + letExp "is_empty" + =<< ( case ns of + [n] -> toExp (pe64 n .==. 0) + _ -> eLast ns' >>= letSubExp "n" >>= (\n -> toExp $ pe64 n .==. 0) + ) + offsets <- letExp "offsets" =<< toExp =<< genExPrefixSum "offsets" ks + ms <- letExp "ms" <=< segMap (init ns) $ \gtids -> do + let idxs = map toExp gtids + offset <- letExp "offset" =<< eIndex offsets idxs + k <- letExp "k" =<< eIndex ks idxs + m <- + letSubExp "m" + =<< eIf + (toExp is_empty) + (eBody [eSubExp $ intConst Int64 0]) + -- Add last size because 'offsets' is an *exclusive* prefix + -- sum. 
+ (eBody [eBinOp (Add Int64 OverflowUndef) (eLast offset) (eLast k)]) + pure [subExpRes m] + pure (Var ns', offsets, Var ms) + +genSegIota :: VName -> Builder GPU (VName, VName, VName) +genSegIota ks = do + (_n, offsets, m) <- exScanAndSum ks + flags <- genFlags m offsets + ones <- letExp "ones" $ BasicOp $ Replicate (Shape [m]) one + iotas <- genSegPrefixSum "iotas" flags ones + res <- letExp "res" <=< genTabulate m $ \i -> do + x <- letSubExp "x" =<< eIndex iotas [eSubExp i] + letTupExp' "xm1" $ BasicOp $ BinOp (Sub Int64 OverflowUndef) x one + pure (flags, offsets, res) + where + one = intConst Int64 1 + +genRepIota :: VName -> Builder GPU (VName, VName, VName) +genRepIota ks = do + (n, offsets, m) <- exScanAndSum ks + is <- letExp "is" <=< genTabulate n $ \i -> do + o <- letSubExp "o" =<< eIndex offsets [eSubExp i] + k <- letSubExp "n" =<< eIndex ks [eSubExp i] + letTupExp' "i" + =<< eIf + (toExp (pe64 k .==. 0)) + (eBody [eSubExp negone]) + (eBody [toExp $ pe64 o]) + zeroes <- letExp "zeroes" $ BasicOp $ Replicate (Shape [m]) zero + starts <- + letExp "starts" <=< genScatter zeroes n $ \gtid -> do + i <- letExp "i" =<< eIndex is [eSubExp gtid] + pure (i, gtid) + flags <- letExp "flags" <=< genTabulate m $ \i -> do + x <- letSubExp "x" =<< eIndex starts [eSubExp i] + letTupExp' "nonzero" =<< toExp (pe64 x .>. 0) + res <- genSegPrefixSum "res" flags starts + pure (flags, offsets, res) + where + zero = intConst Int64 0 + negone = intConst Int64 (-1) + +genPartition :: VName -> VName -> VName -> Builder GPU (VName, VName, VName) +genPartition n k cls = do + let n' = Var n + let k' = Var k + let dims = [k', n'] + -- Create a `[k][n]` array of flags such that `cls_flags[i][j]` + -- is equal 1 if the j'th element is a member of equivalence class `i` i.e. + -- the `i`th row is a flag array for equivalence class `i`. 
+ cls_flags <- + letExp "flags" + <=< segMap dims + $ \[i, j] -> do + c <- letSubExp "c" =<< eIndex cls [toExp j] + cls_flag <- + letSubExp "cls_flag" + =<< eIf + (toExp $ pe64 i .==. pe64 c) + (eBody [toExp $ intConst Int64 1]) + (eBody [toExp $ intConst Int64 0]) + pure [subExpRes cls_flag] + + -- Offsets of each of the individual equivalence classes. + (_, local_offs, _counts) <- exScanAndSum cls_flags + -- The number of elems in each class + counts <- letExp "counts" =<< toExp _counts + -- Offsets of the whole equivalence classes + global_offs <- genExPrefixSum "global_offs" counts + -- Offsets over all of the equivalence classes. + cls_offs <- + letExp "cls_offs" =<< do + segMap dims $ \[i, j] -> do + global_offset <- letExp "global_offset" =<< eIndex global_offs [toExp i] + offset <- + letSubExp "offset" + =<< eBinOp + (Add Int64 OverflowUndef) + (eIndex local_offs [toExp i, toExp j]) + (toExp global_offset) + pure [subExpRes offset] + + scratch <- letExp "scratch" $ BasicOp $ Scratch int64 [n'] + res <- letExp "scatter_res" <=< genScatter scratch n' $ \gtid -> do + c <- letExp "c" =<< eIndex cls [toExp gtid] + ind <- letExp "ind" =<< eIndex cls_offs [toExp c, toExp gtid] + i <- letSubExp "i" =<< toExp gtid + pure (ind, i) + pure (counts, global_offs, res) + +genFilter :: VName -> BuilderT GPU (State VNameSource) (SubExp, VName) +genFilter flags = do + w <- arraySize 0 <$> lookupType flags + flags_int <- letExp "flags_int" <=< segMap [w] $ \[i] -> do + b <- letSubExp "b" =<< eIndex flags [eSubExp i] + v <- + letSubExp "v" + =<< eIf + (eSubExp b) + (eBody [toExp $ intConst Int64 1]) + (eBody [toExp $ intConst Int64 0]) + pure [subExpRes v] + -- offsets <- genExPrefixSum "filter_offs" flags_int + (_n, offsets, num_true) <- exScanAndSum flags_int + -- num_true <- letSubExp "num_true" =<< eIndex flags_int [toExp $ pe64 w - 1] + scratch <- letExp "scratch" $ BasicOp $ Scratch int64 [num_true] + -- is this efficient or do i need to do something smarter? 
like scatter with guard? + -- offsets' <- letExp "offset" <=< segMap [w] $ \[i] -> do + -- b' <- letSubExp "b" =<< eIndex flags [eSubExp i] + -- v' <- + -- letSubExp "v'" + -- =<< eIf + -- (eSubExp b') + -- (eBody [eIndex offsets [eSubExp i]] ) + -- (eBody [toExp $ intConst Int64 (-1)]) + -- pure [subExpRes v'] + + filtered <- letExp "filtered" <=< genScatter scratch w $ \gtid -> do + b <- letSubExp "b" =<< eIndex flags [eSubExp gtid] + -- idx <- letExp "idx" =<< eIndex offsets' [eSubExp gtid] + idx_se <- + letSubExp "idx" + =<< eIf + (eSubExp b) + (eBody [eIndex offsets [eSubExp gtid]]) + (eBody [toExp $ intConst Int64 (-1)]) + -- maybe cleaner? + idx <- letExp "idx" =<< toExp idx_se + pure (idx, gtid) + pure (num_true, filtered) + +buildingBuiltin :: Builder GPU (FunDef GPU) -> FunDef GPU +buildingBuiltin m = fst $ evalState (runBuilderT m mempty) blankNameSource + +segIotaBuiltin :: FunDef GPU +segIotaBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + nsp <- newParam "ns" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, nsp]) . buildBody_ $ do + (flags, offsets, res) <- genSegIota (paramName nsp) + m <- arraySize 0 <$> lookupType res + pure $ subExpsRes [m, Var flags, Var offsets, Var res] + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = segIotaName, + funDefRetType = + map + (,mempty) + [ Prim int64, + Array Bool (Shape [Ext 0]) Unique, + Array int64 (Shape [Free $ Var $ paramName np]) Unique, + Array int64 (Shape [Ext 0]) Unique + ], + funDefParams = [np, nsp], + funDefBody = body + } + +repIotaBuiltin :: FunDef GPU +repIotaBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + nsp <- newParam "ns" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, nsp]) . 
buildBody_ $ do + (flags, offsets, res) <- genRepIota (paramName nsp) + m <- arraySize 0 <$> lookupType res + pure $ subExpsRes [m, Var flags, Var offsets, Var res] + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = repIotaName, + funDefRetType = + map + (,mempty) + [ Prim int64, + Array Bool (Shape [Ext 0]) Unique, + Array int64 (Shape [Free $ Var $ paramName np]) Unique, + Array int64 (Shape [Ext 0]) Unique + ], + funDefParams = [np, nsp], + funDefBody = body + } + +prefixSumBuiltin :: FunDef GPU +prefixSumBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + nsp <- newParam "ns" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, nsp]) . buildBody_ $ + varsRes . pure <$> genPrefixSum "res" (paramName nsp) + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = prefixSumName, + funDefRetType = + [(Array int64 (Shape [Free $ Var $ paramName np]) Unique, mempty)], + funDefParams = [np, nsp], + funDefBody = body + } + +partitionBuiltin :: FunDef GPU +partitionBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + kp <- newParam "k" $ Prim int64 + csp <- newParam "cs" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, kp, csp]) . buildBody_ $ do + (counts, offsets, res) <- genPartition (paramName np) (paramName kp) (paramName csp) + pure $ varsRes [counts, offsets, res] + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = partitionName, + funDefRetType = + map + (,mempty) + [ Array int64 (Shape [Free $ Var $ paramName kp]) Unique, + Array int64 (Shape [Free $ Var $ paramName kp]) Unique, + Array int64 (Shape [Free $ Var $ paramName np]) Unique + ], + funDefParams = [np, kp, csp], + funDefBody = body + } + +-- | Builtin functions used in flattening. Must be prepended to a +-- program that is transformed by flattening. 
The intention is to +-- avoid the code explosion that would result if we inserted +-- primitives everywhere. +flatteningBuiltins :: [FunDef GPU] +flatteningBuiltins = + [ segIotaBuiltin, + repIotaBuiltin, + prefixSumBuiltin, + partitionBuiltin + ] + +-- | @[0,1,2,0,1,0,1,2,3,4,...]@. Returns @(flags,offsets,elems)@. +doSegIota :: VName -> Builder GPU (VName, VName, VName) +doSegIota ns = do + ns_t <- lookupType ns + let n = arraySize 0 ns_t + m <- newVName "m" + flags <- newVName "segiota_flags" + offsets <- newVName "segiota_offsets" + elems <- newVName "segiota_elems" + let args = [(n, Prim int64), (Var ns, ns_t)] + restype = + fromMaybe (error "doSegIota: bad application") $ + applyRetType + (map fst $ funDefRetType segIotaBuiltin) + (funDefParams segIotaBuiltin) + args + letBindNames [m, flags, offsets, elems] $ + Apply + (funDefName segIotaBuiltin) + [(n, Observe), (Var ns, Observe)] + (map (,mempty) restype) + Safe + pure (flags, offsets, elems) + +-- | Produces @[0,0,0,1,1,2,2,2,...]@. Returns @(flags, offsets, +-- elems)@. 
+doRepIota :: VName -> Builder GPU (VName, VName, VName) +doRepIota ns = do + ns_t <- lookupType ns + let n = arraySize 0 ns_t + m <- newVName "m" + flags <- newVName "repiota_flags" + offsets <- newVName "repiota_offsets" + elems <- newVName "repiota_elems" + let args = [(n, Prim int64), (Var ns, ns_t)] + restype = + fromMaybe (error "doRepIota: bad application") $ + applyRetType + (map fst $ funDefRetType repIotaBuiltin) + (funDefParams repIotaBuiltin) + args + letBindNames [m, flags, offsets, elems] $ + Apply + (funDefName repIotaBuiltin) + [(n, Observe), (Var ns, Observe)] + (map (,mempty) restype) + Safe + pure (flags, offsets, elems) + +doPrefixSum :: VName -> Builder GPU VName +doPrefixSum ns = do + ns_t <- lookupType ns + let n = arraySize 0 ns_t + letExp "prefix_sum" $ + Apply + (funDefName prefixSumBuiltin) + [(n, Observe), (Var ns, Observe)] + [(toDecl (staticShapes1 ns_t) Unique, mempty)] + Safe + +doPartition :: VName -> VName -> Builder GPU (VName, VName, VName) +doPartition k cs = do + cs_t <- lookupType cs + let n = arraySize 0 cs_t + counts <- newVName "partition_counts" + offsets <- newVName "partition_offsets" + res <- newVName "partition_res" + let args = [(n, Prim int64), (Var k, Prim int64), (Var cs, cs_t)] + restype = + fromMaybe (error "doPartition: bad application") $ + applyRetType + (map fst $ funDefRetType partitionBuiltin) + (funDefParams partitionBuiltin) + args + letBindNames [counts, offsets, res] $ + Apply + (funDefName partitionBuiltin) + [(n, Observe), (Var k, Observe), (Var cs, Observe)] + (map (,mempty) restype) + Safe + pure (counts, offsets, res) diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs new file mode 100644 index 0000000000..7a056f5915 --- /dev/null +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -0,0 +1,387 @@ +module Futhark.Pass.Flatten.Distribute + ( distributeMap, + distributeBody, + MapArray (..), + mapArrayRowType, + DistResults (..), + DistRep, + ResMap, + Distributed 
(..), + DistStm (..), + DistBody (..), + DistInput (..), + DistInputs, + DistType (..), + distInputType, + DistResult (..), + ResTag (..), + isRegularDistResult, + isParallelStm, + + -- * Segments + Segments, + segmentsShape, + segmentsRank, + ) +where + +import Data.Bifunctor +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Futhark.IR.SOACS +import Futhark.Util (nubOrd) +import Futhark.Util.Pretty + +type Segments = NE.NonEmpty SubExp + +segmentsShape :: Segments -> Shape +segmentsShape = Shape . NE.toList + +segmentsRank :: Segments -> Int +segmentsRank = shapeRank . segmentsShape + +newtype ResTag = ResTag Int + deriving (Eq, Ord, Show) + +-- | Something that is mapped. +data DistInput + = -- | A value bound outside the original map nest. By necessity + -- regular. The type is the parameter type. + DistInputFree VName Type + | -- | A value constructed inside the original map nest. May be + -- irregular. + DistInput ResTag Type + deriving (Eq, Ord, Show) + +type DistInputs = [(VName, DistInput)] + +nubInputs :: DistInputs -> DistInputs +nubInputs = L.nubBy (\a b -> fst a == fst b) + +-- | The type of a 'DistInput'. This corresponds to the parameter +-- type of the original map nest. +distInputType :: DistInput -> Type +distInputType (DistInputFree _ t) = t +distInputType (DistInput _ t) = t + +data DistType + = DistType + -- | Outer regular size. + Segments + -- | Irregular dimensions on top (but after the leading regular + -- size). + Rank + -- | The regular "element type" - in the worst case, at least a + -- scalar. + Type + deriving (Eq, Ord, Show) + +data DistResult = DistResult {distResTag :: ResTag, distResType :: DistType, distResName :: VName} + deriving (Eq, Ord, Show) + +-- | The body of a distributed statement. +data DistBody + = -- | A single statement That may involve parallel operations or produces non unifrom array. 
+ ParallelStm (Stm SOACS) + | -- | Single or Multiple scalar operations grouped into a single traversal + ScalarStm [Stm SOACS] + deriving (Eq, Ord, Show) + +distBodyStms :: DistBody -> [Stm SOACS] +distBodyStms (ParallelStm stm) = [stm] +distBodyStms (ScalarStm stms) = stms + +data DistStm = DistStm + { distStmInputs :: DistInputs, + distStmResult :: [DistResult], + distStmBody :: DistBody + } + deriving (Eq, Ord, Show) + +distStmStms :: DistStm -> [Stm SOACS] +distStmStms = distBodyStms . distStmBody + +-- | First element of tuple are certificates for this result. +-- +-- Second is the name to which is should be bound. +-- +-- Third is the element type (i.e. excluding shape of segments). +type ResMap = M.Map ResTag [([DistInput], VName, Type)] + +-- | The results of a map-distribution that were free or identity +-- mapped in the original map function. These correspond to plain +-- replicated arrays. +type DistRep = (VName, Either SubExp DistInput) + +data DistResults = DistResults ResMap [DistRep] + deriving (Eq, Ord, Show) + +data Distributed = Distributed [DistStm] DistResults + deriving (Eq, Ord, Show) + +instance Pretty ResTag where + pretty (ResTag x) = "r" <> pretty x + +instance Pretty DistInput where + pretty (DistInputFree v _) = pretty v + pretty (DistInput rt _) = pretty rt + +instance Pretty DistType where + pretty (DistType w r t) = + brackets (pretty w) <> pretty r <> pretty t + +instance Pretty DistResult where + pretty (DistResult rt t _) = + pretty rt <> colon <+> pretty t + +instance Pretty DistStm where + pretty (DistStm inputs res stms) = + "let" <+> ppTuple' (map pretty res) <+> "=" indent 2 stm' + where + stm' = + "map" + <+> nestedBlock + ( stack $ + map onInput inputs + ++ map pretty (distBodyStms stms) + ++ [ "return" <+> ppTuple' (map pretty res) + ] + ) + onInput (v, inp) = + "for" + <+> parens (pretty v <> colon <+> pretty (distInputType inp)) + <+> "<-" + <+> pretty inp + +instance Pretty Distributed where + pretty (Distributed stms 
(DistResults resmap reps)) = + stms' res' + where + res' = stack $ map onRes (M.toList resmap) <> map onRep reps + stms' = stack $ map pretty stms + onRes (rt, binds) = + stack ["let" <+> pretty v <+> "=" <+> pretty rt | v <- binds] + onRep (v, Left se) = + "let" <+> pretty v <+> "=" <+> "rep" <> parens (pretty se) + onRep (v, Right tag) = + "let" <+> pretty v <+> "=" <+> "rep" <> parens (pretty tag) + +resultMap :: [(VName, DistInput)] -> [DistStm] -> Pat Type -> Result -> ResMap +resultMap avail_inputs stms pat res = foldMap f $ concatMap distStmResult stms + where + pes = M.fromList [(patElemName pe, pe) | stm <- stms, pe <- concatMap (patElems . stmPat) (distStmStms stm)] + f (DistResult rt _ v) = + case maybe [] findRess $ M.lookup v pes of + [] -> mempty + binds -> M.singleton rt binds + findRess (PatElem v v_t) = do + (SubExpRes cs se, pv) <- zip res (patNames pat) + if se == Var v + then pure (map findCert (unCerts cs), pv, v_t) + else [] + findCert v = fromMaybe (DistInputFree v (Prim Unit)) $ lookup v avail_inputs + +splitIrregDims :: Names -> Type -> (Rank, Type) +splitIrregDims bound_outside (Array pt shape u) = + let (reg, irreg) = + first reverse $ span regDim $ reverse $ shapeDims shape + in (Rank $ length irreg, Array pt (Shape reg) u) + where + regDim (Var v) = v `nameIn` bound_outside + regDim Constant {} = True +splitIrregDims _ t = (mempty, t) + +freeInput :: [(VName, DistInput)] -> VName -> Maybe (VName, DistInput) +freeInput avail_inputs v = + (v,) <$> lookup v avail_inputs + +patInput :: ResTag -> PatElem Type -> (VName, DistInput) +patInput tag pe = + (patElemName pe, DistInput tag $ patElemType pe) + +distributeBody :: + Scope rep -> + Segments -> + DistInputs -> + Body SOACS -> + (DistInputs, [DistStm]) +distributeBody outer_scope w param_inputs body = + let ((_, avail_inputs), stms) = + L.mapAccumL distributeStm (ResTag (length param_inputs), param_inputs) $ + stmsToList $ + bodyStms body + in (avail_inputs, classifyStms (bodyResult body) 
stms) + where + bound_outside = namesFromList $ M.keys outer_scope + distType t = uncurry (DistType w) $ splitIrregDims bound_outside t + distributeStm (ResTag tag, avail_inputs) stm = + let pat = stmPat stm + new_tags = map ResTag $ take (patSize pat) [tag ..] + avail_inputs' = + avail_inputs <> zipWith patInput new_tags (patElems pat) + free_in_stm = freeIn stm + used_free = mapMaybe (freeInput avail_inputs) $ namesToList free_in_stm + used_free_types = + mapMaybe (freeInput avail_inputs) + . namesToList + . foldMap (freeIn . distInputType . snd) + $ used_free + stm' = + DistStm + (nubInputs $ used_free_types <> used_free) + (zipWith3 DistResult new_tags (map distType $ patTypes pat) (patNames pat)) + (ParallelStm stm) + in ((ResTag $ tag + length new_tags, avail_inputs'), stm') + +isParallelDistStm :: DistStm -> Bool +isParallelDistStm (DistStm _ res (ParallelStm stm)) = + isParallelStm stm || not (all isRegularDistResult res) +isParallelDistStm _ = False + +isParallelStm :: Stm SOACS -> Bool +isParallelStm stm = isMap (stmExp stm) && not ("sequential" `inAttrs` stmAuxAttrs (stmAux stm)) + where + isParallelOp Stream {} = error "isParallelStm: Stream" + isParallelOp JVP {} = error "isParallelStm: JVP" + isParallelOp VJP {} = error "isParallelStm: VJP" + isParallelOp _ = True + + -- TODO: Check other operations + isParallelBasicOp Update {} = True + isParallelBasicOp Concat {} = True + isParallelBasicOp Iota {} = True + isParallelBasicOp Replicate {} = True + isParallelBasicOp ArrayLit {} = True + isParallelBasicOp ArrayVal {} = True + isParallelBasicOp FlatUpdate {} = error "isParallelStm: flatUpdate" + isParallelBasicOp FlatIndex {} = error "isParallelStm: flatIndex" + isParallelBasicOp _ = False + + isMap (BasicOp op) = isParallelBasicOp op + isMap (Apply fname _ _ _) = not $ isBuiltInFunction fname -- TODO: do better + isMap (Match _ cases def_case _) = + any isParallelStm $ + bodyStms def_case + <> mconcat (map (bodyStms . 
caseBody) cases) + isMap (Loop _ _ body) = (any isParallelStm . bodyStms) body + isMap (WithAcc _ lam) = (any isParallelStm . bodyStms) $ lambdaBody lam + isMap (Op op) = isParallelOp op + +isRegularDistResult :: DistResult -> Bool +isRegularDistResult (DistResult _ (DistType _ (Rank r) _) _) = r == 0 + +-- we should probably sort the DistStms first and we should assume they are sorted +-- and then given to this function. +classifyStms :: Result -> [DistStm] -> [DistStm] +classifyStms _ [] = [] +classifyStms bodyRes ds = + let (scalars, rest) = break isParallelDistStm ds + scalar_grouped = + [mergeGroup bodyRes scalars rest | not (null scalars)] + in case rest of + [] -> scalar_grouped + p : ps -> + scalar_grouped ++ (p : classifyStms bodyRes ps) + +-- | Merge a group of scalar 'DistStm's into a single one. +mergeGroup :: Result -> [DistStm] -> [DistStm] -> DistStm +mergeGroup bodyRes ds rest = + let resTags = + S.fromList $ concatMap (map distResTag . distStmResult) ds + isInternal (_, DistInput rt _) = rt `S.member` resTags + isInternal _ = False + externalInputs = + nubInputs $ + filter (not . isInternal) $ + concatMap distStmInputs ds + externalResults = + nubOrd $ + filter (isExternal bodyRes rest) $ + concatMap distStmResult ds + allStms = concatMap distStmStms ds + in DistStm externalInputs externalResults (ScalarStm allStms) + +-- | A result is external if it is used by a subsequent 'DistStm' or +-- by the body result. 
+isExternal :: Result -> [DistStm] -> DistResult -> Bool +isExternal bodyRes rest (DistResult rt _ rn) = + rt `S.member` usedByRest || rn `S.member` bodyResVars || rn `S.member` bodyResCerts + where + usedByRest = + S.fromList + [rt' | (_, DistInput rt' _) <- concatMap distStmInputs rest] + bodyResVars = + S.fromList $ + mapMaybe + ( \(SubExpRes _ se) -> case se of + Var v -> Just v + _ -> Nothing + ) + bodyRes + bodyResCerts = + S.fromList $ + concatMap (\(SubExpRes cs _) -> unCerts cs) bodyRes + +-- | The input we are mapping over in 'distributeMap'. +data MapArray t + = -- | A straightforward array passed in to a + -- top-level map. + MapArray VName Type + | -- | Something more exotic - distribution will assign it a + -- 'ResTag', but not do anything else. This is used to + -- distributed nested maps whose inputs are produced in the outer + -- nests. + MapOther t Type + +mapArrayRowType :: MapArray t -> Type +mapArrayRowType (MapArray _ t) = t +mapArrayRowType (MapOther _ t) = t + +-- This is used to handle those results that are constants or lambda +-- parameters. 
+findReps :: [(VName, DistInput)] -> Pat Type -> Lambda SOACS -> [DistRep] +findReps avail_inputs map_pat lam = + mapMaybe f $ zip (patElems map_pat) (bodyResult (lambdaBody lam)) + where + f (pe, SubExpRes _ (Var v)) = + case lookup v avail_inputs of + Nothing -> Just (patElemName pe, Left $ Var v) + Just inp + | v `elem` map paramName (lambdaParams lam) -> + Just (patElemName pe, Right inp) + | otherwise -> Nothing + f (pe, SubExpRes _ (Constant v)) = do + Just (patElemName pe, Left $ Constant v) + +distributeMap :: + Scope rep -> + Pat Type -> + Segments -> + [MapArray t] -> + Lambda SOACS -> + (Distributed, M.Map ResTag t) +distributeMap outer_scope map_pat w arrs lam = + let ((_, arrmap), param_inputs) = + L.mapAccumL paramInput (ResTag 0, mempty) $ + zip (lambdaParams lam) arrs + (avail_inputs, stms) = + distributeBody outer_scope w param_inputs $ lambdaBody lam + resmap = + resultMap avail_inputs stms map_pat $ + bodyResult (lambdaBody lam) + reps = findReps avail_inputs map_pat lam + in ( Distributed stms $ DistResults resmap reps, + arrmap + ) + where + paramInput (ResTag i, m) (p, MapArray arr _) = + ( (ResTag i, m), + (paramName p, DistInputFree arr $ paramType p) + ) + paramInput (ResTag i, m) (p, MapOther x _) = + ( (ResTag (i + 1), M.insert (ResTag i) x m), + (paramName p, DistInput (ResTag i) $ paramType p) + ) diff --git a/src/Futhark/Pass/Flatten/ISRWIM.hs b/src/Futhark/Pass/Flatten/ISRWIM.hs new file mode 100644 index 0000000000..1af307874b --- /dev/null +++ b/src/Futhark/Pass/Flatten/ISRWIM.hs @@ -0,0 +1,201 @@ +{-# LANGUAGE TypeFamilies #-} + +-- | Interchanging scans with inner maps. +-- Basically a copy of ExtractKernels ISRWIM with small change. +module Futhark.Pass.Flatten.ISRWIM + ( iswim, + irwim, + rwimPossible, + ) +where + +import Control.Monad +import Futhark.IR.SOACS +import Futhark.MonadFreshNames +import Futhark.Tools + +-- | Interchange Scan With Inner Map. 
Tries to turn a @scan(map)@ into a +-- @map(scan) +iswim :: + (MonadBuilder m, Rep m ~ SOACS) => + Pat Type -> + SubExp -> + Lambda SOACS -> + [(SubExp, VName)] -> + Maybe (m ()) +iswim res_pat w scan_fun scan_input + | Just (map_pat, map_aux, map_w, map_fun) <- rwimPossible scan_fun = Just $ do + let (accs, arrs) = unzip scan_input + let indexAcc (Var v) = do + v_t <- lookupType v + letSubExp "acc" $ + BasicOp $ + Index v $ + fullSlice v_t [DimFix $ intConst Int64 0] + indexAcc Constant {} = + error "irwim: array accumulator is a constant." + arrs' <- transposedArrays arrs + accs' <- mapM indexAcc accs + + -- let (_red_acc_params, red_elem_params) = + -- splitAt (length arrs) $ lambdaParams red_fun + -- map_rettype = map rowType $ lambdaReturnType red_fun + -- map_params = map (setParamOuterDimTo w) red_elem_params + + let map_arrs' = arrs' + (_scan_acc_params, scan_elem_params) = + splitAt (length arrs) $ lambdaParams scan_fun + map_params = map (setParamOuterDimTo w) scan_elem_params + map_rettype = map (setOuterDimTo w) $ lambdaReturnType scan_fun + + scan_params = lambdaParams map_fun + scan_body = lambdaBody map_fun + scan_rettype = lambdaReturnType map_fun + scan_fun' = Lambda scan_params scan_rettype scan_body + scan_input' = zip accs' $ map paramName map_params + + scan_soac <- scanSOAC [Scan scan_fun' accs'] + let map_body = + mkBody + ( oneStm $ + Let (setPatOuterDimTo w map_pat) (defAux ()) $ + Op $ + Screma w (map snd scan_input') scan_soac + ) + $ varsRes + $ patNames map_pat + map_fun' = Lambda map_params map_rettype map_body + + res_pat' <- + fmap basicPat $ + mapM (newIdent' (<> "_transposed") . transposeIdentType) $ + patIdents res_pat + + addStm . Let res_pat' map_aux . Op . Screma map_w map_arrs' + =<< mapSOAC map_fun' + + forM_ (zip (patIdents res_pat) (patIdents res_pat')) $ \(to, from) -> do + let perm = [1, 0] ++ [2 .. arrayRank (identType from) - 1] + addStm $ + Let (basicPat [to]) (defAux ()) . 
BasicOp $ + Rearrange (identName from) perm + | otherwise = Nothing + +-- | Interchange Reduce With Inner Map. Tries to turn a @reduce(map)@ into a +-- @map(reduce) +irwim :: + (MonadBuilder m, Rep m ~ SOACS) => + Pat Type -> + SubExp -> + Commutativity -> + Lambda SOACS -> + [(SubExp, VName)] -> + Maybe (m ()) +irwim res_pat w comm red_fun red_input + | Just (map_pat, map_aux, map_w, map_fun) <- rwimPossible red_fun = Just $ do + let (accs, arrs) = unzip red_input + arrs' <- transposedArrays arrs + -- FIXME? Can we reasonably assume that the accumulator is a + -- replicate? We also assume that it is non-empty. + let indexAcc (Var v) = do + v_t <- lookupType v + letSubExp "acc" $ + BasicOp $ + Index v $ + fullSlice v_t [DimFix $ intConst Int64 0] + indexAcc Constant {} = + error "irwim: array accumulator is a constant." + accs' <- mapM indexAcc accs + + let (_red_acc_params, red_elem_params) = + splitAt (length arrs) $ lambdaParams red_fun + map_rettype = map rowType $ lambdaReturnType red_fun + map_params = map (setParamOuterDimTo w) red_elem_params + + red_params = lambdaParams map_fun + red_body = lambdaBody map_fun + red_rettype = lambdaReturnType map_fun + red_fun' = Lambda red_params red_rettype red_body + red_input' = zip accs' $ map paramName map_params + red_pat = stripPatOuterDim map_pat + + map_body <- + case irwim red_pat w comm red_fun' red_input' of + Nothing -> do + reduce_soac <- reduceSOAC [Reduce comm red_fun' $ map fst red_input'] + pure + $ mkBody + ( oneStm $ + Let red_pat (defAux ()) $ + Op $ + Screma w (map snd red_input') reduce_soac + ) + $ varsRes + $ patNames map_pat + Just m -> localScope (scopeOfLParams map_params) $ do + map_body_stms <- collectStms_ m + pure $ mkBody map_body_stms $ varsRes $ patNames map_pat + + let map_fun' = Lambda map_params map_rettype map_body + + addStm . Let res_pat map_aux . Op . 
Screma map_w arrs' + =<< mapSOAC map_fun' + | otherwise = Nothing + +-- | Does this reduce operator contain an inner map, and if so, what +-- does that map look like? +rwimPossible :: + Lambda SOACS -> + Maybe (Pat Type, StmAux (), SubExp, Lambda SOACS) +rwimPossible fun + | Body _ stms res <- lambdaBody fun, + [stm] <- stmsToList stms, -- Body has a single binding + map_pat <- stmPat stm, + map Var (patNames map_pat) == map resSubExp res, -- Returned verbatim + Op (Screma map_w map_arrs form) <- stmExp stm, + Just map_fun <- isMapSOAC form, + map paramName (lambdaParams fun) == map_arrs = + Just (map_pat, stmAux stm, map_w, map_fun) + | otherwise = + Nothing + +transposedArrays :: (MonadBuilder m) => [VName] -> m [VName] +transposedArrays arrs = forM arrs $ \arr -> do + t <- lookupType arr + let perm = [1, 0] ++ [2 .. arrayRank t - 1] + letExp (baseName arr) $ BasicOp $ Rearrange arr perm + +removeParamOuterDim :: LParam SOACS -> LParam SOACS +removeParamOuterDim param = + let t = rowType $ paramType param + in param {paramDec = t} + +setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS +setParamOuterDimTo w param = + let t = setOuterDimTo w $ paramType param + in param {paramDec = t} + +setIdentOuterDimTo :: SubExp -> Ident -> Ident +setIdentOuterDimTo w ident = + let t = setOuterDimTo w $ identType ident + in ident {identType = t} + +setOuterDimTo :: SubExp -> Type -> Type +setOuterDimTo w t = + arrayOfRow (rowType t) w + +setPatOuterDimTo :: SubExp -> Pat Type -> Pat Type +setPatOuterDimTo w pat = + basicPat $ map (setIdentOuterDimTo w) $ patIdents pat + +transposeIdentType :: Ident -> Ident +transposeIdentType ident = + ident {identType = transposeType $ identType ident} + +stripIdentOuterDim :: Ident -> Ident +stripIdentOuterDim ident = + ident {identType = rowType $ identType ident} + +stripPatOuterDim :: Pat Type -> Pat Type +stripPatOuterDim pat = + basicPat $ map stripIdentOuterDim $ patIdents pat diff --git a/src/Futhark/Pass/Flatten/Match.hs 
b/src/Futhark/Pass/Flatten/Match.hs new file mode 100644 index 0000000000..5fbd1d45ab --- /dev/null +++ b/src/Futhark/Pass/Flatten/Match.hs @@ -0,0 +1,225 @@ +-- | Flattening of 'Match'. +module Futhark.Pass.Flatten.Match + ( transformMatch, + ) +where + +import Control.Monad +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Tuple.Solo +import Futhark.IR.GPU +import Futhark.IR.SOACS +import Futhark.Pass.Flatten.Builtins +import Futhark.Pass.Flatten.Distribute +import Futhark.Pass.Flatten.Monad +import Futhark.Tools + +-- Take the elements at index `is` from an input `v`. +splitInput :: + Segments -> + DistEnv -> + DistInputs -> + VName -> + VName -> + Builder GPU (Type, VName, ResRep) +splitInput segments env inps is v = do + (t, rep) <- liftSubExpPreserveRep segments inps env (Var v) + (t,v,) <$> case rep of + Regular arr -> do + -- In the regular case we just take the elements + -- of the array given by `is` + n <- letSubExp "n" =<< (toExp . arraySize 0 =<< lookupType is) + arr' <- letExp "split_arr" <=< segMap (MkSolo n) $ \(MkSolo i) -> do + idx <- letSubExp "idx" =<< eIndex is [eSubExp i] + -- unflatten index + let arr_is = unflattenIndex (segmentDims segments) (pe64 idx) + subExpsRes . pure <$> (letSubExp "arr" =<< eIndex arr (map toExp arr_is)) + pure $ Regular arr' + Irregular (IrregularRep segs flags offsets elems _) -> do + -- In the irregular case we take the elements + -- of the `segs` array given by `is` like in the regular case + n <- letSubExp "n" =<< (toExp . arraySize 0 =<< lookupType is) + segs' <- letExp "split_segs" <=< segMap (MkSolo n) $ \(MkSolo i) -> do + idx <- letExp "idx" =<< eIndex is [eSubExp i] + subExpsRes . 
pure <$> (letSubExp "segs" =<< eIndex segs [toExp idx]) + -- From this we calculate the offsets and number of elements + (_, offsets', num_data) <- exScanAndSum segs' + (_, _, ii1) <- doRepIota segs' + (_, _, ii2) <- doSegIota segs' + -- We then take the elements we need from `elems` and `flags` + -- For each index `i`, we roughly: + -- Get the offset of the segment we want to copy by indexing + -- `offsets` through `is` further through `ii1` i.e. + -- `offset = offsets[is[ii1[i]]]` + -- We then add `ii2[i]` to `offset` + -- and use that to index into `elems` and `flags`. + ~[flags', elems'] <- letTupExp "split_F_data" <=< segMap (MkSolo num_data) $ \(MkSolo i) -> do + offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp i]]] + idx <- letExp "idx" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp i]) + flags_split <- letSubExp "flags" =<< eIndex flags [toExp idx] + elems_split <- letSubExp "elems" =<< eIndex elems [toExp idx] + pure $ subExpsRes [flags_split, elems_split] + pure $ + Irregular $ + IrregularRep + { irregularS = segs', + irregularF = flags', + irregularO = offsets', + irregularD = elems', + irregularK = Dense + } + +-- Given the indices for which a branch is taken and its body, +-- distribute the statements of the body of that branch. +distributeBranch :: + Segments -> + DistEnv -> + DistInputs -> + VName -> + Body SOACS -> + Builder GPU (DistInputs, DistEnv, [DistStm]) +distributeBranch segments env inps is body = do + let free_in_body = filter (isVariant inps env . Var) (namesToList $ freeIn body) + (ts, vs, reps) <- + unzip3 <$> mapM (splitInput segments env inps is) free_in_body + let inputs = do + (v, t, i) <- zip3 vs ts [0 ..] 
+ pure (v, DistInput (ResTag i) t) + let env' = DistEnv $ M.fromList $ zip (map ResTag [0 ..]) reps + scope <- askScope + let (inputs', dstms) = distributeBody scope segments inputs body + pure (inputs', env', dstms) + +-- Given a single result from each branch as well the *unlifted* +-- result type, merge the results of all branches into a single result. +mergeResult :: + Segments -> + SubExp -> + [VName] -> + [ResRep] -> + DistResult -> + Builder GPU ResRep +mergeResult segments w iss branchesRep dist_res + -- Regular case + | isRegularDistResult dist_res = do + let (DistType _ _ resType) = distResType dist_res + resultType = + Array (elemType resType) (Shape [w] <> arrayShape resType) NoUniqueness + xs <- mapM regularBranch branchesRep + -- Create the blank space for the result + resultSpace <- letExp "blank_res" =<< eBlank resultType + -- Write back the values of each branch to the blank space + result <- foldM scatterRegular resultSpace $ zip iss xs + result_t <- arrayShape <$> lookupType result + result' <- + letExp "match_res_reg" . 
BasicOp $ + Reshape result (reshapeAll result_t (segmentsShape segments <> arrayShape resType)) + pure $ Regular result' + -- Irregular case + | DistType _ _ (Array pt _ _) <- distResType dist_res = do + branchesIrregRep <- mapM irregularBranch branchesRep + let segsType = Array (IntType Int64) (Shape [w]) NoUniqueness + -- Create a blank space for the 'segs' + segsSpace <- letExp "blank_segs" =<< eBlank segsType + -- Write back the segs of each branch to the blank space + segs <- foldM scatterRegular segsSpace $ zip iss (irregularS <$> branchesIrregRep) + (_, offsets, num_data) <- exScanAndSum segs + let resultType = Array pt (Shape [num_data]) NoUniqueness + -- Create the blank space for the result + resultSpace <- letExp "blank_res" =<< eBlank resultType + -- Write back the values of each branch to the blank space + elems <- foldM (scatterIrregular offsets) resultSpace $ zip iss branchesIrregRep + flags <- genFlags num_data offsets + pure $ + Irregular $ + IrregularRep + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems, + irregularK = Dense + } + | otherwise = error "mergeResult: non-array irregular result" + where + regularBranch (Regular v) = pure v + regularBranch _ = error "mergeResult: mismatched reps" + + irregularBranch (Irregular irreg) = pure irreg + irregularBranch _ = error "mergeResult: mismatched reps" + +transformMatch :: + FlattenOps -> + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + [SubExp] -> + [Case (Body SOACS)] -> + Body SOACS -> + Builder GPU DistEnv +transformMatch ops segments env inps res scrutinees cases defaultCase = do + w <- letSubExp "w" <=< toExp $ product $ segmentDims segments + -- We need to partition the indices of the scrutinees by which case they match. + -- Lift the scrutinees. + -- If it's a variable, we know it's a scalar and the lifted version will therefore be a regular array. 
+ lifted_scrutinees <- forM scrutinees $ \scrut -> do + liftSubExpRegular segments inps env (segmentsShape segments) scrut + -- Cases for tagging values that match the same branch. + -- The default case is the 0'th equvalence class. + let equiv_cases = + zipWith + (\(Case pat _) n -> Case pat $ eBody [toExp $ intConst Int64 n]) + cases + [1 ..] + let equiv_case_default = eBody [toExp $ intConst Int64 0] + -- Match the scrutinees againts the branch cases + equiv_classes <- letExp "equiv_classes" <=< segMap (MkSolo w) $ \(MkSolo i) -> do + -- unflatten index + let seg_is = unflattenIndex (segmentDims segments) (pe64 i) + scruts <- mapM (letSubExp "scruts" <=< flip eIndex (map toExp seg_is)) lifted_scrutinees + cls <- letSubExp "cls" =<< eMatch scruts equiv_cases equiv_case_default + pure [subExpRes cls] + let num_cases = fromIntegral $ length cases + 1 + n_cases <- letExp "n_cases" <=< toExp $ intConst Int64 num_cases + -- Parition the indices of the scrutinees by their equvalence class such + -- that (the indices) of the scrutinees belonging to class 0 come first, + -- then those belonging to class 1 and so on. + (partition_sizes, partition_offs, partition_inds) <- doPartition n_cases equiv_classes + inds_t <- lookupType partition_inds + -- Get the indices of each scrutinee by equivalence class + branch_info <- forM [0 .. num_cases - 1] $ \i -> do + num_data <- + letSubExp ("size" <> nameFromString (show i)) + =<< eIndex partition_sizes [toExp $ intConst Int64 i] + begin <- + letSubExp ("idx_begin" <> nameFromString (show i)) + =<< eIndex partition_offs [toExp $ intConst Int64 i] + inds <- + letExp ("inds_branch" <> nameFromString (show i)) $ + BasicOp . Index partition_inds $ + fullSlice inds_t [DimSlice begin num_data (intConst Int64 1)] + pure (num_data, inds) + let (branch_sizes, inds) = unzip branch_info + + -- Distribute and lift the branch bodies. 
+ -- We put the default case at the start as it's the 0'th equivalence class + -- and is therefore the first segment after the partition. + let branch_bodies = defaultCase : map (\(Case _ body) -> body) cases + (branch_inputs, branch_envs, branch_dstms) <- + unzip3 <$> zipWithM (distributeBranch segments env inps) inds branch_bodies + + let branch_results = map bodyResult branch_bodies + branch_reps <- forM [0 .. num_cases - 1] $ \i -> do + let inputs = branch_inputs !! fromIntegral i + let env' = branch_envs !! fromIntegral i + let dstms = branch_dstms !! fromIntegral i + let result = branch_results !! fromIntegral i + branch_segments = NE.singleton $ branch_sizes !! fromIntegral i + env'' <- foldM (flattenDistStm ops branch_segments) env' dstms + zipWithM (liftDistResultRep branch_segments inputs env'') res result + + -- Merge the results of the branches and insert the resulting res reps + reps <- zipWithM (mergeResult segments w inds) (L.transpose branch_reps) res + pure $ insertReps (zip (map distResTag res) reps) env diff --git a/src/Futhark/Pass/Flatten/Monad.hs b/src/Futhark/Pass/Flatten/Monad.hs new file mode 100644 index 0000000000..7726d0c889 --- /dev/null +++ b/src/Futhark/Pass/Flatten/Monad.hs @@ -0,0 +1,589 @@ +-- | General definitions for the flattening transformation. +-- +-- Defines not just the core monads that are involved, but also the various +-- representations, except perhaps the ones that are completely local to another +-- module. 
+module Futhark.Pass.Flatten.Monad + ( IrregularKind (..), + IrregularRep (..), + ResRep (..), + DistEnv (..), + + -- * Reading inputs + readInputVar, + readInputs, + readInput, + readTypeDims, + + -- * Insertions + insertRep, + insertReps, + insertIrregulars, + insertIrregular, + insertRegulars, + + -- * Building blocks + ensureDenseIrregular, + liftResult, + liftDistResultRep, + liftSubExp, + liftSubExpPreserveRep, + liftSubExpRegular, + mkIrregFromReg, + distCerts, + dataArr, + getIrregRep, + scatterIrregular, + scatterRegular, + + -- * Various + segsAndElems, + inputReps, + resVar, + scopeOfDistInputs, + resultToResReps, + isVariant, + flattenDistStms, + segmentDims, + FlattenOps (..), + ) +where + +import Control.Monad +import Data.Bifunctor (bimap, second) +import Data.Foldable +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe (fromMaybe, isJust) +import Data.Tuple.Solo +import Futhark.IR.GPU +import Futhark.Pass.Flatten.Builtins +import Futhark.Pass.Flatten.Distribute +import Futhark.Tools +import Futhark.Util.IntegralExp +import Prelude hiding (div, rem) + +-- Note [Representation of Flat Arrays] +-- +-- This flattening implementation uses largely the nomenclature and +-- structure described by Cosmin Oancea. In particular, consider an +-- irregular array 'A' where +-- +-- - A has 'n' segments (outermost dimension). +-- +-- - A has element type 't'. +-- +-- - A has a total of 'm' elements (where 'm' is divisible by 'n', +-- and may indeed be 'm'). +-- +-- Then A is represented by the following arrays: +-- +-- - A_D : [m]t; the "data array". +-- +-- - A_S : [n]i64; the "shape array" giving the number of scalar elements of each segment. +-- +-- - A_F : [m]bool; the "flag array", indicating when an element begins a +-- new segment. +-- +-- - A_O : [n]i64; the offset array, indicating for each segment +-- where it starts in the data (and flag) array. 
+-- +-- - A_II1 : [m]t; the "segment indices"; a mapping from element +-- index to index of the segment it belongs to. +-- +-- - A_II2 : [m]t; the "inner indices"; a mapping from element index +-- to index within its corresponding segment. +-- +-- The arrays that are not the data array are collectively called the +-- "structure arrays". All of the structure arrays can be computed +-- from each other, but conceptually they all coexist. +-- +-- Note that we only consider the *outer* dimension to be the +-- "segments". Also, 't' may actually be an array itself (although in +-- this case, the shape of 't' must be invariant to all parallel +-- dimensions). The inner structure is preserved through code, not +-- data. (Or in practice, ad-hoc auxiliary arrays produced by code.) +-- In Cosmin's notation, we maintain only the information for the +-- outermost dimension. +-- +-- As an example, consider an irregular array +-- +-- A = [ [], [ [1,2,3], [4], [], [5,6] ], [ [7], [], [8,9,10] ] ] +-- +-- then +-- +-- n = 3 +-- +-- m = 10 +-- +-- A_D = [1,2,3,4,5,6,7,8,9,10] +-- +-- A_S = [0, 6, 4] +-- +-- A_F = [T,F,F,F,F,F,T,F,F,F] +-- +-- A_O = [0, 0, 6] +-- +-- A_II1 = [1,1,1,1,1,1,2,2,2,2] +-- +-- A_II2 = [0,0,0,1,3,3,0,2,2,2] + +data IrregularKind + = Dense + | Replicated + deriving (Show,Eq) + +data IrregularRep = IrregularRep + { -- | Array of size of each segment, type @[]i64@. + irregularS :: VName, + irregularF :: VName, + irregularO :: VName, + irregularD :: VName, + irregularK :: IrregularKind + } + deriving (Show) + +data ResRep + = -- | This variable is represented + -- completely straightforwardly- if it is + -- an array, it is a regular array. + Regular VName + | -- | The representation of an + -- irregular array. 
+ Irregular IrregularRep + deriving (Show) + +newtype DistEnv = DistEnv {distResMap :: M.Map ResTag ResRep} + +insertRep :: ResTag -> ResRep -> DistEnv -> DistEnv +insertRep rt rep env = env {distResMap = M.insert rt rep $ distResMap env} + +insertReps :: [(ResTag, ResRep)] -> DistEnv -> DistEnv +insertReps = flip $ foldl (flip $ uncurry insertRep) + +insertIrregular :: VName -> VName -> VName -> ResTag -> VName -> IrregularKind -> DistEnv -> DistEnv +insertIrregular ns flags offsets rt elems kind env = + let rep = Irregular $ IrregularRep ns flags offsets elems kind + in insertRep rt rep env + +insertIrregulars :: VName -> VName -> VName -> [(ResTag, VName)] -> IrregularKind -> DistEnv -> DistEnv +insertIrregulars ns flags offsets bnds kind env = + let (tags, elems) = unzip bnds + mkRep elem = + Irregular $ + IrregularRep + { irregularS = ns, + irregularF = flags, + irregularO = offsets, + irregularD = elem, + irregularK = kind + } + in insertReps (zip tags $ map mkRep elems) env + +insertRegulars :: [ResTag] -> [VName] -> DistEnv -> DistEnv +insertRegulars rts xs = + insertReps (zip rts $ map Regular xs) + +instance Monoid DistEnv where + mempty = DistEnv mempty + +instance Semigroup DistEnv where + DistEnv x <> DistEnv y = DistEnv (x <> y) + +resVar :: ResTag -> DistEnv -> ResRep +resVar rt env = fromMaybe bad $ M.lookup rt $ distResMap env + where + bad = error $ "resVar: unknown tag: " ++ show rt + +segsAndElems :: DistEnv -> [DistInput] -> (Maybe (VName, VName, VName), [VName]) +segsAndElems _ [] = (Nothing, []) +segsAndElems env (DistInputFree v _ : vs) = + second (v :) $ segsAndElems env vs +segsAndElems env (DistInput rt _ : vs) = + case resVar rt env of + Regular v' -> + second (v' :) $ segsAndElems env vs + Irregular (IrregularRep segments flags offsets elems k) -> do + case k of + Dense -> do + bimap (mplus $ Just (segments, flags, offsets)) (elems :) $ segsAndElems env vs + Replicated -> + second ( flags :) $ segsAndElems env vs + +-- Mapping from 
original variable names to their distributed resreps +inputReps :: DistInputs -> DistEnv -> M.Map VName (Type, ResRep) +inputReps inputs env = M.fromList $ map (second getRep) inputs + where + getRep di = case di of + DistInput rt t -> (t, resVar rt env) + DistInputFree v' t -> (t, Regular v') + +readInputVar :: Segments -> DistEnv -> [SubExp] -> DistInputs -> VName -> Builder GPU VName +readInputVar _segments env is inputs v = + case lookup v inputs of + Nothing -> pure v + Just (DistInputFree arr t) + | isAcc t -> pure arr + | otherwise -> letExp (baseName v) =<< eIndex arr (map eSubExp is) + Just (DistInput rt t) -> do + case resVar rt env of + Regular arr + | isAcc t -> pure arr + | otherwise -> letExp (baseName v) =<< eIndex arr (map eSubExp is) + Irregular (IrregularRep _ _flags _offsets _elems _) -> + undefined + +readInput :: Segments -> DistEnv -> [SubExp] -> DistInputs -> SubExp -> Builder GPU SubExp +readInput _ _ _ _ (Constant x) = + pure $ Constant x +readInput segments env is inputs (Var v) = + Var <$> readInputVar segments env is inputs v + +readTypeDims :: + Segments -> + DistEnv -> + [SubExp] -> + DistInputs -> + TypeBase Shape u -> + Builder GPU [SubExp] +readTypeDims segments env is inputs = + mapM (readInput segments env is inputs) . arrayDims + +segmentDims :: Segments -> [TPrimExp Int64 VName] +segmentDims = map pe64 . shapeDims . segmentsShape + +flatSegmentIndex :: Segments -> [SubExp] -> TPrimExp Int64 VName +flatSegmentIndex segments = flattenIndex (segmentDims segments) . 
map pe64 + +readInputs :: Segments -> DistEnv -> [SubExp] -> DistInputs -> Builder GPU () +readInputs segments env is = mapM_ onInput + where + bindInputName v e + | v `nameIn` freeIn e = do + v' <- letExp (baseName v <> "_inp") e + letBindNames [v] $ BasicOp $ SubExp $ Var v' + | otherwise = + letBindNames [v] e + onInput (v, DistInputFree arr t) = + bindInputName v + =<< if isAcc t + then eSubExp (Var arr) + else eIndex arr (map eSubExp is) + onInput (v, DistInput rt t) = + case resVar rt env of + Regular arr -> + bindInputName v + =<< if isAcc t + then eSubExp $ Var arr + else eIndex arr (map eSubExp is) + Irregular (IrregularRep _ _ v_O v_D _) -> do + offset <- letSubExp "offset" =<< eIndex v_O [toExp $ flatSegmentIndex segments is] + case arrayDims t of + [num_elems] -> do + let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] + bindInputName v $ BasicOp $ Index v_D slice + _ -> do + num_elems <- + letSubExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) + let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] + v_flat <- + letExp (baseName v <> "_flat") $ BasicOp $ Index v_D slice + v_flat_t <- lookupType v_flat + v' <- + letExp (baseName v <> "_inp") . BasicOp $ + Reshape v_flat (reshapeAll (arrayShape v_flat_t) (arrayShape t)) + letBindNames [v] $ BasicOp $ SubExp $ Var v' + +scopeOfDistInputs :: DistInputs -> Scope GPU +scopeOfDistInputs = scopeOfLParams . 
map f + where + f (v, inp) = Param mempty v (distInputType inp) + +isVariant :: DistInputs -> DistEnv -> SubExp -> Bool +isVariant inps env se = case se of + Constant _ -> False + Var v -> isJust $ M.lookup v $ inputReps inps env + +ensureDenseIrregular :: Name -> IrregularRep -> Builder GPU IrregularRep +ensureDenseIrregular _ rep@IrregularRep {irregularK = Dense} = + pure rep +ensureDenseIrregular desc rep@IrregularRep {} = do + (new_F, new_O, ii1) <- doRepIota (irregularS rep) + m <- arraySize 0 <$> lookupType ii1 + new_D <- letExp (desc <> "_dense_D") <=< segMap (MkSolo m) $ \(MkSolo i) -> do + seg <- letSubExp "seg" =<< eIndex ii1 [eSubExp i] + old_off <- letSubExp "old_off" =<< eIndex (irregularO rep) [eSubExp seg] + new_off <- letSubExp "new_off" =<< eIndex new_O [eSubExp seg] + j <- letSubExp "j" <=< toExp $ pe64 i - pe64 new_off + x <- letSubExp "x" =<< eIndex (irregularD rep) [toExp $ pe64 old_off + pe64 j] + pure [subExpRes x] + pure $ + IrregularRep + { irregularS = irregularS rep, + irregularF = new_F, + irregularO = new_O, + irregularD = new_D, + irregularK = Dense + } +-- Lift a result of a function. +liftResult :: Segments -> DistInputs -> DistEnv -> SubExpRes -> Builder GPU Result +liftResult segments inps env res = map (SubExpRes mempty . 
Var) <$> vs + where + vs = do + (_, rep) <- liftSubExp segments inps env (resSubExp res) + case rep of + Regular v -> pure [v] + Irregular irreg -> mkIrrep irreg + mkIrrep + ( IrregularRep + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems + } + ) = do + flags_t <- lookupType flags + t <- lookupType elems + num_data <- letExp "num_data" =<< toExp (product $ map pe64 $ arrayDims t) + let shape = Shape [Var num_data] + flags' <- letExp "flags" $ BasicOp $ Reshape flags $ reshapeAll (arrayShape flags_t) shape + elems' <- letExp "elems" $ BasicOp $ Reshape elems $ reshapeAll (arrayShape t) shape + pure [num_data, segs, flags', offsets, elems'] + +liftDistResultRep :: + Segments -> + DistInputs -> + DistEnv -> + DistResult -> + SubExpRes -> + Builder GPU ResRep +liftDistResultRep segments inps env dist_res res + | isRegularDistResult dist_res = do + let (DistType _ _ t) = distResType dist_res + expectedShape = segmentsShape segments <> arrayShape t + Regular <$> liftSubExpRegular segments inps env expectedShape (resSubExp res) + | otherwise = + case resSubExp res of + Var v -> Irregular <$> getIrregRep segments env inps v + _ -> error "liftBranchResultRep: irregular result is not a variable" + +mkIrregFromReg :: + Segments -> + VName -> + Builder GPU IrregularRep +mkIrregFromReg segments arr = do + arr_t <- lookupType arr + num_segments <- + letSubExp "reg_num_segments" <=< toExp $ product $ segmentDims segments + segment_size <- + letSubExp "reg_seg_size" <=< toExp . product . map pe64 $ + drop (segmentsRank segments) (arrayDims arr_t) + arr_S <- + letExp "reg_segments" . BasicOp $ + Replicate (Shape [num_segments]) segment_size + num_elems <- + letSubExp "reg_num_elems" <=< toExp $ product $ map pe64 $ arrayDims arr_t + arr_D <- + letExp "reg_D" . 
BasicOp $ + Reshape arr (reshapeAll (arrayShape arr_t) (Shape [num_elems])) + arr_F <- letExp "reg_F" <=< segMap (MkSolo num_elems) $ \(MkSolo i) -> do + flag <- letSubExp "flag" <=< toExp $ (pe64 i `rem` pe64 segment_size) .==. 0 + pure [subExpRes flag] + arr_O <- letExp "reg_O" <=< segMap (MkSolo num_segments) $ \(MkSolo i) -> do + offset <- letSubExp "offset" <=< toExp $ pe64 i * pe64 segment_size + pure [subExpRes offset] + pure $ + IrregularRep + { irregularS = arr_S, + irregularF = arr_F, + irregularO = arr_O, + irregularD = arr_D, + irregularK = Dense + } + +-- If the sub-expression is a constant, replicate it to match the shape of `segments` +-- If it's a variable, lookup the variable in the dist inputs and dist env, +-- and if it can't be found it is a free variable, so we replicate it to match the shape of `segments`. +liftSubExp :: Segments -> DistInputs -> DistEnv -> SubExp -> Builder GPU (Type, ResRep) +liftSubExp segments inps env se = case se of + c@(Constant prim) -> + let t = Prim $ primValueType prim + in ((t,) . 
Regular <$> letExp "lifted_const" (BasicOp $ Replicate (segmentsShape segments) c)) + Var v -> case M.lookup v $ inputReps inps env of + Just (t, Regular v') -> do + (t,) + <$> case t of + Prim {} -> pure $ Regular v' + Array {} -> Irregular <$> mkIrregFromReg segments v' + Acc {} -> pure $ Regular v' + Mem {} -> error "getRepSubExp: Mem" + Just (t, Irregular irreg) -> do + irreg' <- ensureDenseIrregular "lifted_irreg" irreg + pure (t, Irregular irreg') + Nothing -> do + t <- lookupType v + v' <- letExp "free_replicated" $ BasicOp $ Replicate (segmentsShape segments) (Var v) + (t,) + <$> case t of + Prim {} -> pure $ Regular v' + Array {} -> Irregular <$> mkIrregFromReg segments v' + Acc {} -> pure $ Regular v' + Mem {} -> error "getRepSubExp: Mem" + +liftSubExpPreserveRep :: Segments -> DistInputs -> DistEnv -> SubExp -> Builder GPU (Type, ResRep) +liftSubExpPreserveRep segments inps env se = case se of + c@(Constant prim) -> + let t = Prim $ primValueType prim + in do + v <- letExp "lifted_const" $ BasicOp $ Replicate (segmentsShape segments) c + pure (t, Regular v) + Var v -> case M.lookup v $ inputReps inps env of + Just (t, rep) -> pure (t, rep) + Nothing -> do + t <- lookupType v + v' <- letExp "free_replicated" $ BasicOp $ Replicate (segmentsShape segments) (Var v) + pure (t, Regular v') + +-- | Like 'liftSubExp' but always returns a Regular result with the +-- given expected shape. Reshapes the underlying data if necessary. 
-- | Like 'liftSubExp', but always produces a 'Regular' result with the
-- given expected shape, reshaping the lifted data if necessary.
liftSubExpRegular ::
  Segments ->
  DistInputs ->
  DistEnv ->
  Shape ->
  SubExp ->
  Builder GPU VName
liftSubExpRegular segments inps env expectedShape se = do
  arr <- lifted se
  arr_t <- lookupType arr
  if arrayShape arr_t == expectedShape
    then pure arr
    else
      letExp "reg_lifted" . BasicOp $
        Reshape arr (reshapeAll (arrayShape arr_t) expectedShape)
  where
    -- Replicate a value across the segment space.
    replicated = BasicOp . Replicate (segmentsShape segments)
    -- Produce *some* array covering all segments; the shape is
    -- normalised by the caller above.
    lifted c@(Constant _) =
      letExp "lifted_const" $ replicated c
    lifted (Var x) =
      case M.lookup x $ inputReps inps env of
        Just (_, Regular x') -> pure x'
        Just (_, Irregular irreg) ->
          -- An irregular representation with a regular expected shape:
          -- densify it and use its data array.
          irregularD <$> ensureDenseIrregular "lifted_irreg" irreg
        Nothing ->
          -- Free variable: invariant to the segments, so replicate.
          letExp "free_replicated" $ replicated $ Var x

-- | Translate the certificates of a statement into the flat setting,
-- replacing each certificate by its distributed counterpart (for
-- irregular representations, the data array stands in for it).
distCerts :: DistInputs -> StmAux a -> DistEnv -> Certs
distCerts inps aux env = Certs $ map onCert $ unCerts $ stmAuxCerts aux
  where
    onCert c
      | Just inp <- lookup c inps =
          case inp of
            DistInputFree c' _ -> c'
            DistInput rt _ -> case resVar rt env of
              Regular c' -> c'
              Irregular r -> irregularD r
      | otherwise = c

-- | Only sensible for variables of segment-invariant type.
dataArr :: Segments -> DistEnv -> DistInputs -> SubExp -> Builder GPU VName
dataArr segments env inps (Var v)
  | Just v_inp <- lookup v inps =
      case v_inp of
        DistInputFree vs _ -> regData vs
        DistInput rt _ -> case resVar rt env of
          Irregular r -> irregularD <$> ensureDenseIrregular "dataArr" r
          Regular vs -> regData vs
  where
    regData vs = irregularD <$> mkIrregFromReg segments vs
dataArr segments _ _ se = do
  -- Constant (or unknown variable): replicate across the segments and
  -- flatten the result to a single dimension if needed.
  rep <- letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se
  rep_t <- lookupType rep
  let dims = arrayDims rep_t
  case dims of
    [_] -> pure rep
    _ -> do
      n <- toSubExp "n" $ product $ map pe64 dims
      letExp "reshape" $ BasicOp $ Reshape rep $ reshapeAll (arrayShape rep_t) (Shape [n])

-- | Get the irregular representation of a var.
-- Looks the variable up among the distributed inputs; free variables
-- (not inputs at all) are segment-invariant and are replicated first.
getIrregRep :: Segments -> DistEnv -> DistInputs -> VName -> Builder GPU IrregularRep
getIrregRep segments env inps v =
  case lookup v inps of
    Just v_inp -> case v_inp of
      DistInputFree arr _ -> mkIrregFromReg segments arr
      DistInput rt _ -> case resVar rt env of
        Irregular r -> pure r
        -- Regular arrays are converted on the fly.
        Regular arr -> mkIrregFromReg segments arr
    Nothing -> do
      v' <-
        letExp (baseName v <> "_rep") . BasicOp $
          Replicate (segmentsShape segments) (Var v)
      mkIrregFromReg segments v'

-- | This function walks through the *unlifted* result types
-- and uses the *lifted* results to construct the corresponding res reps.
--
-- See the 'liftResult' function for the opposite process i.e.
-- turning 'ResRep's into results.
--
-- The partial let-patterns below assume the lifted results follow the
-- layout produced by 'liftResult': one name per 'Prim' result, and five
-- names per 'Array' result (of which the first, the flat element count,
-- is not needed here and is dropped).
resultToResReps :: [TypeBase s u] -> [VName] -> [ResRep]
resultToResReps types results =
  snd $
    L.mapAccumL
      ( \rs t -> case t of
          Prim {} ->
            -- A scalar result is represented by a single array.
            let (v : rs') = rs
                rep = Regular v
             in (rs', rep)
          Array {} ->
            -- An array result carries its full irregular structure.
            let (_ : segs : flags : offsets : elems : rs') = rs
                rep = Irregular $ IrregularRep segs flags offsets elems Dense
             in (rs', rep)
          Acc {} -> error "resultToResReps: Illegal type 'Acc'"
          Mem {} -> error "resultToResReps: Illegal type 'Mem'"
      )
      results
      types

-- | Write back the irregular results of a branch to a (partially) blank space
-- The `offsets` variable is the offsets of the final result, whereas `irregRep`
-- is the irregular representation of the result.
+scatterIrregular :: + VName -> + VName -> + (VName, IrregularRep) -> + Builder GPU VName +scatterIrregular offsets space (is, irregRep) = do + dense_irreg <- ensureDenseIrregular "scatter_irreg" irregRep + let IrregularRep {irregularS = segs, irregularD = elems, irregularK = kind} = dense_irreg + (_, _, ii1) <- doRepIota segs + (_, _, ii2) <- doSegIota segs + ~(Array _ (Shape [size]) _) <- lookupType elems + letExp "irregular_scatter" <=< genScatter space size $ \gtid -> do + x <- letSubExp "x" =<< eIndex elems [eSubExp gtid] + offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp gtid]]] + i <- letExp "i" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp gtid]) + pure (i, x) + +-- | Write back the regular results to a (partially) blank space +scatterRegular :: + VName -> + (VName, VName) -> + Builder GPU VName +scatterRegular space (is, xs) = do + size <- arraySize 0 <$> lookupType xs + letExp "regular_scatter" <=< genScatter space size $ \gtid -> do + x <- letSubExp "x" =<< eIndex xs [eSubExp gtid] + i <- letExp "i" =<< eIndex is [eSubExp gtid] + pure (i, x) + +-- | Functions for tying together disparate modules - this is to avoid mutually +-- recursive modules. +newtype FlattenOps = FlattenOps + { flattenDistStm :: Segments -> DistEnv -> DistStm -> Builder GPU DistEnv + } + +flattenDistStms :: FlattenOps -> SubExp -> DistInputs -> DistEnv -> [DistStm] -> Result -> Builder GPU Result +flattenDistStms ops w inputs env dstms result = do + let segments = NE.singleton w + env' <- foldM (flattenDistStm ops segments) env dstms + result' <- mapM (liftResult segments inputs env') result + pure $ concat result' diff --git a/src/Futhark/Pass/Flatten/PreProcess.hs b/src/Futhark/Pass/Flatten/PreProcess.hs new file mode 100644 index 0000000000..0ab8430151 --- /dev/null +++ b/src/Futhark/Pass/Flatten/PreProcess.hs @@ -0,0 +1,98 @@ +{-# LANGUAGE TypeFamilies #-} + +-- | Preprocess the program before flattening. 
-- This rewrites SOAC forms
-- that flatten does not want to see directly, while leaving the result in
-- SOACS form so the normal flattening pipeline can continue afterwards.
module Futhark.Pass.Flatten.PreProcess (preprocessProg) where

import Data.Maybe (isNothing)
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyStms)
import Futhark.IR.SOACS.Simplify qualified as SOACS
import Futhark.Pass
import Futhark.Pass.Flatten.ISRWIM (irwim, iswim)
import Futhark.Tools

-- | True when a Screma is none of the recognised simple forms, and so
-- must be dissected into simpler parts before flattening can handle it.
shouldDissectForm :: ScremaForm SOACS -> Bool
shouldDissectForm form =
  isNothing (isMapSOAC form)
    && isNothing (isReduceSOAC form)
    && isNothing (isScanSOAC form)
    && isNothing (isRedomapSOAC form)
    && isNothing (isScanomapSOAC form)
    && isNothing (isMaposcanomapSOAC form)

-- | A mapper that recursively preprocesses the lambdas of nested SOACs.
soacMapper :: Scope SOACS -> SOACMapper SOACS SOACS PassM
soacMapper scope =
  identitySOACMapper {mapOnSOACLambda = onLambda scope}

-- | Run a builder action and simplify the statements it produces,
-- returning only those statements.
runSimplifiedBuilder ::
  Scope SOACS ->
  BuilderT SOACS PassM a ->
  PassM (Stms SOACS)
runSimplifiedBuilder scope m =
  fst <$> runBuilderT (simplifyStms =<< collectStms_ m) scope

-- TODO: maybe it is better to separate these as they are doing different things.
+onStm :: Scope SOACS -> Stm SOACS -> PassM (Stms SOACS) +onStm scope (Let pat aux (Op (Stream w arrs nes lam))) = do + lam' <- onLambda scope lam + runBuilderT_ (auxing aux $ sequentialStreamWholeArray pat w nes lam' arrs) scope +onStm scope (Let pat aux (Op (Screma w arrs form))) = do + soac' <- mapSOACM (soacMapper scope) (Screma w arrs form) + case soac' of + Screma w' arrs' form' + | Just scans <- isScanSOAC form', + Scan scan_lam nes <- singleScan scans, + Just do_iswim <- iswim pat w' scan_lam (zip nes arrs') -> + runSimplifiedBuilder scope $ auxing aux do_iswim + | Just [Reduce comm red_fun nes] <- isReduceSOAC form', + let comm' + | commutativeLambda red_fun = Commutative + | otherwise = comm, + Just do_irwim <- irwim pat w' comm' red_fun (zip nes arrs') -> + runSimplifiedBuilder scope $ auxing aux do_irwim + | shouldDissectForm form' -> + runBuilderT_ (auxing aux $ dissectScrema pat w' form' arrs') scope + | otherwise -> + pure $ oneStm $ Let pat aux $ Op $ Screma w' arrs' form' + _ -> + error "onStm: impossible non-Screma" +onStm scope (Let pat aux e) = + oneStm . 
Let pat aux <$> mapExpM mapper e + where + mapper = + (identityMapper @SOACS) + { mapOnBody = \bscope -> onBody (bscope <> scope), + mapOnOp = mapSOACM (soacMapper scope) + } + +onStms :: Scope SOACS -> Stms SOACS -> PassM (Stms SOACS) +onStms scope stms = mconcat <$> mapM (onStm scope') (stmsToList stms) + where + scope' = scopeOf stms <> scope + +onBody :: Scope SOACS -> Body SOACS -> PassM (Body SOACS) +onBody scope body = do + stms <- onStms scope $ bodyStms body + pure $ body {bodyStms = stms} + +onLambda :: Scope SOACS -> Lambda SOACS -> PassM (Lambda SOACS) +onLambda scope lam = do + body <- onBody (scopeOfLParams (lambdaParams lam) <> scope) $ lambdaBody lam + pure $ lam {lambdaBody = body} + +onFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS) +onFun consts fd = do + body <- onBody (scopeOf consts <> scopeOf fd) $ funDefBody fd + pure $ fd {funDefBody = body} + +preprocessProg :: Prog SOACS -> PassM (Prog SOACS) +preprocessProg prog = do + prog' <- + intraproceduralTransformationWithConsts + (onStms mempty) + onFun + prog + SOACS.simplifySOACS prog' -- Is this a good idea? \ No newline at end of file diff --git a/src/Futhark/Pass/Flatten/WithAcc.hs b/src/Futhark/Pass/Flatten/WithAcc.hs new file mode 100644 index 0000000000..379221ee5d --- /dev/null +++ b/src/Futhark/Pass/Flatten/WithAcc.hs @@ -0,0 +1,153 @@ +-- | Flattening of 'WithAcc'. 
+module Futhark.Pass.Flatten.WithAcc + ( transformWithAcc, + ) +where + +import Control.Monad +import Control.Monad.Identity +import Data.Foldable +import Data.List.NonEmpty qualified as NE +import Futhark.IR.GPU +import Futhark.IR.SOACS +import Futhark.MonadFreshNames +import Futhark.Pass.ExtractKernels.ToGPU (soacsLambdaToGPU) +import Futhark.Pass.Flatten.Builtins +import Futhark.Pass.Flatten.Distribute +import Futhark.Pass.Flatten.Monad +import Futhark.Tools +import Prelude hiding (div, rem) + +transformWithAcc :: + FlattenOps -> + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + Pat Type -> + StmAux () -> + [WithAccInput SOACS] -> + Lambda SOACS -> + Builder GPU DistEnv +transformWithAcc ops segments env inps distres _withacc_pat withacc_aux withacc_inputs acc_lam = do + let inputTypes (_, arrs, _) = mapM lookupType arrs + variant <- + localScope (scopeOfDistInputs inps) $ + any (any (any (isVariant inps env) . arrayDims)) + <$> mapM inputTypes withacc_inputs + when variant $ error "Cannot yet handle variant WithAccs" + + withacc_inputs' <- mapM onInput withacc_inputs + lam_params' <- newAccLamParams $ lambdaParams acc_lam + + iota_w <- genShapeIota $ segmentsShape segments + + iota_p <- newParam "iota_p" $ Prim int64 + + iota_w_t <- lookupType iota_w + let iota_se = Var (paramName iota_p) + + acc_lam_body <- + runBodyBuilder $ + localScope (scopeOfLParams lam_params') $ + bodyBind (lambdaBody (trLam iota_se acc_lam)) + + scope <- askScope + let acc_params = drop num_accs lam_params' + orig_acc_params = drop num_accs $ lambdaParams acc_lam + interchanged_inps = + (paramName iota_p, DistInputFree iota_w iota_w_t) + : [ (paramName p, DistInputFree (paramName acc) (paramType acc)) + | (p, acc) <- zip orig_acc_params acc_params + ] + ++ inps + [w] = NE.toList segments + -- FIXME: we are not using withacc_new_inputs, which has got to be wrong. 
+ (withacc_new_inputs, withacc_dstms) = + distributeBody + scope + segments + interchanged_inps + acc_lam_body + + withacc_lam' <- mkLambda (map trParam lam_params') $ do + env' <- foldM (flattenDistStm ops segments) env withacc_dstms + -- TODO: Isn't this the fix that we need? + concat <$> mapM (liftResult segments withacc_new_inputs env') (bodyResult $ lambdaBody acc_lam) + + withacc_out_vs <- + certifying (distCerts inps withacc_aux env) $ + letTupExp "withacc_flatten_out" (WithAcc withacc_inputs' withacc_lam') + + let out_reps = map Regular withacc_out_vs + pure $ insertReps (zip (map distResTag distres) out_reps) env + where + newAccLamParams ps = do + let (cert_ps, acc_ps) = splitAt num_accs ps + -- Should not rename the certificates. + acc_ps' <- forM acc_ps $ \(Param attrs v t) -> + Param attrs <$> newName v <*> pure t + pure $ cert_ps <> acc_ps' + + num_accs = length withacc_inputs + acc_certs = map paramName $ take num_accs $ lambdaParams acc_lam + + onOp (op_lam, nes) = do + -- We need to add an additional index parameter because we are + -- extending the index space of the accumulator. + idx_p <- newParam "idx" $ Prim int64 + pure + ( soacsLambdaToGPU $ op_lam {lambdaParams = idx_p : lambdaParams op_lam}, + nes + ) + + onInput (shape, arrs, op) = + (segmentsShape segments <> shape,,) + <$> mapM onArr arrs + <*> traverse onOp op + + onArr = readInputVar segments env [] inps + + trType :: TypeBase shape u -> TypeBase shape u + trType (Acc acc ispace ts u) + | acc `elem` acc_certs = + Acc acc (segmentsShape segments <> ispace) ts u + trType t = t + + trParam :: Param (TypeBase shape u) -> Param (TypeBase shape u) + trParam = fmap trType + + trStm i (Let pat aux e) = + Let (fmap trType pat) aux $ trExp i pat e + + trBody i (Body dec stms res) = + Body dec (fmap (trStm i) stms) res + + trLam i (Lambda params ret body) = + Lambda (map trParam params) (map trType ret) (trBody i body) + + trSOAC i = runIdentity . 
mapSOACM mapper + where + mapper = + identitySOACMapper {mapOnSOACLambda = pure . trLam i} + + trExp i _ (WithAcc acc_inputs lam) = + WithAcc acc_inputs $ trLam i lam + trExp i (Pat [PatElem _ acc_t]) (BasicOp (UpdateAcc safety acc is ses)) = do + case acc_t of + Acc cert _ _ _ + | cert `elem` acc_certs -> + BasicOp $ UpdateAcc safety acc (i : is) ses + _ -> + BasicOp $ UpdateAcc safety acc is ses + trExp i _ e = mapExp mapper e + where + mapper = + identityMapper + { mapOnBody = \_ -> pure . trBody i, + mapOnRetType = pure . trType, + mapOnBranchType = pure . trType, + mapOnFParam = pure . trParam, + mapOnLParam = pure . trParam, + mapOnOp = pure . trSOAC i + } diff --git a/src/Futhark/Passes.hs b/src/Futhark/Passes.hs index a9d26647a0..db05ee7b06 100644 --- a/src/Futhark/Passes.hs +++ b/src/Futhark/Passes.hs @@ -39,9 +39,9 @@ import Futhark.Pass.ExpandAllocations import Futhark.Pass.ExplicitAllocations.GPU qualified as GPU import Futhark.Pass.ExplicitAllocations.MC qualified as MC import Futhark.Pass.ExplicitAllocations.Seq qualified as Seq -import Futhark.Pass.ExtractKernels import Futhark.Pass.ExtractMulticore import Futhark.Pass.FirstOrderTransform +import Futhark.Pass.Flatten import Futhark.Pass.LiftAllocations as LiftAllocations import Futhark.Pass.LowerAllocations as LowerAllocations import Futhark.Pass.Simplify @@ -85,7 +85,7 @@ adPipeline = gpuPipeline :: Pipeline SOACS GPU gpuPipeline = standardPipeline - >>> onePass extractKernels + >>> onePass flattenSOACs >>> passes [ simplifyGPU, addGlobalParams, diff --git a/src/Futhark/Util.hs b/src/Futhark/Util.hs index 275967863d..2bfc09f3f7 100644 --- a/src/Futhark/Util.hs +++ b/src/Futhark/Util.hs @@ -22,6 +22,7 @@ module Futhark.Util partitionMaybe, maybeNth, maybeHead, + unsnoc, lookupWithIndex, splitFromEnd, splitAt3, @@ -193,6 +194,12 @@ maybeHead :: [a] -> Maybe a maybeHead [] = Nothing maybeHead (x : _) = Just x +-- | Split the last element from the list, if it exists. 
+unsnoc :: [a] -> Maybe ([a], a) +unsnoc [] = Nothing +unsnoc [x] = Just ([], x) +unsnoc (x : xs) = unsnoc xs >>= \(ys, y) -> Just (x : ys, y) + -- | Lookup a value, returning also the index at which it appears. lookupWithIndex :: (Eq a) => a -> [(a, b)] -> Maybe (Int, b) lookupWithIndex needle haystack = diff --git a/tests/flattening/CosminArrayExample.fut b/tests/flattening/CosminArrayExample.fut deleted file mode 100644 index b741a1acd6..0000000000 --- a/tests/flattening/CosminArrayExample.fut +++ /dev/null @@ -1,17 +0,0 @@ --- Problem here is that we need will distribute the map --- let arrs = map (\x -> iota(2*x)) xs --- let arr's = map (\x arr -> reshape( (x,2), arr) $ zip xs arrs --- let res = map(\arr' -> reduce(op(+), 0, arr')) arr's --- == --- input { --- [ 1i64, 2i64, 3i64, 4i64] --- } --- output { --- [1i64, 6i64, 15i64, 28i64] --- } -def main (xs: []i64) : []i64 = - map (\(x: i64) -> - let arr = #[unsafe] 0..<(2 * x) - let arr' = #[unsafe] unflatten arr - in reduce (+) 0 (arr'[0]) + reduce (+) 0 (arr'[1])) - xs diff --git a/tests/flattening/HighlyNestedMap.fut b/tests/flattening/HighlyNestedMap.fut deleted file mode 100644 index 4ff101dd75..0000000000 --- a/tests/flattening/HighlyNestedMap.fut +++ /dev/null @@ -1,41 +0,0 @@ --- == --- input { --- [ [ [ [1,2,3], [4,5,6] ] --- , [ [6,7,8], [9,10,11] ] --- ] --- , [ [ [3,2,1], [4,5,6] ] --- , [ [8,7,6], [11,10,9] ] --- ] --- ] --- [ [ [ [4,5,6] , [1,2,3] ] --- , [ [9,10,11], [6,7,8] ] --- ] --- , [ [ [4,5,6] , [3,2,1] ] --- , [ [11,10,9], [8,7,6] ] --- ] --- ] --- } --- output { --- [[[[5, 7, 9], --- [5, 7, 9]], --- [[15, 17, 19], --- [15, 17, 19]]], --- [[[7, 7, 7], --- [7, 7, 7]], --- [[19, 17, 15], --- [19, 17, 15]]]] --- } -def add1 [n] (xs: [n]i32, ys: [n]i32) : [n]i32 = - map2 (+) xs ys - -def add2 [n] [m] (xs: [n][m]i32, ys: [n][m]i32) : [n][m]i32 = - map add1 (zip xs ys) - -def add3 [n] [m] [l] (xs: [n][m][l]i32, ys: [n][m][l]i32) : [n][m][l]i32 = - map add2 (zip xs ys) - -def add4 (xs: 
[][][][]i32, ys: [][][][]i32) : [][][][]i32 = - map add3 (zip xs ys) - -def main (a: [][][][]i32) (b: [][][][]i32) : [][][][]i32 = - add4 (a, b) diff --git a/tests/flattening/IntmRes1.fut b/tests/flattening/IntmRes1.fut deleted file mode 100644 index a45f594494..0000000000 --- a/tests/flattening/IntmRes1.fut +++ /dev/null @@ -1,23 +0,0 @@ --- == --- input { --- [ [1,2,3], [4,5,6] --- , [6,7,8], [9,10,11] --- ] --- [1,2,3,4] --- 5 --- } --- output { --- [[7, 8, 9], --- [16, 17, 18], --- [24, 25, 26], --- [33, 34, 35]] --- } -def addToRow [n] (xs: [n]i32, y: i32) : [n]i32 = - map (\(x: i32) : i32 -> x + y) xs - -def main (xss: [][]i32) (cs: []i32) (y: i32) : [][]i32 = - map (\(xs: []i32, c: i32) -> - let y' = y * c + c - let zs = addToRow (xs, y') - in zs) - (zip xss cs) diff --git a/tests/flattening/IntmRes2.fut b/tests/flattening/IntmRes2.fut deleted file mode 100644 index f370fa131b..0000000000 --- a/tests/flattening/IntmRes2.fut +++ /dev/null @@ -1,30 +0,0 @@ --- == --- input { --- [ [ [1,2,3], [4,5,6] ] --- , [ [6,7,8], [9,10,11] ] --- , [ [3,2,1], [4,5,6] ] --- , [ [8,7,6], [11,10,9] ] --- ] --- [1,2,3,4] --- 5 --- } --- output { --- [[[7, 8, 9], --- [10, 11, 12]], --- [[18, 19, 20], --- [21, 22, 23]], --- [[21, 20, 19], --- [22, 23, 24]], --- [[32, 31, 30], --- [35, 34, 33]]] --- } -def addToRow [n] (xs: [n]i32, y: i32) : [n]i32 = - map (\(x: i32) : i32 -> x + y) xs - -def main (xsss: [][][]i32) (cs: []i32) (y: i32) : [][][]i32 = - map (\(xss: [][]i32, c: i32) -> - let y' = y * c + c - in map (\(xs: []i32) -> - addToRow (xs, y')) - xss) - (zip xsss cs) diff --git a/tests/flattening/IntmRes3.fut b/tests/flattening/IntmRes3.fut deleted file mode 100644 index 4ea4497364..0000000000 --- a/tests/flattening/IntmRes3.fut +++ /dev/null @@ -1,36 +0,0 @@ --- == --- input { --- [ [ [ [1,2,3], [4,5,6] ] --- ] --- , [ [ [6,7,8], [9,10,11] ] --- ] --- , [ [ [3,2,1], [4,5,6] ] --- ] --- , [ [ [8,7,6], [11,10,9] ] --- ] --- ] --- [1,2,3,4] --- 5 --- } --- output { --- [[[[7, 
8, 9], --- [10, 11, 12]]], --- [[[18, 19, 20], --- [21, 22, 23]]], --- [[[21, 20, 19], --- [22, 23, 24]]], --- [[[32, 31, 30], --- [35, 34, 33]]]] --- } -def addToRow [n] (xs: [n]i32, y: i32) : [n]i32 = - map (\(x: i32) : i32 -> x + y) xs - -def main (xssss: [][][][]i32) (cs: []i32) (y: i32) : [][][][]i32 = - map (\(xsss: [][][]i32, c: i32) -> - let y' = y * c + c - in map (\(xss: [][]i32) -> - map (\(xs: []i32) -> - addToRow (xs, y')) - xss) - xsss) - (zip xssss cs) diff --git a/tests/flattening/LoopInv1.fut b/tests/flattening/LoopInv1.fut deleted file mode 100644 index 99b9c430d2..0000000000 --- a/tests/flattening/LoopInv1.fut +++ /dev/null @@ -1,24 +0,0 @@ --- == --- input { --- [ [1,2,3], [4,5,6] --- , [6,7,8], [9,10,11] --- , [3,2,1], [4,5,6] --- , [8,7,6], [11,10,9] --- ] --- [1,2,3] --- } --- output { --- [[2, 4, 6], --- [5, 7, 9], --- [7, 9, 11], --- [10, 12, 14], --- [4, 4, 4], --- [5, 7, 9], --- [9, 9, 9], --- [12, 12, 12]] --- } -def addRows [n] (xs: [n]i32, ys: [n]i32) : [n]i32 = - map2 (+) xs ys - -def main (xss: [][]i32) (ys: []i32) : [][]i32 = - map (\(xs: []i32) -> addRows (xs, ys)) xss diff --git a/tests/flattening/LoopInv2.fut b/tests/flattening/LoopInv2.fut deleted file mode 100644 index bdf8ee19d5..0000000000 --- a/tests/flattening/LoopInv2.fut +++ /dev/null @@ -1,26 +0,0 @@ --- == --- input { --- [ [ [1,2,3], [4,5,6] ] --- , [ [6,7,8], [9,10,11] ] --- , [ [3,2,1], [4,5,6] ] --- , [ [8,7,6], [11,10,9] ] --- ] --- [1,2,3] --- } --- output { --- [[[2, 4, 6], --- [5, 7, 9]], --- [[7, 9, 11], --- [10, 12, 14]], --- [[4, 4, 4], --- [5, 7, 9]], --- [[9, 9, 9], --- [12, 12, 12]]] --- } -def addRows [n] (xs: [n]i32, ys: [n]i32) : [n]i32 = - map2 (+) xs ys - -def main (xsss: [][][]i32) (ys: []i32) : [][][]i32 = - map (\(xss: [][]i32) -> - map (\(xs: []i32) -> addRows (xs, ys)) xss) - xsss diff --git a/tests/flattening/LoopInv3.fut b/tests/flattening/LoopInv3.fut deleted file mode 100644 index b4f23df918..0000000000 --- a/tests/flattening/LoopInv3.fut +++ 
/dev/null @@ -1,34 +0,0 @@ --- == --- input { --- [ [ [ [1,2,3], [4,5,6] ] --- ] --- , [ [ [6,7,8], [9,10,11] ] --- ] --- , [ [ [3,2,1], [4,5,6] ] --- ] --- , [ [ [8,7,6], [11,10,9] ] --- ] --- ] --- [1,2,3] --- } --- output { --- [[[[2, 4, 6], --- [5, 7, 9]]], --- [[[7, 9, 11], --- [10, 12, 14]]], --- [[[4, 4, 4], --- [5, 7, 9]]], --- [[[9, 9, 9], --- [12, 12, 12]]]] --- } -def addRows [n] (xs: [n]i32, ys: [n]i32) : [n]i32 = - map2 (+) xs ys - -def main (xssss: [][][][]i32) (ys: []i32) : [][][][]i32 = - map (\(xsss: [][][]i32) -> - map (\(xss: [][]i32) -> - map (\(xs: []i32) -> - addRows (xs, ys)) - xss) - xsss) - xssss diff --git a/tests/flattening/LoopInvReshape.fut b/tests/flattening/LoopInvReshape.fut deleted file mode 100644 index 8c423d4ebe..0000000000 --- a/tests/flattening/LoopInvReshape.fut +++ /dev/null @@ -1,16 +0,0 @@ --- This example presents difficulty for me right now, but also has a --- large potential for improvement later on. --- --- we could turn it into: --- --- let []i32 bettermain ([]i32 xs, [#n]i32 ys, [#n]i32 zs, [#n]i32 is, [#n]i32 js) = --- map (\i32 (i32 y, i32 z, i32 i, i32 j) -> --- xs[i*z + j] --- , zip(ys,zs,is,js)) - -def main [n] [m] (xs: [m]i32, ys: [n]i64, zs: [n]i64, is: [n]i32, js: [n]i32) : []i32 = - map (\(y: i64, z: i64, i: i32, j: i32) : i32 -> - #[unsafe] - let tmp = unflatten (xs :> [y * z]i32) - in tmp[i, j]) - (zip4 ys zs is js) diff --git a/tests/flattening/Map-IotaMapReduce.fut b/tests/flattening/Map-IotaMapReduce.fut deleted file mode 100644 index 6b04d652d5..0000000000 --- a/tests/flattening/Map-IotaMapReduce.fut +++ /dev/null @@ -1,14 +0,0 @@ --- == --- input { --- [2,3,4] --- [8,3,2] --- } --- output { --- [8,9,12] --- } -def main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = - map (\(x: i32, y: i32) : i32 -> - let tmp1 = 0.. 
- map (\(x: i32) : i32 -> - let tmp1 = map i32.i64 (iota (i64.i32 x)) - let tmp2 = map (* y) tmp1 - in reduce (+) 0 tmp2) - xs) - (zip xss ys) diff --git a/tests/flattening/MapIotaReduce.fut b/tests/flattening/MapIotaReduce.fut deleted file mode 100644 index 6b72daa179..0000000000 --- a/tests/flattening/MapIotaReduce.fut +++ /dev/null @@ -1,12 +0,0 @@ --- == --- input { --- [1,2,3,4] --- } --- output { --- [0, 1, 3, 6] --- } -def main (xs: []i32) : []i32 = - map (\(x: i32) : i32 -> - let tmp = 0.. - reduce (+) 0 xs) - xss diff --git a/tests/flattening/VectorAddition.fut b/tests/flattening/VectorAddition.fut deleted file mode 100644 index 85241e4630..0000000000 --- a/tests/flattening/VectorAddition.fut +++ /dev/null @@ -1,10 +0,0 @@ --- == --- input { --- [1,2,3,4] --- [5,6,7,8] --- } --- output { --- [6,8,10,12] --- } -def main (xs: []i32) (ys: []i32) : []i32 = - map2 (+) xs ys diff --git a/tests/flattening/arraylit-irregular.fut b/tests/flattening/arraylit-irregular.fut new file mode 100644 index 0000000000..a2e3dc8a71 --- /dev/null +++ b/tests/flattening/arraylit-irregular.fut @@ -0,0 +1,12 @@ +-- == +-- input { [2i64,2i64,3i64, 7i64, 8i64] } +-- auto output + +def main (xs: []i64) = + map (\x -> + let ys = iota x + let reps = opaque (replicate x 2 with [1] = x) + let ks = opaque ([reps, ys]) + let ks' = map (map (* 2)) ks + in map (i64.sum) ks') + xs diff --git a/tests/flattening/arraylit-nested.fut b/tests/flattening/arraylit-nested.fut new file mode 100644 index 0000000000..27ae3fd00f --- /dev/null +++ b/tests/flattening/arraylit-nested.fut @@ -0,0 +1,13 @@ +-- == +-- input { [1i64,2i64,3i64] [10i64,20i64] } +-- output { +-- [ +-- [[1i64,10i64,11i64], [1i64,20i64,21i64]], +-- [[2i64,10i64,12i64], [2i64,20i64,22i64]], +-- [[3i64,10i64,13i64], [3i64,20i64,23i64]] +-- ] +-- } +entry main (xs: []i64) (ys: []i64) = + map (\x -> + map (\y -> [x, y, x + y]) ys) + xs \ No newline at end of file diff --git a/tests/flattening/arraylit-regualar-array.fut 
b/tests/flattening/arraylit-regualar-array.fut new file mode 100644 index 0000000000..00e5f6b2ac --- /dev/null +++ b/tests/flattening/arraylit-regualar-array.fut @@ -0,0 +1,11 @@ +-- == +-- input { [1i64,2i64,3i64] } +-- auto output +entry main (xs: []i64) = + map (\x -> + let y = opaque (replicate 50 10 with [x] = x) + let d = opaque (replicate 50 10 with [x] = x + 3) + let t = opaque (replicate 50 x with [x] = 10) + let o = iota 50 + in opaque ([y, d, t, o])) + xs diff --git a/tests/flattening/arraylit-simple.fut b/tests/flattening/arraylit-simple.fut new file mode 100644 index 0000000000..11fde0c4e2 --- /dev/null +++ b/tests/flattening/arraylit-simple.fut @@ -0,0 +1,6 @@ +-- Simple test for flattening an update with a constant value +-- == +-- input { [1i64,2i64,3i64] } +-- output { [[10i64, 1i64, 2i64], [10i64, 2i64, 3i64], [10i64, 3i64, 4i64]] } +entry main (xs: []i64) = + map (\x -> [10, x, x + 1]) xs \ No newline at end of file diff --git a/tests/flattening/avoidance0.fut b/tests/flattening/avoidance0.fut new file mode 100644 index 0000000000..9a0b7dbbbe --- /dev/null +++ b/tests/flattening/avoidance0.fut @@ -0,0 +1,12 @@ +-- == +-- input { [3i64, 3i64, 3i64] [5i64,6i64,8i64] } +-- auto output +def main (xs: []i64) (ys: []i64) = + unzip (map2 (\a b -> + let r1 = a - b + let r2 = a + b + let r3 = r1 * r1 + let r4 = a * b + r3 + in (r2, r4)) + xs + ys) diff --git a/tests/flattening/avoidance1.fut b/tests/flattening/avoidance1.fut new file mode 100644 index 0000000000..c8dacfd207 --- /dev/null +++ b/tests/flattening/avoidance1.fut @@ -0,0 +1,13 @@ +-- == +-- input { [3i64, 7i64, 10i64, 1i64, 20i64] } +-- auto output +def main (xs: []i64) = + map (\x -> + let g = x * 5 * 100 + let y = g * 55 + let r2 = y * 2 + let t = y + 100 + r2 + let ys = iota x + let z = t * 2 + in z * g + i64.sum ys) + xs diff --git a/tests/flattening/binop.fut b/tests/flattening/binop.fut new file mode 100644 index 0000000000..6496c9f5c3 --- /dev/null +++ b/tests/flattening/binop.fut @@ 
-0,0 +1,5 @@ +-- == +-- input { [1,2,3] [4,5,6] } +-- output { [5,7,9] } + +def main = map2 (i32.+) diff --git a/tests/flattening/complex-screma.fut b/tests/flattening/complex-screma.fut new file mode 100644 index 0000000000..efc9805f74 --- /dev/null +++ b/tests/flattening/complex-screma.fut @@ -0,0 +1,9 @@ +-- input { [[1.0f32 2.0f32 3.0f32] [4.0f32 5.0f32 6.0f32] [7.0f32 8.0f32 9.0f32]] } +-- auto output + +entry main [n] [m] (a: [m][n]f32) = + map (\row -> + let row_scanned = scan (+) 0 row + in (reduce (+) 0 row, row_scanned)) + a + |> unzip diff --git a/tests/flattening/concat-check-index.fut b/tests/flattening/concat-check-index.fut new file mode 100644 index 0000000000..8e1f0bb51e --- /dev/null +++ b/tests/flattening/concat-check-index.fut @@ -0,0 +1,27 @@ +-- Validation of flattening with 2 lists +-- +-- == +-- entry: validate_flattening_2 +-- input {[0i64, 1i64, 2i64, 3i64, 4i64, 5i64] +-- [0i64, 2i64, 4i64, 6i64, 8i64, 10i64] +-- [0i64,0i64,2i64,6i64,12i64,20i64,30i64] +-- [0i64,0i64, +-- 0i64,1i64,0i64,1i64, +-- 0i64,1i64,2i64,0i64,1i64,4i64, +-- 0i64,1i64,2i64,3i64,0i64,1i64,4i64,9i64, +-- 0i64,1i64,2i64,3i64,4i64,0i64,1i64,4i64,9i64,16i64 +--]} +-- output {[true, true, true, true, true, true]} +entry validate_flattening_2 (ns: []i64) (shp: []i64) (offsets: []i64) (expected: []i64) : []bool = + map2 (\n i -> + let irreg = opaque (iota n `concat` (iota n |> map (**2))) + in + if shp[i] == 0i64 && length irreg == 0i64 then true + else if shp[i] == 0i64 then false + else + let gts = iota shp[i] |> map (\j -> expected[offsets[i] + j]) :> [n + n]i64 + let pairs = zip irreg gts + let eq = map (\(pd, gt) -> pd == gt) pairs + in + reduce (&&) true eq + )ns (indices ns) diff --git a/tests/flattening/concat-d1-3d.fut b/tests/flattening/concat-d1-3d.fut new file mode 100644 index 0000000000..98268f6108 --- /dev/null +++ b/tests/flattening/concat-d1-3d.fut @@ -0,0 +1,16 @@ +-- == +-- input { [0i64, 1i64] +-- [[[[1i32,2i32],[0i32,4i32],[5i32,6i32]], +-- 
[[7i32,8i32],[9i32,10i32],[11i32,12i32]]], +-- [[[13i32,14i32],[1i32,4i32],[17i32,18i32]], +-- [[19i32,20i32],[100i32,5i32],[23i32,24i32]]]] } +-- auto output + +let main [k][a][b][c] (is: [k]i64) (xsss: [k][a][b][c]i32) = + map2 + (\i x -> + let p = map (\rows -> opaque rows[i:]) x + let q = map (\rows -> opaque rows[:i]) x + in (map2 concat p q) :> [a][b][c]i32) + is + xsss diff --git a/tests/flattening/concat-d1.fut b/tests/flattening/concat-d1.fut new file mode 100644 index 0000000000..aafb57598d --- /dev/null +++ b/tests/flattening/concat-d1.fut @@ -0,0 +1,13 @@ +-- == +-- input { [0i64, 1i64] +-- [[[1i32,2i32,3i32],[4i32,5i32,6i32]], +-- [[7i32,8i32,9i32],[10i32,11i32,12i32]]] } +-- auto output + +def main [k] [n] [m] (is: [k]i64) (ass: [k][n][m]i32) = + map2 (\i a -> + let x = map (\row -> opaque row[i:]) a + let y = map (\row -> opaque row[:i]) a + in (map2 concat x y) :> [n][m]i32) + is + ass diff --git a/tests/flattening/concat-d2.fut b/tests/flattening/concat-d2.fut new file mode 100644 index 0000000000..39bb1a0992 --- /dev/null +++ b/tests/flattening/concat-d2.fut @@ -0,0 +1,13 @@ +-- == +-- input { [1i64, 0i64] +-- [[[[1i64,2i64],[3i64,4i64]],[[5i64,6i64],[7i64,8i64]]], +-- [[[9i64,10i64],[11i64,12i64]],[[13i64,14i64],[15i64,16i64]]]] } +-- auto output + +def main [k] [a] [b] [c] (is: [k]i64) (xsss: [k][a][b][c]i64) = + map2 (\i x -> + let p = map (\rows -> map (\row -> opaque row[i:]) rows) x + let q = map (\rows -> map (\row -> opaque row[:i]) rows) x + in (map2 (map2 concat) p q) :> [a][b][c]i64) + is + xsss diff --git a/tests/flattening/concat-iota.fut b/tests/flattening/concat-iota.fut new file mode 100644 index 0000000000..32b7b9517f --- /dev/null +++ b/tests/flattening/concat-iota.fut @@ -0,0 +1,8 @@ +-- Validation of flattening with 2 lists +-- +-- == +-- entry: validate_flattening +-- input {[0i64, 1i64, 2i64, 3i64, 4i64, 5i64]} +-- output {[0i64, 0i64, 2i64, 6i64, 12i64, 20i64]} +entry validate_flattening (ns: []i64) : []i64 = + map (\n -> 
i64.sum (opaque (iota n `concat` iota n))) ns \ No newline at end of file diff --git a/tests/flattening/concat-rep.fut b/tests/flattening/concat-rep.fut new file mode 100644 index 0000000000..eb5e650cbf --- /dev/null +++ b/tests/flattening/concat-rep.fut @@ -0,0 +1,8 @@ +-- Validation of flattening with 3 lists +-- +-- == +-- entry: validate_flattening2 +-- input {[0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 10i64, 13i64]} +-- output {[0i64, 1i64, 4i64, 9i64, 16i64, 25i64, 100i64, 169i64]} +entry validate_flattening2 (ns: []i64) : []i64 = + map (\n -> i64.sum (opaque ((replicate n 1) `concat` iota n `concat` iota n))) ns \ No newline at end of file diff --git a/tests/flattening/dup2d.fut b/tests/flattening/dup2d.fut new file mode 100644 index 0000000000..e33a722966 --- /dev/null +++ b/tests/flattening/dup2d.fut @@ -0,0 +1,7 @@ +-- == +-- input { [[1,2,3],[4,5,6]] } +-- auto output + +def dup = replicate 2 >-> transpose >-> flatten + +entry main (z: [][]i32) = z |> map dup |> dup diff --git a/tests/flattening/dup3d.fut b/tests/flattening/dup3d.fut new file mode 100644 index 0000000000..1b8e2a228e --- /dev/null +++ b/tests/flattening/dup3d.fut @@ -0,0 +1,9 @@ +-- Currently fails; an array that is too small is produced somehow. I +-- suspect replication. 
+-- == +-- input { [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] } +-- auto output + +def dup = replicate 5 >-> transpose >-> flatten + +def main (z: [2][3][4]i32) = z |> map (map dup) |> map dup |> dup diff --git a/tests/flattening/fail-irreg-inner-map.fut b/tests/flattening/fail-irreg-inner-map.fut new file mode 100644 index 0000000000..f545abd968 --- /dev/null +++ b/tests/flattening/fail-irreg-inner-map.fut @@ -0,0 +1,23 @@ +-- == +-- input { [3i64, 5i64, 9, 1] } +-- auto output +def main [n] (xs: [n]i64) = + map (\x -> + let mat = opaque (iota x) + let mat2 = map (+5) mat + in mat2[0]) + xs + +-- -- == +-- -- input { [3i64, 5i64] } +-- -- auto output +-- def main [n] (xs: [n]i64) = +-- map (\x -> +-- let mat = map (\i -> map (\j -> i + j) (iota x)) (iota x) +-- let mat2 = +-- map (\row -> +-- let z = row[0] +-- in map (\e -> e + z) row) +-- mat +-- in mat2[0][0]) +-- xs diff --git a/tests/flattening/flattening-pipeline b/tests/flattening/flattening-pipeline deleted file mode 100755 index ed91df97eb..0000000000 --- a/tests/flattening/flattening-pipeline +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/sh -futhark -s --flattening -i "$1" diff --git a/tests/flattening/flattening-test b/tests/flattening/flattening-test deleted file mode 100755 index 92bc4de552..0000000000 --- a/tests/flattening/flattening-test +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/sh - -HERE=$(dirname "$0") - -if [ $# -lt 1 ]; then - FILES="$HERE/"*.fut -else - FILES=$* -fi - -futhark-test --only-interpret --interpreter="$HERE/flattening-pipeline" $FILES diff --git a/tests/flattening/for/for-inner-map-reduce.fut b/tests/flattening/for/for-inner-map-reduce.fut new file mode 100644 index 0000000000..c6506b5da8 --- /dev/null +++ b/tests/flattening/for/for-inner-map-reduce.fut @@ -0,0 +1,14 @@ +-- == +-- input { [[1i64, 2i64, 3i64], [4i64, 5i64, 6i64]] } +-- auto output + +def main [n] [m] (xss: [n][m]i64) = + map (\xs -> + let d = + loop (ys, s) = (xs, 0) 
+ for i < 4 do + let ys' = map (* i) ys + let s' = s + i64.sum ys' + in (ys', s') + in d.1) + xss diff --git a/tests/flattening/for/for-irregular-init.fut b/tests/flattening/for/for-irregular-init.fut new file mode 100644 index 0000000000..381086b2c1 --- /dev/null +++ b/tests/flattening/for/for-irregular-init.fut @@ -0,0 +1,11 @@ +-- == +-- input { [[1i64, 2i64, 3i64], [4i64, 5i64, 6i64]] } +-- auto output + +def main [n] [m] (xss: [n][m]i64) = + map (\xs -> + let ys = map (*2) xs + let res = loop zs = ys for i < 4 do + map (+i) zs + in res) + xss diff --git a/tests/flattening/for/for-irregular-param-uniform-it.fut b/tests/flattening/for/for-irregular-param-uniform-it.fut new file mode 100644 index 0000000000..7d6b9855c5 --- /dev/null +++ b/tests/flattening/for/for-irregular-param-uniform-it.fut @@ -0,0 +1,14 @@ +-- == +-- input { [10i64,7i64,0i64,10i64] 8i64 } +-- auto output +def main [n] (xs: [n]i64) (i :i64) = + map (\xs -> + let ys = iota xs + let (acc_res, _) = + loop (acc, j) = (ys, 1) + for i < i do + let acc' = map (\y -> y * j) acc + let j' = j + i + in (acc', j') + in reduce (+) 0 (acc_res)) + xs \ No newline at end of file diff --git a/tests/flattening/for/for-irregular-param.fut b/tests/flattening/for/for-irregular-param.fut new file mode 100644 index 0000000000..4a1d566c69 --- /dev/null +++ b/tests/flattening/for/for-irregular-param.fut @@ -0,0 +1,15 @@ +-- == +-- input { [10i64,7i64,3i64] [10i64,7i64,3i64]} +-- auto output +def main [n] (xs: [n]i64) (is: [n]i64) = + map2 (\xs it -> + let ys = iota xs + let (acc_res, _) = + loop (acc, j) = (ys, 1) + for i < it do + let acc' = map (\y -> y * j) acc + let j' = j + i + in (acc', j') + in reduce (+) 0 (acc_res)) + xs + is \ No newline at end of file diff --git a/tests/flattening/for/for-regular-inner2dmap-irregular-param.fut b/tests/flattening/for/for-regular-inner2dmap-irregular-param.fut new file mode 100644 index 0000000000..8d370ae5a6 --- /dev/null +++ 
b/tests/flattening/for/for-regular-inner2dmap-irregular-param.fut @@ -0,0 +1,18 @@ +-- == +-- input { [[1i64, 2i64, 3i64], [4i64, 5i64, 6i64]] [13i64, 11i64, 6i64] } +-- auto output + +def main [n][m][k] (xss: [n][m]i64) (ys: [k]i64) = + map (\xs-> + map (\x -> + let zs = iota x + let (acc, _) = + loop (acc, zs_acc) = (ys, zs) for i < 10 do + let ys' = map (+ i) acc + let zs' = map (+ i) zs_acc + let sum_zs = i64.sum zs' + let acc' = map (+ sum_zs) ys' + in (acc', zs') + in acc + ) xs + ) xss diff --git a/tests/flattening/for/for-regular-inner2dmap-irregular.fut b/tests/flattening/for/for-regular-inner2dmap-irregular.fut new file mode 100644 index 0000000000..7c5ab69195 --- /dev/null +++ b/tests/flattening/for/for-regular-inner2dmap-irregular.fut @@ -0,0 +1,15 @@ +-- == +-- input { [[1i64, 2i64, 3i64], [4i64, 5i64, 6i64]] [13i64, 11i64, 6i64] } +-- auto output + +def main [n][m][k] (xss: [n][m]i64) (ys: [k]i64) = + map (\xs-> + map (\x -> + loop acc = ys for i < 10 do + let ys' = map (+ i ) acc + let zs = iota x + let zs' = map (*i) zs + let sum_zs = i64.sum zs' + in map (+ sum_zs) ys' + ) xs + ) xss diff --git a/tests/flattening/for/for-regular-inner2dmap.fut b/tests/flattening/for/for-regular-inner2dmap.fut new file mode 100644 index 0000000000..8e3e30bc81 --- /dev/null +++ b/tests/flattening/for/for-regular-inner2dmap.fut @@ -0,0 +1,13 @@ +-- == +-- input { [[1i64, 2i64, 3i64], [4i64, 5i64, 6i64]] [13i64, 11i64, 6i64] } +-- auto output + +def main [n][m][k] (xss: [n][m]i64) (ys: [k]i64) = + map (\xs-> + let xs' = map (*2) xs in + map (\x -> + loop acc = ys for i < 10 do + let ys' = map (+ i ) acc + in map (+ x) ys' + ) xs' + ) xss diff --git a/tests/flattening/for/for-regular-iota.fut b/tests/flattening/for/for-regular-iota.fut new file mode 100644 index 0000000000..797730f2e4 --- /dev/null +++ b/tests/flattening/for/for-regular-iota.fut @@ -0,0 +1,11 @@ +-- == +-- input { [[1i64, 2i64, 0i64], [4i64, 0i64, 6i64]] } +-- auto output + +def main [n] [m] (xss: 
[n][m]i64) = + map (\xs -> + loop acc = xs for _i < 3 do + let ys = iota m + let acc' = map2 (+) acc ys + in acc' + ) xss diff --git a/tests/flattening/for/for-regular-replicate.fut b/tests/flattening/for/for-regular-replicate.fut new file mode 100644 index 0000000000..cf5079d602 --- /dev/null +++ b/tests/flattening/for/for-regular-replicate.fut @@ -0,0 +1,12 @@ +-- == +-- input { [[1i64, 2i64, 3i64], [4i64, 5i64, 6i64]] } +-- auto output + +def main [n] [m] (xss: [n][m]i64) = + map (\xs -> + loop acc = xs for i < 3 do + let s = reduce (+) 0 acc + let ys = replicate m (s + i) + let acc' = map2 (+) acc ys + in acc' + ) xss diff --git a/tests/flattening/for/for-regular-while-inner.fut b/tests/flattening/for/for-regular-while-inner.fut new file mode 100644 index 0000000000..d294712f3e --- /dev/null +++ b/tests/flattening/for/for-regular-while-inner.fut @@ -0,0 +1,10 @@ +-- == +-- input { [[5i64, 3i64, 10i64], [2i64, 4i64, 6i64]] } +-- auto output + +def main [n] [m] (xss: [n][m]i64) = + map (\xs -> + loop acc = xs for i < 5 do + loop acc2 = acc while reduce (+) 0 acc2 > 0 do + map (\x -> x - 1) acc2 + ) xss diff --git a/tests/flattening/for/for-sequential.fut b/tests/flattening/for/for-sequential.fut new file mode 100644 index 0000000000..a3a9c21b0c --- /dev/null +++ b/tests/flattening/for/for-sequential.fut @@ -0,0 +1,15 @@ +-- == +-- input { [0i64, 1i64, 5i64, 10i64] } +-- auto output + +def main [n] (xs: [n]i64) = + map (\x -> + let acc = + loop a = x + for i < 10 do + a + i + let mid = acc * 2 + in loop b = mid + for j < mid do + b + j) + xs diff --git a/tests/flattening/for/for-simple.fut b/tests/flattening/for/for-simple.fut new file mode 100644 index 0000000000..ca067d984f --- /dev/null +++ b/tests/flattening/for/for-simple.fut @@ -0,0 +1,13 @@ +-- == +-- input { [1i64, 5i64, 10i64] } +-- auto output + +def main [n] (xs: [n]i64) = + map (\x -> + let y = x * 2 + let z = + loop acc = y + for i < 4 do + acc + i + in z * x) + xs diff --git 
a/tests/flattening/for/for-tag-stress.fut b/tests/flattening/for/for-tag-stress.fut new file mode 100644 index 0000000000..c0c196e845 --- /dev/null +++ b/tests/flattening/for/for-tag-stress.fut @@ -0,0 +1,14 @@ +-- == +-- input { [[1i64, 2i64, 3i64]] } +-- auto output + +def main [n] [m] (xss: [n][m]i64) : [n]i64 = + map (\xs -> + let (row_sum, scaled_prod) = + ( reduce (+) 0 xs + , reduce (*) 1 (map (* 5) xs) + ) + in loop acc = row_sum + for i < 3 do + acc + i * 2 + scaled_prod) + xss diff --git a/tests/flattening/function-lifting/func_const.fut b/tests/flattening/function-lifting/func_const.fut new file mode 100644 index 0000000000..d6102298d5 --- /dev/null +++ b/tests/flattening/function-lifting/func_const.fut @@ -0,0 +1,22 @@ +-- Lifting a function with a constants as argument and result +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [7i64, 7i64,10i64,16i64,25i64,37i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (x : i64) (xs : []i64) : ([]i64, i64) = + let ys = map (x*) xs + in (ys, 7) + +#[noinline] +let foo (x : i64) = + let xs = iota x + let (ys, z) = bar 3 xs + in z + reduce (+) 0 ys + +def main (xs : []i64) = map foo xs + diff --git a/tests/flattening/function-lifting/func_free.fut b/tests/flattening/function-lifting/func_free.fut new file mode 100644 index 0000000000..433ba3156d --- /dev/null +++ b/tests/flattening/function-lifting/func_free.fut @@ -0,0 +1,27 @@ +-- Lifting a function with a free variables as argument and result +-- == +-- entry: main +-- input { [ 0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [280i64,294i64,308i64,322i64,336i64,350i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let v1 : []i64 = [5,9,6] + +#[noinline] +let v2 : []i64 = [3,1,4,1,5] + +#[noinline] +let bar (xs : []i64) (y : i64) : (i64, []i64) = + let z = y + reduce (+) 0 xs + in (z, copy v2) + +#[noinline] +let foo (x : i64) = + let (y, zs) = bar v1 x 
+ let z = reduce (+) 0 zs + in (y * z) + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_fully_irreg.fut b/tests/flattening/function-lifting/func_fully_irreg.fut new file mode 100644 index 0000000000..27382dfba8 --- /dev/null +++ b/tests/flattening/function-lifting/func_fully_irreg.fut @@ -0,0 +1,23 @@ +-- Lifting a function with an irregular +-- parameter and return type +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, 0i64, 3i64, 15i64,45i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (xs : []i64) : []i64 = + let y = reduce (+) 0 xs + in iota y + +#[noinline] +let foo (x : i64) = + let xs = iota x + let ys = bar xs + in reduce (+) 0 ys + +def main (xs : []i64) = map foo xs + diff --git a/tests/flattening/function-lifting/func_irreg_input.fut b/tests/flattening/function-lifting/func_irreg_input.fut new file mode 100644 index 0000000000..718895b430 --- /dev/null +++ b/tests/flattening/function-lifting/func_irreg_input.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, 1i64, 3i64, 6i64,10i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (xs : []i64) : i64 = reduce (+) 0 xs + +#[noinline] +let foo (x : i64) = + let xs = iota x + in bar xs + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_irreg_result.fut b/tests/flattening/function-lifting/func_irreg_result.fut new file mode 100644 index 0000000000..0225943179 --- /dev/null +++ b/tests/flattening/function-lifting/func_irreg_result.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- input { [0i64,1i64,2i64,3i64,4i64, 5i64] } +-- output { [0i64,0i64,1i64,3i64,6i64,10i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (x : i64) : []i64 = iota x + +#[noinline] +let foo (x : i64) = + let xs = bar x + in reduce 
(+) 0 xs + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_irreg_update.fut b/tests/flattening/function-lifting/func_irreg_update.fut new file mode 100644 index 0000000000..b329759b91 --- /dev/null +++ b/tests/flattening/function-lifting/func_irreg_update.fut @@ -0,0 +1,23 @@ +-- Lifting a function which consumes its argument +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, 0i64, 1i64, 2i64, 4i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar [n] (xs : *[n]i64) (z : i64) (ys : [z]i64) : [n]i64 = + let m = n - z + in xs with [m:n] = ys + +#[noinline] +let foo (a : i64) = + let b = a / 2 + let xs = iota a + let ys = iota b :> [b]i64 + let zs = bar xs b ys + in reduce (+) 0 zs + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_mix.fut b/tests/flattening/function-lifting/func_mix.fut new file mode 100644 index 0000000000..cca3be9c17 --- /dev/null +++ b/tests/flattening/function-lifting/func_mix.fut @@ -0,0 +1,25 @@ +-- Lifting a function with both regular and irregular +-- parameters and return types. 
+-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, -1i64, 27i64, 252i64, 1175i64] } +-- input { [5i64, 4i64, 3i64, 2i64, 1i64, 0i64] [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 9i64, 9i64, 0i64, 0i64, 0i64] } +-- input { empty([0]i64) empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (y : i64) (xs : []i64) : ([]i64, i64) = + let z = y * reduce (+) 0 xs + in (iota z, z) + +#[noinline] +let foo (a : i64) (b : i64) = + let xs = iota a + let (ys, z) = bar b xs + in reduce (+) 0 ys - z + +def main (as : []i64) (bs : []i64) = map2 foo as bs + diff --git a/tests/flattening/function-lifting/func_mix_nested.fut b/tests/flattening/function-lifting/func_mix_nested.fut new file mode 100644 index 0000000000..0431af1e94 --- /dev/null +++ b/tests/flattening/function-lifting/func_mix_nested.fut @@ -0,0 +1,31 @@ +-- Lifting a function that calls another function +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64] [0i64, 1i64, 2i64, 3i64, 4i64] } +-- output { [0i64, 0i64, 0i64, 52290i64, 21935100i64] } +-- input { [5i64, 4i64, 3i64, 2i64, 1i64, 0i64] [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 3990i64, 3990i64, 33i64, 0i64, 0i64] } +-- input { empty([0]i64) empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let baz (xs : []i64) (y : i64) : ([]i64, []i64) = + let z = y * reduce (+) 0 xs + in (iota y, iota z) + +#[noinline] +let bar (y : i64) (xs : []i64) : ([]i64, i64) = + let z = y * reduce (+) 0 xs + let (as, bs) = baz (iota z) z + let a = reduce (+) 0 as + in (bs, a) + +#[noinline] +let foo (a : i64) (b : i64) = + let xs = iota a + let (ys, z) = bar b xs + in reduce (+) 0 ys - z + +def main (as : []i64) (bs : []i64) = map2 foo as bs + diff --git a/tests/flattening/function-lifting/func_simple.fut b/tests/flattening/function-lifting/func_simple.fut new file mode 100644 index 0000000000..22002bbf3e --- /dev/null +++ 
b/tests/flattening/function-lifting/func_simple.fut @@ -0,0 +1,16 @@ +-- Lifting a simple function +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [1i64, 2i64, 3i64, 4i64, 5i64, 6i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +def bar (x : i64) = x + 1 + +#[noinline] +def foo (x : i64) = bar x + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/identity-irreg.fut b/tests/flattening/identity-irreg.fut new file mode 100644 index 0000000000..97b27c107b --- /dev/null +++ b/tests/flattening/identity-irreg.fut @@ -0,0 +1,7 @@ +-- == +-- entry: main +-- input { [[1.0f32, 2.0f32], [3.0f32, 4.0f32]] } +-- auto output + +def main [n] [m] (arr: [m][n]f32) : [m][n]f32 = + map (\i -> arr[i]) (iota m) diff --git a/tests/flattening/inner-map-result-order.fut b/tests/flattening/inner-map-result-order.fut new file mode 100644 index 0000000000..339d0f70fb --- /dev/null +++ b/tests/flattening/inner-map-result-order.fut @@ -0,0 +1,12 @@ +-- A test for preserving result order when flattening a distributed inner map. 
+-- == +-- input { [[1i32, 2i32], [3i32, 4i32]] 3i32} +-- auto output + +def main [n] [m] (xss: [n][m]i32) (k: i32) = + map (\xs -> + let a = map (+ 1) xs + let b = map (+ k) a + in (b, a)) + xss + |> unzip diff --git a/tests/flattening/inner_maposcanomap-both-free.fut b/tests/flattening/inner_maposcanomap-both-free.fut new file mode 100644 index 0000000000..f4c5574121 --- /dev/null +++ b/tests/flattening/inner_maposcanomap-both-free.fut @@ -0,0 +1,11 @@ +-- == +-- input { [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32]] } +-- auto output + +entry main [n] [m] (xss: [n][m]i32) = + map (\xs -> + let a = xs[0] + let ys = map (+ a) xs + let zs = scan (+) 0 ys + in map (+ a) zs) + xss diff --git a/tests/flattening/inner_maposcanomap-no-free.fut b/tests/flattening/inner_maposcanomap-no-free.fut new file mode 100644 index 0000000000..22900b603e --- /dev/null +++ b/tests/flattening/inner_maposcanomap-no-free.fut @@ -0,0 +1,11 @@ +-- == +-- input { 3i32 [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32]] } +-- auto output + +entry main [n] [m] (k: i32) (xss: [n][m]i32) = + map (\xs -> + let ys = map (+ k) xs + let zs = scan (+) 0 ys + let ks = map2 (\z y -> z * 2 + y) zs ys + in (ks ++ ys ++ zs)) + xss diff --git a/tests/flattening/inner_maposcanomap-post-free.fut b/tests/flattening/inner_maposcanomap-post-free.fut new file mode 100644 index 0000000000..3bae8b97d4 --- /dev/null +++ b/tests/flattening/inner_maposcanomap-post-free.fut @@ -0,0 +1,11 @@ +-- == +-- input { [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32]] } +-- auto output + +entry main [n] [m] (xss: [n][m]i32) = + map (\xs -> + let a = xs[0] + let ys = map (+ 1) xs + let zs = scan (+) 0 ys + in map (+ a) zs) + xss diff --git a/tests/flattening/inner_maposcanomap-pre-free.fut b/tests/flattening/inner_maposcanomap-pre-free.fut new file mode 100644 index 0000000000..8f59279fea --- /dev/null +++ b/tests/flattening/inner_maposcanomap-pre-free.fut @@ -0,0 +1,11 @@ +-- == +-- input { [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32]] } +-- auto output + 
+entry main [n] [m] (xss: [n][m]i32) = + map (\xs -> + let a = xs[0] + let ys = map (+ a) xs + let zs = scan (+) 0 ys + in map (+ 1) zs) + xss diff --git a/tests/flattening/iota-index.fut b/tests/flattening/iota-index.fut new file mode 100644 index 0000000000..a21c5f4096 --- /dev/null +++ b/tests/flattening/iota-index.fut @@ -0,0 +1,10 @@ +-- iota is probably simplified away, but certs must be kept. +-- == +-- input { [1i64,2i64] [0,1] } +-- output { [0i64,1i64] } +-- input { [1i64,2i64] [0,2] } +-- error: out of bounds +-- input { [1i64,-2i64] [0,1] } +-- error: Range 0..1..<-2 is invalid + +def main = map2 (\n (i:i32) -> (iota n)[i]) diff --git a/tests/flattening/iota-opaque-index.fut b/tests/flattening/iota-opaque-index.fut new file mode 100644 index 0000000000..065c55d294 --- /dev/null +++ b/tests/flattening/iota-opaque-index.fut @@ -0,0 +1,9 @@ +-- == +-- input { [1i64,2i64] [0,1] } +-- output { [0i64,1i64] } +-- input { [1i64,2i64] [0,2] } +-- error: out of bounds +-- input { [1i64,-2i64] [0,1] } +-- error: Range 0..1..<-2 is invalid + +def main = map2 (\n (i:i32) -> (opaque (iota n))[i]) diff --git a/tests/flattening/iota-opaque-slice-red.fut b/tests/flattening/iota-opaque-slice-red.fut new file mode 100644 index 0000000000..a0bb5220f0 --- /dev/null +++ b/tests/flattening/iota-opaque-slice-red.fut @@ -0,0 +1,11 @@ +-- == +-- input { [1i64,2i64] [0i64,1i64] } +-- output { [0i64,1i64] } +-- input { [1i64,5i64] [0i64,3i64] } +-- output { [0i64,7i64] } +-- input { [1i64,2i64] [0i64,3i64] } +-- error: out of bounds +-- input { [1i64,-2i64] [0i64,1i64] } +-- error: out of bounds + +def main = map2 (\n (i: i64) -> i64.sum (opaque (iota n))[i:]) diff --git a/tests/flattening/iota-red.fut b/tests/flattening/iota-red.fut new file mode 100644 index 0000000000..ba2d5ea6fa --- /dev/null +++ b/tests/flattening/iota-red.fut @@ -0,0 +1,7 @@ +-- == +-- input { [0i64,1i64,2i64] } +-- output { [0i64, 0i64, 1i64] } +-- input { [0i64,1i64,-2i64] } +-- error: Range 0..1..<-2 is 
invalid + +def main = map (\n -> i64.sum (iota n)) diff --git a/tests/flattening/iota-scan.fut b/tests/flattening/iota-scan.fut new file mode 100644 index 0000000000..f32604e4ee --- /dev/null +++ b/tests/flattening/iota-scan.fut @@ -0,0 +1,8 @@ +-- == +-- input { [3i64,4i64,5i64] [1i64,2i64,3i64] } +-- auto output +-- input { [1i64,2i64] [0i64,3i64] } +-- error: out of bounds + +def main = +map2 (\n (i: i64) -> (scan (+) 0 (iota n))[i]) \ No newline at end of file diff --git a/tests/flattening/issue1143.fut b/tests/flattening/issue1143.fut new file mode 100644 index 0000000000..3f4b3a9fa1 --- /dev/null +++ b/tests/flattening/issue1143.fut @@ -0,0 +1,47 @@ +def dotprod [n] (xs: [n]f32) (ys: [n]f32) : f32 = + reduce (+) 0.0 (map2 (*) xs ys) + +def house [d] (x: [d]f32) : ([d]f32, f32) = + let dot = dotprod x x + let dot' = dot - x[0] ** 2 + x[0] ** 2 + let beta = if dot' != 0 then 2.0 / dot' else 0 + in (x, beta) + +def matmul [n] [p] [m] (xss: [n][p]f32) (yss: [p][m]f32) : [n][m]f32 = + map (\xs -> map (dotprod xs) (transpose yss)) xss + +def outer [n] [m] (xs: [n]f32) (ys: [m]f32) : [n][m]f32 = + matmul (map (\x -> [x]) xs) [ys] + +def matsub [m] [n] (xss: [m][n]f32) (yss: [m][n]f32) : *[m][n]f32 = + map2 (\xs ys -> map2 (-) xs ys) xss yss + +def matadd [m] [n] (xss: [m][n]f32) (yss: [m][n]f32) : [m][n]f32 = + map2 (\xs ys -> map2 (+) xs ys) xss yss + +def matmul_scalar [m] [n] (xss: [m][n]f32) (k: f32) : *[m][n]f32 = + map (map (* k)) xss + +def block_householder [m] [n] (A: [m][n]f32) (r: i64) : ([][]f32, [][]f32) = + #[unsafe] + let Q = replicate m (replicate m 0) + let (Q, A) = + loop (Q, A) = (Q, copy A) + for k in 0..<(n / r) do + let s = k * r + let V = replicate m (replicate r 0f32) + let Bs = replicate r 0f32 + let (A) = + loop (A) for j in 0.. 
block_householder arr r) arrs diff --git a/tests/flattening/iswim0.fut b/tests/flattening/iswim0.fut new file mode 100644 index 0000000000..3eb0aff4e1 --- /dev/null +++ b/tests/flattening/iswim0.fut @@ -0,0 +1,16 @@ +-- This test fails if the ISWIM transformation messes up the size +-- annotations. +-- == +-- input { +-- [1,1,1] +-- [[1,2,3],[4,5,6],[7,8,9],[0,1,2],[3,4,5]] +-- } +-- output { +-- [[3i32, 6i32, 11i32], [54i32, 162i32, 418i32], [2754i32, 10692i32, 34694i32], [5508i32, 32076i32, 208164i32], [60588i32, 577368i32, 5620428i32]] +-- } +def combineVs [n] (n_row: [n]i32) : [n]i32 = + map2 (*) n_row n_row + +def main [n] [m] (md_starts: [m]i32) (md_vols: [n][m]i32) : [][]i32 = + let e_rows = map (\x -> map (+ 2) x) (map combineVs md_vols) + in scan (\x y -> map2 (*) x y) md_starts e_rows diff --git a/tests/flattening/iswim1.fut b/tests/flattening/iswim1.fut new file mode 100644 index 0000000000..790584af3c --- /dev/null +++ b/tests/flattening/iswim1.fut @@ -0,0 +1,9 @@ +-- Segmented scan with array operator (interchangeable). 
+-- == +-- random input { [10][1][10]i32 } auto output +-- random input { [10][10][1]i32 } auto output +-- random input { [10][10][10]i32 } auto output +-- structure gpu { /SegScan 1 /SegScan/Loop 0 } + +def main [n] [m] [k] (xss: [n][m][k]i32) = + map (scan (map2 (+)) (replicate k 0)) xss diff --git a/tests/flattening/loop_inner_map_reduce.fut b/tests/flattening/loop_inner_map_reduce.fut new file mode 100644 index 0000000000..d6f79c9c4f --- /dev/null +++ b/tests/flattening/loop_inner_map_reduce.fut @@ -0,0 +1,12 @@ +-- == +-- input { [[1i64,2i64,3i64],[4i64,5i64,6i64]] } +-- auto output + +def main [n][m] (xss : [n][m]i64) = + map (\xs -> + let d = loop (ys, s) = (xs, 0) for i < 4 do + let ys' = map (*2) ys + let s' = s + i64.sum ys' + in (ys', s') + in d.1 + ) xss \ No newline at end of file diff --git a/tests/flattening/loop_simple.fut b/tests/flattening/loop_simple.fut new file mode 100644 index 0000000000..de7b3ccf01 --- /dev/null +++ b/tests/flattening/loop_simple.fut @@ -0,0 +1,12 @@ +-- For loop nested in map with statements before and after the loop. 
+-- == +-- input { [1i64, 2i64, 3i64] } +-- auto output + +def main [n] (xs : [n]i64) = + map (\x -> + let y = x * 2 + let z = loop acc = y for i < 4 do + acc + i + in z * x + ) xs diff --git a/tests/flattening/map-nested-deeper.fut b/tests/flattening/map-nested-deeper.fut new file mode 100644 index 0000000000..f941f80e24 --- /dev/null +++ b/tests/flattening/map-nested-deeper.fut @@ -0,0 +1,9 @@ +-- == +-- input { [5i64,7i64] [[5],[7]] } +-- output { [7,9] } + +def main = map2 (\n xs -> + #[unsafe] + let A = #[opaque] replicate n xs + let B = #[opaque] map (\x -> (opaque x)[0]+2i32) A + in B[0]) diff --git a/tests/flattening/map-nested-free2d.fut b/tests/flattening/map-nested-free2d.fut new file mode 100644 index 0000000000..57af621dd1 --- /dev/null +++ b/tests/flattening/map-nested-free2d.fut @@ -0,0 +1,9 @@ +-- == +-- input { [5i64,7i64] [5i64,7i64] [3i64,2i64] } +-- output { [3i64, 2i64] } + +def main = map3 (\n m x -> + #[unsafe] + let A = #[opaque] replicate n (replicate m x) + let B = #[opaque] map (\i -> A[i%x,i%x]) (iota n) + in B[0]) diff --git a/tests/flattening/map-nested.fut b/tests/flattening/map-nested.fut new file mode 100644 index 0000000000..3942a7868d --- /dev/null +++ b/tests/flattening/map-nested.fut @@ -0,0 +1,5 @@ +-- == +-- input { [5i64,7i64] } +-- output { [20i64, 35i64] } + +def main = map (\n -> i64.sum (map (+2) (iota n))) diff --git a/tests/flattening/map-opaque-reduce-map.fut b/tests/flattening/map-opaque-reduce-map.fut new file mode 100644 index 0000000000..df4f3d8a4a --- /dev/null +++ b/tests/flattening/map-opaque-reduce-map.fut @@ -0,0 +1,11 @@ +-- == +-- input { [[0i32, 1i32, 2i32], [3i32, 4i32, 5i32]] } +-- auto output + +def main (xss: [][]i32) = + #[incremental_flattening(only_intra)] + map (\xs -> + let ys = map (+ 1) xs |> opaque + let s = reduce (+) 0 ys + in map (+ s) xs) + xss \ No newline at end of file diff --git a/tests/flattening/map-slice-nested.fut b/tests/flattening/map-slice-nested.fut new file mode 100644 index 
0000000000..0b01ac7880 --- /dev/null +++ b/tests/flattening/map-slice-nested.fut @@ -0,0 +1,5 @@ +-- == +-- input { [1i64,2i64,3i64,4i64,5i64] [-5i64,7i64] [2i64,3i64] [3i64,4i64] } +-- output { [-2i64, 11i64] } + +def main A = map3 (\x i j -> i64.sum (map (+x) A[i:j])) diff --git a/tests/flattening/mapmultidim-irreg-pat.fut b/tests/flattening/mapmultidim-irreg-pat.fut new file mode 100644 index 0000000000..c2ad25ab2b --- /dev/null +++ b/tests/flattening/mapmultidim-irreg-pat.fut @@ -0,0 +1,5 @@ +def main [n] [m] (xs: [n]i64) (ys: [m]i64) = + map (\x -> + let rows = map (\y -> opaque (replicate x y)) ys + in map (\row -> reduce (+) 0 (opaque row)) rows) + xs diff --git a/tests/flattening/mapout.fut b/tests/flattening/mapout.fut new file mode 100644 index 0000000000..67ee76a39c --- /dev/null +++ b/tests/flattening/mapout.fut @@ -0,0 +1,11 @@ +-- A redomap where part of the result is not reduced. +-- == +-- input { [5i64,7i64] [0i64,1i64] } +-- output { [20i64, 35i64] [0i64, 1i64] } + +def main ns is = map2 (\n (i:i64) -> let is = iota n + let xs = map (+2) is + let ys = map (*i) is + in (i64.sum xs, (opaque ys)[i])) + ns is + |> unzip diff --git a/tests/flattening/match-case/if.fut b/tests/flattening/match-case/if.fut new file mode 100644 index 0000000000..cc1594d32e --- /dev/null +++ b/tests/flattening/match-case/if.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- nobench input { [-1i64,1i64,-2i64,2i64,-3i64,3i64] } +-- output { [ 1i64,2i64, 4i64,4i64, 9i64,6i64] } +-- nobench input { [-5i64,-3i64,4i64,2i64,0i64,-1i64,3i64,1i64] } +-- output { [25i64, 9i64,8i64,4i64,0i64, 1i64,6i64,2i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 2i64, 4i64, 6i64, 8i64,10i64] } +-- nobench input { [-1i64,-2i64,-3i64,-4i64,-5i64] } +-- output { [ 1i64, 4i64, 9i64,16i64,25i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let foo (x : i64) = if x < 0 then x * x else x * 2 + +def main [n] (xs : [n]i64) = map foo xs diff 
--git a/tests/flattening/match-case/if_2d_irreg.fut b/tests/flattening/match-case/if_2d_irreg.fut new file mode 100644 index 0000000000..86cbc8b287 --- /dev/null +++ b/tests/flattening/match-case/if_2d_irreg.fut @@ -0,0 +1,9 @@ +-- == +-- input { [[5i64, 10i64, 15i64], [3i64, 4i64, 15i64]] [1i64, 4i64] } +-- auto output +def main [n] [m] (xss: [n][m]i64) (ys: [n]i64) = + map (\xs -> + map (\x -> + opaque (if x >= 10 then map (* x) ys else map (+ x) ys)) + xs) + xss diff --git a/tests/flattening/match-case/if_2d_irreg2.fut b/tests/flattening/match-case/if_2d_irreg2.fut new file mode 100644 index 0000000000..92f47faedb --- /dev/null +++ b/tests/flattening/match-case/if_2d_irreg2.fut @@ -0,0 +1,9 @@ +-- == +-- input { [[5i64, 10i64, 15i64], [3i64, 4i64, 15i64]] [[1i64, 4i64], [2i64, 3i64]] } +-- auto output +def main [n] [m] [k] (xss: [n][m]i64) (yss: [n][k]i64) = + map (\xs -> + map (\x -> + opaque (if x >= 10 then map (map (* x)) yss else map (map (+ x)) yss)) + xs) + xss diff --git a/tests/flattening/match-case/if_fully_irreg.fut b/tests/flattening/match-case/if_fully_irreg.fut new file mode 100644 index 0000000000..4efbbe1771 --- /dev/null +++ b/tests/flattening/match-case/if_fully_irreg.fut @@ -0,0 +1,24 @@ +-- == +-- entry: main +-- nobench input { [ 2i64, 7i64, 1i64, 8i64, 7i64] } +-- output { [ 2i64, 23i64, 0i64, 31i64, 23i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 0i64, 2i64, 6i64, 12i64, 20i64] } +-- nobench input { [ 6i64, 7i64, 8i64, 9i64, 10i64] } +-- output { [16i64, 23i64, 31i64, 40i64, 50i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let bar [n] (xs : [n]i64) = + if n <= 5 then (false, xs) + else (true, copy xs with [5] = n) + +#[noinline] +let foo (x : i64) = + let xs = iota x in + let (b, ys) = bar xs + let z = reduce (+) 0 ys + in if b then z else z * 2 + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/if_invariant.fut 
b/tests/flattening/match-case/if_invariant.fut new file mode 100644 index 0000000000..f699ba1f73 --- /dev/null +++ b/tests/flattening/match-case/if_invariant.fut @@ -0,0 +1,9 @@ +-- == +-- input { [5i64, 10i64, 15i64] [1i64, 4i64, 3i64] 5i64 } +-- auto output +-- input { [5i64, 10i64, 15i64] [1i64, 4i64, 3i64] 6i64 } +-- auto output +def main [n] (xs: [n]i64) (ys: [n]i64) (b: i64) = + map (\x -> + opaque (if b == 5 then map (* x) ys else map (+ x) ys)) + xs diff --git a/tests/flattening/match-case/if_invariant_2irreg.fut b/tests/flattening/match-case/if_invariant_2irreg.fut new file mode 100644 index 0000000000..1f8f246d10 --- /dev/null +++ b/tests/flattening/match-case/if_invariant_2irreg.fut @@ -0,0 +1,18 @@ +-- == +-- input { [5i64, 10i64, 15i64] [10i64, 8i64, 6i64] 5i64 } +-- auto output +-- input { [5i64, 10i64, 15i64] [10i64, 8i64, 6i64] 0i64 } +-- auto output +def main [n] (xs: [n]i64) (ys: [n]i64) (b: i64) = + map2 (\x y -> + let (if_res, if_res2) = + opaque (if b == 5 + then let zs = iota x + let zs' = iota y + in (map (+ x) zs, map (* y) zs') + else let zs = iota x + let zs' = iota y + in (map (+ x) zs, map (+ y) zs')) + in if_res[1] + if_res2[1]) + xs + ys diff --git a/tests/flattening/match-case/if_invariant_2irreg_reg.fut b/tests/flattening/match-case/if_invariant_2irreg_reg.fut new file mode 100644 index 0000000000..6599b086c1 --- /dev/null +++ b/tests/flattening/match-case/if_invariant_2irreg_reg.fut @@ -0,0 +1,22 @@ +-- == +-- input { [5i64, 10i64, 15i64] [1i64, 4i64, 3i64] 5i64} +-- auto output +-- input { [5i64, 10i64, 15i64] [1i64, 4i64, 3i64] 1i64} +-- auto output +def main [n] (xs: [n]i64) (ys: [n]i64) (b: i64) = + map2 (\x y -> + let (if_res, if_res2,if_res3) = + opaque (if b == 5 + then let zs = iota x + let zs' = iota y + let lit0 = [x,4,5,x] + let lit0' = map (+ x) lit0 + in (map (* x) zs, map (* y) zs', lit0') + else let zs = iota x + let zs' = iota y + let lit1 = [y,4,5,y,100] + let lit1' = map (* y) lit1 + in (map (+ x) zs, map (+ y) 
zs', lit1')) + in i64.sum (if_res ++ if_res2 ++ if_res3)) + xs + ys diff --git a/tests/flattening/match-case/if_invariant_dif_branch.fut b/tests/flattening/match-case/if_invariant_dif_branch.fut new file mode 100644 index 0000000000..e5614fb29d --- /dev/null +++ b/tests/flattening/match-case/if_invariant_dif_branch.fut @@ -0,0 +1,17 @@ +-- == +-- input { [5i64, 10i64, 15i64] [10i64, 14i64, 8i64] 5i64 } +-- auto output +-- input { [5i64, 10i64, 15i64, 3i64] [10i64, 14i64, 8i64, 100i64] 1i64 } +-- auto output +def main [n] (xs: [n]i64) (ys: [n]i64) (b: i64) = + map2 (\x y -> + let (if_res) = + opaque (if b == 5 + then let zs = iota x + in (map (+ x) zs) + else let lit1 = [y, 4, 5, y, 100] + let lit1' = map (* y) lit1 + in (lit1')) + in if_res[4]) + xs + ys diff --git a/tests/flattening/match-case/if_invariant_irreg.fut b/tests/flattening/match-case/if_invariant_irreg.fut new file mode 100644 index 0000000000..34d71d39cf --- /dev/null +++ b/tests/flattening/match-case/if_invariant_irreg.fut @@ -0,0 +1,17 @@ +-- == +-- input { [5i64, 10i64, 15i64] 5i64 } +-- auto output +-- input { [5i64, 10i64, 15i64, 20i64, 25i64] 50i64 } +-- auto output +def main [n] (xs: [n]i64) (b: i64) = + map (\x -> + let if_res = + opaque (if b == 5 + then let zs = iota x + let zs = map (\e -> e + 5 + x) zs + in map (+ x) zs + else let zs = iota x + let zs = map (\e -> e * 5 + x) zs + in map (+ x) zs) + in if_res[0]) + xs diff --git a/tests/flattening/match-case/if_irreg_input.fut b/tests/flattening/match-case/if_irreg_input.fut new file mode 100644 index 0000000000..8b4a164b0b --- /dev/null +++ b/tests/flattening/match-case/if_irreg_input.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- nobench input { [-5i64,-3i64,4i64,2i64,0i64,-1i64,3i64,1i64] } +-- output { [-1i64,-1i64,6i64,1i64,0i64,-1i64,3i64,0i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 0i64, 1i64, 3i64, 6i64,10i64] } +-- nobench input { [-1i64,-2i64,-3i64,-4i64,-5i64] } +-- output { 
[-1i64,-1i64,-1i64,-1i64,-1i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let foo (x : i64) = + let ys = iota (i64.abs x) + in if x < 0 then -1 else reduce (+) 0 ys + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/if_irreg_result.fut b/tests/flattening/match-case/if_irreg_result.fut new file mode 100644 index 0000000000..218780ae1e --- /dev/null +++ b/tests/flattening/match-case/if_irreg_result.fut @@ -0,0 +1,20 @@ +-- == +-- entry: main +-- nobench input { [ -5i64,-3i64,4i64,2i64,0i64,-1i64,3i64,1i64] } +-- output { [300i64,36i64,6i64,1i64,0i64, 0i64,3i64,0i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 0i64, 1i64, 3i64, 6i64, 10i64] } +-- nobench input { [ 1i64,-2i64,-3i64, -4i64, -5i64] } +-- output { [ 0i64, 6i64,36i64,120i64,300i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let bar (x : i64) = if x < 0 then iota (x*x) else iota x + +#[noinline] +let foo (x : i64) = + let ys = bar x + in reduce (+) 0 ys + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/match_fully_irreg.fut b/tests/flattening/match-case/match_fully_irreg.fut new file mode 100644 index 0000000000..72f4b6b2c5 --- /dev/null +++ b/tests/flattening/match-case/match_fully_irreg.fut @@ -0,0 +1,25 @@ +-- == +-- entry: main +-- nobench input { [0i64, 0i64, 0i64, 1i64, 1i64, 1i64, 2i64, 2i64, 2i64] [0i64, 1i64, 2i64, 0i64, 1i64, 2i64, 0i64, 1i64, 2i64] } +-- output { [7i64, -5i64, -4i64, 2i64, -1i64, -1i64, 1i64, -1i64, 2i64] } +-- nobench input { [0i64, 0i64, 0i64, 1i64, 1i64, 1i64, 2i64, 2i64, 2i64] [2i64, 2i64, 2i64, 1i64, 1i64, 1i64, 0i64, 0i64, 0i64] } +-- output { [-4i64, -4i64, -4i64, -1i64, -1i64, -1i64, 1i64, 1i64, 1i64] } +-- nobench input { [1i64, 2i64, 3i64] [4i64, 5i64, 6i64] } +-- output { [2i64, 35i64, 135i64] } +-- nobench input { empty([0]i64) empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let foo 
(x : i64) (y : i64) (zs : []i64) = + let (a, as) = + match (x, y) + case (0,0) -> (3,iota 5) + case (0,b) -> (5,iota b) + case (a,0) -> (a,iota 3) + case (a,b) -> (a*b, zs) + in reduce (+) 0 as - a + +let bar (x : i64) (y : i64) = + let zs = iota (x * y) in foo x y zs + +def main [n] (xs : [n]i64) (ys : [n]i64) = map2 bar xs ys diff --git a/tests/flattening/matmul.fut b/tests/flattening/matmul.fut new file mode 100644 index 0000000000..4f191a906b --- /dev/null +++ b/tests/flattening/matmul.fut @@ -0,0 +1,9 @@ +-- == +-- input { [[1i64, 2i64, 3i64],[4i64, 5i64, 6i64]] [[7i64, 8i64],[9i64, 10i64], [11i64, 12i64]] } +-- auto output +def main [n] [m] [p] (A: [n][m]i64) (B: [m][p]i64) : [n][p]i64 = + map (\A_row -> + map (\B_col -> + #[sequential] reduce (+) 0 (map2 (*) A_row B_col)) + (transpose B)) + A diff --git a/tests/flattening/nested-map-iota.fut b/tests/flattening/nested-map-iota.fut new file mode 100644 index 0000000000..d10a9419fe --- /dev/null +++ b/tests/flattening/nested-map-iota.fut @@ -0,0 +1,15 @@ +-- Nested maps over irregular iota arrays +-- == +-- input { [2i64, 3i64, 5i64] [1i64, 2i64] } +-- auto output +def main (xs) (ys) = + map (\x -> + let zs = iota x + let some_res = + map (\y -> + let zs' = map (+ y) zs + let res = i64.sum zs' + in res) + ys + in i64.sum some_res) + xs diff --git a/tests/flattening/problem-Inner0.fut b/tests/flattening/problem-Inner0.fut new file mode 100644 index 0000000000..8454abc96d --- /dev/null +++ b/tests/flattening/problem-Inner0.fut @@ -0,0 +1,13 @@ +-- == +-- input { [3i64, 7i64, 10i64, 2i64, 20i64] } +-- auto output +-- input { [3i64,55] } +-- auto output +-- input { [3i64, 7i64, 10i64] } +-- auto output +def main (xs: []i64) = + map (\x -> + let res = opaque (iota x) + let mes = map (+5) res + in mes[1]) + xs \ No newline at end of file diff --git a/tests/flattening/problem-Inner1.fut b/tests/flattening/problem-Inner1.fut new file mode 100644 index 0000000000..a2cb0a172c --- /dev/null +++ 
b/tests/flattening/problem-Inner1.fut @@ -0,0 +1,15 @@ +-- == +-- input { 4i64 [1i64, 2i64, 3i64] [4i64, 5i64, 6i64,9i64, 5i64] } +-- input { 3i64 [5i64, 2i64, 3i64,3i64,3i64] [4i64, 5i64, 6i64,9i64, 5i64] } +-- auto output +def main [n] [m] (o: i64) (xs: [n]i64) (ys: [m]i64) = + map (\j -> + map (\x -> + map (\y -> + let row_y = replicate y j with [1] = j + x + y + let mat = replicate x (replicate y (j + 4)) with [0] = row_y + let mat2 = opaque (map (\row -> map (+1) row) mat) + in mat2[0][0]) + ys) + xs) + (iota o) diff --git a/tests/flattening/range-irreg-stride.fut b/tests/flattening/range-irreg-stride.fut new file mode 100644 index 0000000000..464fd4acef --- /dev/null +++ b/tests/flattening/range-irreg-stride.fut @@ -0,0 +1,7 @@ +-- == +-- input { 10i64 [1,2] } +-- output { [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], +-- [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]] +-- } + +def main k = map (\s -> (0..s.. [k]i32) diff --git a/tests/flattening/range-opaque-red.fut b/tests/flattening/range-opaque-red.fut new file mode 100644 index 0000000000..a2ca853795 --- /dev/null +++ b/tests/flattening/range-opaque-red.fut @@ -0,0 +1,7 @@ +-- == +-- input { [1i64,2i64] [3i64,3i64] [10i64,8i64] } +-- output { [25i64, 27i64] } +-- input { [1i64,2i64] [3i64,2i64] [10i64,-8i64] } +-- error: Range 2..2..<-8 is invalid + +def main = map3 (\a b c -> i64.sum (opaque (a..b.. 
+ let block = xs[i:, i:r] + let colsums = map i64.sum (transpose block) + in colsums[0]) + is xss diff --git a/tests/flattening/rearrange0.fut b/tests/flattening/rearrange0.fut new file mode 100644 index 0000000000..4bbddbfdf5 --- /dev/null +++ b/tests/flattening/rearrange0.fut @@ -0,0 +1,5 @@ +-- == +-- input { [[[1,2],[4,5],[7,8]],[[8,4],[5,1],[7,2]]] } +-- output { [5,1] } + +def main (xsss: [][][]i32) = map (\xs -> (opaque (transpose (opaque xs)))[1,1]) xsss diff --git a/tests/flattening/rearrange1.fut b/tests/flattening/rearrange1.fut new file mode 100644 index 0000000000..371768b399 --- /dev/null +++ b/tests/flattening/rearrange1.fut @@ -0,0 +1,5 @@ +-- == +-- input { [3i64,4i64] } +-- output { [1i64,1i64] } + +def main = map (\n -> ((transpose (replicate (n+1) (iota n))))[1,1]) diff --git a/tests/flattening/red-arr-res.fut b/tests/flattening/red-arr-res.fut new file mode 100644 index 0000000000..beb489ed8f --- /dev/null +++ b/tests/flattening/red-arr-res.fut @@ -0,0 +1,11 @@ +-- failing now +-- == +-- input { 3i64 4i64 } +-- auto output + +def main (m: i64) (n: i64) : [n][n]f32 = + tabulate n (\j -> + let zeros = replicate m (replicate n 0.0f32) + let row = replicate n 10.0f32 with [j] = 1.0f32 + let updated = zeros with [0] = row + in reduce (map2 (+)) (replicate n 0.0f32) updated) \ No newline at end of file diff --git a/tests/flattening/redomap1.fut b/tests/flattening/redomap1.fut deleted file mode 100644 index 283429d18c..0000000000 --- a/tests/flattening/redomap1.fut +++ /dev/null @@ -1,17 +0,0 @@ --- == --- input { --- [[1,2,3],[1,2,3]] --- [[3,2,1],[6,7,8]] --- } --- output { --- [12, 27] --- } -def main [m] [n] (xss: [m][n]i32) (yss: [m][n]i32) : [m]i32 = - let final_res = - map (\(xs: [n]i32, ys: [n]i32) : i32 -> - let tmp = - map (\(x: i32, y: i32) : i32 -> x + y) - (zip xs ys) - in reduce (+) 0 tmp) - (zip xss yss) - in final_res diff --git a/tests/flattening/redomap2.fut b/tests/flattening/redomap2.fut deleted file mode 100644 index 
42f6df5591..0000000000 --- a/tests/flattening/redomap2.fut +++ /dev/null @@ -1,13 +0,0 @@ --- == --- input { --- [1,2,3] --- [6,7,8] --- } --- output { --- 27 --- } -def main [n] (xs: [n]i32) (ys: [n]i32) : i32 = - let tmp = - map (\(x: i32, y: i32) : i32 -> x + y) - (zip xs ys) - in reduce (+) 0 tmp diff --git a/tests/flattening/reduce-non-primitive.fut b/tests/flattening/reduce-non-primitive.fut new file mode 100644 index 0000000000..2a2fb5cde1 --- /dev/null +++ b/tests/flattening/reduce-non-primitive.fut @@ -0,0 +1,8 @@ +-- Currently fails +-- == +-- input { [1i64,2i64,3i64] } +-- output { [[1i64, 0i64], [4i64, 1i64], [9i64, 3i64]]} + +def main (xs : []i64) = + map (\x -> let ys = iota x + in reduce (\y1 y2 -> [y1[0] + y2[0], y1[1] + y2[1]]) [0,0] (map (\y -> [x, y]) ys)) xs \ No newline at end of file diff --git a/tests/flattening/replicate-multidim-fail.fut b/tests/flattening/replicate-multidim-fail.fut new file mode 100644 index 0000000000..64671e4aa5 --- /dev/null +++ b/tests/flattening/replicate-multidim-fail.fut @@ -0,0 +1,12 @@ +-- == +-- input { 4i64 [1i64,2i64] [3i64,4i64] } +-- auto output +def main [n] [m] (o: i64) (xs: [n]i64) (ys: [m]i64) = + map (\x -> + map (\y -> + let row_y = replicate y x with [1] = x + o + let z = map (+ 10) row_y + let mat = replicate x (row_y) with [0] = z + in mat[0][0]) + ys) + xs diff --git a/tests/flattening/replicate-multidim-var.fut b/tests/flattening/replicate-multidim-var.fut new file mode 100644 index 0000000000..92124e1d92 --- /dev/null +++ b/tests/flattening/replicate-multidim-var.fut @@ -0,0 +1,16 @@ +-- == +-- input { 5i64 4i64 } +-- auto output +def main (m: i64) (n: i64) = + map (\j -> + let row_j = replicate n j with [j] = 7i64 + let mat = replicate m (replicate n m) with [0] = row_j + in map (\row -> + let s = mat[0][1] + let rs = row[0] + let ds = iota j + let y = map (\d -> d + mat[0][j] + rs + s) ds + let sum = reduce (+) 0 y + in map (\x -> x + sum + 3 + rs + s) row) + mat) + (iota n) diff --git 
a/tests/flattening/replicate-multidim-var2.fut b/tests/flattening/replicate-multidim-var2.fut new file mode 100644 index 0000000000..5f04138063 --- /dev/null +++ b/tests/flattening/replicate-multidim-var2.fut @@ -0,0 +1,15 @@ +-- == +-- input { 5i64 4i64 } +-- auto output +def main (m: i64) (n: i64) = + map (\j -> + let ds = iota j + let fs = map (\d -> d + 3) ds + let tmp = replicate m (replicate n fs) + let mat_3d = tmp with [0][0][0] = ds[0] + 3 + -- let mat_3d_2 = + -- map (\mat -> + -- let val_3d = mat[0][0][0] + -- in map (\row -> map (\elem -> elem + val_3d) row) mat) mat_3d + in mat_3d[0][0][0]) + (map (+3) (iota n)) diff --git a/tests/flattening/replicate-multidim-var3.fut b/tests/flattening/replicate-multidim-var3.fut new file mode 100644 index 0000000000..0af06c6b56 --- /dev/null +++ b/tests/flattening/replicate-multidim-var3.fut @@ -0,0 +1,16 @@ +-- == +-- input { 4i64 [1i64, 2i64, 3i64] [4i64, 5i64, 6i64,9i64, 5i64] } +-- input { 3i64 [5i64, 2i64, 3i64,3i64,3i64] [4i64, 5i64, 6i64,9i64, 5i64] } +-- auto output +def main [n] [m] (o: i64) (xs: [n]i64) (ys: [m]i64) = + map (\j -> + map (\x -> + map (\y -> + let row_y = replicate y j with [1] = j + x + y + let mat = replicate x (replicate y (j + 4)) with [0] = row_y + -- let d = map (\row -> i64.sum row) mat + -- in i64.sum d) + in mat[0][0]) + ys) + xs) + (iota o) diff --git a/tests/flattening/replicate-multidim-var4.fut b/tests/flattening/replicate-multidim-var4.fut new file mode 100644 index 0000000000..1ee2f9cede --- /dev/null +++ b/tests/flattening/replicate-multidim-var4.fut @@ -0,0 +1,20 @@ +-- == +-- input { 4i64 [1i64, 2i64, 3i64] [4i64, 5i64, 6i64,9i64, 5i64] } +-- input { 3i64 [5i64, 2i64, 3i64,3i64,3i64] [4i64, 5i64, 6i64,9i64, 5i64] } +-- auto output +def main [n] [m] (o: i64) (xs: [n]i64) (ys: [m]i64) = + map (\j -> + map (\x -> + map (\y -> + let row_y = replicate y j with [1] = j + x + y + let mat = replicate x (replicate y (j + 4)) with [0] = row_y + let mat2 = + map (\row -> + let s = 
mat[0][0] + let rs = row[0] + in map (\x -> x + 3 + rs + s) row) + mat + in mat2[0][0]) + ys) + xs) + (iota o) \ No newline at end of file diff --git a/tests/flattening/replicate-multidim.fut b/tests/flattening/replicate-multidim.fut new file mode 100644 index 0000000000..fec70707a3 --- /dev/null +++ b/tests/flattening/replicate-multidim.fut @@ -0,0 +1,9 @@ +-- == +-- input { 3i64 4i64 } +-- auto output +def main (m: i64) (n: i64) : [n][m]f32 = + map (\j -> + let row_j = replicate n 0.0f32 with [j] = 7.0f32 + let mat = replicate m (replicate n 5.0f32) with [0] = row_j + in map (\row -> reduce (+) 0 row) mat + ) (iota n) diff --git a/tests/flattening/replicate0.fut b/tests/flattening/replicate0.fut new file mode 100644 index 0000000000..1e34ee3240 --- /dev/null +++ b/tests/flattening/replicate0.fut @@ -0,0 +1,6 @@ +-- == +-- input { [1i64,2i64] [0, 1] [4,5] } +-- output { [4,5] } + +def main = map3 (\n (i:i32) (x:i32) -> let A = opaque (replicate n x) + in #[unsafe] A[i]) diff --git a/tests/flattening/replicate1.fut b/tests/flattening/replicate1.fut new file mode 100644 index 0000000000..9d61bf17d3 --- /dev/null +++ b/tests/flattening/replicate1.fut @@ -0,0 +1,7 @@ +-- Now we are replicating a regular array. 
+-- == +-- input { [1i64,2i64] [0, 1] [[4,5],[5,6]] } +-- output { [[4,5],[5,6]] } + +def main = map3 (\n (i:i32) (x:[2]i32) -> let A = opaque (replicate n x) + in #[unsafe] A[i]) diff --git a/tests/flattening/replicate2.fut b/tests/flattening/replicate2.fut new file mode 100644 index 0000000000..5f19f2ccc4 --- /dev/null +++ b/tests/flattening/replicate2.fut @@ -0,0 +1,9 @@ +-- == +-- input { [2i64,5i64] [3i64,4i64] [0,1] [[5,4], [4,5]] } +-- output { [[5, 4], [4, 5]] } + +def main = + map4 \n m (i: i32) (x: [2]i32) -> + let A = opaque (replicate n x) + let B = opaque (replicate m A) + in #[unsafe] B[i, 0] diff --git a/tests/flattening/scan-map-reduce.fut b/tests/flattening/scan-map-reduce.fut new file mode 100644 index 0000000000..bb587455fd --- /dev/null +++ b/tests/flattening/scan-map-reduce.fut @@ -0,0 +1,12 @@ +-- Currently fails +-- == +-- input { [5i64,7i64] [0i64,1i64] } +-- auto output +def main ns is = map2 (\n (i:i64) -> let is = iota n + let xs = map (+2) is + -- let ts = scan (*) 1 xs + let ys = map (*i) is + let zs = scan (+) 0 ys + in (i64.sum xs, (opaque zs)[i])) + ns is + |> unzip diff --git a/tests/flattening/scan-map.fut b/tests/flattening/scan-map.fut new file mode 100644 index 0000000000..b8236afe12 --- /dev/null +++ b/tests/flattening/scan-map.fut @@ -0,0 +1,6 @@ +-- == +-- input { [[1i64,2i64,3i64],[4i64,5i64,6i64]] } +-- auto output + +def main [n] [m] (xss: [n][m]i64) = + map (\xs -> scan (+) 0 (map (* 2) xs)) xss \ No newline at end of file diff --git a/tests/flattening/simple-2d-map.fut b/tests/flattening/simple-2d-map.fut new file mode 100644 index 0000000000..f8ca552868 --- /dev/null +++ b/tests/flattening/simple-2d-map.fut @@ -0,0 +1,10 @@ +-- == +-- input { [1i32, 2i32, 3i32] [4i32, 5i32] [7i32, 8i32, 9i32] } +-- auto output + +def main [n] [m] [p] (xs: [n]i32) (ys: [m]i32) (zs: [p]i32) : [n][m][p]i32 = + map (\x -> + map (\y -> + map (\z -> x - z + y) zs) + ys) + xs \ No newline at end of file diff --git 
a/tests/flattening/simple-scan.fut b/tests/flattening/simple-scan.fut new file mode 100644 index 0000000000..ba2216e16b --- /dev/null +++ b/tests/flattening/simple-scan.fut @@ -0,0 +1,6 @@ +-- == +-- input { [[1i64,2i64,3i64],[4i64,5i64,6i64]] } +-- output { [[1i64,3i64,6i64],[4i64,9i64,15i64]] } + +def main [n] [m] (xss: [n][m]i64) = + map (\xs -> scan (+) 0 xs) xss \ No newline at end of file diff --git a/tests/flattening/slice-red.fut b/tests/flattening/slice-red.fut new file mode 100644 index 0000000000..4362860300 --- /dev/null +++ b/tests/flattening/slice-red.fut @@ -0,0 +1,5 @@ +-- == +-- input { [[0i64,1i64,5i64],[-2i64,9i64,1i64]] [0i64,1i64] } +-- output { [6i64,10i64] } + +def main = map2 (\A (i:i64) -> i64.sum A[i:]) diff --git a/tests/flattening/slice2d-red.fut b/tests/flattening/slice2d-red.fut new file mode 100644 index 0000000000..ad2bc650e5 --- /dev/null +++ b/tests/flattening/slice2d-red.fut @@ -0,0 +1,5 @@ +-- == +-- input { [[[0i64,1i64],[4i64,5i64]],[[-2i64,9i64],[9i64,2i64]]] [0i64,1i64] [1i64,0i64] } +-- output { [6i64,11i64] } + +def main = map3 (\A (i:i64) (j: i64) -> i64.sum (flatten A[i:,j:])) diff --git a/tests/flattening/test_regular_undefined.fut b/tests/flattening/test_regular_undefined.fut new file mode 100644 index 0000000000..9ab35d65e4 --- /dev/null +++ b/tests/flattening/test_regular_undefined.fut @@ -0,0 +1,7 @@ +-- Triggers the current undefined case in onMapInputArr. 
+-- == +-- input { [1i64, 2i64, 3i64] } +-- auto output + +def main (xs: []i64) = + map (\x -> map (+1) [x, x*x]) xs diff --git a/tests/flattening/update_constant.fut b/tests/flattening/update_constant.fut new file mode 100644 index 0000000000..b20a6e69ac --- /dev/null +++ b/tests/flattening/update_constant.fut @@ -0,0 +1,7 @@ +-- Simple test for flattening an update with a constant value +-- == +-- input { [1i64,2i64,3i64] } +-- output { [12i64,11i64,10i64] } + +entry main [n] (xs : [n]i64) = + map (\x -> reduce (+) 0 (iota 5 with [x] = 3)) xs diff --git a/tests/flattening/update_dimfix.fut b/tests/flattening/update_dimfix.fut new file mode 100644 index 0000000000..6ddad56a10 --- /dev/null +++ b/tests/flattening/update_dimfix.fut @@ -0,0 +1,38 @@ +-- Test with fixed dimension +-- == +-- input { [0,1,2,3,4] [0,1,2,3,4] [5,6,7,8,9] } +-- output { +-- [[[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]]] +-- } + +let main (arr: []i32) (is: []i32) (js: []i32) = + [map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ] diff --git a/tests/flattening/update_fully_irregular.fut b/tests/flattening/update_fully_irregular.fut new file mode 100644 index 0000000000..cbd5a98f93 --- /dev/null +++ b/tests/flattening/update_fully_irregular.fut @@ -0,0 +1,7 @@ +-- Fully irregular test-case +-- == +-- input { [5i64,6i64,7i64] 
[2i64,3i64,1i64] [3i64,1i64,2i64] [5i64,6i64,3i64] [1i64,2i64,3i64] } +-- output { [4i64,9i64,19i64] } + +entry main [n] (xs : [n]i64) (vs : [n]i64) (is : [n]i64) (js : [n]i64) (ss: [n]i64) = + map5 (\x v i j s -> reduce (+) 0 (iota x with [i:j:s] = iota v)) xs vs is js ss diff --git a/tests/flattening/update_invariant_is.fut b/tests/flattening/update_invariant_is.fut new file mode 100644 index 0000000000..31c8672602 --- /dev/null +++ b/tests/flattening/update_invariant_is.fut @@ -0,0 +1,7 @@ +-- Test with only invariant indices. +-- == +-- input { [4i64,5i64,6i64] [3i64,3i64,3i64] } +-- output { [3i64,7i64,12i64] } + +entry main [n] (xs : [n]i64) (vs : [n]i64) = + map2(\x v -> reduce (+) 0 (iota x with [1:4] = iota v)) xs vs diff --git a/tests/flattening/update_invariant_vs.fut b/tests/flattening/update_invariant_vs.fut new file mode 100644 index 0000000000..ab2a01225a --- /dev/null +++ b/tests/flattening/update_invariant_vs.fut @@ -0,0 +1,8 @@ +-- Test with only invariant 'vs'. +-- == +-- input { [6i64,7i64,8i64] [0i64,1i64,2i64] [5i64,6i64,7i64] } +-- output { [15i64,16i64,18i64] } + +entry main [n] (xs : [n]i64) (is : [n]i64) (js : [n]i64) = + map3(\x i j -> reduce (+) 0 (iota x with [i:j] = iota 5)) xs is js + diff --git a/tests/flattening/update_invariant_xs.fut b/tests/flattening/update_invariant_xs.fut new file mode 100644 index 0000000000..fe9dc215dc --- /dev/null +++ b/tests/flattening/update_invariant_xs.fut @@ -0,0 +1,7 @@ +-- Test with only invariant 'xs'. +-- == +-- input { [1i64,2i64,3i64] [3i64,3i64,3i64] } +-- output { [8i64,8i64,10i64] } + +entry main [n] (is : [n]i64) (js : [n]i64) = + map2(\i j -> reduce (+) 0 (iota 5 with [i:j] = iota (j-i))) is js diff --git a/tests/flattening/update_mixdim.fut b/tests/flattening/update_mixdim.fut new file mode 100644 index 0000000000..77b6aa16bf --- /dev/null +++ b/tests/flattening/update_mixdim.fut @@ -0,0 +1,12 @@ +-- Mixing slices and indexes in complex ways. 
+-- == +-- input { [0i64,1i64] +-- [2i64,3i64] +-- [[[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]], +-- [[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]]] +-- [[0f64,1f64],[4f64,5f64]] +-- } +-- output { [91.0, 99.0] } + +let main [n] (is : [n]i64) (js : [n]i64) (ass : [n][][]f64) (vs : [n][]f64) = + map4(\i j as vs -> f64.sum(flatten(copy as with [i,i:j] = vs))) is js ass vs diff --git a/tests/flattening/update_multdim.fut b/tests/flattening/update_multdim.fut new file mode 100644 index 0000000000..f90aa3f2b3 --- /dev/null +++ b/tests/flattening/update_multdim.fut @@ -0,0 +1,11 @@ +-- == +-- input { [0i64,1i64] +-- [2i64,3i64] +-- [[[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]], +-- [[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]]] +-- [[[0f64,1f64],[2f64,3f64]],[[4f64,5f64],[6f64,7f64]]] +-- } +-- output { [78.0, 94.0] } + +let main [n] (is : [n]i64) (js : [n]i64) (ass : [n][][]f64) (vss : [n][][]f64) = + map4(\i j as vs -> f64.sum(flatten(copy as with [i:j,i:j] = vs))) is js ass vss diff --git a/tests/flattening/update_variant_is.fut b/tests/flattening/update_variant_is.fut new file mode 100644 index 0000000000..8c62f3ba46 --- /dev/null +++ b/tests/flattening/update_variant_is.fut @@ -0,0 +1,7 @@ +-- Test with only variant indices. +-- == +-- input { [0i64,3i64,1i64] [5i64,8i64,6i64] } +-- output { [28i64,13i64,23i64] } + +entry main [n] (is : [n]i64) (js : [n]i64) = + map2 (\i j -> reduce (+) 0 (iota 8 with [i:j] = iota 5)) is js diff --git a/tests/flattening/update_variant_vs.fut b/tests/flattening/update_variant_vs.fut new file mode 100644 index 0000000000..3fdec27aae --- /dev/null +++ b/tests/flattening/update_variant_vs.fut @@ -0,0 +1,7 @@ +-- Test with only variant 'vs'. 
+-- == +-- input { [3i64,3i64,3i64] } +-- output { [7i64,7i64,7i64] } + +entry main (vs : []i64) = + map (\v -> reduce (+) 0 (iota 5 with [1:4] = iota v)) vs diff --git a/tests/flattening/update_variant_xs.fut b/tests/flattening/update_variant_xs.fut new file mode 100644 index 0000000000..89730113b3 --- /dev/null +++ b/tests/flattening/update_variant_xs.fut @@ -0,0 +1,7 @@ +-- Test with only variant 'xs'. +-- == +-- input { [4i64,5i64,6i64] } +-- output { [3i64,7i64,12i64] } + +entry main [n] (xs : [n]i64) = + map (\x -> reduce (+) 0 (iota x with [1:4] = iota 3)) xs diff --git a/tests/flattening/while/while-2d-irregular-param.fut b/tests/flattening/while/while-2d-irregular-param.fut new file mode 100644 index 0000000000..08de8d1fea --- /dev/null +++ b/tests/flattening/while/while-2d-irregular-param.fut @@ -0,0 +1,15 @@ +-- == +-- input { [[3i64, 1i64, 4i64], [2i64, 5i64, 3i64]] } +-- auto output +def main [n] [m] (xss: [n][m]i64) = + map (\xs -> + map (\x -> + let zs = iota x + let (_, res) = + loop (i, arr) = (0i64, zs) + while i < x do + let arr' = opaque (map (\z -> z * i + x) arr) + in (i + 1, arr') + in reduce (+) 0 res + x) + xs) + xss diff --git a/tests/flattening/while/while-2d-regular-param.fut b/tests/flattening/while/while-2d-regular-param.fut new file mode 100644 index 0000000000..0a1c721054 --- /dev/null +++ b/tests/flattening/while/while-2d-regular-param.fut @@ -0,0 +1,16 @@ +-- == +-- input { [[3i64, 1i64, 4i64], [2i64, 5i64, 3i64]] } +-- auto output +def main [n] [m] (xss: [n][m]i64) = + map (\xs -> + map (\x -> + let extra0 = map (\z -> z + x) xs + let (_, arr, extra) = + loop (i, arr, extra) = (0i64, xs, extra0) + while i < x do + let arr' = opaque (map (\z -> z * i + x) arr) + let extra' = opaque (map (\z -> z + i) extra) + in (i + 1, arr', extra') + in reduce (+) 0 arr + reduce (+) 0 extra + x) + xs) + xss diff --git a/tests/flattening/while/while-filter.fut b/tests/flattening/while/while-filter.fut new file mode 100644 index 
0000000000..5ec7d34a90 --- /dev/null +++ b/tests/flattening/while/while-filter.fut @@ -0,0 +1,16 @@ +-- Currently fails. +-- == +-- input { [3i64, 7i64, 100i64, 1i64, 1000i64] } +-- auto output +def main (xs) = + map (\x -> + let z' = iota x + let (_, res, _) = + loop (i, s, ac) = (0, 0, z') + while i < x do + let ac' = filter (\t -> t < i) ac + let s' = reduce (+) 0 ac' + let s'' = s + s' + in (i + 1, s'', ac') + in res) + xs diff --git a/tests/flattening/while/while-irregular-param.fut b/tests/flattening/while/while-irregular-param.fut new file mode 100644 index 0000000000..118c510957 --- /dev/null +++ b/tests/flattening/while/while-irregular-param.fut @@ -0,0 +1,14 @@ +-- == +-- input { [3i64, 7i64, 10i64, 1i64, 20i64] } +-- auto output +def main (xs) = + map (\x -> + let zs = iota x in + let (_,b) = + loop (i,zs) = (0, zs) while i < x do + let tes = map (\z -> z * i + x) zs + in (i + 1, tes) + in reduce (+) 0 b + x + ) xs + + diff --git a/tests/flattening/while/while-map-reduce.fut b/tests/flattening/while/while-map-reduce.fut new file mode 100644 index 0000000000..752fe4130a --- /dev/null +++ b/tests/flattening/while/while-map-reduce.fut @@ -0,0 +1,20 @@ +-- == +-- input { [2i64, 5i64, 4i64] } +-- auto output +def main (xs) = + map (\x -> + let zs = iota x + let some_res = + map (\z -> + let z' = iota z + let (_, res, _) = + loop (i, s, ac) = (0, 0, z') + while i < x do + let ac' = map (+ i) ac + let s' = reduce (+) 0 ac' + let s'' = s + s' + in (i + 1, s'', ac') + in res) + zs + in reduce (+) 0 some_res) + xs diff --git a/tests/flattening/while/while-simple.fut b/tests/flattening/while/while-simple.fut new file mode 100644 index 0000000000..312e0933c1 --- /dev/null +++ b/tests/flattening/while/while-simple.fut @@ -0,0 +1,13 @@ +-- == +-- input { [3i64, 7i64, 1i64, 10i64] } +-- auto output +def main [n] (xs: [n]i64) = + map (\x -> + let zs = iota x + let some_res = + map (\z -> + let res = loop acc = z while acc < 10 do acc + 2 + in res) + zs + in reduce (+) 0 
some_res + x) + xs diff --git a/tests/flattening/while/while-variant-shape.fut b/tests/flattening/while/while-variant-shape.fut new file mode 100644 index 0000000000..2949f745c3 --- /dev/null +++ b/tests/flattening/while/while-variant-shape.fut @@ -0,0 +1,11 @@ +-- == +-- input { [1i32, 4i32, 8i32, 16i32] } +-- auto output +def main (ns: []i32) : []i32 = + map (\n -> + let res = + loop arr = [1] + while length arr < i64.i32 n do + arr ++ arr + in i32.sum res) + ns \ No newline at end of file diff --git a/tests/flattening/withacc/scatter_irregular.fut b/tests/flattening/withacc/scatter_irregular.fut new file mode 100644 index 0000000000..0eb95964a1 --- /dev/null +++ b/tests/flattening/withacc/scatter_irregular.fut @@ -0,0 +1,13 @@ +-- Maximally irregular case. +-- == +-- input { [0i64, 2, 1, 3, 2, 3, 10] +-- [2,3,1,4,5,6,7] +-- [1i64, 2, 3, 4] +-- [4i64, 3, 4, 5] +-- } +-- output { [2, 1, 3, 4] } + +def main is vs = + map2 \n m -> + let arr = scatter (replicate n 0i32) (take m is) (take m vs) + in last (opaque arr) diff --git a/tests/flattening/withacc/scatter_regular.fut b/tests/flattening/withacc/scatter_regular.fut new file mode 100644 index 0000000000..c822e8f725 --- /dev/null +++ b/tests/flattening/withacc/scatter_regular.fut @@ -0,0 +1,11 @@ +-- Completely regular case. +-- == +-- input { [[1,2,3], [4,5,6], [7,8,9]] +-- [[0i64, 2], [-1i64, 0], [1i64,0]] +-- [[1,2], [3,4], [5,6]] +-- } +-- output { [[1, 2, 2], [4, 5, 6], [6, 5, 9]] } + +def main = + map3 \(xs: []i32) (is: []i64) (vs: []i32) -> + scatter (copy xs) is vs