/ src /
/src/DrHylo.hs
1 module Main where
2
3 import System.IO
4 import System.Console.GetOpt
5 import System.Environment
6 import Language.Haskell.Exts.Syntax as Exts
7 import Language.Haskell.Exts.Parser
8 import Language.Haskell.Exts.Pretty
9 import Data.Maybe
10 import Data.List
11 import PwPf
12 import Matching
13 import FunctorOf
14 import Hylos
15 import Language.Pointfree.Pretty
16 import Language.Pointfree.Syntax as Pf
17 import Language.Pointwise.Syntax as Pw
18 import Language.Pointwise.Pretty
19 import Language.Pointwise.Parser
20 import Language.Pointwise.Matching
21 import Generics.Pointless.Combinators
22 import Control.Monad.State
23 import Data.Generics.Schemes
24 import Data.Generics.Aliases
25
26 -- Managing Options
27
-- | Command-line flags understood by DrHylo.
data Flag
  = Input String   -- ^ source file to read ("stdin" selects standard input)
  | Output String  -- ^ file to write ("stdout" selects standard output)
  | Fixify         -- ^ use fixpoints instead of hylomorphisms
  | Pointwise      -- ^ do not convert to point-free
  | Observable     -- ^ generate observable hylomorphisms
  deriving Eq
30
-- | Option descriptors handed to 'getOpt'; 'usageInfo' lists them in this
-- order.
options :: [OptDescr Flag]
options =
  [ opt 'o' "output"     (OptArg outp "FILE") "output FILE"
  , opt 'i' "input"      (OptArg inp "FILE")  "input FILE"
  , opt 'f' "fix"        (NoArg Fixify)       "use fixpoints instead of hylomorphisms"
  , opt 'w' "pointwise"  (NoArg Pointwise)    "do not convert to point-free"
  , opt 'O' "observable" (NoArg Observable)   "generate observable hylomorphisms"
  ]
  where
    -- One short and one long spelling per option.
    opt short long = Option [short] [long]
38
-- | Build the input/output flag from the optional argument given by getOpt.
-- When no file name is attached, the standard stream is selected by name
-- ("stdin" / "stdout").
inp,outp :: Maybe String -> Flag
outp file = Output (fromMaybe "stdout" file)
inp file = Input (fromMaybe "stdin" file)
42
-- | Parse the command-line arguments into a list of flags.
--
-- Fails in 'IO' with the usage text appended when getOpt reports errors, or
-- when non-option arguments are left over. The input/output options take an
-- optional argument ('OptArg'), so getOpt only picks up a file name that is
-- attached to the option (@-iFILE@ or @--input=FILE@). A detached name becomes
-- a non-option argument; such arguments are named explicitly in the error
-- rather than producing a bare usage text.
parseOpts :: [String] -> IO [Flag]
parseOpts opts = case getOpt Permute options opts of
  (l, [], [])    -> return l
  (_, extra, []) -> failWith (unexpected extra)
  (_, _, errs)   -> failWith (concat errs)
  where
    header = "DrHylo derives point-free hylomorphisms from restricted Haskell syntax\n\nUsage: DrHylo [OPTION...]"
    failWith msg = fail (msg ++ "\n" ++ usageInfo header options)
    unexpected args =
      "unexpected argument(s): " ++ unwords args
        ++ "\n(option values must be attached, e.g. -iFILE or --input=FILE)\n"
48
-- | Handle to read the source module from. The first 'Input' flag wins.
-- Standard input is used when there is no such flag or when it names
-- "stdin"; otherwise the named file is opened for reading.
getInput :: [Flag] -> IO Handle
getInput flags = case [file | Input file <- flags] of
  file : _ | file /= "stdin" -> openFile file ReadMode
  _                          -> return stdin
54
-- | Handle to write the result to. The first 'Output' flag wins. Standard
-- output is used when there is no such flag or when it names "stdout";
-- otherwise the named file is opened in 'WriteMode'.
getOutput :: [Flag] -> IO Handle
getOutput flags = case [file | Output file <- flags] of
  file : _ | file /= "stdout" -> openFile file WriteMode
  _                           -> return stdout
60
-- | Was the 'Fixify' flag (use fixpoints instead of hylomorphisms) given?
fixrequired :: [Flag] -> Bool
fixrequired flags = Fixify `elem` flags
63
-- | Was the 'Pointwise' flag (do not convert to point-free) given?
pwrequired :: [Flag] -> Bool
pwrequired = any (== Pointwise)
66
-- | Was the 'Observable' flag (generate observable hylomorphisms) given?
obrequired :: [Flag] -> Bool
obrequired flags = Observable `elem` flags
69
70 -- Parsing
71
-- | Parse the source text of a module. On failure, fail in 'IO' with the
-- reported location and message.
parse :: String -> IO Module
parse = fromResult . parseModule
  where
    fromResult (ParseOk m)           = return m
    fromResult (ParseFailed loc msg) = fail (show loc ++ ": " ++ msg)
