Skip to content

Commit

Permalink
Merge pull request #83 from grin-compiler/32-trf-const-propagation-2
Browse files Browse the repository at this point in the history
Extended Syntax: constant propagation
  • Loading branch information
Anabra authored Apr 19, 2020
2 parents 47ec536 + a9a92bb commit 4972b34
Show file tree
Hide file tree
Showing 5 changed files with 320 additions and 4 deletions.
1 change: 1 addition & 0 deletions grin/grin.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ library
Transformations.ExtendedSyntax.StaticSingleAssignment
Transformations.ExtendedSyntax.Optimising.ArityRaising
Transformations.ExtendedSyntax.Optimising.CopyPropagation
Transformations.ExtendedSyntax.Optimising.ConstantPropagation
Transformations.ExtendedSyntax.Optimising.CSE
Transformations.ExtendedSyntax.Optimising.EvaluatedCaseElimination
Transformations.ExtendedSyntax.Optimising.SimpleDeadFunctionElimination
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
{-# LANGUAGE LambdaCase, TupleSections, ViewPatterns #-}
module Transformations.ExtendedSyntax.Optimising.ConstantPropagation where


import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Functor.Foldable

import Lens.Micro ((^.))

import Grin.ExtendedSyntax.Grin
import Transformations.ExtendedSyntax.Util

{-
HINT:
propagates only tag values but not literals
GRIN is not a supercompiler
NOTE:
We only need the tag information to simplify case expressions.
This means that Env could be a Name -> Tag mapping.
-}

type Env = Map Name Val

constantPropagation :: Exp -> Exp
constantPropagation e = ana builder (mempty, e) where

builder :: (Env, Exp) -> ExpF (Env, Exp)
builder (env, exp) = case exp of
ECase scrut alts ->
let constVal = getValue scrut env
known = isKnown constVal || Map.member scrut env
matchingAlts = [alt | alt@(Alt cpat name body) <- alts, match cpat constVal]
defaultAlts = [alt | alt@(Alt DefaultPat name body) <- alts]
-- HINT: use cpat as known value in the alternative ; bind cpat to val
altEnv cpat = env `mappend` unify env scrut (cPatToVal cpat)
in case (known, matchingAlts, defaultAlts) of
-- known scutinee, specific pattern
(True, [Alt cpat name body], _) -> (env,) <$> SBlockF (EBind (SReturn $ constVal) (cPatToAsPat cpat name) body)

-- known scutinee, default pattern
(True, _, [Alt DefaultPat name body]) -> (env,) <$> SBlockF (EBind (SReturn $ Var scrut) (VarPat name) body)

-- unknown scutinee
-- HINT: in each alternative set val value like it was matched
_ -> ECaseF scrut [(altEnv cpat, alt) | alt@(Alt cpat name _) <- alts]

-- track values
EBind (SReturn val) bPat rightExp -> (env `mappend` unify env (bPat ^. _BPatVar) val,) <$> project exp

_ -> (env,) <$> project exp

unify :: Env -> Name -> Val -> Env
unify env var val = case val of
ConstTagNode{} -> Map.singleton var val
Unit -> Map.singleton var val -- HINT: default pattern (minor hack)
Var v -> Map.singleton var (getValue v env)
Lit{} -> mempty
_ -> error $ "ConstantPropagation/unify: unexpected value: " ++ show (val) -- TODO: PP

isKnown :: Val -> Bool
isKnown = \case
ConstTagNode{} -> True
_ -> False

match :: CPat -> Val -> Bool
match (NodePat tagA _) (ConstTagNode tagB _) = tagA == tagB
match _ _ = False

getValue :: Name -> Env -> Val
getValue varName env = Map.findWithDefault (Var varName) varName env
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ trivialCaseElimination = ana builder where
builder :: Exp -> ExpF Exp
builder = \case
ECase scrut [Alt DefaultPat altName body] -> SBlockF $ EBind (SReturn (Var scrut)) (VarPat altName) body
ECase scrut [Alt cpat altName body] -> SBlockF $ EBind (SReturn (Var scrut)) (cPatToAsPat altName cpat) body
ECase scrut [Alt cpat altName body] -> SBlockF $ EBind (SReturn (Var scrut)) (cPatToAsPat cpat altName) body
exp -> project exp
6 changes: 3 additions & 3 deletions grin/src/Transformations/ExtendedSyntax/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ cPatToVal = \case
LitPat lit -> Lit lit
DefaultPat -> Unit

cPatToAsPat :: Name -> CPat -> BPat
cPatToAsPat name (NodePat tag args) = AsPat tag args name
cPatToAsPat _ cPat = error $ "cPatToAsPat: cannot convert to as-pattern: " ++ show (PP cPat)
cPatToAsPat :: CPat -> Name -> BPat
cPatToAsPat (NodePat tag args) name = AsPat tag args name
cPatToAsPat cPat _ = error $ "cPatToAsPat: cannot convert to as-pattern: " ++ show (PP cPat)

-- monadic recursion schemes
-- see: https://jtobin.io/monadic-recursion-schemes
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-}
module Transformations.ExtendedSyntax.Optimising.ConstantPropagationSpec where

