0 | module Data.Linear.LVect
  1 |
  2 | import Data.Fin
  3 |
  4 | import Data.Linear.Bifunctor
  5 | import Data.Linear.Interface
  6 | import Data.Linear.Notation
  7 | import Data.Linear.LNat
  8 | import Data.Linear.LList
  9 |
 10 | %default total
 11 |
 12 | public export
 13 | data LVect : Nat -> Type -> Type where
 14 |   Nil : LVect Z a
 15 |   (::) : a -@ LVect n a -@ LVect (S n) a
 16 |
 17 | %name LVect xs, ys, zs, ws
 18 |
 19 | export
 20 | lookup : Fin (S n) -@ LVect (S n) a -@ LPair a (LVect n a)
 21 | lookup FZ     (v :: vs) = (v # vs)
 22 | lookup (FS k) (v :: vs@(_ :: _)) = mapSnd (v ::) (lookup k vs)
 23 |
 24 | export
 25 | insertAt : Fin (S n) -@ a -@ LVect n a -@ LVect (S n) a
 26 | insertAt FZ w vs = w :: vs
 27 | insertAt (FS k) w (v :: vs) = v :: insertAt k w vs
 28 |
 29 | export
 30 | uncurry : (a -@ b -@ c) -@ (LPair a b -@ c)
 31 | uncurry f (x # y) = f x y
 32 |
 33 | export
 34 | lookupInsertAtIdentity :
 35 |   (k : Fin (S n)) -> (vs : LVect (S n) a) ->
 36 |   uncurry (insertAt k) (lookup k vs) === vs
 37 | lookupInsertAtIdentity FZ     (v :: xs) = Refl
 38 | lookupInsertAtIdentity (FS k) (v :: w :: ws)
 39 |   with (lookupInsertAtIdentity k (w :: ws)) | (lookup k (w :: ws))
 40 |   _ | prf | (x # xs) = cong (v ::) prf
 41 |
 42 | export
 43 | insertAtLookupIdentity :
 44 |   (k : Fin (S n)) -> (w : a) -> (vs : LVect n a) ->
 45 |   lookup k (insertAt k w vs) === (w # vs)
 46 | insertAtLookupIdentity FZ w vs = Refl
 47 | insertAtLookupIdentity (FS k) w (v :: vs)
 48 |   with (insertAtLookupIdentity k w vs) | (insertAt k w vs)
 49 |   _ | prf | (x :: xs) = cong (\ x => mapSnd (v ::) x) prf
 50 |
 51 | export
 52 | (<$>) : (f : a -@ b) -> LVect n a -@ LVect n b
 53 | f <$> [] = []
 54 | f <$> x :: xs = f x :: (f <$> xs)
 55 |
 56 | export
 57 | pure : {n : Nat} -> a -> LVect n a
 58 | pure {n = Z} _ = []
 59 | pure {n = S n} x = x :: pure x
 60 |
 61 | export
 62 | (<*>) : LVect n (a -@ b) -@ LVect n a -@ LVect n b
 63 | [] <*> [] = []
 64 | f :: fs <*> x :: xs = f x :: (fs <*> xs)
 65 |
 66 | export
 67 | zip : LVect n a -@ LVect n b -@ LVect n (LPair a b)
 68 | zip [] [] = []
 69 | zip (a :: as) (b :: bs) = (a # b) :: zip as bs
 70 |
 71 | export
 72 | unzip : LVect n (LPair a b) -@ LPair (LVect n a) (LVect n b)
 73 | unzip [] = [] # []
 74 | unzip ((a # b) :: abs) = let (as # bs) = LVect.unzip abs in (a :: as # b :: bs)
 75 |
 76 | export
 77 | splitAt : (1 m : Nat) -> LVect (m + n) a -@ LPair (LVect m a) (LVect n a)
 78 | splitAt Z as = [] # as
 79 | splitAt (S m) (a :: as) = let (xs # ys) = LVect.splitAt m as in (a :: xs # ys)
 80 |
 81 | export
 82 | (++) : LVect m a -@ LVect n a -@ LVect (m + n) a
 83 | [] ++ ys = ys
 84 | (x :: xs) ++ ys = x :: (xs ++ ys)
 85 |
 86 | export
 87 | lfoldr : (0 p : Nat -> Type) -> (forall n. a -@ p n -@ p (S n)) -> p Z -@ LVect n a -@ p n
 88 | lfoldr p c n [] = n
 89 | lfoldr p c n (x :: xs) = c x (lfoldr p c n xs)
 90 |
 91 | export
 92 | lfoldl : (0 p : Nat -> Type) -> (forall n. a -@ p n -@ p (S n)) -> p Z -@ LVect n a -@ p n
 93 | lfoldl p c n [] = n
 94 | lfoldl p c n (x :: xs) = lfoldl (p . S) c (c x n) xs
 95 |
 96 | export
 97 | reverse : LVect m a -@ LVect m a
 98 | reverse = lfoldl (\ m => LVect m a) (::) []
 99 |
100 | export
101 | Consumable a => Consumable (LVect n a) where
102 |   consume [] = ()
103 |   consume (x :: xs) = x `seq` consume xs
104 |
105 | export
106 | Duplicable a => Duplicable (LVect n a) where
107 |   duplicate [] = [[], []]
108 |   duplicate (x :: xs) = (::) <$> duplicate x <*> duplicate xs
109 |
110 | ||| Map a linear vector
111 | export
112 | map : (0 f : a -@ b) -> {auto 1 fns : n `Copies` f} -> LVect n a -@ LVect n b
113 | map f {fns = []} [] = []
114 | map f {fns = f :: fs} (x :: xs) = f x :: map f {fns = fs} xs
115 |
116 | ||| Extract all
117 | export
118 | length : Consumable a => LVect n a -@ LNat
119 | length [] = Zero
120 | length (x :: xs) = let () = consume x in Succ (length xs)
121 |
122 | ||| Fold a linear vector.
123 | export
124 | foldl : (0 f : acc -@ a -@ acc) -> {auto 1 fns : n `Copies` f} -> acc -@ (LVect n a) -@ acc
125 | foldl _ {fns = []} acc [] = acc
126 | foldl f {fns = f :: fs} acc (x :: xs) = foldl f {fns = fs} (f acc x) xs
127 |
128 | export
129 | replicate : (1 n : LNat) -> (0 v : a) -> {auto 1 vs : toNat n `Copies` v} -> LVect (toNat n) a
130 | replicate Zero v {vs = []} = []
131 | replicate (Succ n) v {vs = (v :: vs)} = v :: replicate n v {vs}
132 |
133 | ||| Bind a linear vector.
134 | export
135 | (>>=) : LVect n a -@ ((0 f : a -@ LVect m b) -> {1 fns : n `Copies` f} -> LVect (n * m) b)
136 | (>>=) [] _ {fns = []} = []
137 | (>>=) (v :: xs) f {fns = f :: fs} = f v ++ (>>=) {fns = fs} xs f
138 |
139 | ||| Extract all the copies into a vector of the same length as the number of copies.
140 | export
141 | copiesToVect : {0 v : a} -> n `Copies` v -@ LVect n a
142 | copiesToVect [] = []
143 | copiesToVect (v :: copies) = v :: copiesToVect copies
144 |