76
77 -- Generation of Observable function contexts
78
-- | Is the declaration a type signature that mentions the given plain
-- identifier? A single signature may cover several names.
isTypeSig :: String -> Decl -> Bool
isTypeSig name decl = case decl of
  TypeSig _ names _ -> Ident name `elem` names
  _                 -> False
82
-- | Collect every type-variable occurrence inside a type, duplicates
-- included, via a generic traversal.
getTypeVars :: Exts.Type -> [Name]
getTypeVars = everything (++) (mkQ [] varsOf)
  where
    varsOf :: Exts.Type -> [Name]
    varsOf (TyVar v) = [v]
    varsOf _         = []
88
-- | Extend a function type signature with @typeable v@ and @observable v@
-- class assertions for every type variable @v@ occurring in both the
-- argument and the result type.
--
-- A parenthesised type is unwrapped first. Types that are not (optionally
-- quantified) function types, and declarations that are not type signatures,
-- are returned unchanged instead of failing with a pattern-match error.
addTypeSig :: Decl -> Decl
addTypeSig (TypeSig loc names t) = TypeSig loc names (aux t)
  where
    aux (TyForall mb ctx (TyFun a b)) =
      TyForall mb (ctx ++ constraints a b) (TyFun a b)
    aux (TyFun a b) = TyForall Nothing (constraints a b) (TyFun a b)
    aux (TyParen t') = aux t'
    aux other = other
    constraints a b = inst typeable a b ++ inst observable a b
    -- Variables shared by argument and result, without duplicates.
    vars a b = nub $ intersect (getTypeVars a) (getTypeVars b)
    inst cl a = map (mkInsVar cl) . vars a
addTypeSig d = d
96
-- | Class assertion @cls var@: the unqualified class applied to a single
-- type variable.
mkInsVar :: Name -> Name -> Asst
mkInsVar cls var = ClassA (UnQual cls) [TyVar var]
99
-- | Apply 'addTypeSig' to every type signature for the named function,
-- leaving all other declarations untouched and in order.
addTypeableObservableIns :: String -> [Decl] -> [Decl]
addTypeableObservableIns n = map annotate
  where
    annotate d
      | isTypeSig n d = addTypeSig d
      | otherwise     = d
104
105 -- From Pointwise to Point-free (or not)
106
-- | Run 'pwpfDecl' over every top-level declaration of a module and prepend
-- the LANGUAGE pragma needed by the generated code.
--
-- Declarations that 'pwpfDecl' does not handle are kept unchanged. With the
-- 'Observable' flag, the signatures of the functions reported by 'pwpfDecl'
-- get extra constraints (see 'addTypeableObservableIns'), and
-- DeriveDataTypeable is enabled alongside TypeFamilies.
pwpfModule :: [Flag] -> [(String,Pw.Term)] -> Module -> Module
pwpfModule f c (Module loc name pragmas warnings exports imports decls) =
  Module loc name pragmas' warnings exports imports decls''
  where
    (decls', obs) = (id >< catMaybes) $ unzip $ map aux decls
    decls''
      | obrequired f = foldr addTypeableObservableIns decls' obs
      | otherwise    = decls'
    -- Each entry is exactly one extension name: the pretty printer inserts
    -- the separating commas itself, so a comma inside a name produces an
    -- invalid pragma.
    pragmaNames
      | obrequired f = ["TypeFamilies", "DeriveDataTypeable"]
      | otherwise    = ["TypeFamilies"]
    pragmas' = LanguagePragma loc (map Ident pragmaNames) : pragmas
    aux d = fromMaybe (d, Nothing) (pwpfDecl f c d)
117
-- | Pointwise encodings of the list constructors, substituted into every
-- right-hand side before translation: @[]@ as @In (Inl Unit)@, and @(:)@ as
-- a two-argument lambda building @In (Inr (h :&: t))@.
consts :: [(String,Pw.Term)]
consts =
  [ ("[]", nil)
  , (":", cons)
  ]
  where
    nil  = In (Inl Unit)
    cons = Lam "h" (Lam "t" (In (Inr (Pw.Var "h" :&: Pw.Var "t"))))
120
-- | Translate one simple binding @name = rhs@ (a variable pattern, an
-- unguarded right-hand side and no local declarations).
--
-- The right-hand side goes through 'hs2pw', 'replace' (with the given context
-- plus 'consts'), 'step' and 'nomatch'. If the bound name occurs free in the
-- result, the term is wrapped as @Pw.Fix (Lam name ...)@. Every remaining
-- free variable @v@ is then replaced by @Pw.Const v@. The final term is:
--
-- * printed back pointwise ('pw2hs') when 'Pointwise' is set;
-- * otherwise, if 'Fixify' is not set and 'derivable' accepts it, turned into
--   a ('HyloO' or 'Hylo') hylomorphism built from 'fun', 'alg' and 'coa';
-- * otherwise converted to point-free ('pwpf', 'unpoint', 'pf2hs').
--
-- The second component is @Just name@ only in the hylomorphism case. Any
-- other declaration shape, or a failing step, gives 'Nothing'.
pwpfDecl :: [Flag] -> [(String,Pw.Term)] -> Decl -> Maybe (Decl,Maybe String)
pwpfDecl flags ctx (PatBind loc (PVar (Ident name)) mtyp (UnGuardedRhs rhs) (BDecls [])) =
    do pw <- hs2pw rhs
       let pw0 = step (replace (ctx ++ consts) pw)
       pw1 <- evalStateT (nomatch pw0) 0
       let pw2 | name `elem` free pw1 = Pw.Fix (Lam name pw1)
               | otherwise            = pw1
           pw3 = subst [(v, Pw.Const v) | v <- free pw2] pw2
       -- Kept as a monadic bind so the result pair is matched strictly.
       (rhs', ob) <- return (translate pw3)
       return (PatBind loc (PVar (Ident name)) mtyp (UnGuardedRhs rhs') (BDecls []), ob)
  where
    translate term
      | pwrequired flags = (pw2hs term, Nothing)
      | not (fixrequired flags) && derivable term =
          -- Lazy pattern: 'derivable' is relied upon to guarantee this shape.
          let Pw.Fix (Lam recName (Lam arg body)) = term
              fnctr   = fun body recName
              algTerm = Lam "__" (alg body recName (Pw.Var "__"))
              coaTerm = Lam arg (coa body recName)
              hyloCon = if obrequired flags then HyloO else Hylo
          in ( pf2hs (hyloCon (Pf.Fix fnctr) (unpoint (pwpf [] algTerm)) (unpoint (pwpf [] coaTerm)))
             , Just name )
      | otherwise = (pf2hs (unpoint (pwpf [] term)), Nothing)
pwpfDecl _ _ _ = fail "The transformation must be applied to simple declarations"
142
143
144 -- Handle imports
145
-- | Placeholder source location for generated syntax: empty file name,
-- line 0, column 0.
loc0 :: SrcLoc
loc0 = SrcLoc { srcFilename = "", srcLine = 0, srcColumn = 0 }
148
-- | Import declaration for the whole named module, located at 'loc0'. Both
-- boolean fields are False and both optional fields are Nothing, presumably
-- meaning unqualified, no alias and no import list (confirm against the
-- haskell-src-exts version in use).
mkImportDecl :: String -> ImportDecl
mkImportDecl = plainImport . ModuleName
  where
    plainImport m = ImportDecl loc0 m False False Nothing Nothing
151
-- | Name of the module an import declaration brings in.
getImportName :: ImportDecl -> String
getImportName decl = case decl of
  ImportDecl _ (ModuleName modName) _ _ _ _ -> modName
154
-- | Append the imports the generated code relies on, skipping any module the
-- source already imports. When the flag is True, the Data.Typeable and
-- Observe modules are required in addition to the Pointless ones.
handleImports :: Bool -> Module -> Module
handleImports b (Module loc name pragmas warnings exports imports decls) =
  Module loc name pragmas warnings exports (imports ++ map mkImportDecl missing) decls
  where
    pointlessMods =
      [ "Generics.Pointless.Combinators"
      , "Generics.Pointless.Functors"
      , "Generics.Pointless.RecursionPatterns"
      ]
    observeMods =
      [ "Data.Typeable"
      , "Debug.Observe"
      , "Generics.Pointless.Observe.Functors"
      , "Generics.Pointless.Observe.RecursionPatterns"
      ]
    required = if b then pointlessMods ++ observeMods else pointlessMods
    missing  = required \\ map getImportName imports
162
163
164 -- Main
165
-- | Entry point: read a Haskell module, transform it according to the
-- command-line flags and pretty-print the result.
--
-- The output handle is obtained only after the input has been parsed and the
-- result rendered. A parse failure therefore no longer creates or truncates
-- the output file.
main :: IO ()
main = do
  flags <- parseOpts =<< getArgs
  let ob = obrequired flags
  ihandle <- getInput flags
  source <- hGetContents ihandle
  hsModule <- parse source
  let hsModule0 = casificate hsModule
      hsModule1 = functorOfInst ob hsModule0
      hsModule2 = pwpfModule flags (getCtx hsModule1) hsModule1
      rendered  = prettyPrint (handleImports ob hsModule2)
  -- Walk the whole rendered string before opening the output, so that
  -- failures while building it surface while the output file is untouched.
  length rendered `seq` return ()
  ohandle <- getOutput flags
  hPutStrLn ohandle rendered
  hClose ihandle
  hClose ohandle