rocket_http/listener.rs
1use std::fmt;
2use std::future::Future;
3use std::io;
4use std::net::SocketAddr;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::time::Duration;
8use std::sync::Arc;
9
10use log::warn;
11use tokio::time::Sleep;
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio::net::TcpStream;
14use hyper::server::accept::Accept;
15use state::InitCell;
16
17pub use tokio::net::TcpListener;
18
/// Raw, DER-encoded X.509 client certificate data.
///
/// With the `tls` feature enabled this is a re-export of `rustls::Certificate`,
/// which is exactly isomorphic to the fallback definition used otherwise.
#[cfg(feature = "tls")]
#[doc(inline)]
pub use rustls::Certificate as CertificateData;

/// Raw, DER-encoded X.509 client certificate data.
///
/// Fallback used when the `tls` feature is disabled: a plain newtype over the
/// certificate bytes, mirroring the shape of `rustls::Certificate`.
#[cfg(not(feature = "tls"))]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub struct CertificateData(pub Vec<u8>);
29
30/// A collection of raw certificate data.
31#[derive(Clone, Default)]
32pub struct Certificates(Arc<InitCell<Vec<CertificateData>>>);
33
34impl From<Vec<CertificateData>> for Certificates {
35 fn from(value: Vec<CertificateData>) -> Self {
36 Certificates(Arc::new(value.into()))
37 }
38}
39
40impl Certificates {
41 /// Set the the raw certificate chain data. Only the first call actually
42 /// sets the data; the remaining do nothing.
43 #[cfg(feature = "tls")]
44 pub(crate) fn set(&self, data: Vec<CertificateData>) {
45 self.0.set(data);
46 }
47
48 /// Returns the raw certificate chain data, if any is available.
49 pub fn chain_data(&self) -> Option<&[CertificateData]> {
50 self.0.try_get().map(|v| v.as_slice())
51 }
52}
53
54// TODO.async: 'Listener' and 'Connection' provide common enough functionality
55// that they could be introduced in upstream libraries.
56/// A 'Listener' yields incoming connections
57pub trait Listener {
58 /// The connection type returned by this listener.
59 type Connection: Connection;
60
61 /// Return the actual address this listener bound to.
62 fn local_addr(&self) -> Option<SocketAddr>;
63
64 /// Try to accept an incoming Connection if ready. This should only return
65 /// an `Err` when a fatal problem occurs as Hyper kills the server on `Err`.
66 fn poll_accept(
67 self: Pin<&mut Self>,
68 cx: &mut Context<'_>
69 ) -> Poll<io::Result<Self::Connection>>;
70}
71
72/// A 'Connection' represents an open connection to a client
73pub trait Connection: AsyncRead + AsyncWrite {
74 /// The remote address, i.e. the client's socket address, if it is known.
75 fn peer_address(&self) -> Option<SocketAddr>;
76
77 /// Requests that the connection not delay reading or writing data as much
78 /// as possible. For connections backed by TCP, this corresponds to setting
79 /// `TCP_NODELAY`.
80 fn enable_nodelay(&self) -> io::Result<()>;
81
82 /// DER-encoded X.509 certificate chain presented by the client, if any.
83 ///
84 /// The certificate order must be as it appears in the TLS protocol: the
85 /// first certificate relates to the peer, the second certifies the first,
86 /// the third certifies the second, and so on.
87 ///
88 /// Defaults to an empty vector to indicate that no certificates were
89 /// presented.
90 fn peer_certificates(&self) -> Option<Certificates> { None }
91}
92
93pin_project_lite::pin_project! {
94 /// This is a generic version of hyper's AddrIncoming that is intended to be
95 /// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
96 /// sockets. It does so by bridging the `Listener` trait to what hyper wants (an
97 /// Accept). This type is internal to Rocket.
98 #[must_use = "streams do nothing unless polled"]
99 pub struct Incoming<L> {
100 sleep_on_errors: Option<Duration>,
101 nodelay: bool,
102 #[pin]
103 pending_error_delay: Option<Sleep>,
104 #[pin]
105 listener: L,
106 }
107}
108
109impl<L: Listener> Incoming<L> {
110 /// Construct an `Incoming` from an existing `Listener`.
111 pub fn new(listener: L) -> Self {
112 Self {
113 listener,
114 sleep_on_errors: Some(Duration::from_millis(250)),
115 pending_error_delay: None,
116 nodelay: false,
117 }
118 }
119
120 /// Set whether and how long to sleep on accept errors.
121 ///
122 /// A possible scenario is that the process has hit the max open files
123 /// allowed, and so trying to accept a new connection will fail with
124 /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
125 /// the application will likely close some files (or connections), and try
126 /// to accept the connection again. If this option is `true`, the error
127 /// will be logged at the `error` level, since it is still a big deal,
128 /// and then the listener will sleep for 1 second.
129 ///
130 /// In other cases, hitting the max open files should be treat similarly
131 /// to being out-of-memory, and simply error (and shutdown). Setting
132 /// this option to `None` will allow that.
133 ///
134 /// Default is 1 second.
135 pub fn sleep_on_errors(mut self, val: Option<Duration>) -> Self {
136 self.sleep_on_errors = val;
137 self
138 }
139
140 /// Set whether to request no delay on all incoming connections. The default
141 /// is `false`. See [`Connection::enable_nodelay()`] for details.
142 pub fn nodelay(mut self, nodelay: bool) -> Self {
143 self.nodelay = nodelay;
144 self
145 }
146
147 fn poll_accept_next(
148 self: Pin<&mut Self>,
149 cx: &mut Context<'_>
150 ) -> Poll<io::Result<L::Connection>> {
151 /// This function defines per-connection errors: errors that affect only
152 /// a single connection's accept() and don't imply anything about the
153 /// success probability of the next accept(). Thus, we can attempt to
154 /// `accept()` another connection immediately. All other errors will
155 /// incur a delay before the next `accept()` is performed. The delay is
156 /// useful to handle resource exhaustion errors like ENFILE and EMFILE.
157 /// Otherwise, could enter into tight loop.
158 fn is_connection_error(e: &io::Error) -> bool {
159 matches!(e.kind(),
160 | io::ErrorKind::ConnectionRefused
161 | io::ErrorKind::ConnectionAborted
162 | io::ErrorKind::ConnectionReset)
163 }
164
165 let mut this = self.project();
166 loop {
167 // Check if a previous sleep timer is active, set on I/O errors.
168 if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() {
169 futures::ready!(delay.poll(cx));
170 }
171
172 this.pending_error_delay.set(None);
173
174 match futures::ready!(this.listener.as_mut().poll_accept(cx)) {
175 Ok(stream) => {
176 if *this.nodelay {
177 if let Err(e) = stream.enable_nodelay() {
178 warn!("failed to enable NODELAY: {}", e);
179 }
180 }
181
182 return Poll::Ready(Ok(stream));
183 },
184 Err(e) => {
185 if is_connection_error(&e) {
186 warn!("single connection accept error {}; accepting next now", e);
187 } else if let Some(duration) = this.sleep_on_errors {
188 // We might be able to recover. Try again in a bit.
189 warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis());
190 this.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
191 } else {
192 return Poll::Ready(Err(e));
193 }
194 },
195 }
196 }
197 }
198}
199
200impl<L: Listener> Accept for Incoming<L> {
201 type Conn = L::Connection;
202 type Error = io::Error;
203
204 #[inline]
205 fn poll_accept(
206 self: Pin<&mut Self>,
207 cx: &mut Context<'_>
208 ) -> Poll<Option<io::Result<Self::Conn>>> {
209 self.poll_accept_next(cx).map(Some)
210 }
211}
212
213impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
214 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215 f.debug_struct("Incoming")
216 .field("listener", &self.listener)
217 .finish()
218 }
219}
220
221impl Listener for TcpListener {
222 type Connection = TcpStream;
223
224 #[inline]
225 fn local_addr(&self) -> Option<SocketAddr> {
226 self.local_addr().ok()
227 }
228
229 #[inline]
230 fn poll_accept(
231 self: Pin<&mut Self>,
232 cx: &mut Context<'_>
233 ) -> Poll<io::Result<Self::Connection>> {
234 (*self).poll_accept(cx).map_ok(|(stream, _addr)| stream)
235 }
236}
237
238impl Connection for TcpStream {
239 #[inline]
240 fn peer_address(&self) -> Option<SocketAddr> {
241 self.peer_addr().ok()
242 }
243
244 #[inline]
245 fn enable_nodelay(&self) -> io::Result<()> {
246 self.set_nodelay(true)
247 }
248}