diff --git a/grin/grin.cabal b/grin/grin.cabal index 5fe01a2f..d2e0880e 100644 --- a/grin/grin.cabal +++ b/grin/grin.cabal @@ -151,6 +151,7 @@ library Transformations.ExtendedSyntax.Optimising.ConstantPropagation Transformations.ExtendedSyntax.Optimising.CSE Transformations.ExtendedSyntax.Optimising.EvaluatedCaseElimination + Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxing Transformations.ExtendedSyntax.Optimising.SimpleDeadFunctionElimination Transformations.ExtendedSyntax.Optimising.SparseCaseOptimisation Transformations.ExtendedSyntax.Optimising.TrivialCaseElimination @@ -309,6 +310,7 @@ test-suite grin-test Transformations.ExtendedSyntax.Optimising.CopyPropagationSpec Transformations.ExtendedSyntax.Optimising.CSESpec Transformations.ExtendedSyntax.Optimising.EvaluatedCaseEliminationSpec + Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxingSpec Transformations.ExtendedSyntax.Optimising.SimpleDeadFunctionEliminationSpec Transformations.ExtendedSyntax.Optimising.SparseCaseOptimisationSpec Transformations.ExtendedSyntax.Optimising.TrivialCaseEliminationSpec diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxing.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxing.hs new file mode 100644 index 00000000..3690d893 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxing.hs @@ -0,0 +1,187 @@ +{-# LANGUAGE LambdaCase, TupleSections, OverloadedStrings #-} +module Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxing where + +import Data.Set (Set) +import Data.Vector (Vector) +import Data.Map.Strict (Map) +import Data.Function (fix) +import Data.Bifunctor (second) +import Data.Functor.Infix ((<$$>)) +import Data.Functor.Foldable as Foldable +import Data.Maybe (catMaybes, mapMaybe, isJust) + +import Lens.Micro.Platform + +import qualified Data.Map.Strict as Map +import qualified Data.Set as Set +import qualified Data.Vector as Vector + +import Transformations.ExtendedSyntax.Util (anaM, apoM) +import Transformations.ExtendedSyntax.Names + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeEnv +import Grin.ExtendedSyntax.Pretty + + +generalizedUnboxing :: TypeEnv -> Exp -> (Exp, ExpChanges) +generalizedUnboxing te exp = if (null funs) + then (exp, NoChange) + else second + (const NewNames) -- New functions are created, but NameM monad is not used + (evalNameM exp (transformCalls funs te =<< transformReturns funs te exp)) + where + funs = functionsToUnbox te exp + +-- TODO: Support tagless nodes. + +tailCalls :: Exp -> Maybe [Name] +tailCalls = cata collect where + collect :: ExpF (Maybe [Name]) -> Maybe [Name] + collect = \case + DefF _ _ result -> result + EBindF _ _ result -> result + ECaseF _ alts -> nonEmpty $ concat $ catMaybes alts + AltF _ _ result -> result + SAppF f _ -> Just [f] + e -> Nothing + +nonEmpty :: [a] -> Maybe [a] +nonEmpty [] = Nothing +nonEmpty xs = Just xs + +doesReturnAKnownProduct :: TypeEnv -> Name -> Bool +doesReturnAKnownProduct = isJust <$$> returnsAUniqueTag + +returnsAUniqueTag :: TypeEnv -> Name -> Maybe (Tag, Type) +returnsAUniqueTag te name = do + (tag, vs) <- te ^? function . at name . _Just . _1 . _T_NodeSet . to Map.toList . to singleton . _Just + typ <- singleton (Vector.toList vs) + pure (tag, T_SimpleType typ) + +singleton :: [a] -> Maybe a +singleton = \case + [] -> Nothing + [a] -> Just a + _ -> Nothing + +transitive :: (Ord a) => (a -> Set a) -> Set a -> Set a +transitive f res0 = + let res1 = res0 `Set.union` (Set.unions $ map f $ Set.toList res0) + in if res1 == res0 + then res0 + else transitive f res1 + +-- TODO: Remove the fix combinator, explore the function +-- dependency graph and rewrite disqualify steps based on that. +functionsToUnbox :: TypeEnv -> Exp -> Set Name +functionsToUnbox te (Program exts defs) = result where + funName (Def n _ _) = n + + tailCallsMap :: Map Name [Name] + tailCallsMap = Map.fromList $ mapMaybe (\e -> (,) (funName e) <$> tailCalls e) defs + + tranisitiveTailCalls :: Map Name (Set Name) + tranisitiveTailCalls = Map.fromList $ map (\k -> (k, transitive inTailCalls (Set.singleton k))) $ Map.keys tailCallsMap + where + inTailCalls :: Name -> Set Name + inTailCalls n = maybe mempty Set.fromList $ Map.lookup n tailCallsMap + + nonCandidateTailCallMap = Map.withoutKeys tranisitiveTailCalls result0 + candidateCalledByNonCandidate = (Set.unions $ Map.elems nonCandidateTailCallMap) `Set.intersection` result0 + result = result0 `Set.difference` candidateCalledByNonCandidate + + result0 = Set.fromList $ step initial + initial = map funName $ filter (doesReturnAKnownProduct te . funName) defs + disqualify candidates = filter + (\candidate -> case Map.lookup candidate tailCallsMap of + Nothing -> True + Just calls -> all (`elem` candidates) calls) + candidates + step = fix $ \rec x0 -> + let x1 = disqualify x0 in + if x0 == x1 + then x0 + else rec x1 + +updateTypeEnv :: Set Name -> TypeEnv -> TypeEnv +updateTypeEnv funs te = te & function %~ unboxFun + where + unboxFun = Map.fromList . map changeFun . Map.toList + changeFun (n, ts@(ret, params)) = + if Set.member n funs + then (,) (n <> ".unboxed") + $ maybe ts ((\t -> (t, params)) . T_SimpleType) $ + ret ^? _T_NodeSet + . to Map.elems + . to singleton + . _Just + . to Vector.toList + . to singleton + . _Just + else (n, ts) + +transformReturns :: Set Name -> TypeEnv -> Exp -> NameM Exp +transformReturns toUnbox te exp = apoM builder (Nothing, exp) where + builder :: (Maybe (Tag, Type), Exp) -> NameM (ExpF (Either Exp (Maybe (Tag, Type), Exp))) + builder (mTagType, exp0) = case exp0 of + Def name params body + | Set.member name toUnbox -> pure $ DefF name params (Right (returnsAUniqueTag te name, body)) + | otherwise -> pure $ DefF name params (Left body) + + -- Always skip the lhs of a bind. + EBind lhs pat rhs -> pure $ EBindF (Left lhs) pat (Right (mTagType, rhs)) + + -- Remove the tag from the value + SReturn (ConstTagNode tag [arg]) -> pure $ SReturnF (Var arg) + + -- Rewrite a node variable + simpleExp + -- fromJust works, as when we enter the processing of body of the + -- expression only happens with the provided tag. + | canUnbox simpleExp + , Just (tag, typ) <- mTagType + -> do + freshName <- deriveNewName $ "unboxed." <> (showTS $ PP tag) + asPatName <- deriveWildCard + pure . SBlockF . Left $ EBind simpleExp (AsPat tag [freshName] asPatName) (SReturn $ Var freshName) + + rest -> pure (Right . (,) mTagType <$> project rest) + + -- NOTE: SApp is handled by transformCalls + canUnbox :: SimpleExp -> Bool + canUnbox = \case + SApp n ps -> n `Set.notMember` toUnbox + SReturn{} -> True + SFetch{} -> True + _ -> False + +transformCalls :: Set Name -> TypeEnv -> Exp -> NameM Exp +transformCalls toUnbox typeEnv exp = anaM builderM (True, Nothing, exp) where + builderM :: (Bool, Maybe Name, Exp) -> NameM (ExpF (Bool, Maybe Name, Exp)) + + builderM (isRightExp, mDefName, e) = case e of + + Def name params body + -> pure $ DefF (if Set.member name toUnbox then name <> ".unboxed" else name) params (True, Just name, body) + + -- track the control flow + EBind lhs pat rhs -> pure $ EBindF (False, mDefName, lhs) pat (isRightExp, mDefName, rhs) + + SApp name params + | Set.member name toUnbox + , Just defName <- mDefName + , unboxedName <- name <> ".unboxed" + , Just (tag, fstType) <- returnsAUniqueTag typeEnv name + -> if Set.member defName toUnbox && isRightExp + + -- from candidate to candidate: tailcalls do not need a transform + then pure $ SAppF unboxedName params + + -- from outside to candidate + else do + freshName <- deriveNewName $ "unboxed." <> (showTS $ PP tag) + pure . SBlockF . (isRightExp, mDefName,) $ + EBind (SApp unboxedName params) (VarPat freshName) (SReturn $ ConstTagNode tag [freshName]) + + rest -> pure ((isRightExp, mDefName,) <$> project rest) diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxingSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxingSpec.hs new file mode 100644 index 00000000..3dd65029 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxingSpec.hs @@ -0,0 +1,398 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxingSpec where + +import Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxing + + +import qualified Data.Set as Set +import qualified Data.Map.Strict as Map +import qualified Data.Vector as Vector + +import Test.Hspec + +import Test.ExtendedSyntax.Assertions +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.TypeEnv +import Transformations.ExtendedSyntax.Names (ExpChanges(..)) + + +runTests :: IO () +runTests = hspec spec + +spec :: Spec +spec = do + it "Figure 4.21 (extended)" $ do + let teBefore = emptyTypeEnv + { _function = Map.fromList + [ ("test", (int64_t, Vector.fromList [int64_t])) + , ("foo", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2B", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2C", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo3", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo4", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo5", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("bar", (int64_t, Vector.fromList [])) + ] + } + let before = [prog| + test n = + k0 <- pure 1 + prim_int_add n k0 + + foo a1 a2 a3 = + b1 <- prim_int_add a1 a2 + b2 <- prim_int_add b1 a3 + pure (CInt b2) + + foo2 a1 a2 a3 = + c1 <- prim_int_add a1 a2 + foo c1 c1 a3 + + foo2B a1 a2 a3 = + c1 <- prim_int_add a1 a2 + do + foo c1 c1 a3 + + foo2C a1 a2 a3 = + c1 <- prim_int_add a1 a2 + case c1 of + #default @ alt1 -> pure c1 + (CInt x1) @ alt2 -> foo c1 c1 a3 + + foo3 a1 a2 a3 = + c1 <- prim_int_add a1 a2 + -- In this case the vectorisation did not happen. + c2 <- foo c1 c1 a3 + pure c2 + + foo4 a1 = + v <- pure (CInt a1) + pure v + + foo5 a1 = + n0 <- pure (CInt a1) + p <- store n0 + fetch p + + bar = + k1 <- pure 1 + n1 <- test k1 + (CInt y') @ _0 <- foo a1 a2 a3 + test y' + |] + let teAfter = emptyTypeEnv + { _function = Map.fromList + [ ("test", (int64_t, Vector.fromList [int64_t])) + , ("foo.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2B.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2C.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo3.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo4.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo5.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("bar", (int64_t, Vector.fromList [])) + ] + , _variable = Map.fromList + [ ("unboxed.CInt.0", int64_t) + , ("unboxed.CInt.1", int64_t) + , ("unboxed.CInt.2", int64_t) + , ("unboxed.CInt.3", int64_t) + , ("unboxed.CInt.4", int64_t) + , ("unboxed.CInt.5", int64_t) + ] + } + let after = [prog| + test n = + k0 <- pure 1 + prim_int_add n k0 + + foo.unboxed a1 a2 a3 = + b1 <- prim_int_add a1 a2 + b2 <- prim_int_add b1 a3 + pure b2 + + foo2.unboxed a1 a2 a3 = + c1 <- prim_int_add a1 a2 + foo.unboxed c1 c1 a3 + + foo2B.unboxed a1 a2 a3 = + c1 <- prim_int_add a1 a2 + do + foo.unboxed c1 c1 a3 + + foo2C.unboxed a1 a2 a3 = + c1 <- prim_int_add a1 a2 + case c1 of + #default @ alt1 -> + do + (CInt unboxed.CInt.0) @ _1 <- pure c1 + pure unboxed.CInt.0 + (CInt x1) @ alt2 -> + foo.unboxed c1 c1 a3 + + foo3.unboxed a1 a2 a3 = + c1 <- prim_int_add a1 a2 + c2 <- do + unboxed.CInt.4 <- foo.unboxed c1 c1 a3 + pure (CInt unboxed.CInt.4) + do + (CInt unboxed.CInt.1) @ _2 <- pure c2 + pure unboxed.CInt.1 + + foo4.unboxed a1 = + v <- pure (CInt a1) + do + (CInt unboxed.CInt.2) @ _3 <- pure v + pure unboxed.CInt.2 + + foo5.unboxed a1 = + n0 <- pure (CInt a1) + p <- store n0 + do + (CInt unboxed.CInt.3) @ _4 <- fetch p + pure unboxed.CInt.3 + + bar = + k1 <- pure 1 + n1 <- test k1 + (CInt y') @ _0 <- do + unboxed.CInt.5 <- foo.unboxed a1 a2 a3 + pure (CInt unboxed.CInt.5) + test y' + |] + generalizedUnboxing teBefore before `sameAs` (after, NewNames) + + it "Return values are in cases" $ do + let teBefore = emptyTypeEnv + { _function = + fun_t "int_eq" + [ T_NodeSet $ cnode_t "Int" [T_Int64] + , T_NodeSet $ cnode_t "Int" [T_Int64] + ] + (T_NodeSet $ cnode_t "Int" [T_Int64]) + , _variable = Map.fromList + [ ("eq0", T_NodeSet $ cnode_t "Int" [T_Int64]) + , ("eq1", T_NodeSet $ cnode_t "Int" [T_Int64]) + , ("eq0_1", int64_t) + , ("eq1_1", int64_t) + , ("eq2", bool_t) + ] + } + let before = [prog| + int_eq eq0 eq1 = + (CInt eq0_1) @ alt1 <- fetch eq0 + (CInt eq1_1) @ alt2 <- fetch eq1 + eq2 <- _prim_int_eq eq0_1 eq1_1 + case eq2 of + #False @ alt3 -> + k0 <- pure 0 + pure (CInt k0) + #True @ alt4 -> + k1 <- pure 1 + pure (CInt k1) + |] + let teAfter = emptyTypeEnv + { _function = + fun_t "int_eq.unboxed" + [ T_NodeSet $ cnode_t "Int" [T_Int64] + , T_NodeSet $ cnode_t "Int" [T_Int64] + ] + int64_t + , _variable = Map.fromList + [ ("eq0", T_NodeSet $ cnode_t "Int" [T_Int64]) + , ("eq1", T_NodeSet $ cnode_t "Int" [T_Int64]) + , ("eq0_1", int64_t) + , ("eq1_1", int64_t) + , ("eq2", bool_t) + ] + } + let after = [prog| + int_eq.unboxed eq0 eq1 = + (CInt eq0_1) @ alt1 <- fetch eq0 + (CInt eq1_1) @ alt2 <- fetch eq1 + eq2 <- _prim_int_eq eq0_1 eq1_1 + case eq2 of + #False @ alt3 -> + k0 <- pure 0 + pure k0 + #True @ alt4 -> + k1 <- pure 1 + pure k1 + |] + generalizedUnboxing teBefore before `sameAs` (after, NewNames) + + it "Step 1 for Figure 4.21" $ do + let teBefore = emptyTypeEnv + { _function = Map.fromList + [ ("test", (int64_t, Vector.fromList [int64_t])) + , ("foo", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("bar", (int64_t, Vector.fromList [])) + ] + } + let before = [prog| + test n = + k0 <- pure 1 + prim_int_add n k0 + + foo a1 a2 a3 = + b1 <- prim_int_add a1 a2 + b2 <- prim_int_add b1 a3 + pure (CInt b2) + + bar = + k1 <- pure 1 + n <- test k1 + (CInt y') @ _1 <- foo a1 a2 a3 + test y' + |] + functionsToUnbox teBefore before `shouldBe` (Set.fromList ["foo"]) + + it "Tail calls and general unboxing" $ do + let teBefore = emptyTypeEnv + { _function = Map.fromList + [ ("inside1", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("outside3", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ,(Tag C "Nat", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t])) + , ("outside4", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t])) + , ("outside2", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t])) + , ("outside1", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t])) + ] + } + let before = [prog| + inside1 a1 a2 a3 = + b1 <- prim_int_add a1 a2 + b2 <- prim_int_add b1 a3 + pure (CInt b2) + + outside4 = + k0 <- pure () + k1 <- pure 1 + _1 <- pure k0 + outside3 k2 + + outside3 p1 = + case p1 of + 1 @ alt1 -> inside1 p1 p1 p1 -- :: CInt Int + 2 @ alt2 -> outside2 p1 -- :: CNat Int + + outside2 p1 = + k0 <- pure () + k1 <- pure 1 + _2 <- pure k0 + outside1 p1 + + outside1 p1 = + k2 <- pure 1 + y <- prim_int_add p1 k2 + x <- pure (CNat y) + pure x + |] + functionsToUnbox teBefore before `shouldBe` mempty + + it "Tail call function 1" $ do + let fun = [def| + fun x = + l <- store x + k0 <- pure 3 + tail k0 + |] + tailCalls fun `shouldBe` (Just ["tail"]) + + it "Tail call function 2" $ do + let fun = [def| + fun x = + l <- pure x + k0 <- pure 1 + case k0 of + 1 @ alt1 -> + k1 <- pure 1 + x <- prim_int_add k1 k1 + tail1 x + 2 @ alt2 -> + k2 <- pure 2 + x <- prim_int_add k2 k2 + tail2 x + |] + tailCalls fun `shouldBe` (Just ["tail1", "tail2"]) + + it "Partially tail call function 2" $ do + let fun = [def| + fun x = + l <- store x + k0 <- pure 1 + case k0 of + 1 @ alt1 -> + k1 <- pure 1 + x <- prim_int_add k1 k1 + y <- tail x + pure y + 2 @ alt2 -> + k2 <- pure 2 + x <- prim_int_add k2 k2 + tail x + |] + tailCalls fun `shouldBe` (Just ["tail"]) + + it "Non-tail call function 1" $ do + let fun = [def| + fun x = + l <- store x + k0 <- pure 3 + y <- tail k0 + pure x + |] + tailCalls fun `shouldBe` Nothing