4 import System.Console.GetOpt
5 import System.Environment
6 import Language.Haskell.Exts.Syntax as Exts
7 import Language.Haskell.Exts.Parser
8 import Language.Haskell.Exts.Pretty
15 import Language.Pointfree.Pretty
16 import Language.Pointfree.Syntax as Pf
17 import Language.Pointwise.Syntax as Pw
18 import Language.Pointwise.Pretty
19 import Language.Pointwise.Parser
20 import Language.Pointwise.Matching
21 import Generics.Pointless.Combinators
22 import Control.Monad.State
23 import Data.Generics.Schemes
24 import Data.Generics.Aliases
-- | Command-line flags understood by DrHylo.
--
-- 'Eq' lets flags be compared and looked up with 'elem'; 'Show' is handy
-- when debugging option parsing.
data Flag
  = Input String   -- ^ @-i@ / @--input@: source file (@"stdin"@ for standard input)
  | Output String  -- ^ @-o@ / @--output@: target file (@"stdout"@ for standard output)
  | Fixify         -- ^ @-f@: use fixpoints instead of hylomorphisms
  | Pointwise      -- ^ @-w@: do not convert to point-free
  | Observable     -- ^ @-O@: generate observable hylomorphisms
  deriving (Eq, Show)
-- | Option descriptors for 'getOpt'.
--
-- @-i@ and @-o@ take an /optional/ argument ('OptArg'). The file name must
-- therefore be attached to the flag (@-iFILE@, @--input=FILE@); a separate
-- word after a bare @-i@ is treated as a non-option argument. A bare
-- @-i@ / @-o@ selects stdin / stdout.
options :: [OptDescr Flag]
options =
  [ Option "o" ["output"]     (OptArg outp "FILE") "output FILE"
  , Option "i" ["input"]      (OptArg inp "FILE")  "input FILE"
  , Option "f" ["fix"]        (NoArg Fixify)       "use fixpoints instead of hylomorphisms"
  , Option "w" ["pointwise"]  (NoArg Pointwise)    "do not convert to point-free"
  , Option "O" ["observable"] (NoArg Observable)   "generate observable hylomorphisms"
  ]
-- | Turn the optional file argument of @-i@ / @-o@ into a flag. A missing
-- argument stands for the corresponding standard stream.
inp, outp :: Maybe String -> Flag
outp marg = Output $ case marg of
  Nothing   -> "stdout"
  Just file -> file
inp marg = Input $ case marg of
  Nothing   -> "stdin"
  Just file -> file
-- | Parse the command line into flags.
--
-- Any non-option argument is rejected (files are passed with @-i@ / @-o@).
-- On failure, fails in 'IO' with getOpt's error messages, one line per
-- unexpected argument, and the usage text.
parseOpts :: [String] -> IO [Flag]
parseOpts argv =
  case getOpt Permute options argv of
    (flags, [], []) -> return flags
    (_, extra, errs) ->
      fail (concat errs ++ concatMap unexpected extra ++ "\n" ++ usageInfo header options)
  where
    -- getOpt only reports malformed options; without this, stray arguments
    -- would be rejected with nothing but the usage text.
    unexpected arg = "unexpected argument: " ++ arg ++ "\n"
    header = "DrHylo derives point-free hylomorphisms from restricted Haskell syntax\n\nUsage: DrHylo [OPTION...]"
-- | Handle to read the source module from. Uses the file named by the
-- first 'Input' flag, or stdin when there is no such flag or it names
-- @"stdin"@.
getInput :: [Flag] -> IO Handle
getInput flags =
  case [path | Input path <- flags] of
    [] -> return stdin
    path : _
      | path == "stdin" -> return stdin
      | otherwise       -> openFile path ReadMode
-- | Handle to write the generated module to. Uses the file named by the
-- first 'Output' flag, opened in 'WriteMode', or stdout when there is no
-- such flag or it names @"stdout"@.
getOutput :: [Flag] -> IO Handle
getOutput flags =
  case [path | Output path <- flags] of
    [] -> return stdout
    path : _
      | path == "stdout" -> return stdout
      | otherwise        -> openFile path WriteMode
-- | True when @-f@ ('Fixify') was given: use fixpoints instead of
-- hylomorphisms.
fixrequired :: [Flag] -> Bool
fixrequired flags = or [True | Fixify <- flags]
-- | True when @-w@ ('Pointwise') was given: do not convert to point-free.
pwrequired :: [Flag] -> Bool
pwrequired = any isPointwise
  where
    isPointwise Pointwise = True
    isPointwise _         = False
-- | True when @-O@ ('Observable') was given: generate observable
-- hylomorphisms.
obrequired :: [Flag] -> Bool
obrequired flags = not (null [() | Observable <- flags])
-- | Parse Haskell source text into a module. On a syntax error, fails in
-- 'IO' with the parser's location and message.
parse :: String -> IO Module
parse = fromResult . parseModule
  where
    fromResult (ParseOk m)           = return m
    fromResult (ParseFailed loc msg) = fail (show loc ++ ": " ++ msg)
