1use std::fmt;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use bytes::{Buf, Bytes, BytesMut};
6use futures_util::stream::Stream;
7
8use crate::constants;
9
10pub(crate) struct StreamBuffer<'r> {
11 pub(crate) eof: bool,
12 pub(crate) buf: BytesMut,
13 pub(crate) stream: Pin<Box<dyn Stream<Item = Result<Bytes, crate::Error>> + Send + 'r>>,
14 pub(crate) whole_stream_size_limit: u64,
15 pub(crate) stream_size_counter: u64,
16}
17
18impl<'r> StreamBuffer<'r> {
19 pub fn new<S>(stream: S, whole_stream_size_limit: u64) -> Self
20 where
21 S: Stream<Item = Result<Bytes, crate::Error>> + Send + 'r,
22 {
23 StreamBuffer {
24 eof: false,
25 buf: BytesMut::new(),
26 stream: Box::pin(stream),
27 whole_stream_size_limit,
28 stream_size_counter: 0,
29 }
30 }
31
32 pub fn poll_stream(&mut self, cx: &mut Context<'_>) -> Result<(), crate::Error> {
33 if self.eof {
34 return Ok(());
35 }
36
37 loop {
38 match self.stream.as_mut().poll_next(cx) {
39 Poll::Ready(Some(Ok(data))) => {
40 self.stream_size_counter += data.len() as u64;
41
42 if self.stream_size_counter > self.whole_stream_size_limit {
43 return Err(crate::Error::StreamSizeExceeded {
44 limit: self.whole_stream_size_limit,
45 });
46 }
47
48 self.buf.extend_from_slice(&data)
49 }
50 Poll::Ready(Some(Err(err))) => return Err(err),
51 Poll::Ready(None) => {
52 self.eof = true;
53 return Ok(());
54 }
55 Poll::Pending => return Ok(()),
56 }
57 }
58 }
59
60 pub fn read_exact(&mut self, size: usize) -> Option<Bytes> {
61 if size <= self.buf.len() {
62 Some(self.buf.split_to(size).freeze())
63 } else {
64 None
65 }
66 }
67
68 pub fn peek_exact(&mut self, size: usize) -> Option<&[u8]> {
69 self.buf.get(..size)
70 }
71
72 pub fn read_until(&mut self, pattern: &[u8]) -> Option<Bytes> {
73 memchr::memmem::find(&self.buf, pattern).map(|idx| self.buf.split_to(idx + pattern.len()).freeze())
74 }
75
76 pub fn read_to(&mut self, pattern: &[u8]) -> Option<Bytes> {
77 memchr::memmem::find(&self.buf, pattern).map(|idx| self.buf.split_to(idx).freeze())
78 }
79
80 pub fn advance_past_transport_padding(&mut self) -> bool {
81 match self.buf.iter().position(|b| *b != b' ' && *b != b'\t') {
82 Some(pos) => {
83 self.buf.advance(pos);
84 true
85 }
86 None => {
87 self.buf.clear();
88 false
89 }
90 }
91 }
92
93 pub fn read_field_data(
94 &mut self,
95 boundary: &str,
96 field_name: Option<&str>,
97 ) -> crate::Result<Option<(bool, Bytes)>> {
98 trace!("finding next field: {:?}", field_name);
99 if self.buf.is_empty() && self.eof {
100 trace!("empty buffer && EOF");
101 return Err(crate::Error::IncompleteFieldData {
102 field_name: field_name.map(|s| s.to_owned()),
103 });
104 } else if self.buf.is_empty() {
105 return Ok(None);
106 }
107
108 let boundary_deriv = format!("{}{}{}", constants::CRLF, constants::BOUNDARY_EXT, boundary);
109 let b_len = boundary_deriv.len();
110
111 match memchr::memmem::find(&self.buf, boundary_deriv.as_bytes()) {
112 Some(idx) => {
113 trace!("new field found at {}", idx);
114 let bytes = self.buf.split_to(idx).freeze();
115
116 self.buf.advance(constants::CRLF.len());
118
119 Ok(Some((true, bytes)))
120 }
121 None if self.eof => {
122 trace!("no new field found: EOF. terminating");
123 Err(crate::Error::IncompleteFieldData {
124 field_name: field_name.map(|s| s.to_owned()),
125 })
126 }
127 None => {
128 let buf_len = self.buf.len();
129 let rem_boundary_part_max_len = b_len - 1;
130 let rem_boundary_part_idx = if buf_len >= rem_boundary_part_max_len {
131 buf_len - rem_boundary_part_max_len
132 } else {
133 0
134 };
135
136 trace!("no new field found, not EOF, checking close");
137 let bytes = &self.buf[rem_boundary_part_idx..];
138 match memchr::memmem::rfind(bytes, constants::CR.as_bytes()) {
139 Some(rel_idx) => {
140 let idx = rel_idx + rem_boundary_part_idx;
141
142 match memchr::memmem::find(boundary_deriv.as_bytes(), &self.buf[idx..]) {
143 Some(_) => {
144 let bytes = self.buf.split_to(idx).freeze();
145
146 match bytes.is_empty() {
147 true => Ok(None),
148 false => Ok(Some((false, bytes))),
149 }
150 }
151 None => Ok(Some((false, self.read_full_buf()))),
152 }
153 }
154 None => Ok(Some((false, self.read_full_buf()))),
155 }
156 }
157 }
158 }
159
160 pub fn read_full_buf(&mut self) -> Bytes {
161 self.buf.split_to(self.buf.len()).freeze()
162 }
163}
164
165impl fmt::Debug for StreamBuffer<'_> {
166 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 f.debug_struct("StreamBuffer").finish()
168 }
169}