multer/
buffer.rs

1use std::fmt;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use bytes::{Buf, Bytes, BytesMut};
6use futures_util::stream::Stream;
7
8use crate::constants;
9
10pub(crate) struct StreamBuffer<'r> {
11    pub(crate) eof: bool,
12    pub(crate) buf: BytesMut,
13    pub(crate) stream: Pin<Box<dyn Stream<Item = Result<Bytes, crate::Error>> + Send + 'r>>,
14    pub(crate) whole_stream_size_limit: u64,
15    pub(crate) stream_size_counter: u64,
16}
17
18impl<'r> StreamBuffer<'r> {
19    pub fn new<S>(stream: S, whole_stream_size_limit: u64) -> Self
20    where
21        S: Stream<Item = Result<Bytes, crate::Error>> + Send + 'r,
22    {
23        StreamBuffer {
24            eof: false,
25            buf: BytesMut::new(),
26            stream: Box::pin(stream),
27            whole_stream_size_limit,
28            stream_size_counter: 0,
29        }
30    }
31
32    pub fn poll_stream(&mut self, cx: &mut Context<'_>) -> Result<(), crate::Error> {
33        if self.eof {
34            return Ok(());
35        }
36
37        loop {
38            match self.stream.as_mut().poll_next(cx) {
39                Poll::Ready(Some(Ok(data))) => {
40                    self.stream_size_counter += data.len() as u64;
41
42                    if self.stream_size_counter > self.whole_stream_size_limit {
43                        return Err(crate::Error::StreamSizeExceeded {
44                            limit: self.whole_stream_size_limit,
45                        });
46                    }
47
48                    self.buf.extend_from_slice(&data)
49                }
50                Poll::Ready(Some(Err(err))) => return Err(err),
51                Poll::Ready(None) => {
52                    self.eof = true;
53                    return Ok(());
54                }
55                Poll::Pending => return Ok(()),
56            }
57        }
58    }
59
60    pub fn read_exact(&mut self, size: usize) -> Option<Bytes> {
61        if size <= self.buf.len() {
62            Some(self.buf.split_to(size).freeze())
63        } else {
64            None
65        }
66    }
67
68    pub fn peek_exact(&mut self, size: usize) -> Option<&[u8]> {
69        self.buf.get(..size)
70    }
71
72    pub fn read_until(&mut self, pattern: &[u8]) -> Option<Bytes> {
73        memchr::memmem::find(&self.buf, pattern).map(|idx| self.buf.split_to(idx + pattern.len()).freeze())
74    }
75
76    pub fn read_to(&mut self, pattern: &[u8]) -> Option<Bytes> {
77        memchr::memmem::find(&self.buf, pattern).map(|idx| self.buf.split_to(idx).freeze())
78    }
79
80    pub fn advance_past_transport_padding(&mut self) -> bool {
81        match self.buf.iter().position(|b| *b != b' ' && *b != b'\t') {
82            Some(pos) => {
83                self.buf.advance(pos);
84                true
85            }
86            None => {
87                self.buf.clear();
88                false
89            }
90        }
91    }
92
93    pub fn read_field_data(
94        &mut self,
95        boundary: &str,
96        field_name: Option<&str>,
97    ) -> crate::Result<Option<(bool, Bytes)>> {
98        trace!("finding next field: {:?}", field_name);
99        if self.buf.is_empty() && self.eof {
100            trace!("empty buffer && EOF");
101            return Err(crate::Error::IncompleteFieldData {
102                field_name: field_name.map(|s| s.to_owned()),
103            });
104        } else if self.buf.is_empty() {
105            return Ok(None);
106        }
107
108        let boundary_deriv = format!("{}{}{}", constants::CRLF, constants::BOUNDARY_EXT, boundary);
109        let b_len = boundary_deriv.len();
110
111        match memchr::memmem::find(&self.buf, boundary_deriv.as_bytes()) {
112            Some(idx) => {
113                trace!("new field found at {}", idx);
114                let bytes = self.buf.split_to(idx).freeze();
115
116                // discard \r\n.
117                self.buf.advance(constants::CRLF.len());
118
119                Ok(Some((true, bytes)))
120            }
121            None if self.eof => {
122                trace!("no new field found: EOF. terminating");
123                Err(crate::Error::IncompleteFieldData {
124                    field_name: field_name.map(|s| s.to_owned()),
125                })
126            }
127            None => {
128                let buf_len = self.buf.len();
129                let rem_boundary_part_max_len = b_len - 1;
130                let rem_boundary_part_idx = if buf_len >= rem_boundary_part_max_len {
131                    buf_len - rem_boundary_part_max_len
132                } else {
133                    0
134                };
135
136                trace!("no new field found, not EOF, checking close");
137                let bytes = &self.buf[rem_boundary_part_idx..];
138                match memchr::memmem::rfind(bytes, constants::CR.as_bytes()) {
139                    Some(rel_idx) => {
140                        let idx = rel_idx + rem_boundary_part_idx;
141
142                        match memchr::memmem::find(boundary_deriv.as_bytes(), &self.buf[idx..]) {
143                            Some(_) => {
144                                let bytes = self.buf.split_to(idx).freeze();
145
146                                match bytes.is_empty() {
147                                    true => Ok(None),
148                                    false => Ok(Some((false, bytes))),
149                                }
150                            }
151                            None => Ok(Some((false, self.read_full_buf()))),
152                        }
153                    }
154                    None => Ok(Some((false, self.read_full_buf()))),
155                }
156            }
157        }
158    }
159
160    pub fn read_full_buf(&mut self) -> Bytes {
161        self.buf.split_to(self.buf.len()).freeze()
162    }
163}
164
165impl fmt::Debug for StreamBuffer<'_> {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        f.debug_struct("StreamBuffer").finish()
168    }
169}