// reqwest/async_impl/decoder.rs

1use std::fmt;
2#[cfg(any(
3    feature = "gzip",
4    feature = "zstd",
5    feature = "brotli",
6    feature = "deflate"
7))]
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12#[cfg(any(
13    feature = "gzip",
14    feature = "zstd",
15    feature = "brotli",
16    feature = "deflate"
17))]
18use futures_util::stream::Fuse;
19
20#[cfg(feature = "gzip")]
21use async_compression::tokio::bufread::GzipDecoder;
22
23#[cfg(feature = "brotli")]
24use async_compression::tokio::bufread::BrotliDecoder;
25
26#[cfg(feature = "zstd")]
27use async_compression::tokio::bufread::ZstdDecoder;
28
29#[cfg(feature = "deflate")]
30use async_compression::tokio::bufread::ZlibDecoder;
31
32#[cfg(any(
33    feature = "gzip",
34    feature = "zstd",
35    feature = "brotli",
36    feature = "deflate",
37    feature = "blocking",
38))]
39use futures_core::Stream;
40
41use bytes::Bytes;
42use http::HeaderMap;
43use hyper::body::Body as HttpBody;
44use hyper::body::Frame;
45
46#[cfg(any(
47    feature = "gzip",
48    feature = "brotli",
49    feature = "zstd",
50    feature = "deflate"
51))]
52use tokio_util::codec::{BytesCodec, FramedRead};
53#[cfg(any(
54    feature = "gzip",
55    feature = "brotli",
56    feature = "zstd",
57    feature = "deflate"
58))]
59use tokio_util::io::StreamReader;
60
61use super::body::ResponseBody;
62
/// Which response content codings are accepted and will be decoded.
///
/// Each flag only exists when the matching cargo feature is enabled. See
/// `Accepts::as_str` for how the flags render as a comma-separated list of
/// coding names, and `Decoder::detect` for how they gate decompression.
#[derive(Clone, Copy, Debug)]
pub(super) struct Accepts {
    /// Accept and decode `gzip` responses.
    #[cfg(feature = "gzip")]
    pub(super) gzip: bool,
    /// Accept and decode `br` (brotli) responses.
    #[cfg(feature = "brotli")]
    pub(super) brotli: bool,
    /// Accept and decode `zstd` responses.
    #[cfg(feature = "zstd")]
    pub(super) zstd: bool,
    /// Accept and decode `deflate` responses.
    #[cfg(feature = "deflate")]
    pub(super) deflate: bool,
}
74
75impl Accepts {
76    pub fn none() -> Self {
77        Self {
78            #[cfg(feature = "gzip")]
79            gzip: false,
80            #[cfg(feature = "brotli")]
81            brotli: false,
82            #[cfg(feature = "zstd")]
83            zstd: false,
84            #[cfg(feature = "deflate")]
85            deflate: false,
86        }
87    }
88}
89
/// A response decompressor over a non-blocking stream of chunks.
///
/// The inner decoder may be constructed asynchronously. For compressed
/// responses it starts as `Inner::Pending` and is replaced by the concrete
/// decompressor (or by an empty plain-text body) once the first chunk of the
/// body has been peeked.
pub(crate) struct Decoder {
    inner: Inner,
}

/// The response body as a peekable stream of `io::Result<Bytes>` chunks, so the
/// first chunk can be inspected before a decompressor is constructed.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
type PeekableIoStream = futures_util::stream::Peekable<IoStream>;

/// `PeekableIoStream` adapted into an async reader, which is the input type of
/// the `async_compression::tokio::bufread` decoders used below.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
type PeekableIoStreamReader = StreamReader<PeekableIoStream, Bytes>;
112
/// The decoding state behind a `Decoder`.
///
/// Each compressed variant holds a fused `FramedRead` over its decompressor.
/// Fusing means that once decompression has finished, further polls keep
/// yielding `None`. `poll_frame` relies on this while it drains the underlying
/// body and checks it for trailing bytes.
enum Inner {
    /// A `PlainText` decoder just returns the response content as is.
    PlainText(ResponseBody),