import Transformations.ExtendedSyntax.Optimising.ConstantPropagation

import Test.Hspec

import Grin.ExtendedSyntax.TH
import Test.ExtendedSyntax.Assertions


runTests :: IO ()
runTests = hspec spec


spec :: Spec
spec = do
it "ignores binds" $ do
let before = [expr|
i1 <- pure 1
i2 <- pure i1
n1 <- pure (CNode i2)
n2 <- pure n1
(CNode i3) @ n3 <- pure n1
pure 2
|]
let after = [expr|
i1 <- pure 1
i2 <- pure i1
n1 <- pure (CNode i2)
n2 <- pure n1
(CNode i3) @ n3 <- pure n1
pure 2
|]
constantPropagation before `sameAs` after

it "is not interprocedural" $ do
let before = [prog|
grinMain =
x <- f
case x of
(COne) @ alt1 -> pure 0
(CTwo) @ alt2 -> pure 1

f = pure (COne)
|]
let after = [prog|
grinMain =
x <- f
case x of
(COne) @ alt1 -> pure 0
(CTwo) @ alt2 -> pure 1

f = pure (COne)
|]
constantPropagation before `sameAs` after

it "does not propagate info outwards of case expressions" $ do
let before = [prog|
grinMain =
x <- pure 0
y <- case x of
0 @ alt1 -> pure (COne)
case y of
(COne) @ alt2 -> pure 0
(CTwo) @ alt3 -> pure 1
|]
let after = [prog|
grinMain =
x <- pure 0
y <- case x of
0 @ alt1 -> pure (COne)
case y of
(COne) @ alt2 -> pure 0
(CTwo) @ alt3 -> pure 1
|]
constantPropagation before `sameAs` after

it "base case" $ do
let before = [expr|
i1 <- pure 1
n1 <- pure (CNode i1)
case n1 of
(CNil) @ alt1 -> pure 1
(CNode a1) @ alt2 -> pure 2
|]
let after = [expr|
i1 <- pure 1
n1 <- pure (CNode i1)
do
(CNode a1) @ alt2 <- pure (CNode i1)
pure 2
|]
constantPropagation before `sameAs` after

