0 | ||| Splitting operations and their properties
  1 | module Data.Fin.Split
  2 |
  3 | import public Data.Fin
  4 | import Data.Fin.Properties
  5 |
  6 | import Syntax.WithProof
  7 | import Syntax.PreorderReasoning
  8 |
  9 | %default total
 10 |
 11 | ||| Converts `Fin`s that are used as indexes of parts to an index of a sum.
 12 | |||
 13 | ||| For example, if you have a `Vect` that is a concatenation of two `Vect`s and
 14 | ||| you have an index either in the first or the second of the original `Vect`s,
 15 | ||| using this function you can get an index in the concatenated one.
 16 | public export
 17 | indexSum : {m : Nat} -> Either (Fin m) (Fin n) -> Fin (m + n)
 18 | indexSum (Left  l) = weakenN n l
 19 | indexSum (Right r) = shift m r
 20 |
 21 | ||| Extracts an index of the first or the second part from the index of a sum.
 22 | |||
 23 | ||| For example, if you have a `Vect` that is a concatenation of the `Vect`s and
 24 | ||| you have an index of this `Vect`, you have get an index of either left or right
 25 | ||| original `Vect` using this function.
 26 | public export
 27 | splitSum : {m : Nat} -> Fin (m + n) -> Either (Fin m) (Fin n)
 28 | splitSum {m=Z}   k      = Right k
 29 | splitSum {m=S m} FZ     = Left FZ
 30 | splitSum {m=S m} (FS k) = mapFst FS $ splitSum k
 31 |
 32 | ||| Calculates the index of a square matrix of size `a * b` by indices of each side.
 33 | public export
 34 | indexProd : {n : Nat} -> Fin m -> Fin n -> Fin (m * n)
 35 | indexProd FZ     = weakenN $ mult (pred m) n
 36 | indexProd (FS k) = shift n . indexProd k
 37 |
 38 | ||| Splits the index of a square matrix of size `a * b` to indices of each side.
 39 | public export
 40 | splitProd : {m, n : Nat} -> Fin (m * n) -> (Fin m, Fin n)
 41 | splitProd {m=S _} p = case splitSum p of
 42 |   Left  k => (FZ, k)
 43 |   Right l => mapFst FS $ splitProd l
 44 |
 45 | --- Properties ---
 46 |
 47 | export
 48 | indexSumPreservesLast :
 49 |   (m, n : Nat) -> indexSum {m} (Right $ Fin.last {n}) ~~~ Fin.last {n=m+n}
 50 | indexSumPreservesLast Z     n = reflexive
 51 | indexSumPreservesLast (S m) n = FS (shiftLastIsLast m)
 52 |
 53 | export
 54 | indexProdPreservesLast : (m, n : Nat) ->
 55 |          indexProd (Fin.last {n=m}) (Fin.last {n}) = Fin.last
 56 | indexProdPreservesLast Z n = homoPointwiseIsEqual
 57 |   $ transitive (weakenNZeroIdentity last)
 58 |                (congLast (sym $ plusZeroRightNeutral n))
 59 | indexProdPreservesLast (S m) n = Calc $
 60 |   |~ indexProd (last {n=S m}) (last {n})
 61 |   ~~ FS (shift n (indexProd last last)) ...( Refl )
 62 |   ~~ FS (shift n last)                  ...( cong (FS . shift n) (indexProdPreservesLast m n ) )
 63 |   ~~ last                               ...( homoPointwiseIsEqual prf )
 64 |
 65 |   where
 66 |
 67 |     prf : shift (S n) (Fin.last {n = n + m * S n}) ~~~ Fin.last {n = n + S (n + m * S n)}
 68 |     prf = transitive (shiftLastIsLast (S n))
 69 |                      (congLast (plusSuccRightSucc n (n + m * S n)))
 70 |
 71 | export
 72 | splitSumOfWeakenN : (k : Fin m) -> splitSum {m} {n} (weakenN n k) = Left k
 73 | splitSumOfWeakenN FZ = Refl
 74 | splitSumOfWeakenN (FS k) = cong (mapFst FS) $ splitSumOfWeakenN k
 75 |
 76 | export
 77 | splitSumOfShift : {m : Nat} -> (k : Fin n) -> splitSum {m} {n} (shift m k) = Right k
 78 | splitSumOfShift {m=Z}   k = Refl
 79 | splitSumOfShift {m=S m} k = cong (mapFst FS) $ splitSumOfShift k
 80 |
 81 | export
 82 | splitOfIndexSumInverse : {m : Nat} -> (e : Either (Fin m) (Fin n)) -> splitSum (indexSum e) = e
 83 | splitOfIndexSumInverse (Left l) = splitSumOfWeakenN l
 84 | splitOfIndexSumInverse (Right r) = splitSumOfShift r
 85 |
 86 | export
 87 | indexOfSplitSumInverse : {m, n : Nat} -> (f : Fin (m + n)) -> indexSum (splitSum {m} {n} f) = f
 88 | indexOfSplitSumInverse {m=Z}   f  = Refl
 89 | indexOfSplitSumInverse {m=S _} FZ = Refl
 90 | indexOfSplitSumInverse {m=S _} (FS f) with (indexOfSplitSumInverse f)
 91 |   indexOfSplitSumInverse {m=S _} (FS f) | eq with (splitSum f)
 92 |     indexOfSplitSumInverse {m=S _} (FS _) | eq | Left  _ = cong FS eq
 93 |     indexOfSplitSumInverse {m=S _} (FS _) | eq | Right _ = cong FS eq
 94 |
 95 |
 96 | export
 97 | splitOfIndexProdInverse : {m : Nat} -> (k : Fin m) -> (l : Fin n) ->
 98 |                           splitProd (indexProd k l) = (k, l)
 99 | splitOfIndexProdInverse FZ     l
100 |    = rewrite splitSumOfWeakenN {n = mult (pred m) n} l in Refl
101 | splitOfIndexProdInverse (FS k) l
102 |    = rewrite splitSumOfShift {m=n} $ indexProd k l in
103 |      cong (mapFst FS) $ splitOfIndexProdInverse k l
104 |
105 | export
106 | indexOfSplitProdInverse : {m, n : Nat} -> (f : Fin (m * n)) ->
107 |                           uncurry (indexProd {m} {n}) (splitProd {m} {n} f) = f
108 | indexOfSplitProdInverse {m = S _} f with (@@ splitSum f)
109 |   indexOfSplitProdInverse {m = S _} f | (Left l ** eq= rewrite eq in Calc $
110 |     |~ indexSum (Left l)
111 |     ~~ indexSum (splitSum f) ...( cong indexSum (sym eq) )
112 |     ~~ f                     ...( indexOfSplitSumInverse f )
113 |   indexOfSplitProdInverse f | (Right r ** eqwith (@@ splitProd r)
114 |     indexOfSplitProdInverse f | (Right r ** eq| ((p, q) ** eq2)
115 |       = rewrite eq in rewrite eq2 in Calc $
116 |         |~ indexProd (FS p) q
117 |         ~~ shift n (indexProd p q)                   ...( Refl )
118 |         ~~ shift n (uncurry indexProd (splitProd r)) ...( cong (shift n . uncurry indexProd) (sym eq2) )
119 |         ~~ shift n r                                 ...( cong (shift n) (indexOfSplitProdInverse r) )
120 |         ~~ indexSum (splitSum f)                     ...( sym (cong indexSum eq) )
121 |         ~~ f                                         ...( indexOfSplitSumInverse f )
122 |