h2/frame/
headers.rs

1use super::{util, StreamDependency, StreamId};
2use crate::ext::Protocol;
3use crate::frame::{Error, Frame, Head, Kind};
4use crate::hpack::{self, BytesStr};
5
6use http::header::{self, HeaderName, HeaderValue};
7use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9use bytes::{Buf, BufMut, Bytes, BytesMut};
10
11use std::fmt;
12use std::io::Cursor;
13
/// Destination buffer handed to the frame encoders in this module: a mutable
/// `BytesMut` wrapped in a `Limit`, so that `remaining_mut()` reflects the cap
/// the caller placed on how much may be written.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
15
/// Header frame
///
/// This could be either a request or a response.
///
/// Bundles the frame-level metadata (stream id, optional stream dependency
/// and flags) with the header block that holds the pseudo headers and the
/// regular header fields.
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any.
    ///
    /// Only populated by `load` when the PRIORITY flag is set on the frame.
    stream_dep: Option<StreamDependency>,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}
33
/// Raw flag byte of a HEADERS frame.
///
/// The individual bits are the `END_STREAM`, `END_HEADERS`, `PADDED` and
/// `PRIORITY` constants defined in this module.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
36
/// PUSH_PROMISE frame
///
/// Carries the id of the stream being reserved together with the header
/// block of the promised request.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}
51
/// Raw flag byte of a PUSH_PROMISE frame.
///
/// Only the `END_HEADERS` and `PADDED` bits are inspected by the accessors on
/// this type.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
54
/// CONTINUATION frame
///
/// Produced by the encoders in this module when an already HPACK-encoded
/// header block does not fit into the destination buffer; it holds the bytes
/// that still have to be written.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// The remaining, already encoded, header block bytes.
    header_block: EncodingHeaderBlock,
}
62
/// The pseudo header fields of a header block.
///
/// Request and response pseudo headers share this one struct; a field is
/// `None` when the corresponding pseudo header is absent.
// TODO: These fields shouldn't be `pub`
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    /// `:method`
    pub method: Option<Method>,
    /// `:scheme`
    pub scheme: Option<BytesStr>,
    /// `:authority`
    pub authority: Option<BytesStr>,
    /// `:path`
    pub path: Option<BytesStr>,
    /// `:protocol`
    pub protocol: Option<Protocol>,

    // Response
    /// `:status`
    pub status: Option<StatusCode>,
}
76
/// Iterator over all headers of a header block, in the order they are fed to
/// the HPACK encoder: first every present pseudo header, then the regular
/// fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers
    ///
    /// Reset to `None` once every pseudo header has been yielded.
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
85
/// Decoded (or not yet encoded) contents of a header block, shared by
/// `Headers` and `PushPromise`.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Precomputed size of all of our header fields, for perf reasons
    ///
    /// Each field contributes `name.len() + value.len() + 32`.
    field_size: usize,

    /// Set to true if decoding went over the max header list size.
    is_over_size: bool,

    /// Pseudo headers, these are broken out as they must be sent as part of the
    /// headers frame.
    pseudo: Pseudo,
}
101
/// A header block that has already been run through the HPACK encoder and is
/// waiting to be written into one or more frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// HPACK-encoded bytes that have not been written to a frame yet.
    hpack: Bytes,
}
106
// Flag bits found in the flag byte of HEADERS / PUSH_PROMISE / CONTINUATION
// frame heads.