it "ignores illformed case - multi matching" $ do
let before = [expr|
i1 <- pure 1
n1 <- pure (CNode i1)
_1 <- case n1 of
(CNil) @ alt1 -> pure 1
(CNode a1) @ alt2 -> pure 2
(CNode b1) @ alt3 -> pure 3
case n1 of
(CNil) @ alt4 -> pure 4
#default @ alt5 -> pure 5
#default @ alt6 -> pure 6
|]
let after = [expr|
i1 <- pure 1
n1 <- pure (CNode i1)
_1 <- case n1 of
(CNil) @ alt1 -> pure 1
(CNode a1) @ alt2 -> pure 2
(CNode b1) @ alt3 -> pure 3
case n1 of
(CNil) @ alt4 -> pure 4
#default @ alt5 -> pure 5
#default @ alt6 -> pure 6
|]
constantPropagation before `sameAs` after

it "default pattern" $ do
let before = [expr|
i1 <- pure 1
n1 <- pure (CNode i1)
case n1 of
(CNil) @ alt1 -> pure 2
#default @ alt2 -> pure 3
|]
let after = [expr|
i1 <- pure 1
n1 <- pure (CNode i1)
do
alt2 <- pure n1
pure 3
|]
constantPropagation before `sameAs` after

it "unknown scrutinee - simple" $ do
let before = [expr|
case n1 of
(CNil) @ alt1 -> pure 2
#default @ alt2 -> pure 3
|]
let after = [expr|
case n1 of
(CNil) @ alt1 -> pure 2
#default @ alt2 -> pure 3
|]
constantPropagation before `sameAs` after

it "unknown scrutinee becomes known in alternatives - specific pattern" $ do
let before = [expr|
case n1 of
(CNil) @ alt11 ->
case n1 of
(CNil) @ alt21 -> pure 1
(CNode a1) @ alt22 -> pure 2
(CNode a2) @ alt12 ->
case n1 of
(CNil) @ alt23 -> pure 3
(CNode a3) @ alt24 -> pure 4
|]
let after = [expr|
case n1 of
(CNil) @ alt11 ->
do
(CNil) @ alt21 <- pure (CNil)
pure 1
(CNode a2) @ alt12 ->
do
(CNode a3) @ alt24 <- pure (CNode a2)
pure 4
|]
constantPropagation before `sameAs` after

it "unknown scrutinee becomes known in alternatives - default pattern" $ do
let before = [expr|
case n1 of
#default @ alt11 ->
case n1 of
#default @ alt21 -> pure 1
(CNode a1) @ alt22 -> pure 2
(CNode a2) @ alt12 ->
case n1 of
#default @ alt23 -> pure 3
(CNode a3) @ alt24 -> pure 4
|]
let after = [expr|
case n1 of
#default @ alt11 ->
do
alt21 <- pure n1
pure 1
(CNode a2) @ alt12 ->
do
(CNode a3) @ alt24 <- pure (CNode a2)
pure 4
|]
constantPropagation before `sameAs` after

it "literal - specific pattern" $ do
let before = [expr|
i1 <- pure 1
case i1 of
(CNil) @ alt1 -> pure 1
(CNode a1) @ alt2 -> pure 2
1 @ alt3 -> pure 3
2 @ alt4 -> pure 4
#default @ alt5 -> pure 5
|]
let after = [expr|
i1 <- pure 1
case i1 of
(CNil) @ alt1 -> pure 1
(CNode a1) @ alt2 -> pure 2
1 @ alt3 -> pure 3
2 @ alt4 -> pure 4
#default @ alt5 -> pure 5
|]
constantPropagation before `sameAs` after

it "literal - default pattern" $ do
let before = [expr|
i1 <- pure 3
case i1 of
(CNil) @ alt1 -> pure 1
(CNode a1) @ alt2 -> pure 2
1 @ alt3 -> pure 3
2 @ alt4 -> pure 4
#default @ alt5 -> pure 5
|]
let after = [expr|
i1 <- pure 3
case i1 of
(CNil) @ alt1 -> pure 1
(CNode a1) @ alt2 -> pure 2
1 @ alt3 -> pure 3
2 @ alt4 -> pure 4
#default @ alt5 -> pure 5
|]
constantPropagation before `sameAs` after

0 comments on commit 4972b34

Please sign in to comment.