    /// A `Gzip` decoder will uncompress the gzipped response content before returning it.
    #[cfg(feature = "gzip")]
    Gzip(Pin<Box<Fuse<FramedRead<GzipDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

    /// A `Brotli` decoder will uncompress the brotlied response content before returning it.
    #[cfg(feature = "brotli")]
    Brotli(Pin<Box<Fuse<FramedRead<BrotliDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

    /// A `Zstd` decoder will uncompress the zstd compressed response content before returning it.
    #[cfg(feature = "zstd")]
    Zstd(Pin<Box<Fuse<FramedRead<ZstdDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

    /// A `Deflate` decoder will uncompress the deflated response content before returning it.
    #[cfg(feature = "deflate")]
    Deflate(Pin<Box<Fuse<FramedRead<ZlibDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

    /// A decoder that doesn't have a value yet: the `Pending` future is still
    /// peeking the body to decide between an empty plain-text body and the
    /// requested decompressor.
    #[cfg(any(
        feature = "brotli",
        feature = "zstd",
        feature = "gzip",
        feature = "deflate"
    ))]
    Pending(Pin<Box<Pending>>),
}
142
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
/// A future that peeks the first chunk of the response body.
///
/// It resolves to an `Inner` holding either an empty plain-text body (when the
/// body is already at EOF) or the decompressor selected by the `DecoderType`.
struct Pending(PeekableIoStream, DecoderType);

/// Adapts an `HttpBody` (the `ResponseBody` by default) into a `Stream` of
/// `io::Result<Bytes>`, yielding only the data frames.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
pub(crate) struct IoStream<B = ResponseBody>(B);

/// Which decompressor a `Pending` future builds once it has seen the body is non-empty.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
enum DecoderType {
    #[cfg(feature = "gzip")]
    Gzip,
    #[cfg(feature = "brotli")]
    Brotli,
    #[cfg(feature = "zstd")]
    Zstd,
    #[cfg(feature = "deflate")]
    Deflate,
}
177
178impl fmt::Debug for Decoder {
179    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
180        f.debug_struct("Decoder").finish()
181    }
182}
183
184impl Decoder {
185    #[cfg(feature = "blocking")]
186    pub(crate) fn empty() -> Decoder {
187        Decoder {
188            inner: Inner::PlainText(empty()),
189        }
190    }
191
192    #[cfg(feature = "blocking")]
193    pub(crate) fn into_stream(self) -> IoStream<Self> {
194        IoStream(self)
195    }
196
197    /// A plain text decoder.
198    ///
199    /// This decoder will emit the underlying chunks as-is.
200    fn plain_text(body: ResponseBody) -> Decoder {
201        Decoder {
202            inner: Inner::PlainText(body),
203        }
204    }
205
206    /// A gzip decoder.
207    ///
208    /// This decoder will buffer and decompress chunks that are gzipped.
209    #[cfg(feature = "gzip")]
210    fn gzip(body: ResponseBody) -> Decoder {
211        use futures_util::StreamExt;
212
213        Decoder {
214            inner: Inner::Pending(Box::pin(Pending(
215                IoStream(body).peekable(),
216                DecoderType::Gzip,
217            ))),
218        }
219    }
220
221    /// A brotli decoder.
222    ///
223    /// This decoder will buffer and decompress chunks that are brotlied.
224    #[cfg(feature = "brotli")]
225    fn brotli(body: ResponseBody) -> Decoder {
226        use futures_util::StreamExt;
227
228        Decoder {
229            inner: Inner::Pending(Box::pin(Pending(
230                IoStream(body).peekable(),
231                DecoderType::Brotli,
232            ))),
233        }
234    }
235
236    /// A zstd decoder.
237    ///
238    /// This decoder will buffer and decompress chunks that are zstd compressed.
239    #[cfg(feature = "zstd")]
240    fn zstd(body: ResponseBody) -> Decoder {
241        use futures_util::StreamExt;
242
243        Decoder {
244            inner: Inner::Pending(Box::pin(Pending(
245                IoStream(body).peekable(),
246                DecoderType::Zstd,
247            ))),
248        }
249    }
250
251    /// A deflate decoder.
252    ///
253    /// This decoder will buffer and decompress chunks that are deflated.
254    #[cfg(feature = "deflate")]
255    fn deflate(body: ResponseBody) -> Decoder {
256        use futures_util::StreamExt;
257
258        Decoder {
259            inner: Inner::Pending(Box::pin(Pending(
260                IoStream(body).peekable(),
261                DecoderType::Deflate,
262            ))),
263        }
264    }
265
266    #[cfg(any(
267        feature = "brotli",
268        feature = "zstd",
269        feature = "gzip",
270        feature = "deflate"
271    ))]
272    fn detect_encoding(headers: &mut HeaderMap, encoding_str: &str) -> bool {
273        use http::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
274        use log::warn;
275
276        let mut is_content_encoded = {
277            headers
278                .get_all(CONTENT_ENCODING)
279                .iter()
280                .any(|enc| enc == encoding_str)
281                || headers
282                    .get_all(TRANSFER_ENCODING)
283                    .iter()
284                    .any(|enc| enc == encoding_str)
285        };
286        if is_content_encoded {
287            if let Some(content_length) = headers.get(CONTENT_LENGTH) {
288                if content_length == "0" {
289                    warn!("{encoding_str} response with content-length of 0");
290                    is_content_encoded = false;
291                }
292            }
293        }
294        if is_content_encoded {
295            headers.remove(CONTENT_ENCODING);
296            headers.remove(CONTENT_LENGTH);
297        }
298        is_content_encoded
299    }
300
301    /// Constructs a Decoder from a hyper request.
302    ///
303    /// A decoder is just a wrapper around the hyper request that knows
304    /// how to decode the content body of the request.
305    ///
306    /// Uses the correct variant by inspecting the Content-Encoding header.
307    pub(super) fn detect(
308        _headers: &mut HeaderMap,
309        body: ResponseBody,
310        _accepts: Accepts,
311    ) -> Decoder {
312        #[cfg(feature = "gzip")]
313        {
314            if _accepts.gzip && Decoder::detect_encoding(_headers, "gzip") {
315                return Decoder::gzip(body);
316            }
317        }
318
319        #[cfg(feature = "brotli")]
320        {
321            if _accepts.brotli && Decoder::detect_encoding(_headers, "br") {
322                return Decoder::brotli(body);
323            }
324        }
325
326        #[cfg(feature = "zstd")]
327        {
328            if _accepts.zstd && Decoder::detect_encoding(_headers, "zstd") {
329                return Decoder::zstd(body);
330            }
331        }
332
333        #[cfg(feature = "deflate")]
334        {
335            if _accepts.deflate && Decoder::detect_encoding(_headers, "deflate") {
336                return Decoder::deflate(body);
337            }
338        }
339
340        Decoder::plain_text(body)
341    }
342}
343
/// Exposes the (possibly decompressed) response body as `http_body` frames.
///
/// The compressed variants only ever yield data frames, because `IoStream`
/// skips any non-data frames of the underlying body before they reach the
/// decompressor.
impl HttpBody for Decoder {
    type Data = Bytes;
    type Error = crate::Error;

    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        match self.inner {
            #[cfg(any(
                feature = "brotli",
                feature = "zstd",
                feature = "gzip",
                feature = "deflate"
            ))]
            // Drive the peek future; once it resolves, install the chosen
            // decoder in place of `Pending` and poll again through it.
            Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) {
                Poll::Ready(Ok(inner)) => {
                    self.inner = inner;
                    self.poll_frame(cx)
                }
                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(crate::error::decode_io(e)))),
                Poll::Pending => Poll::Pending,
            },
            // Uncompressed: forward frames untouched, mapping only the error type.
            Inner::PlainText(ref mut body) => {
                match futures_core::ready!(Pin::new(body).poll_frame(cx)) {
                    Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode(err)))),
                    None => Poll::Ready(None),
                }
            }
            // In each compressed arm below, the four `get_mut` calls peel
            // Fuse -> FramedRead -> decompressor -> StreamReader to reach the raw
            // `PeekableIoStream`. That stream must be empty once the compressed
            // stream has ended.
            #[cfg(feature = "gzip")]
            Inner::Gzip(ref mut decoder) => {
                match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
                    None => {
                        // poll inner connection until EOF after gzip stream is finished
                        poll_inner_should_be_empty(
                            decoder.get_mut().get_mut().get_mut().get_mut(),
                            cx,
                        )
                    }
                }
            }
            #[cfg(feature = "brotli")]
            Inner::Brotli(ref mut decoder) => {
                match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
                    None => {
                        // poll inner connection until EOF after brotli stream is finished
                        poll_inner_should_be_empty(
                            decoder.get_mut().get_mut().get_mut().get_mut(),
                            cx,
                        )
                    }
                }
            }
            #[cfg(feature = "zstd")]
            Inner::Zstd(ref mut decoder) => {
                match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
                    None => {
                        // poll inner connection until EOF after zstd stream is finished
                        poll_inner_should_be_empty(
                            decoder.get_mut().get_mut().get_mut().get_mut(),
                            cx,
                        )
                    }
                }
            }
            #[cfg(feature = "deflate")]
            Inner::Deflate(ref mut decoder) => {
                match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
                    None => {
                        // poll inner connection until EOF after deflate stream is finished
                        poll_inner_should_be_empty(
                            decoder.get_mut().get_mut().get_mut().get_mut(),
                            cx,
                        )
                    }
                }
            }
        }
    }

    fn size_hint(&self) -> http_body::SizeHint {
        match self.inner {
            Inner::PlainText(ref body) => HttpBody::size_hint(body),
            // the rest are "unknown", so default
            // (the decompressed length cannot be known from the compressed one)
            #[cfg(any(
                feature = "brotli",
                feature = "zstd",
                feature = "gzip",
                feature = "deflate"
            ))]
            _ => http_body::SizeHint::default(),
        }
    }
}
447
/// Drains the raw body after a decompressor has reported the end of its stream.
///
/// This is shared by every compressed `Inner` variant. Empty chunks are
/// ignored. Any non-empty chunk means the body carried data past the end of
/// the compressed stream, which is reported as a decode error. I/O errors are
/// passed through, and EOF ends the decoded body with `None`.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
fn poll_inner_should_be_empty(
    inner: &mut PeekableIoStream,
    cx: &mut Context,
) -> Poll<Option<Result<Frame<Bytes>, crate::Error>>> {
    let mut stream = Pin::new(inner);
    let outcome = loop {
        let chunk = match futures_core::ready!(stream.as_mut().poll_next(cx)) {
            None => break None,
            Some(Err(err)) => break Some(Err(crate::error::decode_io(err))),
            Some(Ok(chunk)) => chunk,
        };
        if !chunk.is_empty() {
            break Some(Err(crate::error::decode(
                "there are extra bytes after body has been decompressed",
            )));
        }
    };
    Poll::Ready(outcome)
}
475
/// Returns a `ResponseBody` that yields no frames.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
fn empty() -> ResponseBody {
    use http_body_util::{BodyExt, Empty};

    // `Empty` can never fail; the empty `match` turns its uninhabited error
    // into whatever error type `ResponseBody` carries.
    Empty::new().map_err(|never| match never {}).boxed()
}
487
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
impl Future for Pending {
    type Output = Result<Inner, std::io::Error>;

