/ src /
/src/Matching.hs
1 module Matching (casificate) where
2
3 import Language.Haskell.Exts.Syntax
4 import Control.Monad.State
5
6
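{- Convert the argument patterns of multi-equation function definitions into
   a single "case ... of" expression, as in the example:

     f 0 y = e1   --->   f = \v2 -> \v1 -> case (v2, v1) of
     f x y = e2                               (0, y) -> e1
                                              (x, y) -> e2

   With three or more arguments the tuples nest to the right, e.g. the
   argument patterns x y z become the single pattern (x, (y, z)).
-}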
-- | Rewrite every top-level declaration of a module with 'cas_decl', running
-- the traversal in the 'State' monad with the initial name seed @"v"@.
-- The module header, pragmas, warning text, exports and imports are copied
-- through unchanged.
casificate :: Module -> Module
casificate (Module loc name pragmas warning exports imports decls) =
    Module loc name pragmas warning exports imports rewritten
  where
    rewritten = evalState (mapM cas_decl decls) "v"
16
17
-- | The rewriting monad: a 'State' whose state is the string prefix (seed)
-- used to build the generated variable names.
type ST = State String
19
-- | Rewrite a single declaration.
--
-- * A function binding (one or more equations) becomes a pattern binding
--   @f = <expr>@, where @<expr>@ is built by 'cas_matches' from all the
--   equations.
-- * A pattern binding keeps its right-hand side, but the declarations in
--   its @where@ clause (when it is a plain declaration list) are rewritten
--   recursively. This mirrors what 'cas_alt' does for function equations.
-- * Every other declaration is returned unchanged.
cas_decl :: Decl -> ST Decl
cas_decl (FunBind ms@(Match loc name _ _ _ _ : _)) = do
    body <- cas_matches ms
    -- Reuse the first equation's location instead of a dummy one, just as
    -- 'cas_alt' keeps each equation's location on its alternative.
    return $ PatBind loc (PVar name) Nothing (UnGuardedRhs body) (BDecls [])
cas_decl (PatBind loc pat mty rhs (BDecls ds)) = do
    ds' <- mapM cas_decl ds
    return $ PatBind loc pat mty rhs (BDecls ds')
cas_decl other = return other
27
-- | Combine the equations of one function into a single expression.
--
-- For a function with N arguments, the result is a chain of one-argument
-- lambdas binding vN (outermost) down to v1 (innermost). The names are the
-- current state (the seed) followed by an index. The chain wraps a @case@ on
-- the right-nested tuple (vN, (..., v1)), with one alternative per equation
-- (built by 'cas_alt').
--
-- Calls 'error' for an empty equation list or for equations without
-- argument patterns. The previous recursion never terminated in that case.
--
-- NOTE(review): the seed is never modified, so the generated names are
-- always v1, v2, ... . A free reference to an outer name such as @v1@ inside
-- an equation body, when that name is not rebound by the equation's own
-- patterns, would be captured by these lambdas.
cas_matches :: [Match] -> ST Exp
cas_matches [] =
    error "Matching.cas_matches: function binding without equations"
cas_matches ms@(Match _ _ pats _ _ _ : _)
    | null pats =
        error "Matching.cas_matches: equation without argument patterns"
    | otherwise = do
        alts <- mapM cas_alt ms
        seed <- get
        let n     = length pats
            -- vN first, v1 last: the outermost lambda binds vN.
            names = [seed ++ show i | i <- [n, n - 1 .. 1]]
            -- Right-nested pairs, matching how 'cas_alt' nests the patterns;
            -- foldr1 is safe because names is non-empty here.
            scrut = foldr1 (\e acc -> Tuple [e, acc]) (map mkVar names)
            wrap v body = Lambda mkLoc [mkPVar v] body
        return $ foldr wrap (Case scrut alts) names
39
-- | Turn one function equation into a @case@ alternative.
--
-- * The argument patterns become one right-nested tuple pattern, e.g.
--   @x y z@ becomes @(x, (y, z))@; a single argument stays as it is.
-- * The right-hand side, guarded or not, is carried over unchanged.
-- * If the @where@ clause is a declaration list, its declarations are
--   rewritten recursively with 'cas_decl'. Any other kind of binding group
--   (e.g. implicit-parameter bindings) is kept unchanged instead of
--   crashing on a failed pattern match.
cas_alt :: Match -> ST Alt
cas_alt (Match loc _ pats _ rhs binds) = do
    binds' <- casBinds binds
    return $ Alt loc (nestPat pats) (toAlt rhs) binds'
  where
    nestPat [] = error "Matching.cas_alt: equation without argument patterns"
    nestPat ps = foldr1 (\p acc -> PTuple [p, acc]) ps

    toAlt (UnGuardedRhs e) = UnGuardedAlt e
    toAlt (GuardedRhss gs) =
        GuardedAlts [GuardedAlt l guards e | GuardedRhs l guards e <- gs]

    casBinds (BDecls ds) = do
        ds' <- mapM cas_decl ds
        return (BDecls ds')
    casBinds other = return other
50
51
-- auxiliary functions

-- | Unqualified variable expression for the given identifier.
mkVar :: String -> Exp
mkVar = Var . UnQual . Ident

-- | Variable pattern that binds the given identifier.
mkPVar :: String -> Pat
mkPVar = PVar . Ident

-- | Placeholder source location (empty file name, line 0, column 0) for
-- generated syntax.
mkLoc :: SrcLoc
mkLoc = SrcLoc "" 0 0
56
57 {-
58 getSeed :: Data a => a -> String
59 getSeed = flip replicate 'x' .
60 maximum . (1:) .
61 everything (++) (mkQ [] aux)
62 where aux = (:[]) . (+1) . length . takeWhile (=='x')
63 -}