1#[cfg(feature = "__tls")]
2use http::header::HeaderValue;
3use http::uri::{Authority, Scheme};
4use http::Uri;
5use hyper::rt::{Read, ReadBufCursor, Write};
6use hyper_util::client::legacy::connect::{Connected, Connection};
7#[cfg(any(feature = "socks", feature = "__tls"))]
8use hyper_util::rt::TokioIo;
9#[cfg(feature = "default-tls")]
10use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
11use pin_project_lite::pin_project;
12use tower::util::{BoxCloneSyncServiceLayer, MapRequestLayer};
13use tower::{timeout::TimeoutLayer, util::BoxCloneSyncService, ServiceBuilder};
14use tower_service::Service;
15
16use std::future::Future;
17use std::io::{self, IoSlice};
18use std::net::IpAddr;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use std::time::Duration;
23
24#[cfg(feature = "default-tls")]
25use self::native_tls_conn::NativeTlsConn;
26#[cfg(feature = "__rustls")]
27use self::rustls_tls_conn::RustlsTlsConn;
28use crate::dns::DynResolver;
29use crate::error::{cast_to_internal_error, BoxError};
30use crate::proxy::{Proxy, ProxyScheme};
31use sealed::{Conn, Unnameable};
32
33pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector<DynResolver>;
34
35#[derive(Clone)]
36pub(crate) enum Connector {
37 Simple(ConnectorService),
39 WithLayers(BoxCloneSyncService<Unnameable, Conn, BoxError>),
42}
43
44impl Service<Uri> for Connector {
45 type Response = Conn;
46 type Error = BoxError;
47 type Future = Connecting;
48
49 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50 match self {
51 Connector::Simple(service) => service.poll_ready(cx),
52 Connector::WithLayers(service) => service.poll_ready(cx),
53 }
54 }
55
56 fn call(&mut self, dst: Uri) -> Self::Future {
57 match self {
58 Connector::Simple(service) => service.call(dst),
59 Connector::WithLayers(service) => service.call(Unnameable(dst)),
60 }
61 }
62}
63
/// Type-erased connector service: the service type that connector layers
/// passed to `ConnectorBuilder::build` wrap and must produce.
pub(crate) type BoxedConnectorService = BoxCloneSyncService<Unnameable, Conn, BoxError>;

/// A type-erased tower layer that wraps a [`BoxedConnectorService`] and yields
/// another one.
pub(crate) type BoxedConnectorLayer =
    BoxCloneSyncServiceLayer<BoxedConnectorService, Unnameable, Conn, BoxError>;