    /// Waits for the first item of the body, then decides how to decode it.
    ///
    /// * EOF before any chunk: resolves to an empty plain-text body.
    /// * An error as the first item: resolves to that error.
    /// * Otherwise: moves the peekable stream out of `self` and wraps it in the
    ///   decompressor selected by `self.1`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        use futures_util::StreamExt;

        let first = futures_core::ready!(Pin::new(&mut self.0).poll_peek(cx));
        match first {
            None => return Poll::Ready(Ok(Inner::PlainText(empty()))),
            Some(Err(_)) => {
                // `poll_peek` only lends the error out; advance the stream to own it.
                let err = futures_core::ready!(Pin::new(&mut self.0).poll_next(cx))
                    .expect("just peeked Some")
                    .unwrap_err();
                return Poll::Ready(Err(err));
            }
            Some(Ok(_)) => {}
        }

        // Take the stream, leaving an empty placeholder in `self`.
        let body = std::mem::replace(&mut self.0, IoStream(empty()).peekable());
        let reader = StreamReader::new(body);

        let inner = match self.1 {
            #[cfg(feature = "brotli")]
            DecoderType::Brotli => Inner::Brotli(Box::pin(
                FramedRead::new(BrotliDecoder::new(reader), BytesCodec::new()).fuse(),
            )),
            #[cfg(feature = "zstd")]
            DecoderType::Zstd => Inner::Zstd(Box::pin(
                FramedRead::new(ZstdDecoder::new(reader), BytesCodec::new()).fuse(),
            )),
            #[cfg(feature = "gzip")]
            DecoderType::Gzip => Inner::Gzip(Box::pin(
                FramedRead::new(GzipDecoder::new(reader), BytesCodec::new()).fuse(),
            )),
            #[cfg(feature = "deflate")]
            DecoderType::Deflate => Inner::Deflate(Box::pin(
                FramedRead::new(ZlibDecoder::new(reader), BytesCodec::new()).fuse(),
            )),
        };
        Poll::Ready(Ok(inner))
    }
}
553
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
impl<B> Stream for IoStream<B>
where
    B: HttpBody<Data = Bytes> + Unpin,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Item = Result<Bytes, std::io::Error>;

    /// Polls the wrapped body until it yields a data frame, ends, or fails.
    ///
    /// Frames that carry no data (e.g. trailers) are skipped. Body errors are
    /// converted into `io::Error`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        loop {
            let frame = match futures_core::ready!(Pin::new(&mut self.0).poll_frame(cx)) {
                None => return Poll::Ready(None),
                Some(Err(err)) => {
                    return Poll::Ready(Some(Err(crate::error::into_io(err.into()))));
                }
                Some(Ok(frame)) => frame,
            };
            if let Ok(data) = frame.into_data() {
                return Poll::Ready(Some(Ok(data)));
            }
        }
    }
}
585
// ===== impl Accepts =====
587
588impl Accepts {
589    /*
590    pub(super) fn none() -> Self {
591        Accepts {
592            #[cfg(feature = "gzip")]
593            gzip: false,
594            #[cfg(feature = "brotli")]
595            brotli: false,
596            #[cfg(feature = "zstd")]
597            zstd: false,
598            #[cfg(feature = "deflate")]
599            deflate: false,
600        }
601    }
602    */
603
604    pub(super) fn as_str(&self) -> Option<&'static str> {
605        match (
606            self.is_gzip(),
607            self.is_brotli(),
608            self.is_zstd(),
609            self.is_deflate(),
610        ) {
611            (true, true, true, true) => Some("gzip, br, zstd, deflate"),
612            (true, true, false, true) => Some("gzip, br, deflate"),
613            (true, true, true, false) => Some("gzip, br, zstd"),
614            (true, true, false, false) => Some("gzip, br"),
615            (true, false, true, true) => Some("gzip, zstd, deflate"),
616            (true, false, false, true) => Some("gzip, deflate"),
617            (false, true, true, true) => Some("br, zstd, deflate"),
618            (false, true, false, true) => Some("br, deflate"),
619            (true, false, true, false) => Some("gzip, zstd"),
620            (true, false, false, false) => Some("gzip"),
621            (false, true, true, false) => Some("br, zstd"),
622            (false, true, false, false) => Some("br"),
623            (false, false, true, true) => Some("zstd, deflate"),
624            (false, false, true, false) => Some("zstd"),
625            (false, false, false, true) => Some("deflate"),
626            (false, false, false, false) => None,
627        }
628    }
629
630    fn is_gzip(&self) -> bool {
631        #[cfg(feature = "gzip")]
632        {
633            self.gzip
634        }
635
636        #[cfg(not(feature = "gzip"))]
637        {
638            false
639        }
640    }
641
642    fn is_brotli(&self) -> bool {
643        #[cfg(feature = "brotli")]
644        {
645            self.brotli
646        }
647
648        #[cfg(not(feature = "brotli"))]
649        {
650            false
651        }
652    }
653
654    fn is_zstd(&self) -> bool {
655        #[cfg(feature = "zstd")]
656        {
657            self.zstd
658        }
659
660        #[cfg(not(feature = "zstd"))]
661        {
662            false
663        }
664    }
665
666    fn is_deflate(&self) -> bool {
667        #[cfg(feature = "deflate")]
668        {
669            self.deflate
670        }
671
672        #[cfg(not(feature = "deflate"))]
673        {
674            false
675        }
676    }
677}
678
679impl Default for Accepts {
680    fn default() -> Accepts {
681        Accepts {
682            #[cfg(feature = "gzip")]
683            gzip: true,
684            #[cfg(feature = "brotli")]
685            brotli: true,
686            #[cfg(feature = "zstd")]
687            zstd: true,
688            #[cfg(feature = "deflate")]
689            deflate: true,
690        }
691    }
692}
693
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn accepts_as_str() {
        // Expected value: the enabled coding names joined in canonical order.
        fn expected_header(accepts: &Accepts) -> String {
            let codings = [
                (accepts.is_gzip(), "gzip"),
                (accepts.is_brotli(), "br"),
                (accepts.is_zstd(), "zstd"),
                (accepts.is_deflate(), "deflate"),
            ];
            codings
                .iter()
                .filter(|(enabled, _)| *enabled)
                .map(|(_, name)| *name)
                .collect::<Vec<_>>()
                .join(", ")
        }

        // Bits 0..=3 of `mask` toggle gzip, brotli, zstd and deflate, so the
        // 16 masks cover every combination of flags.
        for mask in 0u8..16 {
            // Flags whose feature is disabled are never read, hence the allow.
            #[allow(unused_variables)]
            let (gzip, brotli, zstd, deflate) = (
                (mask & 0b0001) != 0,
                (mask & 0b0010) != 0,
                (mask & 0b0100) != 0,
                (mask & 0b1000) != 0,
            );
            let accepts = Accepts {
                #[cfg(feature = "gzip")]
                gzip,
                #[cfg(feature = "brotli")]
                brotli,
                #[cfg(feature = "zstd")]
                zstd,
                #[cfg(feature = "deflate")]
                deflate,
            };

            let expected = expected_header(&accepts);
            assert_eq!(accepts.as_str().unwrap_or(""), expected.as_str());
        }
    }
}