0 | ||| Deriving traversable instances using reflection
  1 | ||| You can for instance define:
  2 | ||| ```
  3 | ||| data Tree a = Leaf a | Node (Tree a) (Tree a)
  4 | ||| treeFoldable : Traversable Tree
  5 | ||| treeFoldable = %runElab derive
  6 | ||| ```
  7 |
  8 | module Deriving.Traversable
  9 |
 10 | import public Control.Monad.Either
 11 | import public Control.Monad.State
 12 | import public Data.List1
 13 | import public Data.Maybe
 14 | import public Data.Morphisms
 15 | import public Decidable.Equality
 16 | import public Language.Reflection
 17 |
 18 | import public Deriving.Common
 19 |
 20 | %language ElabReflection
 21 | %default total
 22 |
 23 | ------------------------------------------------------------------------------
 24 | -- Errors
 25 |
 26 | ||| Possible errors for the functor-deriving machinery.
 27 | public export
 28 | data Error : Type where
 29 |   NotFreeOf : Name -> TTImp -> Error
 30 |   NotAnApplication : TTImp -> Error
 31 |   NotATraversable : TTImp -> Error
 32 |   NotABitraversable : TTImp -> Error
 33 |   NotTraversableInItsLastArg : TTImp -> Error
 34 |   UnsupportedType : TTImp -> Error
 35 |   NotAFiniteStructure : Error
 36 |   NotAnUnconstrainedValue : Count -> Error
 37 |   InvalidGoal : Error
 38 |   ConfusingReturnType : Error
 39 |   -- Contextual information
 40 |   WhenCheckingConstructor : Name -> Error -> Error
 41 |   WhenCheckingArg : TTImp -> Error -> Error
 42 |
 43 | export
 44 | Show Error where
 45 |   show = joinBy "\n" . go [<] where
 46 |
 47 |     go : SnocList String -> Error -> List String
 48 |     go acc (NotFreeOf x ty) = acc <>> ["The term \{show ty} is not free of \{show x}"]
 49 |     go acc (NotAnApplication s) = acc <>> ["The type \{show s} is not an application"]
 50 |     go acc (NotATraversable s) = acc <>> ["Couldn't find a `Traversable' instance for the type constructor \{show s}"]
 51 |     go acc (NotABitraversable s) = acc <>> ["Couldn't find a `Bitraversable' instance for the type constructor \{show s}"]
 52 |     go acc (NotTraversableInItsLastArg s) = acc <>> ["Not traversable in its last argument \{show s}"]
 53 |     go acc (UnsupportedType s) = acc <>> ["Unsupported type \{show s}"]
 54 |     go acc NotAFiniteStructure = acc <>> ["Cannot traverse an infinite structure"]
 55 |     go acc (NotAnUnconstrainedValue rig) = acc <>> ["Cannot traverse a \{enunciate rig} value"]
 56 |     go acc InvalidGoal = acc <>> ["Expected a goal of the form `Traversable f`"]
 57 |     go acc ConfusingReturnType = acc <>> ["Confusing telescope"]
 58 |     go acc (WhenCheckingConstructor nm err) = go (acc :< "When checking constructor \{show nm}") err
 59 |     go acc (WhenCheckingArg s err) = go (acc :< "When checking argument of type \{show s}") err
 60 |
 61 | record Parameters where
 62 |   constructor MkParameters
 63 |   asTraversables : List Nat
 64 |   asBitraversables : List Nat
 65 |
 66 | initParameters : Parameters
 67 | initParameters = MkParameters [] []
 68 |
 69 | paramConstraints : Parameters -> Nat -> Maybe TTImp
 70 | paramConstraints params pos
 71 |     = IVar emptyFC `{Prelude.Interfaces.Traversable}   <$ guard (pos `elem` params.asTraversables)
 72 |   <|> IVar emptyFC `{Prelude.Interfaces.Bitraversable} <$ guard (pos `elem` params.asBitraversables)
 73 |
 74 | ------------------------------------------------------------------------------
 75 | -- Core machinery: being traversable
 76 |
 77 | -- Not meant to be re-exported as it's using the internal notion of error
 78 | isFreeOf' :
 79 |   {0 m : Type -> Type} ->
 80 |   {auto elab : Elaboration m} ->
 81 |   {auto error : MonadError Error m} ->
 82 |   (x : Name) -> (ty : TTImp) -> m (IsFreeOf x ty)
 83 | isFreeOf' x ty = case isFreeOf x ty of
 84 |   Nothing => throwError (NotFreeOf x ty)
 85 |   Just prf => pure prf
 86 |
 87 | ||| IsTraversableIn is parametrised by
 88 | ||| @ t  the name of the data type whose constructors are being analysed
 89 | ||| @ x  the name of the type variable that the traversable action will act on
 90 | ||| @ ty the type being analysed
 91 | ||| The inductive type delivers a proof that x can be traversed over in ty,
 92 | ||| assuming that t also is traversable.
 93 | public export
 94 | data IsTraversableIn : (t, x : Name) -> (ty : TTImp) -> Type where
 95 |   ||| The type variable x occurs alone
 96 |   TIVar : IsTraversableIn t x (IVar fc x)
 97 |   ||| There is a recursive subtree of type (t a1 ... an u) and u is Traversable in x.
 98 |   ||| We do not insist that u is exactly x so that we can deal with nested types
 99 |   ||| like the following:
