0 | module System.Concurrency.Session
  1 |
  2 | import Control.Linear.LIO
  3 |
  4 | import Data.List.AtIndex
  5 | import Data.Nat
  6 |
  7 | import Data.OpenUnion
  8 | import System
  9 | import System.File
 10 | import System.Concurrency as Threads
 11 | import System.Concurrency.Linear
 12 |
 13 | import Language.Reflection
 14 |
 15 | %default total
 16 |
 17 | ------------------------------------------------------------------------
 18 | -- Session types
 19 |
 20 | namespace Session
 21 |
 22 |   ||| A session type describes the interactions one thread may have with
 23 |   ||| another over a shared bidirectional channel: it may send or receive
 24 |   ||| values of arbitrary types, or be done communicating.
 25 |   public export
 26 |   data Session : Type where
 27 |     Send : (ty : Type) -> (s : Session) -> Session
 28 |     Recv : (ty : Type) -> (s : Session) -> Session
 29 |     End  : Session
 30 |
 31 |   ||| Dual describes how the other party to the communication sees the
 32 |   ||| interactions: our sends become their receives and vice-versa.
 33 |   public export
 34 |   Dual : Session -> Session
 35 |   Dual (Send ty s) = Recv ty (Dual s)
 36 |   Dual (Recv ty s) = Send ty (Dual s)
 37 |   Dual End = End
 38 |
 39 |   ||| Duality is involutive: the dual of my dual is me
 40 |   export
 41 |   dualInvolutive : (s : Session) -> Dual (Dual s) === s
 42 |   dualInvolutive (Send ty s) = cong (Send ty) (dualInvolutive s)
 43 |   dualInvolutive (Recv ty s) = cong (Recv ty) (dualInvolutive s)
 44 |   dualInvolutive End = Refl
 45 |
 46 | ||| We can collect the list of types that will be sent over the
 47 | ||| course of a session by walking down its description
 48 | ||| This definition is purely internal and will not show up in
 49 | ||| the library's interface.
 50 | SendTypes : Session -> List Type
 51 | SendTypes (Send ty s) = ty :: SendTypes s
 52 | SendTypes (Recv ty s) = SendTypes s
 53 | SendTypes End = []
 54 |
 55 | ||| We can collect the list of types that will be received over
 56 | ||| the course of a session by walking down its description
 57 | ||| This definition is purely internal and will not show up in
 58 | ||| the library's interface.
 59 | RecvTypes : Session -> List Type
 60 | RecvTypes (Send ty s) = RecvTypes s
 61 | RecvTypes (Recv ty s) = ty :: RecvTypes s
 62 | RecvTypes End = []
 63 |
 64 | ||| The types received by my dual are exactly the ones I am sending
 65 | ||| This definition is purely internal and will not show up in
 66 | ||| the library's interface.
 67 | RecvDualTypes : (s : Session) -> RecvTypes (Dual s) === SendTypes s
 68 | RecvDualTypes (Send ty s) = cong (ty ::) (RecvDualTypes s)
 69 | RecvDualTypes (Recv ty s) = RecvDualTypes s
 70 | RecvDualTypes End = Refl
 71 |
 72 | ||| The types sent by my dual are exactly the ones I receive
 73 | ||| This definition is purely internal and will not show up in
 74 | ||| the library's interface.
 75 | SendDualTypes : (s : Session) -> SendTypes (Dual s) === RecvTypes s
 76 | SendDualTypes (Send ty s) = SendDualTypes s
 77 | SendDualTypes (Recv ty s) = cong (ty ::) (SendDualTypes s)
 78 | SendDualTypes End = Refl
 79 |
 80 | namespace Seen
 81 |
 82 |   ||| The inductive family (Seen m n f) states that the function `f`
 83 |   ||| was obtained by composing an interleaving of `m` receiving
 84 |   ||| steps and `n` sending ones.
 85 |   public export
 86 |   data Seen : Nat -> Nat -> (Session -> Session) -> Type where
 87 |     None : Seen 0 0 Prelude.id
 88 |     Recv : (ty : Type) -> Seen m n f -> Seen (S m) n (f . Recv ty)
 89 |     Send : (ty : Type) -> Seen m n f -> Seen m (S n) (f . Send ty)
 90 |
 91 | ||| If we know that `ty` is at index `n` in the list of received types
 92 | ||| and that `f` is a function defined using an interleaving of steps
 93 | ||| comprising `m` receiving stepsx then `ty` is at index `m + n` in `f s`.
 94 | atRecvIndex : Seen m _ f ->
 95 |           (s : Session) ->
 96 |           AtIndex ty (RecvTypes s) n ->
 97 |           AtIndex ty (RecvTypes (f s)) (m + n)
 98 | atRecvIndex None accS accAt = accAt
 99 | atRecvIndex (Recv ty s) accS accAt