/// END_STREAM flag bit; set on trailers by `Headers::trailers`.
const END_STREAM: u8 = 0x1;
/// END_HEADERS flag bit; cleared by the encoder when a continuation frame
/// follows.
const END_HEADERS: u8 = 0x4;
/// PADDED flag bit; the payload then starts with a one-byte pad length.
const PADDED: u8 = 0x8;
/// PRIORITY flag bit; the payload then carries a 5-byte stream dependency.
const PRIORITY: u8 = 0x20;
/// Mask of every flag bit known to this module; used by the `load`
/// constructors of the flag types to drop unknown bits.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
112
113// ===== impl Headers =====
114
115impl Headers {
116    /// Create a new HEADERS frame
117    pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
118        Headers {
119            stream_id,
120            stream_dep: None,
121            header_block: HeaderBlock {
122                field_size: calculate_headermap_size(&fields),
123                fields,
124                is_over_size: false,
125                pseudo,
126            },
127            flags: HeadersFlag::default(),
128        }
129    }
130
131    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
132        let mut flags = HeadersFlag::default();
133        flags.set_end_stream();
134
135        Headers {
136            stream_id,
137            stream_dep: None,
138            header_block: HeaderBlock {
139                field_size: calculate_headermap_size(&fields),
140                fields,
141                is_over_size: false,
142                pseudo: Pseudo::default(),
143            },
144            flags,
145        }
146    }
147
148    /// Loads the header frame but doesn't actually do HPACK decoding.
149    ///
150    /// HPACK decoding is done in the `load_hpack` step.
151    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
152        let flags = HeadersFlag(head.flag());
153        let mut pad = 0;
154
155        tracing::trace!("loading headers; flags={:?}", flags);
156
157        if head.stream_id().is_zero() {
158            return Err(Error::InvalidStreamId);
159        }
160
161        // Read the padding length
162        if flags.is_padded() {
163            if src.is_empty() {
164                return Err(Error::MalformedMessage);
165            }
166            pad = src[0] as usize;
167
168            // Drop the padding
169            src.advance(1);
170        }
171
172        // Read the stream dependency
173        let stream_dep = if flags.is_priority() {
174            if src.len() < 5 {
175                return Err(Error::MalformedMessage);
176            }
177            let stream_dep = StreamDependency::load(&src[..5])?;
178
179            if stream_dep.dependency_id() == head.stream_id() {
180                return Err(Error::InvalidDependencyId);
181            }
182
183            // Drop the next 5 bytes
184            src.advance(5);
185
186            Some(stream_dep)
187        } else {
188            None
189        };
190
191        if pad > 0 {
192            if pad > src.len() {
193                return Err(Error::TooMuchPadding);
194            }
195
196            let len = src.len() - pad;
197            src.truncate(len);
198        }
199
200        let headers = Headers {
201            stream_id: head.stream_id(),
202            stream_dep,
203            header_block: HeaderBlock {
204                fields: HeaderMap::new(),
205                field_size: 0,
206                is_over_size: false,
207                pseudo: Pseudo::default(),
208            },
209            flags,
210        };
211
212        Ok((headers, src))
213    }
214
215    pub fn load_hpack(
216        &mut self,
217        src: &mut BytesMut,
218        max_header_list_size: usize,
219        decoder: &mut hpack::Decoder,
220    ) -> Result<(), Error> {
221        self.header_block.load(src, max_header_list_size, decoder)
222    }
223
224    pub fn stream_id(&self) -> StreamId {
225        self.stream_id
226    }
227
228    pub fn is_end_headers(&self) -> bool {
229        self.flags.is_end_headers()
230    }
231
232    pub fn set_end_headers(&mut self) {
233        self.flags.set_end_headers();
234    }
235
236    pub fn is_end_stream(&self) -> bool {
237        self.flags.is_end_stream()
238    }
239
240    pub fn set_end_stream(&mut self) {
241        self.flags.set_end_stream()
242    }
243
244    pub fn is_over_size(&self) -> bool {
245        self.header_block.is_over_size
246    }
247
248    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
249        (self.header_block.pseudo, self.header_block.fields)
250    }
251
252    #[cfg(feature = "unstable")]
253    pub fn pseudo_mut(&mut self) -> &mut Pseudo {
254        &mut self.header_block.pseudo
255    }
256
257    pub(crate) fn pseudo(&self) -> &Pseudo {
258        &self.header_block.pseudo
259    }
260
261    /// Whether it has status 1xx
262    pub(crate) fn is_informational(&self) -> bool {
263        self.header_block.pseudo.is_informational()
264    }
265
266    pub fn fields(&self) -> &HeaderMap {
267        &self.header_block.fields
268    }
269
270    pub fn into_fields(self) -> HeaderMap {
271        self.header_block.fields
272    }
273
274    pub fn encode(
275        self,
276        encoder: &mut hpack::Encoder,
277        dst: &mut EncodeBuf<'_>,
278    ) -> Option<Continuation> {
279        // At this point, the `is_end_headers` flag should always be set
280        debug_assert!(self.flags.is_end_headers());
281
282        // Get the HEADERS frame head
283        let head = self.head();
284
285        self.header_block
286            .into_encoding(encoder)
287            .encode(&head, dst, |_| {})
288    }
289
290    fn head(&self) -> Head {
291        Head::new(Kind::Headers, self.flags.into(), self.stream_id)
292    }
293}
294
295impl<T> From<Headers> for Frame<T> {
296    fn from(src: Headers) -> Self {
297        Frame::Headers(src)
298    }
299}
300
301impl fmt::Debug for Headers {
302    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
303        let mut builder = f.debug_struct("Headers");
304        builder
305            .field("stream_id", &self.stream_id)
306            .field("flags", &self.flags);
307
308        if let Some(ref protocol) = self.header_block.pseudo.protocol {
309            builder.field("protocol", protocol);
310        }
311
312        if let Some(ref dep) = self.stream_dep {
313            builder.field("stream_dep", dep);
314        }
315
316        // `fields` and `pseudo` purposefully not included
317        builder.finish()
318    }
319}
320
321// ===== util =====
322
/// Error returned by [`parse_u64`] when the input is not an acceptable
/// decimal number.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parses an unsigned decimal integer from raw ASCII bytes.
///
/// Only the digits `0`-`9` are accepted: no sign, no whitespace and no
/// separators.
///
/// # Errors
///
/// Returns [`ParseU64Error`] when `src`
///
/// * is empty (there is no number to parse),
/// * is longer than 19 bytes (the value could overflow a `u64`), or
/// * contains a byte that is not an ASCII digit.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    // An empty slice holds no digits, so it must not be read as zero.
    // With at most 19 digits the value is at most 9_999_999_999_999_999_999,
    // which fits in a `u64`, so the arithmetic below cannot overflow.
    if src.is_empty() || src.len() > 19 {
        return Err(ParseU64Error);
    }

    src.iter().try_fold(0u64, |acc, &d| {
        if d.is_ascii_digit() {
            Ok(acc * 10 + u64::from(d - b'0'))
        } else {
            Err(ParseU64Error)
        }
    })
}
345
346// ===== impl PushPromise =====
347
/// Reasons why `PushPromise::validate_request` rejects a promised request.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// The request has a `content-length` header that is not exactly zero;
    /// carries the result of parsing the header value.
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The request method is neither `GET` nor `HEAD`.
    NotSafeAndCacheable,
}
353
354impl PushPromise {
355    pub fn new(
356        stream_id: StreamId,
357        promised_id: StreamId,
358        pseudo: Pseudo,
359        fields: HeaderMap,
360    ) -> Self {
361        PushPromise {
362            flags: PushPromiseFlag::default(),
363            header_block: HeaderBlock {
364                field_size: calculate_headermap_size(&fields),
365                fields,
366                is_over_size: false,
367                pseudo,
368            },
369            promised_id,
370            stream_id,
371        }
372    }
373
374    pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
375        use PushPromiseHeaderError::*;
376        // The spec has some requirements for promised request headers
377        // [https://httpwg.org/specs/rfc7540.html#PushRequests]
378
379        // A promised request "that indicates the presence of a request body
380        // MUST reset the promised stream with a stream error"
381        if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
382            let parsed_length = parse_u64(content_length.as_bytes());
383            if parsed_length != Ok(0) {
384                return Err(InvalidContentLength(parsed_length));
385            }
386        }
387        // "The server MUST include a method in the :method pseudo-header field
388        // that is safe and cacheable"
389        if !Self::safe_and_cacheable(req.method()) {
390            return Err(NotSafeAndCacheable);
391        }
392
393        Ok(())
394    }
395
396    fn safe_and_cacheable(method: &Method) -> bool {
397        // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
398        // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
399        method == Method::GET || method == Method::HEAD
400    }
401
402    pub fn fields(&self) -> &HeaderMap {
403        &self.header_block.fields
404    }
405
406    #[cfg(feature = "unstable")]
407    pub fn into_fields(self) -> HeaderMap {
408        self.header_block.fields
409    }
410
411    /// Loads the push promise frame but doesn't actually do HPACK decoding.
412    ///
413    /// HPACK decoding is done in the `load_hpack` step.
414    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
415        let flags = PushPromiseFlag(head.flag());
416        let mut pad = 0;
417
418        if head.stream_id().is_zero() {
419            return Err(Error::InvalidStreamId);
420        }
421
422        // Read the padding length
423        if flags.is_padded() {
424            if src.is_empty() {
425                return Err(Error::MalformedMessage);
426            }
427
428            // TODO: Ensure payload is sized correctly
429            pad = src[0] as usize;
430
431            // Drop the padding
432            src.advance(1);
433        }
434
435        if src.len() < 5 {
436            return Err(Error::MalformedMessage);
437        }
438
439        let (promised_id, _) = StreamId::parse(&src[..4]);
440        // Drop promised_id bytes
441        src.advance(4);
442
443        if pad > 0 {
444            if pad > src.len() {
445                return Err(Error::TooMuchPadding);
446            }
447
448            let len = src.len() - pad;
449            src.truncate(len);
450        }
451
452        let frame = PushPromise {
453            flags,
454            header_block: HeaderBlock {
455                fields: HeaderMap::new(),
456                field_size: 0,
457                is_over_size: false,
458                pseudo: Pseudo::default(),
459            },
460            promised_id,
461            stream_id: head.stream_id(),
462        };
463        Ok((frame, src))
464    }
465
466    pub fn load_hpack(
467        &mut self,
468        src: &mut BytesMut,
469        max_header_list_size: usize,
470        decoder: &mut hpack::Decoder,
471    ) -> Result<(), Error> {
472        self.header_block.load(src, max_header_list_size, decoder)
473    }
474
475    pub fn stream_id(&self) -> StreamId {
476        self.stream_id
477    }
478
479    pub fn promised_id(&self) -> StreamId {
480        self.promised_id
481    }
482
483    pub fn is_end_headers(&self) -> bool {
484        self.flags.is_end_headers()
485    }
486
487    pub fn set_end_headers(&mut self) {
488        self.flags.set_end_headers();
489    }
490
491    pub fn is_over_size(&self) -> bool {
492        self.header_block.is_over_size
493    }
494
495    pub fn encode(
496        self,
497        encoder: &mut hpack::Encoder,
498        dst: &mut EncodeBuf<'_>,
499    ) -> Option<Continuation> {
500        // At this point, the `is_end_headers` flag should always be set
501        debug_assert!(self.flags.is_end_headers());
502
503        let head = self.head();
504        let promised_id = self.promised_id;
505
506        self.header_block
507            .into_encoding(encoder)
508            .encode(&head, dst, |dst| {
509                dst.put_u32(promised_id.into());
510            })
511    }
512
513    fn head(&self) -> Head {
514        Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
515    }
516
517    /// Consume `self`, returning the parts of the frame
518    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
519        (self.header_block.pseudo, self.header_block.fields)
520    }
521}
522
523impl<T> From<PushPromise> for Frame<T> {
524    fn from(src: PushPromise) -> Self {
525        Frame::PushPromise(src)
526    }
527}
528
529impl fmt::Debug for PushPromise {
530    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
531        f.debug_struct("PushPromise")
532            .field("stream_id", &self.stream_id)
533            .field("promised_id", &self.promised_id)
534            .field("flags", &self.flags)
535            // `fields` and `pseudo` purposefully not included
536            .finish()
537    }
538}
539
540// ===== impl Continuation =====
541
542impl Continuation {
543    fn head(&self) -> Head {
544        Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
545    }
546
547    pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
548        // Get the CONTINUATION frame head
549        let head = self.head();
550
551        self.header_block.encode(&head, dst, |_| {})
552    }
553}
554
555// ===== impl Pseudo =====
556
557impl Pseudo {
558    pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
559        let parts = uri::Parts::from(uri);
560
561        let (scheme, path) = if method == Method::CONNECT && protocol.is_none() {
562            (None, None)
563        } else {
564            let path = parts
565                .path_and_query
566                .map(|v| BytesStr::from(v.as_str()))
567                .unwrap_or(BytesStr::from_static(""));
568
569            let path = if !path.is_empty() {
570                path
571            } else if method == Method::OPTIONS {
572                BytesStr::from_static("*")
573            } else {
574                BytesStr::from_static("/")
575            };
576
577            (parts.scheme, Some(path))
578        };
579
580        let mut pseudo = Pseudo {
581            method: Some(method),
582            scheme: None,
583            authority: None,
584            path,
585            protocol,
586            status: None,
587        };
588
589        // If the URI includes a scheme component, add it to the pseudo headers
590        if let Some(scheme) = scheme {
591            pseudo.set_scheme(scheme);
592        }
593
594        // If the URI includes an authority component, add it to the pseudo
595        // headers
596        if let Some(authority) = parts.authority {
597            pseudo.set_authority(BytesStr::from(authority.as_str()));
598        }
599
600        pseudo
601    }
602
603    pub fn response(status: StatusCode) -> Self {
604        Pseudo {
605            method: None,
606            scheme: None,
607            authority: None,
608            path: None,
609            protocol: None,
610            status: Some(status),
611        }
612    }
613
614    #[cfg(feature = "unstable")]
615    pub fn set_status(&mut self, value: StatusCode) {
616        self.status = Some(value);
617    }
618
619    pub fn set_scheme(&mut self, scheme: uri::Scheme) {
620        let bytes_str = match scheme.as_str() {
621            "http" => BytesStr::from_static("http"),
622            "https" => BytesStr::from_static("https"),
623            s => BytesStr::from(s),
624        };
625        self.scheme = Some(bytes_str);
626    }
627
628    #[cfg(feature = "unstable")]
629    pub fn set_protocol(&mut self, protocol: Protocol) {
630        self.protocol = Some(protocol);
631    }
632
633    pub fn set_authority(&mut self, authority: BytesStr) {
634        self.authority = Some(authority);
635    }
636
637    /// Whether it has status 1xx
638    pub(crate) fn is_informational(&self) -> bool {
639        self.status
640            .map_or(false, |status| status.is_informational())
641    }
642}
643
644// ===== impl EncodingHeaderBlock =====
645
646impl EncodingHeaderBlock {
647    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
648    where
649        F: FnOnce(&mut EncodeBuf<'_>),
650    {
651        let head_pos = dst.get_ref().len();
652
653        // At this point, we don't know how big the h2 frame will be.
654        // So, we write the head with length 0, then write the body, and
655        // finally write the length once we know the size.
656        head.encode(0, dst);
657
658        let payload_pos = dst.get_ref().len();
659
660        f(dst);
661
662        // Now, encode the header payload
663        let continuation = if self.hpack.len() > dst.remaining_mut() {
664            dst.put((&mut self.hpack).take(dst.remaining_mut()));
665
666            Some(Continuation {
667                stream_id: head.stream_id(),
668                header_block: self,
669            })
670        } else {
671            dst.put_slice(&self.hpack);
672
673            None
674        };
675
676        // Compute the header block length
677        let payload_len = (dst.get_ref().len() - payload_pos) as u64;
678
679        // Write the frame length
680        let payload_len_be = payload_len.to_be_bytes();
681        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
682        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
683
684        if continuation.is_some() {
685            // There will be continuation frames, so the `is_end_headers` flag
686            // must be unset
687            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
688
689            dst.get_mut()[head_pos + 4] -= END_HEADERS;
690        }
691
692        continuation
693    }
694}
695
696// ===== impl Iter =====
697
698impl Iterator for Iter {
699    type Item = hpack::Header<Option<HeaderName>>;
700
701    fn next(&mut self) -> Option<Self::Item> {
702        use crate::hpack::Header::*;
703
704        if let Some(ref mut pseudo) = self.pseudo {
705            if let Some(method) = pseudo.method.take() {
706                return Some(Method(method));
707            }
708
709            if let Some(scheme) = pseudo.scheme.take() {
710                return Some(Scheme(scheme));
711            }
712
713            if let Some(authority) = pseudo.authority.take() {
714                return Some(Authority(authority));
715            }
716
717            if let Some(path) = pseudo.path.take() {
718                return Some(Path(path));
719            }
720
721            if let Some(protocol) = pseudo.protocol.take() {
722                return Some(Protocol(protocol));
723            }
724
725            if let Some(status) = pseudo.status.take() {
726                return Some(Status(status));
727            }
728        }
729
730        self.pseudo = None;
731
732        self.fields
733            .next()
734            .map(|(name, value)| Field { name, value })
735    }
736}
737
738// ===== impl HeadersFlag =====
739
740impl HeadersFlag {
741    pub fn empty() -> HeadersFlag {
742        HeadersFlag(0)
743    }
744
745    pub fn load(bits: u8) -> HeadersFlag {
746        HeadersFlag(bits & ALL)
747    }
748
749    pub fn is_end_stream(&self) -> bool {
750        self.0 & END_STREAM == END_STREAM
751    }
752
753    pub fn set_end_stream(&mut self) {
754        self.0 |= END_STREAM;
755    }
756
757    pub fn is_end_headers(&self) -> bool {
758        self.0 & END_HEADERS == END_HEADERS
759    }
760
761    pub fn set_end_headers(&mut self) {
762        self.0 |= END_HEADERS;
763    }
764
765    pub fn is_padded(&self) -> bool {
766        self.0 & PADDED == PADDED
767    }
768
769    pub fn is_priority(&self) -> bool {
770        self.0 & PRIORITY == PRIORITY
771    }
772}
773
774impl Default for HeadersFlag {
775    /// Returns a `HeadersFlag` value with `END_HEADERS` set.
776    fn default() -> Self {
777        HeadersFlag(END_HEADERS)
778    }
779}
780
781impl From<HeadersFlag> for u8 {
782    fn from(src: HeadersFlag) -> u8 {
783        src.0
784    }
785}
786
787impl fmt::Debug for HeadersFlag {
788    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
789        util::debug_flags(fmt, self.0)
790            .flag_if(self.is_end_headers(), "END_HEADERS")
791            .flag_if(self.is_end_stream(), "END_STREAM")
792            .flag_if(self.is_padded(), "PADDED")
793            .flag_if(self.is_priority(), "PRIORITY")
794            .finish()
795    }
796}
797
798// ===== impl PushPromiseFlag =====
799
800impl PushPromiseFlag {
801    pub fn empty() -> PushPromiseFlag {
802        PushPromiseFlag(0)
803    }
804
805    pub fn load(bits: u8) -> PushPromiseFlag {
806        PushPromiseFlag(bits & ALL)
807    }
808
809    pub fn is_end_headers(&self) -> bool {
810        self.0 & END_HEADERS == END_HEADERS
811    }
812
813    pub fn set_end_headers(&mut self) {
814        self.0 |= END_HEADERS;
815    }
816
817    pub fn is_padded(&self) -> bool {
818        self.0 & PADDED == PADDED
819    }
820}
821
822impl Default for PushPromiseFlag {
823    /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
824    fn default() -> Self {
825        PushPromiseFlag(END_HEADERS)
826    }
827}
828
829impl From<PushPromiseFlag> for u8 {
830    fn from(src: PushPromiseFlag) -> u8 {
831        src.0
832    }
833}
834
835impl fmt::Debug for PushPromiseFlag {
836    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
837        util::debug_flags(fmt, self.0)
838            .flag_if(self.is_end_headers(), "END_HEADERS")
839            .flag_if(self.is_padded(), "PADDED")
840            .finish()
841    }
842}
843
844// ===== HeaderBlock =====
845
846impl HeaderBlock {
847    fn load(
848        &mut self,
849        src: &mut BytesMut,
850        max_header_list_size: usize,
851        decoder: &mut hpack::Decoder,
852    ) -> Result<(), Error> {
853        let mut reg = !self.fields.is_empty();
854        let mut malformed = false;
855        let mut headers_size = self.calculate_header_list_size();
856
857        macro_rules! set_pseudo {
858            ($field:ident, $val:expr) => {{
859                if reg {
860                    tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
861                    malformed = true;
862                } else if self.pseudo.$field.is_some() {
863                    tracing::trace!("load_hpack; header malformed -- repeated pseudo");
864                    malformed = true;
865                } else {
866                    let __val = $val;
867                    headers_size +=
868                        decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
869                    if headers_size < max_header_list_size {
870                        self.pseudo.$field = Some(__val);
871                    } else if !self.is_over_size {
872                        tracing::trace!("load_hpack; header list size over max");
873                        self.is_over_size = true;
874                    }
875                }
876            }};
877        }
878
879        let mut cursor = Cursor::new(src);
880
881        // If the header frame is malformed, we still have to continue decoding
882        // the headers. A malformed header frame is a stream level error, but
883        // the hpack state is connection level. In order to maintain correct
884        // state for other streams, the hpack decoding process must complete.
885        let res = decoder.decode(&mut cursor, |header| {
886            use crate::hpack::Header::*;
887
888            match header {
889                Field { name, value } => {
890                    // Connection level header fields are not supported and must
891                    // result in a protocol error.
892
893                    if name == header::CONNECTION
894                        || name == header::TRANSFER_ENCODING
895                        || name == header::UPGRADE
896                        || name == "keep-alive"
897                        || name == "proxy-connection"
898                    {
899                        tracing::trace!("load_hpack; connection level header");
900                        malformed = true;
901                    } else if name == header::TE && value != "trailers" {
902                        tracing::trace!(
903                            "load_hpack; TE header not set to trailers; val={:?}",
904                            value
905                        );
906                        malformed = true;
907                    } else {
908                        reg = true;
909
910                        headers_size += decoded_header_size(name.as_str().len(), value.len());
911                        if headers_size < max_header_list_size {
912                            self.field_size +=
913                                decoded_header_size(name.as_str().len(), value.len());
914                            self.fields.append(name, value);
915                        } else if !self.is_over_size {
916                            tracing::trace!("load_hpack; header list size over max");
917                            self.is_over_size = true;
918                        }
919                    }
920                }
921                Authority(v) => set_pseudo!(authority, v),
922                Method(v) => set_pseudo!(method, v),
923                Scheme(v) => set_pseudo!(scheme, v),
924                Path(v) => set_pseudo!(path, v),
925                Protocol(v) => set_pseudo!(protocol, v),
926                Status(v) => set_pseudo!(status, v),
927            }
928        });
929
930        if let Err(e) = res {
931            tracing::trace!("hpack decoding error; err={:?}", e);
932            return Err(e.into());
933        }
934
935        if malformed {
936            tracing::trace!("malformed message");
937            return Err(Error::MalformedMessage);
938        }
939
940        Ok(())
941    }
942
943    fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
944        let mut hpack = BytesMut::new();
945        let headers = Iter {
946            pseudo: Some(self.pseudo),
947            fields: self.fields.into_iter(),
948        };
949
950        encoder.encode(headers, &mut hpack);
951
952        EncodingHeaderBlock {
953            hpack: hpack.freeze(),
954        }
955    }
956
957    /// Calculates the size of the currently decoded header list.
958    ///
959    /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
960    ///
961    /// > The value is based on the uncompressed size of header fields,
962    /// > including the length of the name and value in octets plus an
963    /// > overhead of 32 octets for each header field.
964    fn calculate_header_list_size(&self) -> usize {
965        macro_rules! pseudo_size {
966            ($name:ident) => {{
967                self.pseudo
968                    .$name
969                    .as_ref()
970                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
971                    .unwrap_or(0)
972            }};
973        }
974
975        pseudo_size!(method)
976            + pseudo_size!(scheme)
977            + pseudo_size!(status)
978            + pseudo_size!(authority)
979            + pseudo_size!(path)
980            + self.field_size
981    }
982}
983
984fn calculate_headermap_size(map: &HeaderMap) -> usize {
985    map.iter()
986        .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
987        .sum::<usize>()
988}
989
/// Size one header field contributes to the header list size: the name
/// length plus the value length plus a fixed per-field overhead of 32.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;

    name + value + PER_FIELD_OVERHEAD
}
993
994#[cfg(test)]
995mod test {
996    use super::*;
997    use crate::frame;
998    use crate::hpack::{huffman, Encoder};
999
    /// Encodes three `hello` headers into a buffer that is too small for the
    /// whole block, so that the encoded block is split between a first frame
    /// and a continuation frame, and checks the exact bytes of both frames.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Leave room for the frame head plus 8 payload bytes only, which
        // forces the rest of the block into a continuation.
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        assert_eq!(17, dst.len());
        // Frame head: 3-byte length (8), then the flag byte at index 4 is 0
        // because END_HEADERS was cleared for the continuation.
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        assert_eq!(0x80 | 4, dst[15]);

        // Only the first byte of the encoded "world" value made it into the
        // first frame.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The remaining bytes of "world" open the continuation payload.
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        assert_eq!(24, dst.len());
        // Frame head: 3-byte length (15), and the flag byte at index 4 is 4
        // (END_HEADERS).
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // Next is not indexed
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }
1054
1055    fn huff_decode(src: &[u8]) -> BytesMut {
1056        let mut buf = BytesMut::new();
1057        huffman::decode(src, &mut buf).unwrap()
1058    }
1059
1060    #[test]
1061    fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
1062        // CONNECT requests MUST NOT include :scheme & :path pseudo-header fields
1063        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.5
1064
1065        assert_eq!(
1066            Pseudo::request(
1067                Method::CONNECT,
1068                Uri::from_static("https://example.com:8443"),
1069                None
1070            ),
1071            Pseudo {
1072                method: Method::CONNECT.into(),
1073                authority: BytesStr::from_static("example.com:8443").into(),
1074                ..Default::default()
1075            }
1076        );
1077
1078        assert_eq!(
1079            Pseudo::request(
1080                Method::CONNECT,
1081                Uri::from_static("https://example.com/test"),
1082                None
1083            ),
1084            Pseudo {
1085                method: Method::CONNECT.into(),
1086                authority: BytesStr::from_static("example.com").into(),
1087                ..Default::default()
1088            }
1089        );
1090
1091        assert_eq!(
1092            Pseudo::request(Method::CONNECT, Uri::from_static("example.com:8443"), None),
1093            Pseudo {
1094                method: Method::CONNECT.into(),
1095                authority: BytesStr::from_static("example.com:8443").into(),
1096                ..Default::default()
1097            }
1098        );
1099    }
1100
1101    #[test]
1102    fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
1103        // On requests that contain the :protocol pseudo-header field, the
1104        // :scheme and :path pseudo-header fields of the target URI (see
1105        // Section 5) MUST also be included.
1106        // See: https://datatracker.ietf.org/doc/html/rfc8441#section-4
1107
1108        assert_eq!(
1109            Pseudo::request(
1110                Method::CONNECT,
1111                Uri::from_static("https://example.com:8443"),
1112                Protocol::from_static("the-bread-protocol").into()
1113            ),
1114            Pseudo {
1115                method: Method::CONNECT.into(),
1116                authority: BytesStr::from_static("example.com:8443").into(),
1117                scheme: BytesStr::from_static("https").into(),
1118                path: BytesStr::from_static("/").into(),
1119                protocol: Protocol::from_static("the-bread-protocol").into(),
1120                ..Default::default()
1121            }
1122        );
1123
1124        assert_eq!(
1125            Pseudo::request(
1126                Method::CONNECT,
1127                Uri::from_static("https://example.com:8443/test"),
1128                Protocol::from_static("the-bread-protocol").into()
1129            ),
1130            Pseudo {
1131                method: Method::CONNECT.into(),
1132                authority: BytesStr::from_static("example.com:8443").into(),
1133                scheme: BytesStr::from_static("https").into(),
1134                path: BytesStr::from_static("/test").into(),
1135                protocol: Protocol::from_static("the-bread-protocol").into(),
1136                ..Default::default()
1137            }
1138        );
1139
1140        assert_eq!(
1141            Pseudo::request(
1142                Method::CONNECT,
1143                Uri::from_static("http://example.com/a/b/c"),
1144                Protocol::from_static("the-bread-protocol").into()
1145            ),
1146            Pseudo {
1147                method: Method::CONNECT.into(),
1148                authority: BytesStr::from_static("example.com").into(),
1149                scheme: BytesStr::from_static("http").into(),
1150                path: BytesStr::from_static("/a/b/c").into(),
1151                protocol: Protocol::from_static("the-bread-protocol").into(),
1152                ..Default::default()
1153            }
1154        );
1155    }
1156
1157    #[test]
1158    fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
1159        // an OPTIONS request for an "http" or "https" URI that does not include a path component;
1160        // these MUST include a ":path" pseudo-header field with a value of '*' (see Section 7.1 of [HTTP]).
1161        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1
1162        assert_eq!(
1163            Pseudo::request(Method::OPTIONS, Uri::from_static("example.com:8080"), None,),
1164            Pseudo {
1165                method: Method::OPTIONS.into(),
1166                authority: BytesStr::from_static("example.com:8080").into(),
1167                path: BytesStr::from_static("*").into(),
1168                ..Default::default()
1169            }
1170        );
1171    }
1172
1173    #[test]
1174    fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
1175        let methods = [
1176            Method::GET,
1177            Method::POST,
1178            Method::PUT,
1179            Method::DELETE,
1180            Method::HEAD,
1181            Method::PATCH,
1182            Method::TRACE,
1183        ];
1184
1185        for method in methods {
1186            assert_eq!(
1187                Pseudo::request(
1188                    method.clone(),
1189                    Uri::from_static("http://example.com:8080"),
1190                    None,
1191                ),
1192                Pseudo {
1193                    method: method.clone().into(),
1194                    authority: BytesStr::from_static("example.com:8080").into(),
1195                    scheme: BytesStr::from_static("http").into(),
1196                    path: BytesStr::from_static("/").into(),
1197                    ..Default::default()
1198                }
1199            );
1200            assert_eq!(
1201                Pseudo::request(
1202                    method.clone(),
1203                    Uri::from_static("https://example.com/a/b/c"),
1204                    None,
1205                ),
1206                Pseudo {
1207                    method: method.into(),
1208                    authority: BytesStr::from_static("example.com").into(),
1209                    scheme: BytesStr::from_static("https").into(),
1210                    path: BytesStr::from_static("/a/b/c").into(),
1211                    ..Default::default()
1212                }
1213            );
1214        }
1215    }
1216}