rocket_http/
listener.rs

1use std::fmt;
2use std::future::Future;
3use std::io;
4use std::net::SocketAddr;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::time::Duration;
8use std::sync::Arc;
9
10use log::warn;
11use tokio::time::Sleep;
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio::net::TcpStream;
14use hyper::server::accept::Accept;
15use state::InitCell;
16
17pub use tokio::net::TcpListener;
18
19/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
20// NOTE: `rustls::Certificate` is exactly isomorphic to `CertificateData`.
21#[doc(inline)]
22#[cfg(feature = "tls")]
23pub use rustls::Certificate as CertificateData;
24
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
///
/// This definition is only used when the `tls` feature is disabled. It derives
/// `Debug` so that, like its TLS-enabled counterpart, it can be formatted in
/// diagnostics.
#[cfg(not(feature = "tls"))]
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct CertificateData(pub Vec<u8>);
29
30/// A collection of raw certificate data.
31#[derive(Clone, Default)]
32pub struct Certificates(Arc<InitCell<Vec<CertificateData>>>);
33
34impl From<Vec<CertificateData>> for Certificates {
35    fn from(value: Vec<CertificateData>) -> Self {
36        Certificates(Arc::new(value.into()))
37    }
38}
39
40impl Certificates {
41    /// Set the the raw certificate chain data. Only the first call actually
42    /// sets the data; the remaining do nothing.
43    #[cfg(feature = "tls")]
44    pub(crate) fn set(&self, data: Vec<CertificateData>) {
45        self.0.set(data);
46    }
47
48    /// Returns the raw certificate chain data, if any is available.
49    pub fn chain_data(&self) -> Option<&[CertificateData]> {
50        self.0.try_get().map(|v| v.as_slice())
51    }
52}
53
54// TODO.async: 'Listener' and 'Connection' provide common enough functionality
55// that they could be introduced in upstream libraries.
56/// A 'Listener' yields incoming connections
57pub trait Listener {
58    /// The connection type returned by this listener.
59    type Connection: Connection;
60
61    /// Return the actual address this listener bound to.
62    fn local_addr(&self) -> Option<SocketAddr>;
63
64    /// Try to accept an incoming Connection if ready. This should only return
65    /// an `Err` when a fatal problem occurs as Hyper kills the server on `Err`.
66    fn poll_accept(
67        self: Pin<&mut Self>,
68        cx: &mut Context<'_>
69    ) -> Poll<io::Result<Self::Connection>>;
70}
71
72/// A 'Connection' represents an open connection to a client
73pub trait Connection: AsyncRead + AsyncWrite {
74    /// The remote address, i.e. the client's socket address, if it is known.
75    fn peer_address(&self) -> Option<SocketAddr>;
76
77    /// Requests that the connection not delay reading or writing data as much
78    /// as possible. For connections backed by TCP, this corresponds to setting
79    /// `TCP_NODELAY`.
80    fn enable_nodelay(&self) -> io::Result<()>;
81
82    /// DER-encoded X.509 certificate chain presented by the client, if any.
83    ///
84    /// The certificate order must be as it appears in the TLS protocol: the
85    /// first certificate relates to the peer, the second certifies the first,
86    /// the third certifies the second, and so on.
87    ///
88    /// Defaults to an empty vector to indicate that no certificates were
89    /// presented.
90    fn peer_certificates(&self) -> Option<Certificates> { None }
91}
92
93pin_project_lite::pin_project! {
94    /// This is a generic version of hyper's AddrIncoming that is intended to be
95    /// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
96    /// sockets. It does so by bridging the `Listener` trait to what hyper wants (an
97    /// Accept). This type is internal to Rocket.
98    #[must_use = "streams do nothing unless polled"]
99    pub struct Incoming<L> {
100        sleep_on_errors: Option<Duration>,
101        nodelay: bool,
102        #[pin]
103        pending_error_delay: Option<Sleep>,
104        #[pin]
105        listener: L,
106    }
107}
108
109impl<L: Listener> Incoming<L> {
110    /// Construct an `Incoming` from an existing `Listener`.
111    pub fn new(listener: L) -> Self {
112        Self {
113            listener,
114            sleep_on_errors: Some(Duration::from_millis(250)),
115            pending_error_delay: None,
116            nodelay: false,
117        }
118    }
119
120    /// Set whether and how long to sleep on accept errors.
121    ///
122    /// A possible scenario is that the process has hit the max open files
123    /// allowed, and so trying to accept a new connection will fail with
124    /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
125    /// the application will likely close some files (or connections), and try
126    /// to accept the connection again. If this option is `true`, the error
127    /// will be logged at the `error` level, since it is still a big deal,
128    /// and then the listener will sleep for 1 second.
129    ///
130    /// In other cases, hitting the max open files should be treat similarly
131    /// to being out-of-memory, and simply error (and shutdown). Setting
132    /// this option to `None` will allow that.
133    ///
134    /// Default is 1 second.
135    pub fn sleep_on_errors(mut self, val: Option<Duration>) -> Self {
136        self.sleep_on_errors = val;
137        self
138    }
139
140    /// Set whether to request no delay on all incoming connections. The default
141    /// is `false`. See [`Connection::enable_nodelay()`] for details.
142    pub fn nodelay(mut self, nodelay: bool) -> Self {
143        self.nodelay = nodelay;
144        self
145    }
146
147    fn poll_accept_next(
148        self: Pin<&mut Self>,
149        cx: &mut Context<'_>
150    ) -> Poll<io::Result<L::Connection>> {
151        /// This function defines per-connection errors: errors that affect only
152        /// a single connection's accept() and don't imply anything about the
153        /// success probability of the next accept(). Thus, we can attempt to
154        /// `accept()` another connection immediately. All other errors will
155        /// incur a delay before the next `accept()` is performed. The delay is
156        /// useful to handle resource exhaustion errors like ENFILE and EMFILE.
157        /// Otherwise, could enter into tight loop.
158        fn is_connection_error(e: &io::Error) -> bool {
159            matches!(e.kind(),
160                | io::ErrorKind::ConnectionRefused
161                | io::ErrorKind::ConnectionAborted
162                | io::ErrorKind::ConnectionReset)
163        }
164
165        let mut this = self.project();
166        loop {
167            // Check if a previous sleep timer is active, set on I/O errors.
168            if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() {
169                futures::ready!(delay.poll(cx));
170            }
171
172            this.pending_error_delay.set(None);
173
174            match futures::ready!(this.listener.as_mut().poll_accept(cx)) {
175                Ok(stream) => {
176                    if *this.nodelay {
177                        if let Err(e) = stream.enable_nodelay() {
178                            warn!("failed to enable NODELAY: {}", e);
179                        }
180                    }
181
182                    return Poll::Ready(Ok(stream));
183                },
184                Err(e) => {
185                    if is_connection_error(&e) {
186                        warn!("single connection accept error {}; accepting next now", e);
187                    } else if let Some(duration) = this.sleep_on_errors {
188                        // We might be able to recover. Try again in a bit.
189                        warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis());
190                        this.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
191                    } else {
192                        return Poll::Ready(Err(e));
193                    }
194                },
195            }
196        }
197    }
198}
199
200impl<L: Listener> Accept for Incoming<L> {
201    type Conn = L::Connection;
202    type Error = io::Error;
203
204    #[inline]
205    fn poll_accept(
206        self: Pin<&mut Self>,
207        cx: &mut Context<'_>
208    ) -> Poll<Option<io::Result<Self::Conn>>> {
209        self.poll_accept_next(cx).map(Some)
210    }
211}
212
213impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
214    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215        f.debug_struct("Incoming")
216            .field("listener", &self.listener)
217            .finish()
218    }
219}
220
221impl Listener for TcpListener {
222    type Connection = TcpStream;
223
224    #[inline]
225    fn local_addr(&self) -> Option<SocketAddr> {
226        self.local_addr().ok()
227    }
228
229    #[inline]
230    fn poll_accept(
231        self: Pin<&mut Self>,
232        cx: &mut Context<'_>
233    ) -> Poll<io::Result<Self::Connection>> {
234        (*self).poll_accept(cx).map_ok(|(stream, _addr)| stream)
235    }
236}
237
238impl Connection for TcpStream {
239    #[inline]
240    fn peer_address(&self) -> Option<SocketAddr> {
241        self.peer_addr().ok()
242    }
243
244    #[inline]
245    fn enable_nodelay(&self) -> io::Result<()> {
246        self.set_nodelay(true)
247    }
248}