100 |   = rewrite plusSuccRightSucc (pred m) n in
101 |     atRecvIndex s (Recv ty accS) (S accAt)
102 | atRecvIndex (Send ty s) accS accAt
103 |   = atRecvIndex s (Send ty accS) accAt
104 |
105 | ||| If we know that `ty` is at index `n` in the list of sent types
106 | ||| and that `f` is a function defined using an interleaving of steps
107 | ||| comprising `m` sending steps then `ty` is at index `m + n` in `f s`.
108 | atSendIndex : Seen _ m f ->
109 |           (s : Session) ->
110 |           AtIndex ty (SendTypes s) n ->
111 |           AtIndex ty (SendTypes (f s)) (m + n)
112 | atSendIndex None accS accAt = accAt
113 | atSendIndex (Recv ty s) accS accAt
114 |   = atSendIndex s (Recv ty accS) accAt
115 | atSendIndex (Send ty s) accS accAt
116 |   = rewrite plusSuccRightSucc (pred m) n in
117 |     atSendIndex s (Send ty accS) (S accAt)
118 |
119 |
120 | ||| A (bidirectional) channel is parametrised by a session it must respect.
121 | |||
122 | ||| It is implemented in terms of two low-level channels: one for sending
123 | ||| and one for receiving. This ensures that we never are in a situation
124 | ||| where a thread with session (Send Nat (Recv String ...)) sends a natural
125 | ||| number and subsequently performs a receive before the other party
126 | ||| to the communication had time to grab the Nat thus receiving it
127 | ||| instead of a String.
128 | |||
129 | ||| The low-level channels can only carry values of a single type. And so
130 | ||| they are given respective union types corresponding to the types that
131 | ||| can be sent on the one hand and the ones that can be received on the
132 | ||| other.
133 | ||| These union types are tagged unions where if `ty` is at index `k` in
134 | ||| the list of types `tys` then `(k, v)` is a value of `Union tys` provided
135 | ||| that `v` has type `ty`.
136 | |||
137 | ||| `sendStep`, `recvStep`, `seePrefix`, and `seen` encode the fact that
138 | ||| we have already performed some of the protocol and so the low-level
139 | ||| channels' respective types necessarily mention types that we won't
140 | ||| see anymore.
141 | export
142 | record Channel (s : Session) where
143 |   constructor MkChannel
144 |   {sendStep     : Nat}
145 |   {recvStep     : Nat}
146 |   {0 seenPrefix : Session -> Session}
147 |   0 seen        : Seen recvStep sendStep seenPrefix
148 |
149 |   sendChan : Threads.Channel (Union (SendTypes (seenPrefix s)))
150 |   recvChan : Threads.Channel (Union (RecvTypes (seenPrefix s)))
151 |
152 | ||| Consume a linear channel with a `Recv ty` step at the head of the
153 | ||| session type in order to obtain a value of type `ty` together with
154 | ||| a linear channel for the rest of the session.
155 | export
156 | recv : LinearIO io =>
157 |   Channel (Recv ty s) -@
158 |   L1 io (Res ty (const (Channel s)))
159 | recv (MkChannel {recvStep} seen sendCh recvCh) = do
160 |   x@(Element k prf val) <- channelGet recvCh
161 |   -- Here we check that we got the right message by projecting out of
162 |   -- the union type using the current `recvStep`. Both ends should be
163 |   -- in sync because of the `RecvDualTypes` and `SendDualTypes` lemmas.
164 |   let Just val = prj (recvStep + 0) (atRecvIndex seen (Recv ty s) Z) x
165 |     | Nothing => die1 "Error: invalid recv expected \{show recvStep} but got \{show k}"
166 |   pure1 (val # MkChannel (Recv ty seen) sendCh recvCh)
167 |
168 |
169 | ||| Consume a linear channel with a `Send ty` step at the head of the
170 | ||| session type in order to send a value of type `ty` and obtain a
171 | ||| linear channel for the rest of the session.
172 | export
173 | send : LinearIO io =>
174 |   (1 _ : Channel (Send ty s)) ->
175 |   ty ->
176 |   L1 io (Channel s)
177 | send (MkChannel {sendStep} seen sendCh recvCh) x = do
178 |   let val = inj (sendStep + 0) (atSendIndex seen (Send ty s) Z) x
179 |   channelPut sendCh val
180 |   pure1 (MkChannel (Send ty seen) sendCh recvCh)
181 |
182 | ||| Discard the channel provided that the session has reached its `End`.
183 | export
184 | end : LinearIO io => Channel End -@ L io ()
185 | end (MkChannel _ _ _) = do
186 |   pure ()
187 |
188 | ||| Given a session, create a bidirectional communication channel and
189 | ||| return its two endpoints
190 | export
191 | makeChannel :
192 |   LinearIO io =>
193 |   (0 s : Session) ->
194 |   L1 io (LPair (Channel s) (Channel (Dual s)))
195 | makeChannel s = do
196 |   sendChan <- Threads.makeChannel
197 |   recvChan <- Threads.makeChannel
198 |   let 1 posCh : Channel s
199 |     := MkChannel None sendChan recvChan
200 |   let 1 negCh : Channel (Dual s)
201 |     := MkChannel None
202 |          (rewrite SendDualTypes s in recvChan)
203 |          (rewrite RecvDualTypes s in sendChan)
204 |   pure1 (posCh # negCh)
205 |
206 | ||| Given a session and two functions communicating according to that
207 | ||| session, we can run the two programs concurrently and collect their
208 | ||| final results.
209 | export
210 | fork : (0 s : Session) ->
211 |        (Channel       s  -@ L IO a) -@
212 |        (Channel (Dual s) -@ L IO b) -@
213 |        L IO (a, b)
214 | fork s kA kB = do
215 |   let 1 io = makeChannel s
216 |   (posCh # negCh) <- io
217 |   par (kA posCh) (kB negCh)
218 |