68
/// Collects connector configuration; turned into a [`Connector`] by `build`.
pub(crate) struct ConnectorBuilder {
    /// Transport backend (plain HTTP or one of the TLS flavours).
    inner: Inner,
    /// Proxies consulted for every new connection.
    proxies: Arc<Vec<Proxy>>,
    /// Whether connections are wrapped for trace-level I/O logging.
    verbose: verbose::Wrapper,
    /// Connect timeout; enforced either inside the base service or as an
    /// outer `TimeoutLayer`, depending on whether layers are used.
    timeout: Option<Duration>,
    /// The user's TCP_NODELAY choice, restored after TLS handshakes.
    #[cfg(feature = "__tls")]
    nodelay: bool,
    /// Whether `TlsInfo` should be attached to established connections.
    #[cfg(feature = "__tls")]
    tls_info: bool,
    /// `User-Agent` sent in CONNECT requests to HTTP proxies.
    #[cfg(feature = "__tls")]
    user_agent: Option<HeaderValue>,
}
81
82impl ConnectorBuilder {
83 pub(crate) fn build(self, layers: Vec<BoxedConnectorLayer>) -> Connector
84where {
85 let mut base_service = ConnectorService {
87 inner: self.inner,
88 proxies: self.proxies,
89 verbose: self.verbose,
90 #[cfg(feature = "__tls")]
91 nodelay: self.nodelay,
92 #[cfg(feature = "__tls")]
93 tls_info: self.tls_info,
94 #[cfg(feature = "__tls")]
95 user_agent: self.user_agent,
96 simple_timeout: None,
97 };
98
99 if layers.is_empty() {
100 base_service.simple_timeout = self.timeout;
102 return Connector::Simple(base_service);
103 }
104
105 let unnameable_service = ServiceBuilder::new()
109 .layer(MapRequestLayer::new(|request: Unnameable| request.0))
110 .service(base_service);
111 let mut service = BoxCloneSyncService::new(unnameable_service);
112
113 for layer in layers {
114 service = ServiceBuilder::new().layer(layer).service(service);
115 }
116
117 match self.timeout {
121 Some(timeout) => {
122 let service = ServiceBuilder::new()
123 .layer(TimeoutLayer::new(timeout))
124 .service(service);
125 let service = ServiceBuilder::new()
126 .map_err(|error: BoxError| cast_to_internal_error(error))
127 .service(service);
128 let service = BoxCloneSyncService::new(service);
129
130 Connector::WithLayers(service)
131 }
132 None => {
133 let service = ServiceBuilder::new().service(service);
137 let service = ServiceBuilder::new()
138 .map_err(|error: BoxError| cast_to_internal_error(error))
139 .service(service);
140 let service = BoxCloneSyncService::new(service);
141 Connector::WithLayers(service)
142 }
143 }
144 }
145
146 #[cfg(not(feature = "__tls"))]
147 pub(crate) fn new<T>(
148 mut http: HttpConnector,
149 proxies: Arc<Vec<Proxy>>,
150 local_addr: T,
151 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
152 interface: Option<&str>,
153 nodelay: bool,
154 ) -> ConnectorBuilder
155 where
156 T: Into<Option<IpAddr>>,
157 {
158 http.set_local_address(local_addr.into());
159 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
160 if let Some(interface) = interface {
161 http.set_interface(interface.to_owned());
162 }
163 http.set_nodelay(nodelay);
164
165 ConnectorBuilder {
166 inner: Inner::Http(http),
167 proxies,
168 verbose: verbose::OFF,
169 timeout: None,
170 }
171 }
172
173 #[cfg(feature = "default-tls")]
174 pub(crate) fn new_default_tls<T>(
175 http: HttpConnector,
176 tls: TlsConnectorBuilder,
177 proxies: Arc<Vec<Proxy>>,
178 user_agent: Option<HeaderValue>,
179 local_addr: T,
180 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
181 interface: Option<&str>,
182 nodelay: bool,
183 tls_info: bool,
184 ) -> crate::Result<ConnectorBuilder>
185 where
186 T: Into<Option<IpAddr>>,
187 {
188 let tls = tls.build().map_err(crate::error::builder)?;
189 Ok(Self::from_built_default_tls(
190 http,
191 tls,
192 proxies,
193 user_agent,
194 local_addr,
195 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
196 interface,
197 nodelay,
198 tls_info,
199 ))
200 }
201
202 #[cfg(feature = "default-tls")]
203 pub(crate) fn from_built_default_tls<T>(
204 mut http: HttpConnector,
205 tls: TlsConnector,
206 proxies: Arc<Vec<Proxy>>,
207 user_agent: Option<HeaderValue>,
208 local_addr: T,
209 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
210 interface: Option<&str>,
211 nodelay: bool,
212 tls_info: bool,
213 ) -> ConnectorBuilder
214 where
215 T: Into<Option<IpAddr>>,
216 {
217 http.set_local_address(local_addr.into());
218 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
219 if let Some(interface) = interface {
220 http.set_interface(interface);
221 }
222 http.set_nodelay(nodelay);
223 http.enforce_http(false);
224
225 ConnectorBuilder {
226 inner: Inner::DefaultTls(http, tls),
227 proxies,
228 verbose: verbose::OFF,
229 nodelay,
230 tls_info,
231 user_agent,
232 timeout: None,
233 }
234 }
235
236 #[cfg(feature = "__rustls")]
237 pub(crate) fn new_rustls_tls<T>(
238 mut http: HttpConnector,
239 tls: rustls::ClientConfig,
240 proxies: Arc<Vec<Proxy>>,
241 user_agent: Option<HeaderValue>,
242 local_addr: T,
243 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
244 interface: Option<&str>,
245 nodelay: bool,
246 tls_info: bool,
247 ) -> ConnectorBuilder
248 where
249 T: Into<Option<IpAddr>>,
250 {
251 http.set_local_address(local_addr.into());
252 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
253 if let Some(interface) = interface {
254 http.set_interface(interface.to_owned());
255 }
256 http.set_nodelay(nodelay);
257 http.enforce_http(false);
258
259 let (tls, tls_proxy) = if proxies.is_empty() {
260 let tls = Arc::new(tls);
261 (tls.clone(), tls)
262 } else {
263 let mut tls_proxy = tls.clone();
264 tls_proxy.alpn_protocols.clear();
265 (Arc::new(tls), Arc::new(tls_proxy))
266 };
267
268 ConnectorBuilder {
269 inner: Inner::RustlsTls {
270 http,
271 tls,
272 tls_proxy,
273 },
274 proxies,
275 verbose: verbose::OFF,
276 nodelay,
277 tls_info,
278 user_agent,
279 timeout: None,
280 }
281 }
282
283 pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
284 self.timeout = timeout;
285 }
286
287 pub(crate) fn set_verbose(&mut self, enabled: bool) {
288 self.verbose.0 = enabled;
289 }
290
291 pub(crate) fn set_keepalive(&mut self, dur: Option<Duration>) {
292 match &mut self.inner {
293 #[cfg(feature = "default-tls")]
294 Inner::DefaultTls(http, _tls) => http.set_keepalive(dur),
295 #[cfg(feature = "__rustls")]
296 Inner::RustlsTls { http, .. } => http.set_keepalive(dur),
297 #[cfg(not(feature = "__tls"))]
298 Inner::Http(http) => http.set_keepalive(dur),
299 }
300 }
301}
302
/// The base connector service: dials destinations directly or via proxies.
#[allow(missing_debug_implementations)]
#[derive(Clone)]
pub(crate) struct ConnectorService {
    /// Transport backend (plain HTTP or one of the TLS flavours).
    inner: Inner,
    /// Proxies checked in order; the first one that intercepts a URI is used.
    proxies: Arc<Vec<Proxy>>,
    /// Whether connections are wrapped for trace-level I/O logging.
    verbose: verbose::Wrapper,
    /// Timeout this service enforces itself; only set when the connector is
    /// built without layers (otherwise a `TimeoutLayer` handles it).
    simple_timeout: Option<Duration>,
    /// The user's TCP_NODELAY choice, restored after TLS handshakes.
    #[cfg(feature = "__tls")]
    nodelay: bool,
    /// Whether `TlsInfo` should be attached to established connections.
    #[cfg(feature = "__tls")]
    tls_info: bool,
    /// `User-Agent` sent in CONNECT requests to HTTP proxies.
    #[cfg(feature = "__tls")]
    user_agent: Option<HeaderValue>,
}
321
/// The transport backend, selected by the enabled TLS features.
#[derive(Clone)]
enum Inner {
    /// Plain HTTP; only exists when no TLS backend is compiled in.
    #[cfg(not(feature = "__tls"))]
    Http(HttpConnector),
    /// HTTP connector paired with a native-tls connector.
    #[cfg(feature = "default-tls")]
    DefaultTls(HttpConnector, TlsConnector),
    /// HTTP connector paired with rustls configurations.
    #[cfg(feature = "__rustls")]
    RustlsTls {
        /// Underlying TCP connector.
        http: HttpConnector,
        /// Config for TLS sessions with the destination server.
        tls: Arc<rustls::ClientConfig>,
        /// Config for the TLS hop to an HTTPS proxy; `new_rustls_tls` clears
        /// its ALPN list when proxies are configured.
        tls_proxy: Arc<rustls::ClientConfig>,
    },
}
335
336impl ConnectorService {
337 #[cfg(feature = "socks")]
338 async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
339 let dns = match proxy {
340 ProxyScheme::Socks4 { .. } => socks::DnsResolve::Local,
341 ProxyScheme::Socks5 {
342 remote_dns: false, ..
343 } => socks::DnsResolve::Local,
344 ProxyScheme::Socks5 {
345 remote_dns: true, ..
346 } => socks::DnsResolve::Proxy,
347 ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
348 unreachable!("connect_socks is only called for socks proxies");
349 }
350 };
351
352 match &self.inner {
353 #[cfg(feature = "default-tls")]
354 Inner::DefaultTls(_http, tls) => {
355 if dst.scheme() == Some(&Scheme::HTTPS) {
356 let host = dst.host().ok_or("no host in url")?.to_string();
357 let conn = socks::connect(proxy, dst, dns).await?;
358 let conn = TokioIo::new(conn);
359 let conn = TokioIo::new(conn);
360 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
361 let io = tls_connector.connect(&host, conn).await?;
362 let io = TokioIo::new(io);
363 return Ok(Conn {
364 inner: self.verbose.wrap(NativeTlsConn { inner: io }),
365 is_proxy: false,
366 tls_info: self.tls_info,
367 });
368 }
369 }
370 #[cfg(feature = "__rustls")]
371 Inner::RustlsTls { tls, .. } => {
372 if dst.scheme() == Some(&Scheme::HTTPS) {
373 use std::convert::TryFrom;
374 use tokio_rustls::TlsConnector as RustlsConnector;
375
376 let tls = tls.clone();
377 let host = dst.host().ok_or("no host in url")?.to_string();
378 let conn = socks::connect(proxy, dst, dns).await?;
379 let conn = TokioIo::new(conn);
380 let conn = TokioIo::new(conn);
381 let server_name =
382 rustls_pki_types::ServerName::try_from(host.as_str().to_owned())
383 .map_err(|_| "Invalid Server Name")?;
384 let io = RustlsConnector::from(tls)
385 .connect(server_name, conn)
386 .await?;
387 let io = TokioIo::new(io);
388 return Ok(Conn {
389 inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
390 is_proxy: false,
391 tls_info: false,
392 });
393 }
394 }
395 #[cfg(not(feature = "__tls"))]
396 Inner::Http(_) => (),
397 }
398
399 socks::connect(proxy, dst, dns).await.map(|tcp| Conn {
400 inner: self.verbose.wrap(TokioIo::new(tcp)),
401 is_proxy: false,
402 tls_info: false,
403 })
404 }
405
406 async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> {
407 match self.inner {
408 #[cfg(not(feature = "__tls"))]
409 Inner::Http(mut http) => {
410 let io = http.call(dst).await?;
411 Ok(Conn {
412 inner: self.verbose.wrap(io),
413 is_proxy,
414 tls_info: false,
415 })
416 }
417 #[cfg(feature = "default-tls")]
418 Inner::DefaultTls(http, tls) => {
419 let mut http = http.clone();
420
421 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
425 http.set_nodelay(true);
426 }
427
428 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
429 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
430 let io = http.call(dst).await?;
431
432 if let hyper_tls::MaybeHttpsStream::Https(stream) = io {
433 if !self.nodelay {
434 stream
435 .inner()
436 .get_ref()
437 .get_ref()
438 .get_ref()
439 .inner()
440 .inner()
441 .set_nodelay(false)?;
442 }
443 Ok(Conn {
444 inner: self.verbose.wrap(NativeTlsConn { inner: stream }),
445 is_proxy,
446 tls_info: self.tls_info,
447 })
448 } else {
449 Ok(Conn {
450 inner: self.verbose.wrap(io),
451 is_proxy,
452 tls_info: false,
453 })
454 }
455 }
456 #[cfg(feature = "__rustls")]
457 Inner::RustlsTls { http, tls, .. } => {
458 let mut http = http.clone();
459
460 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
464 http.set_nodelay(true);
465 }
466
467 let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
468 let io = http.call(dst).await?;
469
470 if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
471 if !self.nodelay {
472 let (io, _) = stream.inner().get_ref();
473 io.inner().inner().set_nodelay(false)?;
474 }
475 Ok(Conn {
476 inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
477 is_proxy,
478 tls_info: self.tls_info,
479 })
480 } else {
481 Ok(Conn {
482 inner: self.verbose.wrap(io),
483 is_proxy,
484 tls_info: false,
485 })
486 }
487 }
488 }
489 }
490
491 async fn connect_via_proxy(
492 self,
493 dst: Uri,
494 proxy_scheme: ProxyScheme,
495 ) -> Result<Conn, BoxError> {
496 log::debug!("proxy({proxy_scheme:?}) intercepts '{dst:?}'");
497
498 let (proxy_dst, _auth) = match proxy_scheme {
499 ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
500 ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
501 #[cfg(feature = "socks")]
502 ProxyScheme::Socks4 { .. } => return self.connect_socks(dst, proxy_scheme).await,
503 #[cfg(feature = "socks")]
504 ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
505 };
506
507 #[cfg(feature = "__tls")]
508 let auth = _auth;
509
510 match &self.inner {
511 #[cfg(feature = "default-tls")]
512 Inner::DefaultTls(http, tls) => {
513 if dst.scheme() == Some(&Scheme::HTTPS) {
514 let host = dst.host().to_owned();
515 let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
516 let http = http.clone();
517 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
518 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
519 let conn = http.call(proxy_dst).await?;
520 log::trace!("tunneling HTTPS over proxy");
521 let tunneled = tunnel(
522 conn,
523 host.ok_or("no host in url")?.to_string(),
524 port,
525 self.user_agent.clone(),
526 auth,
527 )
528 .await?;
529 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
530 let io = tls_connector
531 .connect(host.ok_or("no host in url")?, TokioIo::new(tunneled))
532 .await?;
533 return Ok(Conn {
534 inner: self.verbose.wrap(NativeTlsConn {
535 inner: TokioIo::new(io),
536 }),
537 is_proxy: false,
538 tls_info: false,
539 });
540 }
541 }
542 #[cfg(feature = "__rustls")]
543 Inner::RustlsTls {
544 http,
545 tls,
546 tls_proxy,
547 } => {
548 if dst.scheme() == Some(&Scheme::HTTPS) {
549 use rustls_pki_types::ServerName;
550 use std::convert::TryFrom;
551 use tokio_rustls::TlsConnector as RustlsConnector;
552
553 let host = dst.host().ok_or("no host in url")?.to_string();
554 let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
555 let http = http.clone();
556 let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
557 let tls = tls.clone();
558 let conn = http.call(proxy_dst).await?;
559 log::trace!("tunneling HTTPS over proxy");
560 let maybe_server_name = ServerName::try_from(host.as_str().to_owned())
561 .map_err(|_| "Invalid Server Name");
562 let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
563 let server_name = maybe_server_name?;
564 let io = RustlsConnector::from(tls)
565 .connect(server_name, TokioIo::new(tunneled))
566 .await?;
567
568 return Ok(Conn {
569 inner: self.verbose.wrap(RustlsTlsConn {
570 inner: TokioIo::new(io),
571 }),
572 is_proxy: false,
573 tls_info: false,
574 });
575 }
576 }
577 #[cfg(not(feature = "__tls"))]
578 Inner::Http(_) => (),
579 }
580
581 self.connect_with_maybe_proxy(proxy_dst, true).await
582 }
583}
584
585fn into_uri(scheme: Scheme, host: Authority) -> Uri {
586 http::Uri::builder()
588 .scheme(scheme)
589 .authority(host)
590 .path_and_query(http::uri::PathAndQuery::from_static("/"))
591 .build()
592 .expect("scheme and authority is valid Uri")
593}
594
595async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError>
596where
597 F: Future<Output = Result<T, BoxError>>,
598{
599 if let Some(to) = timeout {
600 match tokio::time::timeout(to, f).await {
601 Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
602 Ok(Ok(try_res)) => Ok(try_res),
603 Ok(Err(e)) => Err(e),
604 }
605 } else {
606 f.await
607 }
608}
609
610impl Service<Uri> for ConnectorService {
611 type Response = Conn;
612 type Error = BoxError;
613 type Future = Connecting;
614
615 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
616 Poll::Ready(Ok(()))
617 }
618
619 fn call(&mut self, dst: Uri) -> Self::Future {
620 log::debug!("starting new connection: {dst:?}");
621 let timeout = self.simple_timeout;
622 for prox in self.proxies.iter() {
623 if let Some(proxy_scheme) = prox.intercept(&dst) {
624 return Box::pin(with_timeout(
625 self.clone().connect_via_proxy(dst, proxy_scheme),
626 timeout,
627 ));
628 }
629 }
630
631 Box::pin(with_timeout(
632 self.clone().connect_with_maybe_proxy(dst, false),
633 timeout,
634 ))
635 }
636}
637
/// Extracts TLS session details (the peer certificate) from a connection, if
/// it is a TLS connection.
#[cfg(feature = "__tls")]
trait TlsInfoFactory {
    /// Returns the TLS info for this connection, or `None` for non-TLS streams.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo>;
}
642
/// A bare TCP stream carries no TLS session.
#[cfg(feature = "__tls")]
impl TlsInfoFactory for tokio::net::TcpStream {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        None
    }
}
649
/// `TokioIo` is a transparent adapter; delegate to the wrapped stream.
#[cfg(feature = "__tls")]
impl<T: TlsInfoFactory> TlsInfoFactory for TokioIo<T> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        self.inner().tls_info()
    }
}
656
/// native-tls stream over a direct TCP connection: reports the DER-encoded
/// peer certificate when one is available and can be encoded.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for tokio_native_tls::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            _ => None,
        };
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
669
/// native-tls stream running over a proxy tunnel (itself a possibly-TLS
/// stream): reports the DER-encoded peer certificate of the outer session.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory
    for tokio_native_tls::TlsStream<
        TokioIo<hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            _ => None,
        };
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
686
/// Reports TLS info only for the HTTPS variant; plain HTTP yields `None`.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_tls::MaybeHttpsStream::Https(tls) = self {
            tls.tls_info()
        } else {
            None
        }
    }
}
696
/// rustls stream over a direct TCP connection: reports the first certificate
/// the peer presented, as raw DER bytes.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let peer_certificate = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|first| first.to_vec());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
709
/// rustls stream running over a proxy tunnel: reports the first certificate
/// presented by the peer of the outer session, as raw DER bytes.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory
    for tokio_rustls::client::TlsStream<
        TokioIo<hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let peer_certificate = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|first| first.to_vec());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
