@@ -6,14 +6,15 @@ use postgres_protocol::message::backend::Message;
6
6
use postgres_protocol::message::frontend;
7
7
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
8
8
use std::error::Error as StdError;
9
+ use std::mem;
9
10
10
11
use crate::proto::client::{Client, PendingRequest};
11
12
use crate::proto::statement::Statement;
12
13
use crate::Error;
13
14
14
- pub enum CopyMessage {
15
- Data( Vec<u8>) ,
16
- Done ,
15
+ pub struct CopyMessage {
16
+ pub data: Vec<u8>,
17
+ pub done: bool ,
17
18
}
18
19
19
20
pub struct CopyInReceiver {
@@ -40,13 +41,14 @@ impl Stream for CopyInReceiver {
40
41
}
41
42
42
43
match self.receiver.poll()? {
43
- Async::Ready(Some(CopyMessage::Data(buf))) => Ok(Async::Ready(Some(buf))),
44
- Async::Ready(Some(CopyMessage::Done)) => {
45
- self.done = true;
46
- let mut buf = vec![];
47
- frontend::copy_done(&mut buf);
48
- frontend::sync(&mut buf);
49
- Ok(Async::Ready(Some(buf)))
44
+ Async::Ready(Some(mut data)) => {
45
+ if data.done {
46
+ self.done = true;
47
+ frontend::copy_done(&mut data.data);
48
+ frontend::sync(&mut data.data);
49
+ }
50
+
51
+ Ok(Async::Ready(Some(data.data)))
50
52
}
51
53
Async::Ready(None) => {
52
54
self.done = true;
85
87
#[state_machine_future(transitions(WriteCopyDone))]
86
88
WriteCopyData {
87
89
stream: S,
90
+ buf: Vec<u8>,
88
91
pending_message: Option<CopyMessage>,
89
92
sender: mpsc::Sender<CopyMessage>,
90
93
receiver: mpsc::Receiver<Message>,
@@ -133,6 +136,7 @@ where
133
136
let state = state.take();
134
137
transition!(WriteCopyData {
135
138
stream: state.stream,
139
+ buf: vec![],
136
140
pending_message: None,
137
141
sender: state.sender,
138
142
receiver: state.receiver
@@ -148,34 +152,58 @@ where
148
152
fn poll_write_copy_data<'a>(
149
153
state: &'a mut RentToOwn<'a, WriteCopyData<S>>,
150
154
) -> Poll<AfterWriteCopyData, Error> {
155
+ if let Some(message) = state.pending_message.take() {
156
+ match state
157
+ .sender
158
+ .start_send(message)
159
+ .map_err(|_| Error::closed())?
160
+ {
161
+ AsyncSink::Ready => {}
162
+ AsyncSink::NotReady(message) => {
163
+ state.pending_message = Some(message);
164
+ return Ok(Async::NotReady);
165
+ }
166
+ }
167
+ }
168
+
151
169
loop {
152
- let message = match state.pending_message.take() {
153
- Some(message) => message,
154
- None => match try_ready!(state.stream.poll().map_err(Error::copy_in_stream)) {
170
+ let done = loop {
171
+ match try_ready!(state.stream.poll().map_err(Error::copy_in_stream)) {
155
172
Some(data) => {
156
- let mut buf = vec![];
157
173
// FIXME avoid collect
158
- frontend::copy_data(&data.into_buf().collect::<Vec<_>>(), &mut buf)
174
+ frontend::copy_data(&data.into_buf().collect::<Vec<_>>(), &mut state. buf)
159
175
.map_err(Error::encode)?;
160
- CopyMessage::Data(buf)
176
+ if state.buf.len() > 4096 {
177
+ break false;
178
+ }
161
179
}
162
- None => {
163
- let state = state.take();
164
- transition!(WriteCopyDone {
165
- future: state.sender.send(CopyMessage::Done),
166
- receiver: state.receiver
167
- })
168
- }
169
- },
180
+ None => break true,
181
+ }
170
182
};
171
183
172
- match state.sender.start_send(message) {
173
- Ok(AsyncSink::Ready) => {}
174
- Ok(AsyncSink::NotReady(message)) => {
184
+ let message = CopyMessage {
185
+ data: mem::replace(&mut state.buf, vec![]),
186
+ done,
187
+ };
188
+
189
+ if done {
190
+ let state = state.take();
191
+ transition!(WriteCopyDone {
192
+ future: state.sender.send(message),
193
+ receiver: state.receiver,
194
+ });
195
+ }
196
+
197
+ match state
198
+ .sender
199
+ .start_send(message)
200
+ .map_err(|_| Error::closed())?
201
+ {
202
+ AsyncSink::Ready => {}
203
+ AsyncSink::NotReady(message) => {
175
204
state.pending_message = Some(message);
176
205
return Ok(Async::NotReady);
177
206
}
178
- Err(_) => return Err(Error::closed()),
179
207
}
180
208
}
181
209
}
0 commit comments