diff --git a/Cargo.lock b/Cargo.lock index 06d7bcfb9..bb3458144 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -396,6 +396,7 @@ dependencies = [ "slog-term", "subprocess", "tempfile", + "thiserror", "tokio", "tokio-rustls", "tokio-tungstenite", @@ -1838,18 +1839,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.56" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.56" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 21c908a87..f41d22a01 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -42,6 +42,7 @@ slog-async = "2.8.0" slog-bunyan = "2.5.0" slog-json = "2.6.1" slog-term = "2.9.1" +thiserror = "1.0.64" tokio-rustls = "0.25.0" toml = "0.8.19" waitgroup = "0.1.2" diff --git a/dropshot/src/lib.rs b/dropshot/src/lib.rs index 8716777d9..e4889ccab 100644 --- a/dropshot/src/lib.rs +++ b/dropshot/src/lib.rs @@ -836,6 +836,8 @@ pub use pagination::PaginationOrder; pub use pagination::PaginationParams; pub use pagination::ResultsPage; pub use pagination::WhichPage; +pub use server::BuildError; +pub use server::ServerBuilder; pub use server::ServerContext; pub use server::ShutdownWaitFuture; pub use server::{HttpServer, HttpServerStarter}; diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index 8816c9f77..2bc59f0bb 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -33,6 +33,7 @@ use std::panic; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use thiserror::Error; use tokio::io::ReadBuf; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; @@ -104,10 +105,7 @@ pub struct ServerConfig { } pub struct HttpServerStarter { - app_state: Arc>, - local_addr: SocketAddr, - wrapped: WrappedHttpServerStarter, - handler_waitgroup: WaitGroup, + server: HttpServer, } impl HttpServerStarter { @@ -127,219 +125,18 @@ impl HttpServerStarter { log: &Logger, tls: Option, ) -> Result, GenericError> { - let server_config = ServerConfig { - // We start aggressively to ensure test coverage. - request_body_max_bytes: config.request_body_max_bytes, - page_max_nitems: NonZeroU32::new(10000).unwrap(), - page_default_nitems: NonZeroU32::new(100).unwrap(), - default_handler_task_mode: config.default_handler_task_mode, - log_headers: config.log_headers.clone(), - }; - - let handler_waitgroup = WaitGroup::new(); - let starter = match &tls { - Some(tls) => { - let (starter, app_state, local_addr) = - InnerHttpsServerStarter::new( - config, - server_config, - api, - private, - log, - tls, - handler_waitgroup.worker(), - )?; - HttpServerStarter { - app_state, - local_addr, - wrapped: WrappedHttpServerStarter::Https(starter), - handler_waitgroup, - } - } - None => { - let (starter, app_state, local_addr) = - InnerHttpServerStarter::new( - config, - server_config, - api, - private, - log, - handler_waitgroup.worker(), - )?; - HttpServerStarter { - app_state, - local_addr, - wrapped: WrappedHttpServerStarter::Http(starter), - handler_waitgroup, - } - } - }; - - for (path, method, _) in &starter.app_state.router { - debug!(starter.app_state.log, "registered endpoint"; - "method" => &method, - "path" => &path - ); - } - - Ok(starter) - } - - pub fn start(self) -> HttpServer { - let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - let log_close = self.app_state.log.new(o!()); - let join_handle = match self.wrapped { - WrappedHttpServerStarter::Http(http) => http.start(rx, log_close), - WrappedHttpServerStarter::Https(https) => { - https.start(rx, log_close) - } - }; - info!(self.app_state.log, "listening"); - - let handler_waitgroup = self.handler_waitgroup; - let join_handle = async move { - // After the server shuts down, we also want to wait for any - // detached handler futures to complete. - () = join_handle - .await - .map_err(|e| format!("server stopped: {e}"))?; - () = handler_waitgroup.wait().await; - Ok(()) - }; - - #[cfg(feature = "usdt-probes")] - let probe_registration = match usdt::register_probes() { - Ok(_) => { - debug!( - self.app_state.log, - "successfully registered DTrace USDT probes" - ); - ProbeRegistration::Succeeded - } - Err(e) => { - let msg = e.to_string(); - error!( - self.app_state.log, - "failed to register DTrace USDT probes: {}", msg - ); - ProbeRegistration::Failed(msg) - } - }; - #[cfg(not(feature = "usdt-probes"))] - let probe_registration = { - debug!( - self.app_state.log, - "DTrace USDT probes compiled out, not registering" - ); - ProbeRegistration::Disabled - }; - - HttpServer { - probe_registration, - app_state: self.app_state, - local_addr: self.local_addr, - closer: CloseHandle { close_channel: Some(tx) }, - join_future: join_handle.boxed().shared(), - } - } -} - -enum WrappedHttpServerStarter { - Http(InnerHttpServerStarter), - Https(InnerHttpsServerStarter), -} - -struct InnerHttpServerStarter( - HttpAcceptor, - ServerConnectionHandler, -); - -type InnerHttpServerStarterNewReturn = - (InnerHttpServerStarter, Arc>, SocketAddr); - -impl InnerHttpServerStarter { - /// Begins execution of the underlying Http server. - fn start( - self, - mut close_signal: tokio::sync::oneshot::Receiver<()>, - log_close: Logger, - ) -> tokio::task::JoinHandle<()> { - use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; - use hyper_util::server::conn::auto; - - tokio::spawn(async move { - let mut builder = auto::Builder::new(TokioExecutor::new()); - // http/1 settings - builder.http1().timer(TokioTimer::new()); - // http/2 settings - builder.http2().timer(TokioTimer::new()); - - // Use a graceful watcher to keep track of all existing connections, - // and when the close_signal is trigger, force all known conns - // to start a graceful shutdown. - let graceful = - hyper_util::server::graceful::GracefulShutdown::new(); - - loop { - tokio::select! { - (sock, remote_addr) = self.0.accept() => { - let fut = builder.serve_connection_with_upgrades( - TokioIo::new(sock), - self.1.make_http_request_handler(remote_addr), - ); - let fut = graceful.watch(fut.into_owned()); - tokio::spawn(fut); - }, - - _ = &mut close_signal => { - info!(log_close, "received request to begin graceful shutdown"); - break; - } - } - } - - // optional: could use another select on a timeout - graceful.shutdown().await + Ok(Self { + server: ServerBuilder::new(log.clone(), private) + .tls(tls) + .api(api) + .config(config.clone()) + .build() + .map_err(Box::new)?, }) } - /// Set up an HTTP server bound on the specified address that runs - /// registered handlers. You must invoke `start()` on the returned instance - /// of `HttpServerStarter` (and await the result) to actually start the - /// server. - fn new( - config: &ConfigDropshot, - server_config: ServerConfig, - api: ApiDescription, - private: C, - log: &Logger, - handler_waitgroup_worker: waitgroup::Worker, - ) -> Result, std::io::Error> { - // We use `from_std` instead of just calling `bind` here directly - // to avoid invoking an async function. - let std_listener = std::net::TcpListener::bind(&config.bind_address)?; - std_listener.set_nonblocking(true)?; - let tcp = TcpListener::from_std(std_listener)?; - let local_addr = tcp.local_addr()?; - let incoming = - HttpAcceptor { tcp, log: log.new(o!("local_addr" => local_addr)) }; - - let app_state = Arc::new(DropshotState { - private, - config: server_config, - router: api.into_router(), - log: log.new(o!("local_addr" => local_addr)), - local_addr, - tls_acceptor: None, - handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), - }); - - let make_service = ServerConnectionHandler::new(app_state.clone()); - Ok(( - InnerHttpServerStarter(incoming, make_service), - app_state, - local_addr, - )) + pub fn start(self) -> HttpServer { + self.server } } @@ -508,16 +305,11 @@ impl HttpsAcceptor { } } -struct InnerHttpsServerStarter( - HttpsAcceptor, - ServerConnectionHandler, -); - /// Create a TLS configuration from the Dropshot config structure. impl TryFrom<&ConfigTls> for rustls::ServerConfig { - type Error = std::io::Error; + type Error = BuildError; - fn try_from(config: &ConfigTls) -> std::io::Result { + fn try_from(config: &ConfigTls) -> Result { let (mut cert_reader, mut key_reader): ( Box, Box, @@ -532,25 +324,17 @@ impl TryFrom<&ConfigTls> for rustls::ServerConfig { ConfigTls::AsFile { cert_file, key_file } => { let certfile = Box::new(std::io::BufReader::new( std::fs::File::open(cert_file).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::Other, - format!( - "failed to open {}: {}", - cert_file.display(), - e - ), + BuildError::generic_system( + e, + format!("opening {}", cert_file.display()), ) })?, )); let keyfile = Box::new(std::io::BufReader::new( std::fs::File::open(key_file).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::Other, - format!( - "failed to open {}: {}", - key_file.display(), - e - ), + BuildError::generic_system( + e, + format!("opening {}", key_file.display()), ) })?, )); @@ -561,17 +345,17 @@ impl TryFrom<&ConfigTls> for rustls::ServerConfig { let certs = rustls_pemfile::certs(&mut cert_reader) .collect::, _>>() .map_err(|err| { - io_error(format!("failed to load certificate: {err}")) + BuildError::generic_system(err, "loading TLS certificates") })?; let keys = rustls_pemfile::pkcs8_private_keys(&mut key_reader) .collect::, _>>() .map_err(|err| { - io_error(format!("failed to load private key: {err}")) + BuildError::generic_system(err, "loading TLS private key") })?; let mut keys_iter = keys.into_iter(); let (Some(private_key), None) = (keys_iter.next(), keys_iter.next()) else { - return Err(io_error("expected a single private key".into())); + return Err(BuildError::NotOnePrivateKey); }; let mut cfg = rustls::ServerConfig::builder() @@ -583,102 +367,6 @@ impl TryFrom<&ConfigTls> for rustls::ServerConfig { } } -type InnerHttpsServerStarterNewReturn = - (InnerHttpsServerStarter, Arc>, SocketAddr); - -impl InnerHttpsServerStarter { - /// Begins execution of the underlying Http server. - fn start( - mut self, - mut close_signal: tokio::sync::oneshot::Receiver<()>, - log_close: Logger, - ) -> tokio::task::JoinHandle<()> { - use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; - use hyper_util::server::conn::auto; - - tokio::spawn(async move { - let mut builder = auto::Builder::new(TokioExecutor::new()); - // http/1 settings - builder.http1().timer(TokioTimer::new()); - - // Use a graceful watcher to keep track of all existing connections, - // and when the close_signal is trigger, force all known conns - // to start a graceful shutdown. - let graceful = - hyper_util::server::graceful::GracefulShutdown::new(); - - loop { - tokio::select! { - Some(Ok(sock)) = self.0.accept() => { - let remote_addr = sock.remote_addr(); - let fut = builder.serve_connection_with_upgrades( - TokioIo::new(sock), - self.1.make_http_request_handler(remote_addr), - ); - let fut = graceful.watch(fut.into_owned()); - tokio::spawn(fut); - }, - - _ = &mut close_signal => { - info!(log_close, "received request to begin graceful shutdown"); - break; - } - } - } - - // optional: could use another select on a timeout - graceful.shutdown().await - }) - } - - fn new( - config: &ConfigDropshot, - server_config: ServerConfig, - api: ApiDescription, - private: C, - log: &Logger, - tls: &ConfigTls, - handler_waitgroup_worker: waitgroup::Worker, - ) -> Result, GenericError> { - let acceptor = Arc::new(Mutex::new(TlsAcceptor::from(Arc::new( - rustls::ServerConfig::try_from(tls)?, - )))); - - let tcp = { - let listener = std::net::TcpListener::bind(&config.bind_address)?; - listener.set_nonblocking(true)?; - // We use `from_std` instead of just calling `bind` here directly - // to avoid invoking an async function, to match the interface - // provided by `HttpServerStarter::new`. - TcpListener::from_std(listener)? - }; - - let local_addr = tcp.local_addr()?; - let logger = log.new(o!("local_addr" => local_addr)); - let tcp = HttpAcceptor { tcp, log: logger.clone() }; - let https_acceptor = - HttpsAcceptor::new(logger.clone(), acceptor.clone(), tcp); - - let app_state = Arc::new(DropshotState { - private, - config: server_config, - router: api.into_router(), - log: logger, - local_addr, - tls_acceptor: Some(acceptor), - handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), - }); - - let make_service = ServerConnectionHandler::new(Arc::clone(&app_state)); - - Ok(( - InnerHttpsServerStarter(https_acceptor, make_service), - app_state, - local_addr, - )) - } -} - type SharedBoxFuture = Shared + Send>>>; /// Future returned by [`HttpServer::wait_for_shutdown()`]. @@ -1142,8 +830,271 @@ impl Service> } } -fn io_error(err: String) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::Other, err) +#[derive(Debug, Error)] +pub enum BuildError { + #[error("failed to bind to {address}")] + BindError { + address: SocketAddr, + #[source] + error: std::io::Error, + }, + #[error("expected exactly one TLS private key")] + NotOnePrivateKey, + #[error("must register an API")] + MissingApi, + #[error("only one API can be registered with a server")] + TooManyApis, + #[error("{context}")] + SystemError { + context: String, + #[source] + error: std::io::Error, + }, +} + +impl BuildError { + fn bind_error(error: std::io::Error, address: SocketAddr) -> BuildError { + BuildError::BindError { address, error } + } + + fn generic_system>( + error: std::io::Error, + context: S, + ) -> BuildError { + BuildError::SystemError { context: context.into(), error } + } +} + +#[derive(Debug)] +pub struct ServerBuilder { + // required caller-provided values + private: C, + log: Logger, + + // optional caller-provided values + config: ConfigDropshot, + tls: Option, + api: DebugIgnore>>, + + // our own internal state + error: Option, +} + +impl ServerBuilder { + pub fn new(log: Logger, private: C) -> ServerBuilder { + ServerBuilder { + private, + log, + config: Default::default(), + tls: Default::default(), + api: Default::default(), + error: Default::default(), + } + } + + pub fn config(mut self, config: ConfigDropshot) -> Self { + self.config = config; + self + } + + pub fn tls(mut self, tls: Option) -> Self { + self.tls = tls; + self + } + + pub fn api(mut self, api: ApiDescription) -> Self { + if self.api.is_none() { + self.api = DebugIgnore(Some(api)); + } else { + self.error(BuildError::TooManyApis); + } + + self + } + + fn error(&mut self, error: BuildError) { + if self.error.is_none() { + self.error = Some(error); + } + } + + pub fn build(self) -> Result, BuildError> { + let server_config = ServerConfig { + // We start aggressively to ensure test coverage. + request_body_max_bytes: self.config.request_body_max_bytes, + page_max_nitems: NonZeroU32::new(10000).unwrap(), + page_default_nitems: NonZeroU32::new(100).unwrap(), + default_handler_task_mode: self.config.default_handler_task_mode, + log_headers: self.config.log_headers.clone(), + }; + let handler_waitgroup = WaitGroup::new(); + + let config = self.config; + let private = self.private; + let log = self.log; + let tls = self.tls; + let api = self.api.0.ok_or_else(|| BuildError::MissingApi)?; + + let std_listener = std::net::TcpListener::bind(&config.bind_address) + .map_err(|e| BuildError::bind_error(e, config.bind_address))?; + std_listener.set_nonblocking(true).map_err(|e| { + BuildError::generic_system(e, "setting non-blocking") + })?; + // We use `from_std` instead of just calling `bind` here directly + // to avoid invoking an async function. + let tcp = TcpListener::from_std(std_listener).map_err(|e| { + BuildError::generic_system(e, "creating TCP listener") + })?; + let local_addr = tcp.local_addr().map_err(|e| { + BuildError::generic_system(e, "getting local TCP address") + })?; + + let log = log.new(o!("local_addr" => local_addr)); + + let tls_acceptor = tls + .as_ref() + .map(|tls| { + Ok(Arc::new(Mutex::new(TlsAcceptor::from(Arc::new( + rustls::ServerConfig::try_from(tls)?, + ))))) + }) + .transpose()?; + + let app_state = Arc::new(DropshotState { + private, + config: server_config, + router: api.into_router(), + log: log.clone(), + local_addr, + tls_acceptor: tls_acceptor.clone(), + handler_waitgroup_worker: DebugIgnore(handler_waitgroup.worker()), + }); + let make_service = ServerConnectionHandler::new(Arc::clone(&app_state)); + + let incoming = HttpAcceptor { tcp, log: log.clone() }; + + let log = &app_state.log; + for (path, method, _) in &app_state.router { + debug!(log, "registered endpoint"; + "method" => &method, + "path" => &path + ); + } + + let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); + + let log_close = log.clone(); + let join_handle = tokio::spawn(async move { + use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; + use hyper_util::server::conn::auto; + + let mut builder = auto::Builder::new(TokioExecutor::new()); + // http/1 settings + builder.http1().timer(TokioTimer::new()); + // XXX-dap previously, the TLS one did NOT do this http2 step + // http/2 settings + builder.http2().timer(TokioTimer::new()); + + // Use a graceful watcher to keep track of all existing connections, + // and when the close_signal is trigger, force all known conns + // to start a graceful shutdown. + let graceful = + hyper_util::server::graceful::GracefulShutdown::new(); + + // The following code looks superficially similar between the HTTP + // and HTTPS paths. However, the concrete types of various objects + // are different and so it's not easy to actually share the code. + let log = log_close; + match tls_acceptor { + Some(tls_acceptor) => { + let mut https_acceptor = + HttpsAcceptor::new(log.clone(), tls_acceptor, incoming); + loop { + tokio::select! { + Some(Ok(sock)) = https_acceptor.accept() => { + let remote_addr = sock.remote_addr(); + let handler = make_service + .make_http_request_handler(remote_addr); + let fut = builder + .serve_connection_with_upgrades( + TokioIo::new(sock), + handler, + ); + let fut = graceful.watch(fut.into_owned()); + tokio::spawn(fut); + }, + + _ = &mut rx => { + info!(log, "beginning graceful shutdown"); + break; + } + } + } + } + None => loop { + tokio::select! { + (sock, remote_addr) = incoming.accept() => { + let handler = make_service + .make_http_request_handler(remote_addr); + let fut = builder + .serve_connection_with_upgrades( + TokioIo::new(sock), + handler, + ); + let fut = graceful.watch(fut.into_owned()); + tokio::spawn(fut); + }, + + _ = &mut rx => { + info!(log, "beginning graceful shutdown"); + break; + } + } + }, + }; + + // optional: could use another select on a timeout + graceful.shutdown().await + }); + + info!(log, "listening"); + + let join_handle = async move { + // After the server shuts down, we also want to wait for any + // detached handler futures to complete. + () = join_handle + .await + .map_err(|e| format!("server stopped: {e}"))?; + () = handler_waitgroup.wait().await; + Ok(()) + }; + + #[cfg(feature = "usdt-probes")] + let probe_registration = match usdt::register_probes() { + Ok(_) => { + debug!(&log, "successfully registered DTrace USDT probes"); + ProbeRegistration::Succeeded + } + Err(e) => { + let msg = e.to_string(); + error!(&log, "failed to register DTrace USDT probes: {}", msg); + ProbeRegistration::Failed(msg) + } + }; + #[cfg(not(feature = "usdt-probes"))] + let probe_registration = { + debug!(&log, "DTrace USDT probes compiled out, not registering"); + ProbeRegistration::Disabled + }; + + Ok(HttpServer { + probe_registration, + app_state, + local_addr, + closer: CloseHandle { close_channel: Some(tx) }, + join_future: join_handle.boxed().shared(), + }) + } } #[cfg(test)]