100 |   |||   data Full a = Leaf a | Node (Full (a, a))
101 |   |||   data Term a = Var a | App (Term a) (Term a) | Lam (Term (Maybe a))
102 |   TIRec : (0 _ : IsAppView (_, t) _ f) -> IsTraversableIn t x arg ->
103 |           IsTraversableIn t x (IApp fc f arg)
104 |   ||| The subterm is delayed (Lazy only, we can't traverse infinite structures)
105 |   TIDelayed : IsTraversableIn t x ty -> IsTraversableIn t x (IDelayed fc LLazy ty)
106 |   ||| There are nested subtrees somewhere inside a 3rd party type constructor
107 |   ||| which satisfies the Bitraversable interface
108 |   TIBifold : IsFreeOf x sp -> HasImplementation Bitraversable sp ->
109 |              IsTraversableIn t x arg1 -> Either (IsTraversableIn t x arg2) (IsFreeOf x arg2) ->
110 |              IsTraversableIn t x (IApp fc1 (IApp fc2 sp arg1) arg2)
111 |   ||| There are nested subtrees somewhere inside a 3rd party type constructor
112 |   ||| which satisfies the Traversable interface
113 |   TIFold : IsFreeOf x sp -> HasImplementation Traversable sp ->
114 |            IsTraversableIn t x arg -> IsTraversableIn t x (IApp fc sp arg)
115 |   ||| A type free of x is trivially Traversable in it
116 |   TIFree : IsFreeOf x a -> IsTraversableIn t x a
117 |
118 | parameters
119 |   {0 m : Type -> Type}
120 |   {auto elab : Elaboration m}
121 |   {auto error : MonadError Error m}
122 |   {auto cs : MonadState Parameters m}
123 |   (t : Name)
124 |   (ps : List (Name, Nat))
125 |   (x : Name)
126 |
127 |   ||| When analysing the type of a constructor for the type family t,
128 |   ||| we hope to observe
129 |   |||   1. either that it is traversable in x
130 |   |||   2. or that it is free of x
131 |   ||| If it is not the case, we will use the `MonadError Error` constraint
132 |   ||| to fail with an informative message.
133 |   public export
134 |   TypeView : TTImp -> Type
135 |   TypeView ty = Either (IsTraversableIn t x ty) (IsFreeOf x ty)
136 |
137 |   export
138 |   fromTypeView : TypeView ty -> IsTraversableIn t x ty
139 |   fromTypeView (Left prf) = prf
140 |   fromTypeView (Right fo) = TIFree fo
141 |
142 |   ||| Hoping to observe that ty is traversable
143 |   export
144 |   typeView : (ty : TTImp) -> m (TypeView ty)
145 |
146 |   ||| To avoid code duplication in typeView, we have an auxiliary function
147 |   ||| specifically to handle the application case
148 |   typeAppView :
149 |     {fc : FC} ->
150 |     {f : TTImp} -> IsFreeOf x f ->
151 |     (arg : TTImp) ->
152 |     m (TypeView (IApp fc f arg))
153 |
154 |   typeAppView {fc, f} isFO arg = do
155 |     chka <- typeView arg
156 |     case chka of
157 |       -- if x is present in the argument then the function better be:
158 |       -- 1. free of x
159 |       -- 2. either an occurrence of t i.e. a subterm
160 |       --    or a type constructor already known to be functorial
161 |       Left sp => do
162 |         let Just (MkAppView (_, hd) ts prf) = appView f
163 |            | _ => throwError (NotAnApplication f)
164 |         case decEq t hd of
165 |           Yes Refl => pure $ Left (TIRec prf sp)
166 |           No diff => case !(hasImplementation Traversable f) of
167 |             Just prf => pure (Left (TIFold isFO prf sp))
168 |             Nothing => case lookup hd ps of
169 |               Just n => do
170 |                 -- record that the nth parameter should be functorial
171 |                 ns <- gets asTraversables
172 |                 let ns = ifThenElse (n `elem` ns) ns (n :: ns)
173 |                 modify { asTraversables := ns }
174 |                 -- and happily succeed
175 |                 logMsg "derive.traversable.assumption" 10 $
176 |                   "I am assuming that the parameter \{show hd} is a Traversable"
177 |                 pure (Left (TIFold isFO assert_hasImplementation sp))
178 |               Nothing => throwError (NotATraversable f)
179 |       -- Otherwise it better be the case that f is also free of x so that
180 |       -- we can mark the whole type as being x-free.
181 |       Right fo => do
182 |         Right _ <- typeView f
183 |           | _ => throwError $ NotTraversableInItsLastArg (IApp fc f arg)
184 |         pure (Right assert_IsFreeOf)
185 |
186 |   typeView tm@(IVar fc y) = case decEq x y of
187 |     Yes Refl => pure (Left TIVar)
188 |     No _ => pure (Right assert_IsFreeOf)
189 |   typeView fab@(IApp _ (IApp fc1 f arg1) arg2) = do
190 |     chka1 <- typeView arg1
191 |     case chka1 of
192 |       Right _ => do isFO <- isFreeOf' x (IApp _ f arg1)
193 |                     typeAppView {f = assert_smaller fab (IApp _ f arg1)} isFO arg2
194 |       Left sp => do
195 |         isFO <- isFreeOf' x f
196 |         case !(hasImplementation Bitraversable f) of
197 |           Just prf => pure (Left (TIBifold isFO prf sp !(typeView arg2)))
198 |           Nothing => do
199 |             let Just (MkAppView (_, hd) ts prf) = appView f
200 |                | _ => throwError (NotAnApplication f)
201 |             case lookup hd ps of
202 |               Just n => do
203 |                 -- record that the nth parameter should be bitraversable
204 |                 ns <- gets asBitraversables
205 |                 let ns = ifThenElse (n `elem` ns) ns (n :: ns)
206 |                 modify { asBitraversables := ns }
207 |                 -- and happily succeed
208 |                 logMsg "derive.traversable.assumption" 10 $
209 |                     "I am assuming that the parameter \{show hd} is a Bitraversable"
210 |                 pure (Left (TIBifold isFO assert_hasImplementation sp !(typeView arg2)))
211 |               Nothing => throwError (NotABitraversable f)
212 |   typeView (IApp _ f arg) = do
213 |     isFO <- isFreeOf' x f
214 |     typeAppView isFO arg
215 |   typeView (IDelayed _ lz f) = case !(typeView f) of
216 |     Left sp => case lz of
217 |       LLazy => pure (Left (TIDelayed sp))
218 |       _ => throwError NotAFiniteStructure
219 |     Right _ => pure (Right assert_IsFreeOf)
220 |   typeView (IPrimVal _ _) = pure (Right assert_IsFreeOf)
221 |   typeView (IType _) = pure (Right assert_IsFreeOf)
222 |   typeView ty = case isFreeOf x ty of
223 |     Nothing => throwError (UnsupportedType ty)
224 |     Just prf => pure (Right prf)
225 |
226 | ------------------------------------------------------------------------------
227 | -- Core machinery: building the traverse function from an IsTraversableIn proof
228 |
229 | parameters (fc : FC) (mutualWith : List Name)
230 |
231 |   ||| traverseFun takes
232 |   ||| @ mutualWith a list of mutually defined type constructors. Calls to their
233 |   ||| respective mapping functions typically need an assert_total because the
234 |   ||| termination checker is not doing enough inlining to see that things are
235 |   ||| terminating
236 |   ||| @ assert records whether we should mark recursive calls as total because
237 |   ||| we are currently constructing the argument to a higher order function
238 |   ||| which will obscure the termination argument. Starts as `Nothing`, becomes
239 |   ||| `Just False` if an `assert_total` has already been inserted.
240 |   ||| @ ty the type being transformed by the mapping function
241 |   ||| @ rec the name of the mapping function being defined (used for recursive calls)
242 |   ||| @ f the name of the function we're mapping
243 |   ||| @ arg the (optional) name of the argument being mapped over. This lets us use
244 |   ||| Nothing when generating arguments to higher order functions so that we generate
245 |   ||| the eta contracted `map (mapTree f)` instead of `map (\ ts => mapTree f ts)`.
246 |   traverseFun : (assert : Maybe Bool) -> {ty : TTImp} -> IsTraversableIn t x ty ->
247 |                 (rec, f : Name) -> (arg : Maybe TTImp) -> TTImp
248 |   traverseFun assert TIVar rec f t = apply fc (IVar fc f) (toList t)
249 |   traverseFun assert (TIRec y sp) rec f t
250 |     -- only add assert_total if it is declared to be needed
251 |     = ifThenElse (fromMaybe False assert) (IApp fc (IVar fc (UN $ Basic "assert_total"))) id
252 |     $ apply fc (IVar fc rec) (traverseFun (Just False) sp rec f Nothing :: toList t)
253 |   traverseFun assert (TIDelayed sp) rec f Nothing
254 |     -- here we need to eta-expand to avoid "Lazy t does not unify with t" errors
255 |     = let nm = UN $ Basic "eta" in
256 |       ILam fc MW ExplicitArg (Just nm) (IDelayed fc LLazy (Implicit fc False))
257 |     $ apply fc `((<$>))
258 |     [ `(delay)
259 |     , traverseFun assert sp rec f (Just (IVar fc nm))
260 |     ]
261 |   traverseFun assert (TIDelayed sp) rec f (Just t)
262 |     = apply fc `((<$>))
263 |     [ `(delay)
264 |     , traverseFun assert sp rec f (Just t)
265 |     ]
266 |   traverseFun assert {ty = IApp _ ty _} (TIFold _ _ sp) rec f t
267 |     -- only add assert_total if we are calling a mutually defined Traversable implementation.
268 |     = let isMutual = fromMaybe False (appView ty >>= \ v => pure (snd v.head `elem` mutualWith)) in
269 |       ifThenElse isMutual (IApp fc (IVar fc (UN $ Basic "assert_total"))) id
270 |     $ apply fc (IVar fc (UN $ Basic "traverse"))
271 |       (traverseFun ((False <$ guard isMutual) <|> assert <|> Just True) sp rec f Nothing
272 |        :: toList t)
273 |   traverseFun assert (TIBifold _ _ sp1 (Left sp2)) rec f t
274 |     = apply fc (IVar fc (UN $ Basic "bitraverse"))
275 |       (traverseFun (assert <|> Just True) sp1 rec f Nothing
276 |       :: traverseFun (assert <|> Just True) sp2 rec f Nothing
277 |       :: toList t)
278 |   traverseFun assert (TIBifold _ _ sp (Right _)) rec f t
279 |     = apply fc (IVar fc (UN $ Basic "bitraverseFst"))
280 |       (traverseFun (assert <|> Just True) sp rec f Nothing
281 |       :: toList t)
282 |   traverseFun assert (TIFree y) rec f t = `(mempty)
283 |
284 | ------------------------------------------------------------------------------
285 | -- User-facing: Traversable deriving
286 |
287 | applyA : FC -> TTImp -> List (Either (Argument TTImp) TTImp) -> TTImp
288 | applyA fc c [] = `(pure ~(c))
289 | applyA fc c (Right a :: as) = applyA fc (IApp fc c a) as
290 | applyA fc c as =
291 |   let (pref, suff) = spanBy canBeApplied ([<] <>< as) in
292 |   let (lams, args, vals) = preEta 0 (pref <>> []) in
293 |   let eta = foldr (\ x => ILam fc MW ExplicitArg (Just x) (Implicit fc False)) (apply c args) lams in
294 |   fire eta (map Left vals ++ (suff <>> []))
295 |
296 |   where
297 |
298 |     canBeApplied : Either (Argument TTImp) TTImp -> Maybe (Either TTImp TTImp)
299 |     canBeApplied (Left (Arg _ t)) = pure (Left t)
300 |     canBeApplied (Right t) = pure (Right t)
301 |     canBeApplied _ = Nothing
302 |
303 |     preEta : Nat -> List (Either (Argument TTImp) TTImp) ->
304 |              (List Name, List (Argument TTImp), List TTImp)
305 |     preEta n [] = ([], [], [])
306 |     preEta n (a :: as) =
307 |       let (n, ns, args, vals) = the (Nat, List Name, List (Argument TTImp), List _) $
308 |             let x = UN (Basic ("y" ++ show n))vx = IVar fc x in case a of
309 |               Left (Arg fc t) => (S n, [x], [Arg fc vx], [t])
310 |               Left (NamedArg fc nm t) => (S n, [x], [NamedArg fc nm vx], [t])
311 |               Left (AutoArg fc t) => (S n, [x], [AutoArg fc vx], [t])
312 |               Right t => (n, [], [Arg fc t], [])
313 |       in
314 |       let (nss, argss, valss) = preEta n as in
315 |       (ns ++ nss, args ++ argss, vals ++ valss)
316 |
317 |     go : TTImp -> List (Either TTImp TTImp) -> TTImp
318 |     go f [] = f
319 |     go f (Left a :: as) = go (apply fc `((<*>)) [f, a]) as
320 |     go f (Right a :: as) = go (apply fc `((<*>)) [f, IApp fc `(pure) a]) as
321 |
322 |     fire : TTImp -> List (Either TTImp TTImp) -> TTImp
323 |     fire f [] = f
324 |     fire f (a :: as) = go (apply fc `((<$>)) [f, either id id a]) as
325 |
326 | namespace Traversable
327 |
328 |   derive' : (Elaboration m, MonadError Error m) =>
329 |             {default Private vis : Visibility} ->
330 |             {default Total treq : TotalReq} ->
331 |             {default [] mutualWith : List Name} ->
332 |             m (Traversable f)
333 |   derive' = do
334 |
335 |     -- expand the mutualwith names to have the internal, fully qualified, names
336 |     mutualWith <- map concat $ for mutualWith $ \ nm => do
337 |                     ntys <- getType nm
338 |                     pure (fst <$> ntys)
339 |
340 |     -- The goal should have the shape (Traversable t)
341 |     Just (IApp _ (IVar _ traversable) t) <- goal
342 |       | _ => throwError InvalidGoal
343 |     when (`{Prelude.Interfaces.Traversable} /= traversable) $
344 |       logMsg "derive.traversable" 1 "Expected to derive Traversable but got \{show traversable}"
345 |
346 |     -- t should be a type constructor
347 |     logMsg "derive.traversable" 1 "Deriving Traversable for \{showPrec App $ mapTTImp cleanup t}"
348 |     MkIsType f params cs <- isType t
349 |     logMsg "derive.traversable.constructors" 1 $
350 |       joinBy "\n" $ "" :: map (\ (n, ty) => "  \{showPrefix True $ dropNS n} : \{show $ mapTTImp cleanup ty}") cs
351 |
352 |     -- Generate a clause for each data constructor
353 |     let fc = emptyFC
354 |     let un = UN . Basic
355 |     let traverseName = un ("traverse" ++ show (dropNS f))
356 |     let funName = un "f"
357 |     let fun  = IVar fc funName
358 |     (ns, cls) <- runStateT {m = m} initParameters $ for cs $ \ (cName, ty) =>
359 |       withError (WhenCheckingConstructor cName) $ do
360 |         -- Grab the types of the constructor's explicit arguments
361 |         let Just (MkConstructorView (paraz :< (para, _)) args) = constructorView ty
362 |               | _ => throwError ConfusingReturnType
363 |         let paras = paraz <>> []
364 |         logMsg "derive.traversable.clauses" 10 $
365 |           "\{showPrefix True (dropNS cName)} (\{joinBy ", " (map (showPrec Dollar . mapTTImp cleanup . unArg . snd) args)})"
366 |         let vars = map (map (IVar fc . un . ("x" ++) . show . (`minus` 1)))
367 |                  $ zipWith (<$) [1..length args] (map snd args)
368 |         recs <- for (zip vars args) $ \ (v, (rig, arg)) => do
369 |                   res <- withError (WhenCheckingArg (mapTTImp cleanup (unArg arg))) $ do
370 |                            res <- typeView f paras para (unArg arg)
371 |                            case res of
372 |                              Left _ => case rig of
373 |                                MW => pure ()
374 |                                _ => throwError (NotAnUnconstrainedValue rig)
375 |                              _ => pure ()
376 |                            pure res
377 |                   pure $ case res of
378 |                     Left sp => -- do not bother with assert_total if you're generating
379 |                                -- a covering/partial definition
380 |                                let useTot = False <$ guard (treq /= Total) in
381 |                                Just (v, Left (traverseFun fc mutualWith useTot sp traverseName funName . Just <$> v))
382 |                     Right free => do ignore $ isExplicit v
383 |                                      Just (v, Right (unArg v))
384 |         let (vars, recs) = unzip (catMaybes recs)
385 |         pure $ PatClause fc
386 |           (apply fc (IVar fc traverseName) [ fun, apply (IVar fc cName) vars])
387 |           (applyA fc (IVar fc cName) recs)
388 |
389 |     -- Generate the type of the mapping function
390 |     let paramNames = unArg . fst <$> params
391 |     let a = un $ freshName paramNames "a"
392 |     let b = un $ freshName paramNames "b"
393 |     let f = un $ freshName paramNames "f"
394 |     let va = IVar fc a
395 |     let vb = IVar fc b
396 |     let vf = IVar fc f
397 |     let ty = MkTy fc (NoFC traverseName) $ withParams fc (paramConstraints ns) params
398 |            $ IPi fc M0 ImplicitArg (Just a) (IType fc)
399 |            $ IPi fc M0 ImplicitArg (Just b) (IType fc)
400 |            $ IPi fc M0 ImplicitArg (Just f) (IPi fc MW ExplicitArg Nothing (IType fc) (IType fc))
401 |            $ `(Applicative ~(vf) => (~(va) -> ~(vf) ~(vb)) -> ~(t) ~(va) -> ~(vf) (~(t) ~(vb)))
402 |     logMsg "derive.traversable.clauses" 1 $
403 |       joinBy "\n" ("" :: ("  " ++ show (mapITy cleanup ty))
404 |                       :: map (("  " ++) . showClause InDecl . mapClause cleanup) cls)
405 |
406 |     -- Define the instance
407 |     check $ ILocal fc
408 |       [ IClaim (MkFCVal fc (MkIClaimData MW vis [Totality treq] ty))
409 |       , IDef fc traverseName cls
410 |       ] `(MkTraversable {t = ~(t)} ~(IVar fc traverseName))
411 |
412 |   ||| Derive an implementation of Traversable for a type constructor.
413 |   ||| This can be used like so:
414 |   ||| ```
415 |   ||| data Tree a = Leaf a | Node (Tree a) (Tree a)
416 |   ||| treeTraversable : Traversable Tree
417 |   ||| treeTraversable = %runElab derive
418 |   ||| ```
419 |   export
420 |   derive : {default Private vis : Visibility} ->
421 |            {default Total treq : TotalReq} ->
422 |            {default [] mutualWith : List Name} ->
423 |            Elab (Traversable f)
424 |   derive = do
425 |     res <- runEitherT {e = Error, m = Elab} (derive' {vis, treq, mutualWith})
426 |     case res of
427 |       Left err => fail (show err)
428 |       Right prf => pure prf
429 |