diff --git a/rttp_client/src/connection/async_connection.rs b/rttp_client/src/connection/async_connection.rs index 5537e3d..d38323e 100644 --- a/rttp_client/src/connection/async_connection.rs +++ b/rttp_client/src/connection/async_connection.rs @@ -3,14 +3,14 @@ use std::net::{TcpStream, ToSocketAddrs}; use futures::io::{AllowStdIo, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use socket2::{Domain, Protocol, Socket, Type}; use socks::{Socks4Stream, Socks5Stream}; -use std::io::{self, Read, Write}; +use std::io::{self, Write}; use std::time; use url::Url; #[cfg(feature = "tls-rustls")] use std::sync::Arc; -use crate::connection::connection::Connection; +use crate::connection::connection::{read_proxy_connect_response, Connection}; use crate::connection::connection_reader::{response_body_kind, ResponseBodyKind}; use crate::error; use crate::request::RawRequest; @@ -401,7 +401,7 @@ impl<'a> AsyncConnection<'a> { let addr = format!("{}:{}", proxy.host(), proxy.port()); let stream = self.async_tcp_stream(&addr).await?; let mut stream = AllowStdIo::new(stream); - let header = self.conn.proxy_http_header(url); + let header = self.conn.proxy_http_header(url, proxy); self.async_write_request(&mut stream, &header).await?; self.async_read_stream(url, &mut stream).await @@ -417,21 +417,7 @@ impl<'a> AsyncConnection<'a> { .write_all(connect_header.as_bytes()) .map_err(error::request)?; stream.flush().map_err(error::request)?; - - // HTTP/1.1 200 Connection Established - let mut res = vec![0u8; 1024]; - let bytes = stream.read(&mut res).map_err(error::request)?; - - let res_s = match String::from_utf8(res[..bytes].to_vec()) { - Ok(r) => r, - Err(_) => return Err(error::bad_proxy("parse proxy server response error.")), - }; - if !res_s - .to_ascii_lowercase() - .contains("connection established") - { - return Err(error::bad_proxy("Proxy server response error.")); - } + read_proxy_connect_response(&mut stream)?; self.async_send_with_stream(url, stream).await } diff --git a/rttp_client/src/connection/block_connection.rs b/rttp_client/src/connection/block_connection.rs index 24ff380..f762bee 100644 --- a/rttp_client/src/connection/block_connection.rs +++ b/rttp_client/src/connection/block_connection.rs @@ -1,4 +1,4 @@ -use std::io::{Read, Write}; +use std::io::Write; use socks::{Socks4Stream, Socks5Stream}; use url::Url; @@ -78,7 +78,7 @@ impl<'a> BlockConnection<'a> { fn call_with_proxy_http(&self, url: &Url, proxy: &Proxy) -> error::Result> { let addr = format!("{}:{}", proxy.host(), proxy.port()); let mut stream = self.conn.block_tcp_stream(&addr)?; - let header = self.conn.proxy_http_header(url); + let header = self.conn.proxy_http_header(url, proxy); stream .write_all(header.as_bytes()) @@ -101,25 +101,10 @@ impl<'a> BlockConnection<'a> { let mut stream = self.conn.block_tcp_stream(&addr)?; stream - .write(connect_header.as_bytes()) + .write_all(connect_header.as_bytes()) .map_err(error::request)?; stream.flush().map_err(error::request)?; - - //HTTP/1.1 200 Connection Established - let mut res = [0u8; 1024]; - stream.read(&mut res).map_err(error::request)?; - - let res_s = String::from_utf8(res.to_vec()) - .map_err(|_| error::bad_proxy("parse proxy server response error."))?; - if !res_s - .to_ascii_lowercase() - .contains("connection established") - { - return Err(error::bad_proxy(format!( - "Proxy server response error: {}", - res_s - ))); - } + crate::connection::connection::read_proxy_connect_response(&mut stream)?; self.conn.block_send_with_stream(url, &mut stream) } diff --git a/rttp_client/src/connection/connection.rs b/rttp_client/src/connection/connection.rs index 2b4ede2..e6cf18f 100644 --- a/rttp_client/src/connection/connection.rs +++ b/rttp_client/src/connection/connection.rs @@ -218,30 +218,23 @@ impl<'a> Connection<'a> { let mut proxy_header = String::new(); proxy_header.push_str(&format!("CONNECT {}:{} HTTP/1.1\r\n", host, port)); proxy_header.push_str(&format!("Host: {}:{}\r\n", host, port)); - - if let Some(username) = proxy.username() { - let auth = if let Some(password) = proxy.password() { - format!("{}:{}", username, password) - } else { - format!("{}:", username) - }; - let auth = STANDARD.encode(auth.as_bytes()); - proxy_header.push_str(&format!("Authorization: Basic {}\r\n", auth)); - } + append_proxy_authorization_header(&mut proxy_header, proxy); proxy_header.push_str("\r\n"); Ok(proxy_header) } - pub fn proxy_http_header(&self, url: &Url) -> String { + pub fn proxy_http_header(&self, url: &Url, proxy: &Proxy) -> String { let header = self.header(); let (_, rest) = header.split_once("\r\n").unwrap_or(("", "")); - format!( + let mut proxy_header = format!( "{} {} HTTP/1.1\r\n{}", self.request.origin().method().to_uppercase(), absolute_url(url), rest - ) + ); + append_proxy_authorization_header(&mut proxy_header, proxy); + proxy_header } pub fn redirect_url(&self, url: &Url, location: &str) -> error::Result { @@ -262,6 +255,88 @@ fn absolute_url(url: &Url) -> String { absolute.to_string() } +fn proxy_authorization_value(proxy: &Proxy) -> Option { + proxy.username().as_ref().map(|username| { + let auth = if let Some(password) = proxy.password() { + format!("{}:{}", username, password) + } else { + format!("{}:", username) + }; + STANDARD.encode(auth.as_bytes()) + }) +} + +fn append_proxy_authorization_header(header: &mut String, proxy: &Proxy) { + if let Some(auth) = proxy_authorization_value(proxy) { + header.push_str(&format!("Proxy-Authorization: Basic {}\r\n", auth)); + } +} + +fn write_http_request( + stream: &mut W, + header: &str, + body: Option<&RequestBody>, +) -> error::Result<()> +where + W: io::Write, +{ + stream + .write_all(header.as_bytes()) + .map_err(error::request)?; + if let Some(body) = body { + stream.write_all(body.bytes()).map_err(error::request)?; + } + stream.flush().map_err(error::request)?; + Ok(()) +} + +pub(crate) fn parse_proxy_connect_response(header: &[u8]) -> error::Result<()> { + let header = String::from_utf8(header.to_vec()) + .map_err(|_| error::bad_proxy("parse proxy server response error."))?; + let status_line = header + .lines() + .next() + .ok_or_else(|| error::bad_proxy("Proxy server response error."))?; + let status_code = status_line + .split_whitespace() + .nth(1) + .ok_or_else(|| error::bad_proxy("Proxy server response error."))? + .parse::() + .map_err(|_| error::bad_proxy("parse proxy server response error."))?; + + if status_code == 200 { + Ok(()) + } else { + Err(error::bad_proxy(format!( + "Proxy server response error: {}", + status_line + ))) + } +} + +pub(crate) fn read_proxy_connect_response(reader: &mut R) -> error::Result<()> +where + R: io::Read, +{ + let mut header = Vec::new(); + let mut byte = [0u8; 1]; + + loop { + let read = reader.read(&mut byte).map_err(error::request)?; + if read == 0 { + if header.is_empty() { + return Err(error::bad_proxy("Proxy server response error.")); + } + return Err(error::bad_proxy("Incomplete proxy response headers")); + } + + header.push(byte[0]); + if header.ends_with(b"\r\n\r\n") { + return parse_proxy_connect_response(&header); + } + } +} + impl<'a> Connection<'a> { pub fn block_tcp_stream(&self, addr: &String) -> error::Result { let config = self.config(); @@ -307,16 +382,7 @@ impl<'a> Connection<'a> { where S: io::Write, { - let header = self.header(); - let body = self.body(); - - stream.write(header.as_bytes()).map_err(error::request)?; - if let Some(body) = body { - stream.write(body.bytes()).map_err(error::request)?; - } - stream.flush().map_err(error::request)?; - - Ok(()) + write_http_request(stream, self.header(), self.body().as_ref()) } pub fn block_read_stream(&self, url: &Url, stream: &mut S) -> error::Result> @@ -446,3 +512,84 @@ impl<'a> Connection<'a> { self.block_read_stream(url, &mut tls) } } + +#[cfg(test)] +mod tests { + use std::io::{self, Cursor, Write}; + + use crate::request::RequestBody; + use crate::types::Proxy; + + use super::{ + parse_proxy_connect_response, proxy_authorization_value, read_proxy_connect_response, + write_http_request, + }; + + struct PartialWriter { + max_chunk: usize, + written: Vec, + } + + impl PartialWriter { + fn new(max_chunk: usize) -> Self { + Self { + max_chunk, + written: Vec::new(), + } + } + } + + impl Write for PartialWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + let take = buf.len().min(self.max_chunk); + self.written.extend_from_slice(&buf[..take]); + Ok(take) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + + #[test] + fn test_write_http_request_retries_until_full_payload_is_written() { + let header = "POST / HTTP/1.1\r\nContent-Length: 5\r\n\r\n"; + let body = RequestBody::with_text("hello"); + let mut writer = PartialWriter::new(3); + + write_http_request(&mut writer, header, Some(&body)).unwrap(); + + assert_eq!( + format!("{}hello", header).as_bytes(), + writer.written.as_slice() + ); + } + + #[test] + fn test_proxy_authorization_value_encodes_credentials() { + let proxy = Proxy::http_with_authorization("127.0.0.1", 8080, "user", "secret"); + + assert_eq!( + Some("dXNlcjpzZWNyZXQ=".to_string()), + proxy_authorization_value(&proxy) + ); + } + + #[test] + fn test_parse_proxy_connect_response_requires_200_status() { + let header = b"HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"; + let err = parse_proxy_connect_response(header).unwrap_err(); + + assert!(err + .to_string() + .contains("407 Proxy Authentication Required")); + } + + #[test] + fn test_read_proxy_connect_response_waits_for_complete_headers() { + let header = b"HTTP/1.1 200 Connection Established\r\nProxy-Agent: test\r\n\r\n"; + let mut reader = Cursor::new(header); + + read_proxy_connect_response(&mut reader).unwrap(); + } +} diff --git a/rttp_client/tests/support/mod.rs b/rttp_client/tests/support/mod.rs index 51954b6..176c690 100644 --- a/rttp_client/tests/support/mod.rs +++ b/rttp_client/tests/support/mod.rs @@ -132,6 +132,17 @@ fn proxy_http_request(mut stream: TcpStream, auth: Option<(&str, &str)>) -> io:: Ok(()) } +fn header_value(request: &[u8], name: &str) -> Option { + String::from_utf8_lossy(request).lines().find_map(|line| { + let (header_name, value) = line.split_once(':')?; + if header_name.eq_ignore_ascii_case(name) { + Some(value.trim().to_string()) + } else { + None + } + }) +} + pub fn spawn_http_server() -> (SocketAddr, JoinHandle<()>) { spawn_http_server_count(1) } @@ -233,6 +244,25 @@ pub fn spawn_http_proxy_server() -> (SocketAddr, JoinHandle<()>) { (addr, handle) } +pub fn spawn_http_proxy_auth_echo_server() -> (SocketAddr, JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind http proxy auth server"); + let addr = listener.local_addr().expect("http proxy auth addr"); + let handle = thread::spawn(move || { + if let Ok((mut stream, _)) = listener.accept() { + let request = read_http_request(&mut stream); + let auth = header_value(&request, "Proxy-Authorization").unwrap_or_default(); + let body = auth.as_bytes(); + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + body.len() + ); + let _ = stream.write_all(response.as_bytes()); + let _ = stream.write_all(body); + } + }); + (addr, handle) +} + pub fn spawn_invalid_gzip_server() -> (SocketAddr, JoinHandle<()>) { let listener = TcpListener::bind("127.0.0.1:0").expect("bind invalid gzip server"); let addr = listener.local_addr().expect("invalid gzip addr"); @@ -320,6 +350,61 @@ fn spawn_socks5_proxy_server_with_auth( (addr, handle) } +#[cfg(feature = "tls-rustls")] +pub fn spawn_https_proxy_server_with_credentials( + username: &'static str, + password: &'static str, +) -> (SocketAddr, SocketAddr, JoinHandle<()>) { + use base64::Engine; + use std::io::copy; + + let (target_addr, _target_handle) = spawn_tls_server(); + let listener = TcpListener::bind("127.0.0.1:0").expect("bind https proxy server"); + let proxy_addr = listener.local_addr().expect("https proxy addr"); + let handle = thread::spawn(move || { + if let Ok((mut client, _)) = listener.accept() { + let request = read_http_request(&mut client); + let request_str = String::from_utf8_lossy(&request); + let request_line = request_str.lines().next().unwrap_or_default().to_string(); + let proxy_auth = header_value(&request, "Proxy-Authorization").unwrap_or_default(); + let expected_auth = format!( + "Basic {}", + base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", username, password)) + ); + + if proxy_auth != expected_auth { + let _ = client + .write_all(b"HTTP/1.1 407 Proxy Authentication Required\r\nContent-Length: 0\r\n\r\n"); + return; + } + + let target = request_line + .split_whitespace() + .nth(1) + .unwrap_or_default() + .to_string(); + let mut server = TcpStream::connect(&target).expect("connect tls target"); + + let _ = client.write_all(b"HTTP/1.1 200 Conne"); + let _ = client.flush(); + thread::sleep(Duration::from_millis(20)); + let _ = client.write_all(b"ction Established\r\nProxy-Agent: test\r\n\r\n"); + let _ = client.flush(); + + let mut client_reader = client.try_clone().expect("clone client"); + let mut server_writer = server.try_clone().expect("clone target"); + let relay = thread::spawn(move || { + let _ = copy(&mut client_reader, &mut server_writer); + }); + + let _ = copy(&mut server, &mut client); + let _ = relay.join(); + } + }); + + (proxy_addr, target_addr, handle) +} + #[cfg(feature = "tls-rustls")] pub fn spawn_tls_server() -> (SocketAddr, JoinHandle<()>) { use rcgen::generate_simple_self_signed; diff --git a/rttp_client/tests/test_http_async.rs b/rttp_client/tests/test_http_async.rs index 329d6e6..1ac825e 100644 --- a/rttp_client/tests/test_http_async.rs +++ b/rttp_client/tests/test_http_async.rs @@ -122,6 +122,29 @@ fn test_async_http_proxy_uses_absolute_form_for_http_requests() { }); } +#[test] +#[cfg(feature = "async")] +fn test_async_http_proxy_with_auth_uses_proxy_authorization_header() { + let (addr, _handle) = support::spawn_http_proxy_auth_echo_server(); + block_on(async { + let response = client() + .get() + .url("http://example.com/proxy?q=1") + .proxy(Proxy::http_with_authorization( + "127.0.0.1", + u32::from(addr.port()), + "user", + "secret", + )) + .rasync() + .await; + assert!(response.is_ok()); + + let response = response.unwrap(); + assert_eq!("Basic dXNlcjpzZWNyZXQ=", response.body().string().unwrap()); + }); +} + #[test] #[cfg(feature = "async")] fn test_async_proxy_socks5() { @@ -140,3 +163,31 @@ fn test_async_proxy_socks5() { println!("{}", response); }); } + +#[test] +#[cfg(all(feature = "async", feature = "tls-rustls"))] +fn test_async_https_proxy_with_auth_uses_connect_tunnel() { + let (proxy_addr, target_addr, _proxy_handle) = + support::spawn_https_proxy_server_with_credentials("user", "secret"); + block_on(async { + let response = client() + .get() + .url(format!("https://localhost:{}/", target_addr.port())) + .proxy(Proxy::http_with_authorization( + "127.0.0.1", + u32::from(proxy_addr.port()), + "user", + "secret", + )) + .config( + rttp_client::Config::builder() + .verify_ssl_cert(false) + .verify_ssl_hostname(false), + ) + .rasync() + .await; + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!("OK", response.body().string().unwrap()); + }); +} diff --git a/rttp_client/tests/test_http_basic.rs b/rttp_client/tests/test_http_basic.rs index ce43941..7505774 100644 --- a/rttp_client/tests/test_http_basic.rs +++ b/rttp_client/tests/test_http_basic.rs @@ -242,6 +242,51 @@ fn test_http_proxy_uses_absolute_form_for_http_requests() { ); } +#[test] +fn test_http_proxy_with_auth_uses_proxy_authorization_header() { + let (addr, _handle) = support::spawn_http_proxy_auth_echo_server(); + let response = client() + .get() + .url("http://example.com/proxy?q=1") + .proxy(Proxy::http_with_authorization( + "127.0.0.1", + u32::from(addr.port()), + "user", + "secret", + )) + .emit(); + assert!(response.is_ok()); + + let response = response.unwrap(); + assert_eq!("Basic dXNlcjpzZWNyZXQ=", response.body().string().unwrap()); +} + +#[test] +#[cfg(feature = "tls-rustls")] +fn test_https_proxy_with_auth_uses_connect_tunnel() { + let (proxy_addr, target_addr, _proxy_handle) = + support::spawn_https_proxy_server_with_credentials("user", "secret"); + let response = client() + .get() + .url(format!("https://localhost:{}/", target_addr.port())) + .proxy(Proxy::http_with_authorization( + "127.0.0.1", + u32::from(proxy_addr.port()), + "user", + "secret", + )) + .config( + Config::builder() + .verify_ssl_cert(false) + .verify_ssl_hostname(false), + ) + .emit(); + assert!(response.is_ok()); + + let response = response.unwrap(); + assert_eq!("OK", response.body().string().unwrap()); +} + #[test] fn test_connection_closed() { let (addr, _handle) = support::spawn_http_server_count(5);