77 -- Generation of Observable function contexts
-- | Is the declaration a type signature that mentions the given name?
-- Every other kind of declaration yields 'False'. The fallback equation is
-- required because callers test every declaration of a module.
isTypeSig :: String -> Decl -> Bool
isTypeSig name (TypeSig _ names _) = Ident name `elem` names
isTypeSig _    _                   = False
-- | All type variables occurring anywhere inside a type, in 'everything'
-- traversal order. Duplicates are kept.
getTypeVars :: Exts.Type -> [Name]
getTypeVars = everything (++) ([] `mkQ` getVar)
  where
    getVar :: Exts.Type -> [Name]
    getVar (TyVar v) = [v]
    -- mkQ applies getVar to every Type node (TyFun, TyCon, ...), not only
    -- to variables, so it must be total.
    getVar _         = []
-- | Extend the type signature of a function of type @a -> b@ with
-- @typeable v@ and @observable v@ class assertions. One pair is added for
-- every type variable @v@ that occurs in both @a@ and @b@.
--
-- An existing @forall@/context is extended rather than replaced.
-- Non-function types and non-signature declarations are returned
-- unchanged instead of failing a pattern match.
addTypeSig :: Decl -> Decl
addTypeSig (TypeSig loc names t) = TypeSig loc names (aux t)
  where
    aux (TyForall mb ctx (TyFun a b)) = TyForall mb (ctx ++ insts a b) (TyFun a b)
    aux (TyFun a b)                   = TyForall Nothing (insts a b) (TyFun a b)
    -- Not a function type: there is no argument/result pair to relate.
    aux ty                            = ty
    insts a b = inst typeable a b ++ inst observable a b
    -- Shared variables of argument and result, each listed once.
    vars a b = nub (getTypeVars a `intersect` getTypeVars b)
    inst cl a b = map (mkInsVar cl) (vars a b)
addTypeSig d = d
-- | The class assertion @cl n@ for an unqualified class name @cl@ applied
-- to the single type variable @n@.
mkInsVar :: Name -> Name -> Asst
mkInsVar cl = ClassA (UnQual cl) . (: []) . TyVar
-- | Apply 'addTypeSig' to every type signature that mentions the given
-- name. All other declarations pass through untouched, in their original
-- order.
addTypeableObservableIns :: String -> [Decl] -> [Decl]
addTypeableObservableIns n = map extend
  where
    extend d
      | isTypeSig n d = addTypeSig d
      | otherwise     = d