726
/// Reports TLS info only for the HTTPS variant; plain HTTP yields `None`.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_rustls::MaybeHttpsStream::Https(tls) = self {
            tls.tls_info()
        } else {
            None
        }
    }
}
736
/// Everything hyper needs from a connection: hyper I/O traits, connection
/// metadata, and the thread-safety/`'static` bounds for boxing it.
pub(crate) trait AsyncConn:
    Read + Write + Connection + Send + Sync + Unpin + 'static
{
}

// Blanket impl: any type meeting the bounds is an `AsyncConn`.
impl<T: Read + Write + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}
743
/// `AsyncConn` plus, when a TLS backend is compiled in, `TlsInfoFactory`;
/// this is the bound every boxed connection must satisfy.
#[cfg(feature = "__tls")]
trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
/// Without TLS support there is no TLS info to report.
#[cfg(not(feature = "__tls"))]
trait AsyncConnWithInfo: AsyncConn {}

// Blanket impls: every qualifying type is automatically an `AsyncConnWithInfo`.
#[cfg(feature = "__tls")]
impl<T: AsyncConn + TlsInfoFactory> AsyncConnWithInfo for T {}
#[cfg(not(feature = "__tls"))]
impl<T: AsyncConn> AsyncConnWithInfo for T {}

/// A type-erased, heap-allocated connection.
type BoxConn = Box<dyn AsyncConnWithInfo>;
755
756pub(crate) mod sealed {
757 use super::*;
758 #[derive(Debug)]
759 pub struct Unnameable(pub(super) Uri);
760
761 pin_project! {
762 #[allow(missing_debug_implementations)]
767 pub struct Conn {
768 #[pin]
769 pub(super)inner: BoxConn,
770 pub(super) is_proxy: bool,
771 pub(super) tls_info: bool,
773 }
774 }
775
776 impl Connection for Conn {
777 fn connected(&self) -> Connected {
778 let connected = self.inner.connected().proxy(self.is_proxy);
779 #[cfg(feature = "__tls")]
780 if self.tls_info {
781 if let Some(tls_info) = self.inner.tls_info() {
782 connected.extra(tls_info)
783 } else {
784 connected
785 }
786 } else {
787 connected
788 }
789 #[cfg(not(feature = "__tls"))]
790 connected
791 }
792 }
793
794 impl Read for Conn {
795 fn poll_read(
796 self: Pin<&mut Self>,
797 cx: &mut Context,
798 buf: ReadBufCursor<'_>,
799 ) -> Poll<io::Result<()>> {
800 let this = self.project();
801 Read::poll_read(this.inner, cx, buf)
802 }
803 }
804
805 impl Write for Conn {
806 fn poll_write(
807 self: Pin<&mut Self>,
808 cx: &mut Context,
809 buf: &[u8],
810 ) -> Poll<Result<usize, io::Error>> {
811 let this = self.project();
812 Write::poll_write(this.inner, cx, buf)
813 }
814
815 fn poll_write_vectored(
816 self: Pin<&mut Self>,
817 cx: &mut Context<'_>,
818 bufs: &[IoSlice<'_>],
819 ) -> Poll<Result<usize, io::Error>> {
820 let this = self.project();
821 Write::poll_write_vectored(this.inner, cx, bufs)
822 }
823
824 fn is_write_vectored(&self) -> bool {
825 self.inner.is_write_vectored()
826 }
827
828 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
829 let this = self.project();
830 Write::poll_flush(this.inner, cx)
831 }
832
833 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
834 let this = self.project();
835 Write::poll_shutdown(this.inner, cx)
836 }
837 }
838}
839
/// Boxed, `Send` future resolving to an established [`Conn`].
pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
841
/// Establishes an HTTP `CONNECT` tunnel to `host:port` over `conn`, an
/// already-open connection to a proxy.
///
/// `user_agent` and `auth` (sent as `Proxy-Authorization`) are added as
/// request headers when present. On success `conn` is returned, positioned
/// right after the proxy's response headers.
///
/// # Errors
/// - I/O errors while writing the request or reading the response;
/// - `"unexpected eof while tunneling"` if the proxy closes early;
/// - `"proxy headers too long for tunnel"` if a `200` response's headers do
///   not fit the 8 KiB buffer;
/// - `"proxy authentication required"` on a `407` status;
/// - `"unsuccessful tunnel"` on any other status.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: Read + Write + Unpin,
{
    use hyper_util::rt::TokioIo;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Length of the `HTTP/1.x NNN` prefix inspected below. The status line
    // may arrive split across several reads, so nothing is judged until at
    // least this many bytes are buffered.
    const STATUS_PREFIX_LEN: usize = 12;

    let mut buf = format!(
        "\
         CONNECT {host}:{port} HTTP/1.1\r\n\
         Host: {host}:{port}\r\n\
         "
    )
    .into_bytes();

    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    if let Some(value) = auth {
        log::debug!("tunnel to {host}:{port} using basic auth");
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // Blank line terminating the request headers.
    buf.extend_from_slice(b"\r\n");

    let mut tokio_conn = TokioIo::new(&mut conn);

    tokio_conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = tokio_conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];
        if recvd.len() < STATUS_PREFIX_LEN {
            // Partial status line: keep reading before deciding.
            continue;
        }
        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
913
/// Error returned when the proxy closes the connection before the tunnel
/// handshake completes.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
918
/// Adapter that exposes a native-tls client stream to hyper.
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use super::TlsInfoFactory;
    use hyper::rt::{Read, ReadBufCursor, Write};
    use hyper_tls::MaybeHttpsStream;
    use hyper_util::client::legacy::connect::{Connected, Connection};
    use hyper_util::rt::TokioIo;
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_native_tls::TlsStream;

    pin_project! {
        // A native-tls stream over transport `T`, wrapped in `TokioIo` so it
        // implements hyper's `Read`/`Write`.
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }

    // Direct TLS connection (also used for SOCKS).
    impl Connection for NativeTlsConn<TokioIo<TokioIo<TcpStream>>> {
        fn connected(&self) -> Connected {
            // Walk down through the TLS wrappers to the transport and report
            // its connection info.
            let connected = self
                .inner
                .inner()
                .get_ref()
                .get_ref()
                .get_ref()
                .inner()
                .connected();
            // With ALPN support compiled in, report an `h2` negotiation to hyper.
            #[cfg(feature = "native-tls-alpn")]
            match self.inner.inner().get_ref().negotiated_alpn().ok() {
                Some(Some(alpn_protocol)) if alpn_protocol == b"h2" => connected.negotiated_h2(),
                _ => connected,
            }
            #[cfg(not(feature = "native-tls-alpn"))]
            connected
        }
    }

    // TLS connection running inside a proxy tunnel.
    impl Connection for NativeTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
        fn connected(&self) -> Connected {
            // Same walk as above, ending at the tunnel stream.
            let connected = self
                .inner
                .inner()
                .get_ref()
                .get_ref()
                .get_ref()
                .inner()
                .connected();
            #[cfg(feature = "native-tls-alpn")]
            match self.inner.inner().get_ref().negotiated_alpn().ok() {
                Some(Some(alpn_protocol)) if alpn_protocol == b"h2" => connected.negotiated_h2(),
                _ => connected,
            }
            #[cfg(not(feature = "native-tls-alpn"))]
            connected
        }
    }

    // hyper I/O is forwarded unchanged to the wrapped TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> Read for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: ReadBufCursor<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            let this = self.project();
            Read::poll_read(this.inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Write for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            let this = self.project();
            Write::poll_write(this.inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            let this = self.project();
            Write::poll_write_vectored(this.inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            let this = self.project();
            Write::poll_flush(this.inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            let this = self.project();
            Write::poll_shutdown(this.inner, cx)
        }
    }

    // TLS info comes from the wrapped TLS stream.
    impl<T> TlsInfoFactory for NativeTlsConn<T>
    where
        TokioIo<TlsStream<T>>: TlsInfoFactory,
    {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }
}
1042
/// Adapter that exposes a rustls client stream to hyper.
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use super::TlsInfoFactory;
    use hyper::rt::{Read, ReadBufCursor, Write};
    use hyper_rustls::MaybeHttpsStream;
    use hyper_util::client::legacy::connect::{Connected, Connection};
    use hyper_util::rt::TokioIo;
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_rustls::client::TlsStream;

    pin_project! {
        // A rustls stream over transport `T`, wrapped in `TokioIo` so it
        // implements hyper's `Read`/`Write`.
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }

    // Direct TLS connection (also used for SOCKS): reports the transport's
    // info, flagged as h2 when ALPN selected it.
    impl Connection for RustlsTlsConn<TokioIo<TokioIo<TcpStream>>> {
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.inner().get_ref();
            let connected = transport.inner().connected();
            if session.alpn_protocol() == Some(b"h2") {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    // TLS connection running inside a proxy tunnel.
    impl Connection for RustlsTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.inner().get_ref();
            let connected = transport.inner().connected();
            if session.alpn_protocol() == Some(b"h2") {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    // hyper I/O is forwarded unchanged to the wrapped TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> Read for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: ReadBufCursor<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            Read::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Write for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            Write::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            Write::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_shutdown(self.project().inner, cx)
        }
    }

    // TLS info comes from the wrapped TLS stream.
    impl<T> TlsInfoFactory for RustlsTlsConn<T>
    where
        TokioIo<TlsStream<T>>: TlsInfoFactory,
    {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }
}
1156
/// SOCKS4/SOCKS5 proxy connection helpers.
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::{Socks4Stream, Socks5Stream};

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host name is resolved.
    pub(super) enum DnsResolve {
        /// Resolve on this machine and hand the proxy an IP address.
        Local,
        /// Hand the proxy the host name and let it resolve it.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS proxy `proxy`.
    ///
    /// The port defaults to 443 for `https` and 80 otherwise. With
    /// `DnsResolve::Local`, the host is resolved here (asynchronously) and
    /// the first address found is sent to the proxy; if the lookup yields no
    /// address, the host name is sent unchanged.
    ///
    /// # Errors
    /// Fails if `dst` has no host, local resolution fails, or the SOCKS
    /// handshake fails.
    ///
    /// # Panics
    /// If `proxy` is not a SOCKS scheme.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        // `Uri::host` keeps the brackets around IPv6 literals ("[::1]"); the
        // resolver and the SOCKS address parser expect the bare address.
        let mut host = original_host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(original_host)
            .to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // Use tokio's resolver: `std::net::ToSocketAddrs` performs a
            // blocking lookup that would stall the async runtime thread.
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port))
                .await?
                .next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        match proxy {
            ProxyScheme::Socks4 { addr } => {
                let stream = Socks4Stream::connect(addr, (host.as_str(), port))
                    .await
                    .map_err(|e| format!("socks connect error: {e}"))?;
                Ok(stream.into_inner())
            }
            ProxyScheme::Socks5 { addr, ref auth, .. } => {
                let stream = if let Some((username, password)) = auth {
                    Socks5Stream::connect_with_password(
                        addr,
                        (host.as_str(), port),
                        username,
                        password,
                    )
                    .await
                    .map_err(|e| format!("socks connect error: {e}"))?
                } else {
                    Socks5Stream::connect(addr, (host.as_str(), port))
                        .await
                        .map_err(|e| format!("socks connect error: {e}"))?
                };

                Ok(stream.into_inner())
            }
            _ => unreachable!(),
        }
    }
}
1226
1227mod verbose {
1228 use hyper::rt::{Read, ReadBufCursor, Write};
1229 use hyper_util::client::legacy::connect::{Connected, Connection};
1230 use std::cmp::min;
1231 use std::fmt;
1232 use std::io::{self, IoSlice};
1233 use std::pin::Pin;
1234 use std::task::{Context, Poll};
1235
1236 pub(super) const OFF: Wrapper = Wrapper(false);
1237
1238 #[derive(Clone, Copy)]
1239 pub(super) struct Wrapper(pub(super) bool);
1240
1241 impl Wrapper {
1242 pub(super) fn wrap<T: super::AsyncConnWithInfo>(&self, conn: T) -> super::BoxConn {
1243 if self.0 && log::log_enabled!(log::Level::Trace) {
1244 Box::new(Verbose {
1245 id: crate::util::fast_random() as u32,
1247 inner: conn,
1248 })
1249 } else {
1250 Box::new(conn)
1251 }
1252 }
1253 }
1254
1255 struct Verbose<T> {
1256 id: u32,
1257 inner: T,
1258 }
1259
1260 impl<T: Connection + Read + Write + Unpin> Connection for Verbose<T> {
1261 fn connected(&self) -> Connected {
1262 self.inner.connected()
1263 }
1264 }
1265
1266 impl<T: Read + Write + Unpin> Read for Verbose<T> {
1267 fn poll_read(
1268 mut self: Pin<&mut Self>,
1269 cx: &mut Context,
1270 mut buf: ReadBufCursor<'_>,
1271 ) -> Poll<std::io::Result<()>> {
1272 let mut vbuf = hyper::rt::ReadBuf::uninit(unsafe { buf.as_mut() });
1276 match Pin::new(&mut self.inner).poll_read(cx, vbuf.unfilled()) {
1277 Poll::Ready(Ok(())) => {
1278 log::trace!("{:08x} read: {:?}", self.id, Escape(vbuf.filled()));
1279 let len = vbuf.filled().len();
1280 unsafe {
1283 buf.advance(len);
1284 }
1285 Poll::Ready(Ok(()))
1286 }
1287 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1288 Poll::Pending => Poll::Pending,
1289 }
1290 }
1291 }
1292
1293 impl<T: Read + Write + Unpin> Write for Verbose<T> {
1294 fn poll_write(
1295 mut self: Pin<&mut Self>,
1296 cx: &mut Context,
1297 buf: &[u8],
1298 ) -> Poll<Result<usize, std::io::Error>> {
1299 match Pin::new(&mut self.inner).poll_write(cx, buf) {
1300 Poll::Ready(Ok(n)) => {
1301 log::trace!("{:08x} write: {:?}", self.id, Escape(&buf[..n]));
1302 Poll::Ready(Ok(n))
1303 }
1304 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1305 Poll::Pending => Poll::Pending,
1306 }
1307 }
1308
1309 fn poll_write_vectored(
1310 mut self: Pin<&mut Self>,
1311 cx: &mut Context<'_>,
1312 bufs: &[IoSlice<'_>],
1313 ) -> Poll<Result<usize, io::Error>> {
1314 match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) {
1315 Poll::Ready(Ok(nwritten)) => {
1316 log::trace!(
1317 "{:08x} write (vectored): {:?}",
1318 self.id,
1319 Vectored { bufs, nwritten }
1320 );
1321 Poll::Ready(Ok(nwritten))
1322 }
1323 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1324 Poll::Pending => Poll::Pending,
1325 }
1326 }
1327
1328 fn is_write_vectored(&self) -> bool {
1329 self.inner.is_write_vectored()
1330 }
1331
1332 fn poll_flush(
1333 mut self: Pin<&mut Self>,
1334 cx: &mut Context,
1335 ) -> Poll<Result<(), std::io::Error>> {
1336 Pin::new(&mut self.inner).poll_flush(cx)
1337 }
1338
1339 fn poll_shutdown(
1340 mut self: Pin<&mut Self>,
1341 cx: &mut Context,
1342 ) -> Poll<Result<(), std::io::Error>> {
1343 Pin::new(&mut self.inner).poll_shutdown(cx)
1344 }
1345 }
1346
1347 #[cfg(feature = "__tls")]
1348 impl<T: super::TlsInfoFactory> super::TlsInfoFactory for Verbose<T> {
1349 fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
1350 self.inner.tls_info()
1351 }
1352 }
1353
1354 struct Escape<'a>(&'a [u8]);
1355
1356 impl fmt::Debug for Escape<'_> {
1357 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1358 write!(f, "b\"")?;
1359 for &c in self.0 {
1360 if c == b'\n' {
1362 write!(f, "\\n")?;
1363 } else if c == b'\r' {
1364 write!(f, "\\r")?;
1365 } else if c == b'\t' {
1366 write!(f, "\\t")?;
1367 } else if c == b'\\' || c == b'"' {
1368 write!(f, "\\{}", c as char)?;
1369 } else if c == b'\0' {
1370 write!(f, "\\0")?;
1371 } else if c >= 0x20 && c < 0x7f {
1373 write!(f, "{}", c as char)?;
1374 } else {
1375 write!(f, "\\x{c:02x}")?;
1376 }
1377 }
1378 write!(f, "\"")?;
1379 Ok(())
1380 }
1381 }
1382
1383 struct Vectored<'a, 'b> {
1384 bufs: &'a [IoSlice<'b>],
1385 nwritten: usize,
1386 }
1387
1388 impl fmt::Debug for Vectored<'_, '_> {
1389 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1390 let mut left = self.nwritten;
1391 for buf in self.bufs.iter() {
1392 if left == 0 {
1393 break;
1394 }
1395 let n = min(left, buf.len());
1396 Escape(&buf[..n]).fmt(f)?;
1397 left -= n;
1398 }
1399 Ok(())
1400 }
1401 }
1402}
1403
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::tunnel;
    use crate::error::BoxError;
    use crate::proxy;
    use http::header::HeaderValue;
    use hyper_util::rt::TokioIo;
    use std::io::{Read, Write};
    use std::net::{SocketAddr, TcpListener};
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    // Starts a one-shot mock HTTP proxy on an ephemeral localhost port.
    //
    // The mock accepts one connection, asserts that the first chunk it reads is
    // exactly the expected CONNECT request (including the optional `$auth`
    // header line), then replies with `$write` (default: `TUNNEL_OK`).
    //
    // Evaluates to `(addr, server)`. Callers must `join` `server` once the
    // client side is done: a failed assertion in the mock only panics its own
    // thread, and the error-path tests would otherwise still pass because the
    // client merely observes the connection closing.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "")
        }};
        ($write:expr, $auth:expr) => {{
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let connect_expected = format!(
                "\
                 CONNECT {0}:{1} HTTP/1.1\r\n\
                 Host: {0}:{1}\r\n\
                 User-Agent: {2}\r\n\
                 {3}\
                 \r\n\
                 ",
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            let server = thread::spawn(move || {
                let (mut sock, _) = listener.accept().unwrap();
                let mut buf = [0u8; 4096];
                let n = sock.read(&mut buf).unwrap();
                assert_eq!(&buf[..n], &connect_expected[..]);

                sock.write_all($write).unwrap();
            });
            (addr, server)
        }};
    }

    fn ua() -> Option<HeaderValue> {
        Some(HeaderValue::from_static(TUNNEL_UA))
    }

    /// Connects to the mock proxy at `addr` and runs `tunnel` through it with
    /// the test user agent and the given `auth` header, on a fresh
    /// current-thread runtime. The tunnelled connection itself is discarded.
    fn run_tunnel(addr: SocketAddr, auth: Option<HeaderValue>) -> Result<(), BoxError> {
        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt");
        rt.block_on(async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), auth).await?;
            Ok::<(), BoxError>(())
        })
    }

    #[test]
    fn test_tunnel() {
        let (addr, server) = mock_tunnel!();

        run_tunnel(addr, None).unwrap();
        server.join().expect("mock proxy server panicked");
    }

    #[test]
    fn test_tunnel_eof() {
        let (addr, server) = mock_tunnel!(b"HTTP/1.1 200 OK");

        run_tunnel(addr, None).unwrap_err();
        server.join().expect("mock proxy server panicked");
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let (addr, server) = mock_tunnel!(b"foo bar baz hallo");

        run_tunnel(addr, None).unwrap_err();
        server.join().expect("mock proxy server panicked");
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let (addr, server) = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
            "
        );

        let error = run_tunnel(addr, None).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
        server.join().expect("mock proxy server panicked");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let (addr, server) = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
        );

        run_tunnel(
            addr,
            Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
        )
        .unwrap();
        server.join().expect("mock proxy server panicked");
    }
}