105 -- From Pointwise to Point-free (or not)
-- | Run 'pwpfDecl' over every top-level declaration of a module.
--
-- * Declarations 'pwpfDecl' rejects are kept unchanged.
-- * With @-O@, the type signatures of the functions turned into
--   hylomorphisms get Typeable/Observable contexts via
--   'addTypeableObservableIns'.
-- * A LANGUAGE pragma (TypeFamilies, plus DeriveDataTypeable with @-O@) is
--   prepended to the module's pragmas.
pwpfModule :: [Flag] -> [(String,Pw.Term)] -> Module -> Module
pwpfModule f c (Module loc name pragmas warnings exports imports decls) =
    Module loc name pragmas' warnings exports imports decls''
  where
    -- New declarations, plus the names of the functions that became
    -- hylomorphisms.
    (decls', obs) = (id >< catMaybes) $ unzip $ map aux decls
    decls'' = if obrequired f then foldr addTypeableObservableIns decls' obs else decls'
    -- Each entry must be a bare extension name: the pretty-printer inserts
    -- the separating commas itself, so a trailing comma here would yield a
    -- malformed pragma.
    pragmaNames = if obrequired f then ["TypeFamilies", "DeriveDataTypeable"] else ["TypeFamilies"]
    pragmas' = LanguagePragma loc (map Ident pragmaNames) : pragmas
    aux d = fromMaybe (d, Nothing) (pwpfDecl f c d)
-- | Pointwise encodings of the list constructors, substituted into every
-- definition before translation.
consts :: [(String,Pw.Term)]
consts = [nilDef, consDef]
  where
    -- []: In applied to the left injection of Unit
    nilDef  = ("[]", In (Inl Unit))
    -- (:): takes head h and tail t and wraps their pair, right-injected, in In
    consDef = (":", Lam "h" (Lam "t" (In (Inr (Pw.Var "h" :&: Pw.Var "t")))))
-- | Translate one simple top-level binding @name = rhs@ (unguarded, with no
-- local declarations) through the pointwise language.
--
-- @d@ is a list of named pointwise terms substituted into the body together
-- with 'consts'. Depending on the flags, the new right-hand side is:
--
-- * with @-w@ ('pwrequired'): the pointwise term printed back as Haskell;
-- * without @-f@ ('fixrequired') and when 'derivable' accepts the term:
--   a hylomorphism ('HyloO' with @-O@, otherwise 'Hylo') built from the
--   terms produced by 'alg' and 'coa';
-- * otherwise: the point-free translation of the whole term.
--
-- Returns the rewritten declaration paired with @Just name@ exactly when a
-- hylomorphism was generated. Any other kind of declaration gives 'Nothing'
-- ('fail' in Maybe), as does a failure of 'nomatch', which runs as a
-- StateT over Maybe.
--
-- NOTE(review): the binding of @pw@, the @else@ branch of the @pw2@
-- conditional and the definition of @t@ are not visible in this listing.
121 pwpfDecl :: [Flag] -> [(String,Pw.Term)] -> Decl -> Maybe (Decl,Maybe String)
122 pwpfDecl f d (PatBind loc (PVar (Ident name)) mtyp (UnGuardedRhs rhs) (BDecls [])) =
-- Substitute the supplied context and the list constructors, then apply 'step'.
124 pw0 <- return (step (replace (d++consts) pw))
-- 'nomatch' is run with an initial Int state of 0.
125 pw1 <- evalStateT (nomatch pw0) 0
-- A definition whose own name occurs free in its body is recursive:
-- close it with a fixpoint over that name.
126 pw2 <- return (if name `elem` free pw1
127 then Pw.Fix (Lam name pw1)
-- Every remaining free variable is treated as an external constant.
129 pw3 <- return (subst (map (\v -> (v, Pw.Const v)) (free pw2)) pw2)
130 (rhs',ob) <- return (if pwrequired f
131 then (pw2hs pw3,Nothing)
132 else if not (fixrequired f) && derivable pw3
-- This let-pattern is lazy and irrefutable; presumably 'derivable' guarantees
-- the Fix/Lam/Lam shape. TODO confirm, otherwise this fails at use time.
133 then let (Pw.Fix (Lam nam (Lam x z))) = pw3
135 a = Lam "__" (alg z nam (Pw.Var "__"))
136 c = Lam x (coa z nam)
137 hyl = if obrequired f then HyloO else Hylo
-- Only this branch reports the name, so the caller can add observable contexts.
138 in (pf2hs (hyl (Pf.Fix t) (unpoint (pwpf [] a)) (unpoint (pwpf [] c))),Just name)
-- Fallback: plain point-free translation of the whole term.
139 else (pf2hs (unpoint (pwpf [] pw3)),Nothing))
140 return (PatBind loc (PVar (Ident name)) mtyp (UnGuardedRhs rhs') (BDecls []),ob)
141 pwpfDecl _ _ _ = fail "The transformation must be applied to simple declarations"
-- | A plain @import M@ declaration for the given module name, located at
-- 'loc0'.
mkImportDecl :: String -> ImportDecl
mkImportDecl modName =
  ImportDecl loc0 (ModuleName modName)
             False    -- presumably: not qualified
             False    -- presumably: not a {-# SOURCE #-} import
             Nothing  -- presumably: no "as" alias
             Nothing  -- presumably: no import list
-- | Name of the module an import declaration refers to.
getImportName :: ImportDecl -> String
getImportName decl = case decl of
  ImportDecl _ (ModuleName modName) _ _ _ _ -> modName
-- | Append the imports the generated code needs to a module.
--
-- The Pointless combinator, functor and recursion-pattern modules are
-- always required. When the flag is 'True', the Typeable/Observe modules
-- are required as well. Modules the module already imports are skipped;
-- the rest are appended in the listed order.
handleImports :: Bool -> Module -> Module
handleImports b (Module loc name pragmas warnings exports imports decls) =
    Module loc name pragmas warnings exports (imports ++ map mkImportDecl missing) decls
  where
    pointless = [ "Generics.Pointless.Combinators"
                , "Generics.Pointless.Functors"
                , "Generics.Pointless.RecursionPatterns" ]
    observing = [ "Data.Typeable"
                , "Debug.Observe"
                , "Generics.Pointless.Observe.Functors"
                , "Generics.Pointless.Observe.RecursionPatterns" ]
    wanted  = if b then pointless ++ observing else pointless
    present = map getImportName imports
    missing = filter (`notElem` present) wanted
-- | Entry point. Steps:
--
-- 1. Parse the options.
-- 2. Read a Haskell module from the input (stdin by default).
-- 3. Run 'casificate', 'functorOfInst' and 'pwpfModule' over it.
-- 4. Add the imports the generated code needs.
-- 5. Pretty-print the result to the output (stdout by default).
main :: IO ()
main = do
  flags <- getArgs >>= parseOpts
  let ob = obrequired flags
  ihandle <- getInput flags
  ohandle <- getOutput flags
  source <- hGetContents ihandle
  hsModule <- parse source
  let hsModule0 = casificate hsModule
      hsModule1 = functorOfInst ob hsModule0
      hsModule2 = pwpfModule flags (getCtx hsModule1) hsModule1
  hPutStrLn ohandle (prettyPrint (handleImports ob hsModule2))
  -- GHC only flushes stdout/stderr at program exit. A file handle opened by
  -- getOutput must be closed explicitly, or buffered output can be lost.
  if ohandle == stdout then hFlush ohandle else hClose ohandle