From 20b3cd2cefb564b5ddd439e04ae5c532c78853c1 Mon Sep 17 00:00:00 2001 From: Mike Oleshchuk Date: Fri, 22 May 2026 16:34:46 -0700 Subject: [PATCH] ran cargo fmt --- examples/tds_server_dummy.rs | 285 ++++++++---------- src/client.rs | 4 +- src/client/prepared.rs | 5 +- src/client/rpc_response.rs | 5 +- src/client/tls.rs | 11 +- src/lib.rs | 14 +- src/server/auth/env_provider.rs | 7 +- src/server/auth/gssapi.rs | 9 +- src/server/auth/handler.rs | 58 ++-- src/server/auth/login_info.rs | 6 +- src/server/auth/traits.rs | 2 +- src/server/builder.rs | 9 +- src/server/codec.rs | 87 +++--- src/server/connection.rs | 68 ++--- src/server/handler.rs | 2 +- src/server/messages.rs | 4 +- src/server/mod.rs | 37 ++- src/server/prepared.rs | 20 +- src/server/query.rs | 10 +- src/server/response.rs | 82 ++--- src/server/router.rs | 19 +- src/server/server.rs | 9 +- src/server/sp_cursor.rs | 6 +- src/server/sp_executesql.rs | 13 +- src/server/sp_prepare.rs | 74 ++--- src/server/sp_prepexec.rs | 19 +- src/tds.rs | 2 +- src/tds/codec/column_data.rs | 156 ++++------ src/tds/codec/column_data/string.rs | 3 +- src/tds/codec/login.rs | 3 +- src/tds/codec/rpc_request.rs | 6 +- src/tds/codec/token.rs | 20 +- src/tds/codec/token/token_col_metadata.rs | 5 +- src/tds/codec/token/token_done.rs | 12 +- src/tds/codec/token/token_env_change.rs | 6 +- src/tds/codec/token/token_error.rs | 2 +- src/tds/codec/token/token_fed_auth_info.rs | 2 +- src/tds/codec/token/token_info.rs | 2 +- src/tds/codec/token/token_login_ack.rs | 4 +- .../token_row/bytes_mut_with_data_columns.rs | 5 +- src/tds/codec/token/token_session_state.rs | 2 +- src/tds/tls.rs | 11 +- tests/query.rs | 35 +-- 43 files changed, 493 insertions(+), 648 deletions(-) diff --git a/examples/tds_server_dummy.rs b/examples/tds_server_dummy.rs index 62af14c6..375b99cd 100644 --- a/examples/tds_server_dummy.rs +++ b/examples/tds_server_dummy.rs @@ -27,10 +27,10 @@ mod server { use tiberius::server::codec::{decode_rpc_params, DecodedRpcParam}; use tiberius::server::{ process_connection, AttentionHandler, AuthBuilder, AuthError, AuthHandler, AuthSuccess, - BulkLoadHandler, DefaultEnvChangeProvider, EnvChangeProvider, ErrorHandler, FedAuthValidator, - LoginInfo, ResultSetWriter, RpcHandler, SqlAuthSource, SqlBatchHandler, SspiAcceptor, - SspiStart, SspiStep, TdsAuthHandler, TdsBackendMessage, TdsConnectionContext, TdsConnectionState, - TdsServerHandlers, + BulkLoadHandler, DefaultEnvChangeProvider, EnvChangeProvider, ErrorHandler, + FedAuthValidator, LoginInfo, ResultSetWriter, RpcHandler, SqlAuthSource, SqlBatchHandler, + SspiAcceptor, SspiStart, SspiStep, TdsAuthHandler, TdsBackendMessage, TdsConnectionContext, + TdsConnectionState, TdsServerHandlers, }; use tiberius::{ numeric::Numeric, @@ -39,10 +39,10 @@ mod server { AltMetaDataColumn, BaseMetaDataColumn, Collation, ColumnData, ColumnFlag, DoneStatus, FedAuthInfoOption, FixedLenType, LoginMessage, MetaDataColumn, PreloginMessage, RpcStatus, SessionStateEntry, SsVariantInfo, TokenAltMetaData, TokenAltRow, TokenColInfo, - TokenColMetaData, TokenColName, TokenDone, TokenEnvChange, TokenError, TokenFedAuthInfo, - TokenFeatureExtAck, TokenInfo, TokenLoginAck, TokenOrder, TokenReturnValue, TokenRow, - TokenSessionState, TokenTabName, TvpColumn, TvpData, TvpInfo, TypeInfo, UdtInfo, - VariantData, Uuid, VarLenContext, VarLenType, + TokenColMetaData, TokenColName, TokenDone, TokenEnvChange, TokenError, TokenFeatureExtAck, + TokenFedAuthInfo, TokenInfo, TokenLoginAck, TokenOrder, TokenReturnValue, TokenRow, + TokenSessionState, TokenTabName, TvpColumn, TvpData, TvpInfo, TypeInfo, UdtInfo, Uuid, + VarLenContext, VarLenType, VariantData, }; use tiberius::{EncryptionLevel, Result}; @@ -220,8 +220,7 @@ mod server { collation: Option, flags: enumflags2::BitFlags, ) -> MetaDataColumn<'static> { - let table_name = if matches!(ty, VarLenType::Text | VarLenType::NText | VarLenType::Image) - { + let table_name = if matches!(ty, VarLenType::Text | VarLenType::NText | VarLenType::Image) { Some(vec!["dummy_table".into()]) } else { None @@ -501,7 +500,9 @@ mod server { fn feature_ext_ack(&self, login: &LoginMessage<'_>) -> Option { if self.force_feature_ack && !login.has_feature_ext() { log_event("login: sending FeatureExtAck (forced)"); - return Some(TokenFeatureExtAck { features: Vec::new() }); + return Some(TokenFeatureExtAck { + features: Vec::new(), + }); } let ack = self.inner.feature_ext_ack(login); @@ -521,7 +522,8 @@ mod server { { let force_fedauth = std::env::var("TDS_DUMMY_FORCE_FEDAUTH").ok().as_deref() == Some("1"); - let wants_fedauth = login.fed_auth_token().is_some() || login.fed_auth_nonce().is_some(); + let wants_fedauth = + login.fed_auth_token().is_some() || login.fed_auth_nonce().is_some(); if !force_fedauth && !wants_fedauth { return None; @@ -693,7 +695,11 @@ mod server { { Box::pin(async move { let db_name = message.db_name_ref(); - let db_name = if db_name.is_empty() { "master" } else { db_name }; + let db_name = if db_name.is_empty() { + "master" + } else { + db_name + }; log_event(&format!( "login from={} user={} app={} db={} tds_version={:?} packet_size={}", @@ -765,19 +771,12 @@ mod server { if lower.contains("tds_info") { log_event("sql_batch: tds_info"); - let info = TokenInfo::new( - 5701, - 0, - 0, - "dummy info token", - "tiberius", - "tds_dummy", - 1, - ); + let info = + TokenInfo::new(5701, 0, 0, "dummy info token", "tiberius", "tds_dummy", 1); client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Info( - info, - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Info(info), + )) .await?; } @@ -805,9 +804,10 @@ mod server { ]; let mut writer = ResultSetWriter::start(client, columns).await?; writer - .send_row_iter([ColumnData::I32(Some(100)), ColumnData::String(Some( - "gamma".into(), - ))]) + .send_row_iter([ + ColumnData::I32(Some(100)), + ColumnData::String(Some("gamma".into())), + ]) .await?; writer.finish(1).await?; return Ok(()); @@ -891,12 +891,7 @@ mod server { 16, None, )); - columns.push(meta_var( - "binary_legacy_col", - VarLenType::Binary, - 16, - None, - )); + columns.push(meta_var("binary_legacy_col", VarLenType::Binary, 16, None)); let num1 = Numeric::new_with_scale(1234, 1); let dec1 = Numeric::new_with_scale(5678, 2); @@ -1064,7 +1059,12 @@ mod server { meta_var("char_big_col", VarLenType::BigChar, 20, collation), ]; if include_legacy { - columns.push(meta_var("varchar_short_col", VarLenType::VarChar, 50, collation)); + columns.push(meta_var( + "varchar_short_col", + VarLenType::VarChar, + 50, + collation, + )); columns.push(meta_var("char_short_col", VarLenType::Char, 20, collation)); } let mut writer = ResultSetWriter::start(client, columns).await?; @@ -1167,9 +1167,7 @@ mod server { }), )]; let mut writer = ResultSetWriter::start(client, columns).await?; - writer - .send_row_iter([ColumnData::Tvp(Some(tvp))]) - .await?; + writer.send_row_iter([ColumnData::Tvp(Some(tvp))]).await?; writer.finish(1).await?; return Ok(()); } @@ -1304,15 +1302,13 @@ mod server { columns: columns.clone(), }; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::ColMetaData( - col_meta, - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::ColMetaData(col_meta), + )) .await?; - let force_colname = std::env::var("TDS_DUMMY_FORCE_COLNAME") - .ok() - .as_deref() - == Some("1"); + let force_colname = + std::env::var("TDS_DUMMY_FORCE_COLNAME").ok().as_deref() == Some("1"); let legacy_colname = (client.tds_version() as u32) < TDS_VER_70; if force_colname || legacy_colname { client @@ -1327,11 +1323,11 @@ mod server { } client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::TabName( - TokenTabName { + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::TabName(TokenTabName { tables: vec![vec!["dummy_table".into()]], - }, - ))) + }), + )) .await?; let mut colinfo = BytesMut::with_capacity(3); @@ -1339,27 +1335,29 @@ mod server { colinfo.put_u8(0); // table #0 (expression) colinfo.put_u8(0x04); // EXPRESSION status client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::ColInfo( - TokenColInfo { data: colinfo }, - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::ColInfo(TokenColInfo { data: colinfo }), + )) .await?; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Order( - TokenOrder::new(vec![1]), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Order(TokenOrder::new(vec![1])), + )) .await?; let mut row = TokenRow::with_capacity(1); row.push(ColumnData::I32(Some(66))); client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Row(row))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Row(row), + )) .await?; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Done( - TokenDone::with_rows(1), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Done(TokenDone::with_rows(1)), + )) .await?; return Ok(()); } @@ -1371,9 +1369,9 @@ mod server { columns: columns.clone(), }; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::ColMetaData( - col_meta, - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::ColMetaData(col_meta), + )) .await?; let alt_column = meta_fixed("alt_value", FixedLenType::Int4); @@ -1396,7 +1394,9 @@ mod server { let mut row = TokenRow::with_capacity(1); row.push(ColumnData::I32(Some(10))); client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Row(row))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Row(row), + )) .await?; let mut alt_row = TokenAltRow::with_capacity(1, 1); @@ -1409,9 +1409,9 @@ mod server { .await?; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Done( - TokenDone::with_rows(1), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Done(TokenDone::with_rows(1)), + )) .await?; return Ok(()); } @@ -1432,9 +1432,9 @@ mod server { )) .await?; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Done( - TokenDone::default(), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Done(TokenDone::default()), + )) .await?; return Ok(()); } @@ -1443,9 +1443,9 @@ mod server { log_event("sql_batch: tds_fedauth"); log_event("tds_fedauth: FedAuthInfo is login-only; skipping in batch"); client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Done( - TokenDone::default(), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Done(TokenDone::default()), + )) .await?; return Ok(()); } @@ -1457,17 +1457,17 @@ mod server { columns: columns.clone(), }; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::ColMetaData( - col_meta, - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::ColMetaData(col_meta), + )) .await?; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::TabName( - TokenTabName { + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::TabName(TokenTabName { tables: vec![vec!["dummy_table".into()]], - }, - ))) + }), + )) .await?; let mut colinfo = BytesMut::with_capacity(3); @@ -1475,26 +1475,28 @@ mod server { colinfo.put_u8(1); // table #1 colinfo.put_u8(0x00); // no flags client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::ColInfo( - TokenColInfo { data: colinfo }, - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::ColInfo(TokenColInfo { data: colinfo }), + )) .await?; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Order( - TokenOrder::new(vec![1]), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Order(TokenOrder::new(vec![1])), + )) .await?; let mut row = TokenRow::with_capacity(1); row.push(ColumnData::I32(Some(55))); client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Row(row))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Row(row), + )) .await?; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Done( - TokenDone::with_rows(1), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Done(TokenDone::with_rows(1)), + )) .await?; return Ok(()); } @@ -1521,20 +1523,13 @@ mod server { "iso-8859-1".into(), )), tiberius::server::BackendToken::EnvChange(TokenEnvChange::PacketSize( - 8192, - 4096, + 8192, 4096, )), tiberius::server::BackendToken::EnvChange( - TokenEnvChange::UnicodeDataSortingLID( - "0x0409".into(), - "0x0000".into(), - ), + TokenEnvChange::UnicodeDataSortingLID("0x0409".into(), "0x0000".into()), ), tiberius::server::BackendToken::EnvChange( - TokenEnvChange::UnicodeDataSortingCFL( - "0x0001".into(), - "0x0000".into(), - ), + TokenEnvChange::UnicodeDataSortingCFL("0x0001".into(), "0x0000".into()), ), tiberius::server::BackendToken::EnvChange(TokenEnvChange::SqlCollation { old: collation_old, @@ -1582,9 +1577,7 @@ mod server { new: tx_desc_new.clone(), }, ), - tiberius::server::BackendToken::EnvChange( - TokenEnvChange::ResetConnection, - ), + tiberius::server::BackendToken::EnvChange(TokenEnvChange::ResetConnection), tiberius::server::BackendToken::EnvChange(TokenEnvChange::UserName( "dummy_user".into(), "old_user".into(), @@ -1724,24 +1717,20 @@ mod server { if lower.contains("tds_error") { log_event("sql_batch: tds_error"); - let err = TokenError::new( - 50000, - 1, - 16, - "dummy error", - "tiberius", - "tds_error", - 1, - ); + let err = + TokenError::new(50000, 1, 16, "dummy error", "tiberius", "tds_error", 1); client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Error( - err, - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Error(err), + )) .await?; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Done( - TokenDone::with_status(DoneStatus::Error.into(), 0), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Done(TokenDone::with_status( + DoneStatus::Error.into(), + 0, + )), + )) .await?; return Ok(()); } @@ -1786,9 +1775,9 @@ mod server { if !lower.contains("select") { log_event("sql_batch: non-select done"); client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Done( - TokenDone::default(), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Done(TokenDone::default()), + )) .await?; return Ok(()); } @@ -1860,8 +1849,7 @@ mod server { let is_execute = proc_name.eq_ignore_ascii_case("sp_execute"); let is_prepexec = proc_name.eq_ignore_ascii_case("sp_prepexec"); let is_exec_family = is_executesql || is_execute || is_prepexec; - let suppress_param_echo = - is_executesql || is_prepare || is_execute || is_prepexec; + let suppress_param_echo = is_executesql || is_prepare || is_execute || is_prepexec; if is_prepare { output_only = true; } @@ -1907,19 +1895,11 @@ mod server { } } - let info = TokenInfo::new( - 8127, - 0, - 0, - "dummy rpc info", - "tiberius", - "tds_rpc", - 1, - ); + let info = TokenInfo::new(8127, 0, 0, "dummy rpc info", "tiberius", "tds_rpc", 1); client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Info( - info, - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Info(info), + )) .await?; let mut pending_return_tokens = { @@ -2035,9 +2015,9 @@ mod server { } client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::DoneProc( - TokenDone::with_rows(1), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::DoneProc(TokenDone::with_rows(1)), + )) .await }) } @@ -2060,9 +2040,9 @@ mod server { { Box::pin(async move { client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Done( - TokenDone::default(), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Done(TokenDone::default()), + )) .await }) } @@ -2086,9 +2066,12 @@ mod server { client.clear_attention(); let status = DoneStatus::Attention; client - .send(TdsBackendMessage::Token(tiberius::server::BackendToken::Done( - TokenDone::with_status(status.into(), 0), - ))) + .send(TdsBackendMessage::Token( + tiberius::server::BackendToken::Done(TokenDone::with_status( + status.into(), + 0, + )), + )) .await?; client.set_state(TdsConnectionState::ReadyForQuery); Ok(()) @@ -2138,14 +2121,8 @@ mod server { fn load_tls_acceptor() -> Option { let cert_path = std::env::var("TDS_DUMMY_TLS_CERT").ok()?; let key_path = std::env::var("TDS_DUMMY_TLS_KEY").ok()?; - let tls12_only = std::env::var("TDS_DUMMY_TLS12_ONLY") - .ok() - .as_deref() - == Some("1"); - let tls13_only = std::env::var("TDS_DUMMY_TLS13_ONLY") - .ok() - .as_deref() - == Some("1"); + let tls12_only = std::env::var("TDS_DUMMY_TLS12_ONLY").ok().as_deref() == Some("1"); + let tls13_only = std::env::var("TDS_DUMMY_TLS13_ONLY").ok().as_deref() == Some("1"); if tls12_only && tls13_only { log_event("TLS config: both TLS12_ONLY and TLS13_ONLY set; using TLS 1.3"); @@ -2241,7 +2218,11 @@ mod server { #[cfg(feature = "server-rustls")] log_event(&format!( "TLS acceptor: {}", - if tls_acceptor.is_some() { "enabled" } else { "disabled" } + if tls_acceptor.is_some() { + "enabled" + } else { + "disabled" + } )); loop { diff --git a/src/client.rs b/src/client.rs index aae34af2..cce1d536 100644 --- a/src/client.rs +++ b/src/client.rs @@ -642,9 +642,7 @@ where ReceivedToken::Error(e) => { last_error.get_or_insert(crate::Error::Server(e)); } - ReceivedToken::DoneInProc(ref done) - if !done.status().contains(DoneStatus::More) => - { + ReceivedToken::DoneInProc(ref done) if !done.status().contains(DoneStatus::More) => { if columns.is_some() { results.push(std::mem::take(&mut current)); columns = None; diff --git a/src/client/prepared.rs b/src/client/prepared.rs index e4b15a7f..c29b59a8 100644 --- a/src/client/prepared.rs +++ b/src/client/prepared.rs @@ -176,10 +176,7 @@ impl Drop for PreparedStatement { } /// Build the RPC parameter list for `sp_execute`: `[@handle, @P1, @P2, ...]`. -fn build_execute_params<'a>( - handle: PreparedHandle, - params: &[&'a dyn ToSql], -) -> Vec> { +fn build_execute_params<'a>(handle: PreparedHandle, params: &[&'a dyn ToSql]) -> Vec> { let mut rpc_params: Vec> = Vec::with_capacity(params.len() + 1); rpc_params.push(RpcParam { name: Cow::Borrowed(""), diff --git a/src/client/rpc_response.rs b/src/client/rpc_response.rs index ed540f3e..518821a8 100644 --- a/src/client/rpc_response.rs +++ b/src/client/rpc_response.rs @@ -411,8 +411,9 @@ mod tests { mk_done_proc_final(), ]); - let (outputs, status, metadata) = - collect_rpc_outputs_with_metadata_from_stream(s).await.unwrap(); + let (outputs, status, metadata) = collect_rpc_outputs_with_metadata_from_stream(s) + .await + .unwrap(); assert_eq!(outputs.len(), 1); assert_eq!(outputs[0].get::().unwrap(), Some(42)); diff --git a/src/client/tls.rs b/src/client/tls.rs index 4a8e1593..0e2c25a4 100644 --- a/src/client/tls.rs +++ b/src/client/tls.rs @@ -151,7 +151,10 @@ impl TlsPreloginWrapper { pub fn handshake_complete(&mut self) { self.pending_handshake = false; - debug_assert!(self.pending_len == 0, "pending TLS handshake data not flushed"); + debug_assert!( + self.pending_len == 0, + "pending TLS handshake data not flushed" + ); self.wr_buf.clear(); self.wr_pos = 0; self.pending_len = 0; @@ -289,10 +292,8 @@ impl AsyncWrite for TlsPreloginWrapper inner.wr_buf.len() - inner.wr_pos, ); - let written = ready!( - Pin::new(&mut inner.stream.as_mut().unwrap()) - .poll_write(cx, &inner.wr_buf[inner.wr_pos..]) - )?; + let written = ready!(Pin::new(&mut inner.stream.as_mut().unwrap()) + .poll_write(cx, &inner.wr_buf[inner.wr_pos..]))?; if written == 0 { return Poll::Ready(Err(io::Error::new( diff --git a/src/lib.rs b/src/lib.rs index 644208cf..cad9f062 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -266,8 +266,8 @@ mod result; mod row; mod tds; -mod sql_browser; pub mod server; +mod sql_browser; pub use client::{ AuthMethod, CancellationToken, Client, Config, Cursor, CursorConcurrencyOptions, CursorHandle, @@ -281,14 +281,14 @@ pub use result::*; pub use row::{Column, ColumnType, Row}; pub use sql_browser::SqlBrowser; pub use tds::{ - codec::{ - AltMetaDataColumn, BaseMetaDataColumn, BulkLoadRequest, ColumnData, ColumnFlag, Encode, - FedAuthInfoOption, FixedLenType, IntoRow, LoginMessage, MetaDataColumn, PreloginMessage, - RpcOption, RpcProcId, RpcStatus, DoneStatus, SessionStateEntry, SsVariantInfo, + codec::{ + AltMetaDataColumn, BaseMetaDataColumn, BulkLoadRequest, ColumnData, ColumnFlag, DoneStatus, + Encode, FedAuthInfoOption, FixedLenType, IntoRow, LoginMessage, MetaDataColumn, + PreloginMessage, RpcOption, RpcProcId, RpcStatus, SessionStateEntry, SsVariantInfo, TokenAltMetaData, TokenAltRow, TokenColInfo, TokenColMetaData, TokenColName, TokenDone, - TokenEnvChange, TokenError, TokenFedAuthInfo, TokenFeatureExtAck, TokenInfo, TokenLoginAck, + TokenEnvChange, TokenError, TokenFeatureExtAck, TokenFedAuthInfo, TokenInfo, TokenLoginAck, TokenOrder, TokenReturnValue, TokenRow, TokenSessionState, TokenSspi, TokenTabName, - TypeInfo, TypeLength, TvpColumn, TvpData, TvpInfo, UdtData, UdtInfo, VarLenContext, + TvpColumn, TvpData, TvpInfo, TypeInfo, TypeLength, UdtData, UdtInfo, VarLenContext, VarLenType, VariantData, }, numeric, diff --git a/src/server/auth/env_provider.rs b/src/server/auth/env_provider.rs index 954259be..16b6db40 100644 --- a/src/server/auth/env_provider.rs +++ b/src/server/auth/env_provider.rs @@ -30,7 +30,12 @@ impl Default for DefaultEnvChangeProvider { impl DefaultEnvChangeProvider { pub fn login_ack(&self, login: &LoginMessage<'_>) -> TokenLoginAck { - TokenLoginAck::new(1, login.tds_version(), &self.program_name, self.server_version) + TokenLoginAck::new( + 1, + login.tds_version(), + &self.program_name, + self.server_version, + ) } pub fn env_changes(&self, client: &C, login: &LoginMessage<'_>) -> Vec diff --git a/src/server/auth/gssapi.rs b/src/server/auth/gssapi.rs index 1ddfadfa..5f099890 100644 --- a/src/server/auth/gssapi.rs +++ b/src/server/auth/gssapi.rs @@ -51,9 +51,7 @@ impl SspiAcceptor for GssapiAcceptor { let complete = ctx.is_complete(); let session_user = if complete { - ctx.source_name() - .ok() - .map(|name| name.to_string()) + ctx.source_name().ok().map(|name| name.to_string()) } else { None }; @@ -89,10 +87,7 @@ impl SspiSession for GssapiSession { })?; let complete = self.ctx.is_complete(); let session_user = if complete { - self.ctx - .source_name() - .ok() - .map(|name| name.to_string()) + self.ctx.source_name().ok().map(|name| name.to_string()) } else { None }; diff --git a/src/server/auth/handler.rs b/src/server/auth/handler.rs index f5668f63..00f9d22a 100644 --- a/src/server/auth/handler.rs +++ b/src/server/auth/handler.rs @@ -17,7 +17,9 @@ use super::builder::AuthBuilder; use super::env_provider::DefaultEnvChangeProvider; use super::error::AuthError; use super::login_info::LoginInfo; -use super::traits::{EnvChangeProvider, FedAuthValidator, SqlAuthSource, SspiAcceptor, SspiSession, SspiStep}; +use super::traits::{ + EnvChangeProvider, FedAuthValidator, SqlAuthSource, SspiAcceptor, SspiSession, SspiStep, +}; /// Maximum number of concurrent pending SSPI authentication sessions. const MAX_PENDING_SSPI_SESSIONS: usize = 1000; @@ -117,15 +119,8 @@ where .and_then(|info| info.server()) .map(|s| s.to_string()) .unwrap_or_else(|| client.socket_addr().ip().to_string()); - let token = crate::TokenError::new( - err.code, - err.state, - err.class, - err.message, - server, - "", - 1, - ); + let token = + crate::TokenError::new(err.code, err.state, err.class, err.message, server, "", 1); let done = TokenDone::with_status(DoneStatus::Error.into(), 0); client @@ -247,23 +242,14 @@ where if let Some(token) = message.fed_auth_token() { let Some(validator) = self.fed_auth.as_ref() else { return self - .send_login_error( - client, - Some(&info), - AuthError::login_failed(info.user()), - ) + .send_login_error(client, Some(&info), AuthError::login_failed(info.user())) .await; }; match validator.validate(&info, token).await { Ok(success) => { return self - .finish_login( - client, - &message, - &info, - success.session_user.as_deref(), - ) + .finish_login(client, &message, &info, success.session_user.as_deref()) .await; } Err(err) => { @@ -275,11 +261,7 @@ where if let Some(initial) = message.integrated_security_bytes() { let Some(acceptor) = self.sspi.as_ref() else { return self - .send_login_error( - client, - Some(&info), - AuthError::login_failed(info.user()), - ) + .send_login_error(client, Some(&info), AuthError::login_failed(info.user())) .await; }; @@ -333,12 +315,7 @@ where match sql_auth.authenticate(&info, password).await { Ok(success) => { return self - .finish_login( - client, - &message, - &info, - success.session_user.as_deref(), - ) + .finish_login(client, &message, &info, success.session_user.as_deref()) .await; } Err(err) => { @@ -348,7 +325,9 @@ where } if self.allow_trust { - return self.finish_login(client, &message, &info, info.user()).await; + return self + .finish_login(client, &message, &info, info.user()) + .await; } self.send_login_error(client, Some(&info), AuthError::login_failed(info.user())) @@ -373,18 +352,16 @@ where let Some(mut session) = session else { return self - .send_login_error( - client, - None, - AuthError::login_failed(None), - ) + .send_login_error(client, None, AuthError::login_failed(None)) .await; }; let step = match session.sspi.step(token.as_ref()) { Ok(step) => step, Err(err) => { - return self.send_login_error(client, Some(&session.info), err).await; + return self + .send_login_error(client, Some(&session.info), err) + .await; } }; @@ -397,8 +374,7 @@ where sessions.insert(addr, session); } - self.handle_sspi_step(client, &login, &info, step) - .await + self.handle_sspi_step(client, &login, &info, step).await }) } } diff --git a/src/server/auth/login_info.rs b/src/server/auth/login_info.rs index 594c65f6..a8bc8fb9 100644 --- a/src/server/auth/login_info.rs +++ b/src/server/auth/login_info.rs @@ -30,7 +30,11 @@ impl LoginInfo { let hostname = login.hostname_ref().trim(); Self { - user: if user.is_empty() { None } else { Some(user.to_string()) }, + user: if user.is_empty() { + None + } else { + Some(user.to_string()) + }, database: if database.is_empty() { None } else { diff --git a/src/server/auth/traits.rs b/src/server/auth/traits.rs index 72debd82..852d3524 100644 --- a/src/server/auth/traits.rs +++ b/src/server/auth/traits.rs @@ -4,7 +4,7 @@ use std::fmt::Debug; use async_trait::async_trait; -use crate::tds::codec::{TokenEnvChange, TokenFedAuthInfo, TokenFeatureExtAck, TokenLoginAck}; +use crate::tds::codec::{TokenEnvChange, TokenFeatureExtAck, TokenFedAuthInfo, TokenLoginAck}; use crate::LoginMessage; use super::error::{AuthResult, AuthSuccess}; diff --git a/src/server/builder.rs b/src/server/builder.rs index 3010ee6a..9be9b9db 100644 --- a/src/server/builder.rs +++ b/src/server/builder.rs @@ -138,7 +138,9 @@ pub struct TdsServerBuilder { error: E, } -impl Default for TdsServerBuilder { +impl Default + for TdsServerBuilder +{ fn default() -> Self { Self::new() } @@ -233,7 +235,10 @@ impl TdsServerBuilder { /// .auth(my_auth_handler) /// .query(my_query_handler); /// ``` - pub fn query(self, handler: H) -> TdsServerBuilder>, R, B, AT, E> + pub fn query( + self, + handler: H, + ) -> TdsServerBuilder>, R, B, AT, E> where H: QueryHandler, { diff --git a/src/server/codec.rs b/src/server/codec.rs index f6c2817d..e9bfdca5 100644 --- a/src/server/codec.rs +++ b/src/server/codec.rs @@ -7,9 +7,8 @@ use bytes::{Buf, BytesMut}; const MAX_PENDING_PAYLOAD_SIZE: usize = 16 * 1024 * 1024; use crate::server::messages::{ - AllHeaders, RequestFlags, RpcMessage, SqlBatchMessage, TdsBackendMessage, - TdsFrontendMessage, TraceActivityHeader, TransactionDescriptor, TransactionDescriptorHeader, - UnknownHeader, + AllHeaders, RequestFlags, RpcMessage, SqlBatchMessage, TdsBackendMessage, TdsFrontendMessage, + TraceActivityHeader, TransactionDescriptor, TransactionDescriptorHeader, UnknownHeader, }; use crate::server::state::TdsConnectionState; use crate::tds::codec::{ @@ -18,8 +17,8 @@ use crate::tds::codec::{ }; use crate::tds::Context; use crate::SqlReadBytes; -use asynchronous_codec::Decoder; use crate::{Error, Result}; +use asynchronous_codec::Decoder; use byteorder::{LittleEndian, ReadBytesExt}; use enumflags2::BitFlags; use futures_util::io::AsyncRead; @@ -81,8 +80,7 @@ impl TdsCodec { if self.pending_discard { if let Some(pending) = self.pending_type { - let starts_new_message = - header.r#type() != pending || header.id() == 1; + let starts_new_message = header.r#type() != pending || header.id() == 1; if starts_new_message { self.pending_type = None; self.pending_payload.clear(); @@ -107,9 +105,7 @@ impl TdsCodec { None => { // Check initial payload size if payload.len() > MAX_PENDING_PAYLOAD_SIZE { - return Err(Error::Protocol( - "tds: payload exceeded maximum size".into(), - )); + return Err(Error::Protocol("tds: payload exceeded maximum size".into())); } self.pending_type = Some(header.r#type()); self.pending_payload = payload; @@ -372,9 +368,8 @@ fn decode_all_headers(bytes: &[u8]) -> Result { } let mut desc_bytes = [0u8; 8]; desc_bytes.copy_from_slice(&data[..8]); - let outstanding_requests = u32::from_le_bytes([ - data[8], data[9], data[10], data[11], - ]); + let outstanding_requests = + u32::from_le_bytes([data[8], data[9], data[10], data[11]]); headers.transaction_descriptor = Some(TransactionDescriptorHeader { descriptor: TransactionDescriptor::new(desc_bytes), outstanding_requests, @@ -388,18 +383,14 @@ fn decode_all_headers(bytes: &[u8]) -> Result { } let mut activity_id = [0u8; 16]; activity_id.copy_from_slice(&data[..16]); - let sequence_number = - u32::from_le_bytes([data[16], data[17], data[18], data[19]]); + let sequence_number = u32::from_le_bytes([data[16], data[17], data[18], data[19]]); headers.trace_activity = Some(TraceActivityHeader { activity_id, sequence_number, }); } _ => { - headers.unknown.push(UnknownHeader { - header_type, - data, - }); + headers.unknown.push(UnknownHeader { header_type, data }); } } } @@ -695,24 +686,30 @@ mod tests { offset = end; } - let attention_packet = - encode_packet(PacketType::AttentionSignal, PacketStatus::EndOfMessage, id, BytesMut::new()); + let attention_packet = encode_packet( + PacketType::AttentionSignal, + PacketStatus::EndOfMessage, + id, + BytesMut::new(), + ); id = id.wrapping_add(1); let mut payload2 = BytesMut::new(); BatchRequest::new("SELECT 1", [0; 8]) .encode(&mut payload2) .expect("batch encode"); - let followup_packet = - encode_packet(PacketType::SQLBatch, PacketStatus::EndOfMessage, id, payload2); + let followup_packet = encode_packet( + PacketType::SQLBatch, + PacketStatus::EndOfMessage, + id, + payload2, + ); buf.extend_from_slice(&batch_packets[0]); - assert!( - codec - .decode(&mut buf, TdsConnectionState::ReadyForQuery) - .expect("decode") - .is_none() - ); + assert!(codec + .decode(&mut buf, TdsConnectionState::ReadyForQuery) + .expect("decode") + .is_none()); buf.extend_from_slice(&attention_packet); let msg = codec @@ -759,19 +756,25 @@ mod tests { let payload_bytes = payload.freeze(); let first_chunk = BytesMut::from(&payload_bytes[..chunk_size]); - let first_packet = - encode_packet(PacketType::SQLBatch, PacketStatus::NormalMessage, 1, first_chunk); + let first_packet = encode_packet( + PacketType::SQLBatch, + PacketStatus::NormalMessage, + 1, + first_chunk, + ); buf.extend_from_slice(&first_packet); - assert!( - codec - .decode(&mut buf, TdsConnectionState::ReadyForQuery) - .expect("decode") - .is_none() + assert!(codec + .decode(&mut buf, TdsConnectionState::ReadyForQuery) + .expect("decode") + .is_none()); + + let attention_packet = encode_packet( + PacketType::AttentionSignal, + PacketStatus::EndOfMessage, + 2, + BytesMut::new(), ); - - let attention_packet = - encode_packet(PacketType::AttentionSignal, PacketStatus::EndOfMessage, 2, BytesMut::new()); buf.extend_from_slice(&attention_packet); let msg = codec .decode(&mut buf, TdsConnectionState::ReadyForQuery) @@ -782,8 +785,12 @@ mod tests { BatchRequest::new("SELECT 1", [0; 8]) .encode(&mut payload2) .expect("batch encode"); - let followup_packet = - encode_packet(PacketType::SQLBatch, PacketStatus::EndOfMessage, 1, payload2); + let followup_packet = encode_packet( + PacketType::SQLBatch, + PacketStatus::EndOfMessage, + 1, + payload2, + ); buf.extend_from_slice(&followup_packet); let msg = codec diff --git a/src/server/connection.rs b/src/server/connection.rs index 516ed0a6..0a2436d4 100644 --- a/src/server/connection.rs +++ b/src/server/connection.rs @@ -22,10 +22,10 @@ use crate::tds::codec::{ DoneStatus, Encode, FeatureLevel, Packet, PacketHeader, PacketStatus, PacketType, TokenDone, TokenEnvChange, }; -use std::sync::Arc; use crate::tds::Context as TdsContext; -use crate::Error; use crate::EncryptionLevel; +use crate::Error; +use std::sync::Arc; /// Buffer size for read/write operations. const BUFFER_SIZE: usize = 8192; @@ -292,7 +292,11 @@ impl TdsConnection { TdsBackendMessage::Prelogin(message) => { let mut payload = BytesMut::new(); message.encode(&mut payload)?; - let packet_type = match self.metadata.custom.get("prelogin_packet_type").map(String::as_str) + let packet_type = match self + .metadata + .custom + .get("prelogin_packet_type") + .map(String::as_str) { Some("tabular") => PacketType::TabularResult, _ => PacketType::PreLogin, @@ -320,7 +324,9 @@ impl TdsConnection { self.write_payload_as_packets(PacketType::TabularResult, payload, false)?; Ok(()) } - TdsBackendMessage::Packet(packet) => self.codec.encode(TdsBackendMessage::Packet(packet), &mut self.write_buf), + TdsBackendMessage::Packet(packet) => self + .codec + .encode(TdsBackendMessage::Packet(packet), &mut self.write_buf), } } @@ -433,30 +439,24 @@ impl TdsConnection { token.encode_with_columns(payload, &meta.columns) } BackendToken::Order(token) => token.encode(payload), - BackendToken::Done(token) => { - self.encode_done_token( - token, - crate::tds::codec::TokenType::Done, - payload, - done_row_count_bytes, - ) - } - BackendToken::DoneProc(token) => { - self.encode_done_token( - token, - crate::tds::codec::TokenType::DoneProc, - payload, - done_row_count_bytes, - ) - } - BackendToken::DoneInProc(token) => { - self.encode_done_token( - token, - crate::tds::codec::TokenType::DoneInProc, - payload, - done_row_count_bytes, - ) - } + BackendToken::Done(token) => self.encode_done_token( + token, + crate::tds::codec::TokenType::Done, + payload, + done_row_count_bytes, + ), + BackendToken::DoneProc(token) => self.encode_done_token( + token, + crate::tds::codec::TokenType::DoneProc, + payload, + done_row_count_bytes, + ), + BackendToken::DoneInProc(token) => self.encode_done_token( + token, + crate::tds::codec::TokenType::DoneInProc, + payload, + done_row_count_bytes, + ), BackendToken::ReturnStatus(status) => { payload.put_u8(crate::tds::codec::TokenType::ReturnStatus as u8); payload.put_u32_le(status); @@ -478,8 +478,8 @@ impl TdsConnection { if !self.message_in_progress { self.context.reset_packet_id(); } - let packet_size = (self.context.packet_size() as usize) - .saturating_sub(crate::tds::codec::HEADER_BYTES); + let packet_size = + (self.context.packet_size() as usize).saturating_sub(crate::tds::codec::HEADER_BYTES); if packet_size == 0 { return Err(Error::Protocol("invalid packet size".into())); @@ -585,7 +585,9 @@ impl TdsConnectionContext for TdsConnection { TdsConnection::clear_attention(self); } - fn poll_attention<'a>(&'a mut self) -> std::pin::Pin> + Send + 'a>> + fn poll_attention<'a>( + &'a mut self, + ) -> std::pin::Pin> + Send + 'a>> where Self: Sized, { @@ -637,9 +639,7 @@ impl Sink for TdsConnection { } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut() - .poll_flush_buf(cx) - .map_err(Into::into) + self.get_mut().poll_flush_buf(cx).map_err(Into::into) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/src/server/handler.rs b/src/server/handler.rs index c4e82562..cf2eb850 100644 --- a/src/server/handler.rs +++ b/src/server/handler.rs @@ -66,10 +66,10 @@ use std::net::SocketAddr; use futures_util::sink::Sink; -use crate::EncryptionLevel; use crate::server::messages::{AllHeaders, TdsBackendMessage, TransactionDescriptor}; use crate::server::state::TdsConnectionState; use crate::tds::codec::FeatureLevel; +use crate::EncryptionLevel; use crate::Result; /// Well-known connection metadata fields. diff --git a/src/server/messages.rs b/src/server/messages.rs index 40c95831..5d02600b 100644 --- a/src/server/messages.rs +++ b/src/server/messages.rs @@ -4,8 +4,8 @@ use bytes::BytesMut; use crate::tds::codec::{ Packet, PreloginMessage, RpcOption, RpcProcId, TokenAltMetaData, TokenAltRow, TokenColInfo, - TokenColMetaData, TokenColName, TokenDone, TokenEnvChange, TokenError, TokenFedAuthInfo, - TokenFeatureExtAck, TokenInfo, TokenLoginAck, TokenOrder, TokenReturnValue, TokenRow, + TokenColMetaData, TokenColName, TokenDone, TokenEnvChange, TokenError, TokenFeatureExtAck, + TokenFedAuthInfo, TokenInfo, TokenLoginAck, TokenOrder, TokenReturnValue, TokenRow, TokenSessionState, TokenSspi, TokenTabName, }; use enumflags2::BitFlags; diff --git a/src/server/mod.rs b/src/server/mod.rs index 2a457cd7..cdf38915 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -93,13 +93,13 @@ //! - `server-rustls`: Enables rustls-based TLS //! - `integrated-auth-gssapi`: Enables GSSAPI/Kerberos authentication (Unix only) -pub mod backend; pub mod auth; +pub mod backend; pub mod builder; pub mod codec; pub mod connection; -pub mod handler; mod handle_macro; +pub mod handler; pub mod messages; pub mod prepared; pub mod query; @@ -113,14 +113,15 @@ pub mod sp_prepexec; pub mod state; pub mod tls; -pub use backend::{NetBackend, NetListener, NetStream, NetStreamExt}; -pub use builder::{BuiltTdsServer, NotSet, Set, TdsServerBuilder}; +#[cfg(all(unix, feature = "integrated-auth-gssapi"))] +pub use auth::gssapi::GssapiAcceptor; pub use auth::{ AuthBuilder, AuthError, AuthResult, AuthSuccess, DefaultEnvChangeProvider, EnvChangeProvider, FedAuthValidator, LoginInfo, SqlAuthSource, SspiAcceptor, SspiSession, SspiStart, SspiStep, - TdsAuthHandler, - METADATA_APPLICATION, METADATA_DATABASE, METADATA_SERVER, METADATA_USER, + TdsAuthHandler, METADATA_APPLICATION, METADATA_DATABASE, METADATA_SERVER, METADATA_USER, }; +pub use backend::{NetBackend, NetListener, NetStream, NetStreamExt}; +pub use builder::{BuiltTdsServer, NotSet, Set, TdsServerBuilder}; pub use codec::{decode_rpc_params, DecodedRpcParam, RpcParamSet, TdsCodec}; pub use connection::TdsConnection; pub use handler::{ @@ -133,11 +134,19 @@ pub use messages::{ TdsFrontendMessage, TransactionDescriptor, }; pub use prepared::{PreparedHandle, PreparedStatement, ProcedureCache, ProcedureCacheConfig}; +pub use query::{QueryColumn, QueryColumnType, QueryHandler, QueryOutput, SimpleQueryAdapter}; pub use response::{ finish_proc, finish_proc_more, infer_type_info, send_output_param, send_output_params, send_return_status, OutputParameter, ResultSetWriter, }; +pub use router::{RejectUnknownProc, SystemProcRouter, SystemProcRouterBuilder}; pub use server::process_connection; +pub use sp_cursor::{ + parse_cursor_close, parse_cursor_fetch, parse_cursor_open, CursorCache, CursorCacheConfig, + CursorEntry, CursorHandle, ParsedCursorClose, ParsedCursorFetch, ParsedCursorOpen, + SpCursorCloseHandler, SpCursorCloseRpcHandler, SpCursorFetchHandler, SpCursorFetchRpcHandler, + SpCursorOpenHandler, SpCursorOpenRpcHandler, +}; pub use sp_executesql::{ parse_executesql, ExecuteSqlParam, ParsedExecuteSql, SpExecuteSqlHandler, SpExecuteSqlRpcHandler, @@ -147,20 +156,8 @@ pub use sp_prepare::{ PreparedStatementRpcHandler, SpExecuteHandler, SpExecuteRpcHandler, SpPrepareHandler, SpPrepareRpcHandler, SpUnprepareHandler, SpUnprepareRpcHandler, }; -pub use sp_prepexec::{ - parse_prepexec, ParsedPrepExec, SpPrepExecHandler, SpPrepExecRpcHandler, -}; -pub use sp_cursor::{ - parse_cursor_close, parse_cursor_fetch, parse_cursor_open, CursorCache, CursorCacheConfig, - CursorEntry, CursorHandle, ParsedCursorClose, ParsedCursorFetch, ParsedCursorOpen, - SpCursorCloseHandler, SpCursorCloseRpcHandler, SpCursorFetchHandler, SpCursorFetchRpcHandler, - SpCursorOpenHandler, SpCursorOpenRpcHandler, -}; -pub use query::{QueryColumn, QueryColumnType, QueryHandler, QueryOutput, SimpleQueryAdapter}; -pub use router::{RejectUnknownProc, SystemProcRouter, SystemProcRouterBuilder}; +pub use sp_prepexec::{parse_prepexec, ParsedPrepExec, SpPrepExecHandler, SpPrepExecRpcHandler}; pub use state::TdsConnectionState; -pub use tls::{NoTls, TlsAccept, TlsStream}; #[cfg(feature = "server-rustls")] pub use tls::RustlsAcceptor; -#[cfg(all(unix, feature = "integrated-auth-gssapi"))] -pub use auth::gssapi::GssapiAcceptor; +pub use tls::{NoTls, TlsAccept, TlsStream}; diff --git a/src/server/prepared.rs b/src/server/prepared.rs index bcce0a17..a9094aec 100644 --- a/src/server/prepared.rs +++ b/src/server/prepared.rs @@ -154,7 +154,7 @@ impl Default for ProcedureCacheConfig { fn default() -> Self { Self { max_capacity: 1000, - max_age: Duration::from_secs(60 * 60), // 1 hour + max_age: Duration::from_secs(60 * 60), // 1 hour idle_timeout: Duration::from_secs(30 * 60), // 30 minutes } } @@ -475,11 +475,7 @@ mod tests { fn procedure_cache_prepare_and_get() { let mut cache = ProcedureCache::new(1); - let handle = cache.prepare( - "SELECT 1".to_string(), - vec![], - vec![], - ); + let handle = cache.prepare("SELECT 1".to_string(), vec![], vec![]); assert!(cache.contains(&handle)); assert_eq!(cache.len(), 1); @@ -495,7 +491,11 @@ mod tests { let handle = cache.prepare( "SELECT @p1".to_string(), - vec![TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Intn, 4, None))], + vec![TypeInfo::VarLenSized(VarLenContext::new( + VarLenType::Intn, + 4, + None, + ))], vec!["@p1".to_string()], ); @@ -590,11 +590,7 @@ mod tests { #[test] fn prepared_statement_record_execution() { - let mut stmt = PreparedStatement::new( - "SELECT 1".to_string(), - vec![], - vec![], - ); + let mut stmt = PreparedStatement::new("SELECT 1".to_string(), vec![], vec![]); assert_eq!(stmt.execution_count, 0); diff --git a/src/server/query.rs b/src/server/query.rs index a64ac144..d1005182 100644 --- a/src/server/query.rs +++ b/src/server/query.rs @@ -224,11 +224,9 @@ impl QueryColumnType { QueryColumnType::Date => { TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Daten, 3, None)) } - QueryColumnType::Time(scale) => TypeInfo::VarLenSized(VarLenContext::new( - VarLenType::Timen, - scale as usize, - None, - )), + QueryColumnType::Time(scale) => { + TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Timen, scale as usize, None)) + } QueryColumnType::DateTime2(scale) => TypeInfo::VarLenSized(VarLenContext::new( VarLenType::Datetime2, scale as usize, @@ -639,7 +637,7 @@ where pub async fn error_message(&mut self, number: u32, message: impl Into) -> Result<()> { let error = TokenError::new( number, - 0, // state + 0, // state 16, // class (severity) - 16 is "user error" message.into(), String::new(), // server diff --git a/src/server/response.rs b/src/server/response.rs index d6854ca0..7bc9b0e3 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -119,7 +119,10 @@ impl<'a> OutputParameter<'a> { /// let output = OutputParameter::from_input(input_param, ColumnData::I32(Some(42))); /// writer.send_output_param(output).await?; /// ``` - pub fn from_input(input: &crate::server::codec::DecodedRpcParam, value: ColumnData<'a>) -> Self { + pub fn from_input( + input: &crate::server::codec::DecodedRpcParam, + value: ColumnData<'a>, + ) -> Self { Self { name: Cow::Owned(input.name.clone()), value, @@ -186,29 +189,19 @@ fn column_data_into_static(value: ColumnData<'_>) -> ColumnData<'static> { ColumnData::DateTimeOffset(v) => ColumnData::DateTimeOffset(v), // Cow types - convert to owned - ColumnData::String(s) => { - ColumnData::String(s.map(|cow| Cow::Owned(cow.into_owned()))) - } - ColumnData::Binary(b) => { - ColumnData::Binary(b.map(|cow| Cow::Owned(cow.into_owned()))) - } - ColumnData::Xml(x) => { - ColumnData::Xml(x.map(|cow| Cow::Owned(cow.into_owned()))) - } - ColumnData::Udt(u) => { - ColumnData::Udt(u.map(|cow| Cow::Owned(cow.into_owned()))) - } - ColumnData::Variant(v) => { - ColumnData::Variant(v.map(|var| var.into_owned())) - } - ColumnData::Tvp(t) => { - ColumnData::Tvp(t.map(|tvp| tvp_data_into_static(tvp))) - } + ColumnData::String(s) => ColumnData::String(s.map(|cow| Cow::Owned(cow.into_owned()))), + ColumnData::Binary(b) => ColumnData::Binary(b.map(|cow| Cow::Owned(cow.into_owned()))), + ColumnData::Xml(x) => ColumnData::Xml(x.map(|cow| Cow::Owned(cow.into_owned()))), + ColumnData::Udt(u) => ColumnData::Udt(u.map(|cow| Cow::Owned(cow.into_owned()))), + ColumnData::Variant(v) => ColumnData::Variant(v.map(|var| var.into_owned())), + ColumnData::Tvp(t) => ColumnData::Tvp(t.map(|tvp| tvp_data_into_static(tvp))), } } /// Convert TvpData to a 'static version. -fn tvp_data_into_static(tvp: crate::tds::codec::TvpData<'_>) -> crate::tds::codec::TvpData<'static> { +fn tvp_data_into_static( + tvp: crate::tds::codec::TvpData<'_>, +) -> crate::tds::codec::TvpData<'static> { crate::tds::codec::TvpData { db_name: Cow::Owned(tvp.db_name.into_owned()), schema: Cow::Owned(tvp.schema.into_owned()), @@ -251,8 +244,12 @@ pub fn infer_type_info(value: &ColumnData<'_>, collation: Collation) -> TypeInfo ColumnData::I16(_) => TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Intn, 2, None)), ColumnData::I32(_) => TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Intn, 4, None)), ColumnData::I64(_) => TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Intn, 8, None)), - ColumnData::F32(_) => TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Floatn, 4, None)), - ColumnData::F64(_) => TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Floatn, 8, None)), + ColumnData::F32(_) => { + TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Floatn, 4, None)) + } + ColumnData::F64(_) => { + TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Floatn, 8, None)) + } ColumnData::Bit(_) => TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Bitn, 1, None)), ColumnData::String(s) => { // Use NVARCHAR with appropriate length @@ -262,7 +259,11 @@ pub fn infer_type_info(value: &ColumnData<'_>, collation: Collation) -> TypeInfo .map(|s| s.len().saturating_mul(2)) // UTF-16 encoding .unwrap_or(0); // Use max if > 4000 chars (8000 bytes), otherwise use actual length or default - let type_len = if len > 8000 { 0xFFFF } else { std::cmp::max(len, 8000) }; + let type_len = if len > 8000 { + 0xFFFF + } else { + std::cmp::max(len, 8000) + }; TypeInfo::VarLenSized(VarLenContext::new( VarLenType::NVarchar, type_len, @@ -275,7 +276,11 @@ pub fn infer_type_info(value: &ColumnData<'_>, collation: Collation) -> TypeInfo ColumnData::Binary(b) => { let len = b.as_ref().map(|b| b.len()).unwrap_or(0); // Use max if > 8000, otherwise use actual length or default - let type_len = if len > 8000 { 0xFFFF } else { std::cmp::max(len, 8000) }; + let type_len = if len > 8000 { + 0xFFFF + } else { + std::cmp::max(len, 8000) + }; TypeInfo::VarLenSized(VarLenContext::new(VarLenType::BigVarBin, type_len, None)) } ColumnData::Numeric(n) => { @@ -334,10 +339,16 @@ pub fn infer_type_info(value: &ColumnData<'_>, collation: Collation) -> TypeInfo } ColumnData::Udt(u) => { let len = u.as_ref().map(|b| b.len()).unwrap_or(0); - let type_len = if len > 8000 { 0xFFFF } else { std::cmp::max(len, 8000) }; + let type_len = if len > 8000 { + 0xFFFF + } else { + std::cmp::max(len, 8000) + }; TypeInfo::VarLenSized(VarLenContext::new(VarLenType::BigVarBin, type_len, None)) } - ColumnData::Variant(_) => TypeInfo::SsVariant(crate::tds::codec::SsVariantInfo { max_len: 8016 }), + ColumnData::Variant(_) => { + TypeInfo::SsVariant(crate::tds::codec::SsVariantInfo { max_len: 8016 }) + } ColumnData::Tvp(_) => TypeInfo::Tvp(crate::tds::codec::TvpInfo { db_name: String::new(), schema: String::new(), @@ -365,7 +376,9 @@ where columns: columns.clone(), }; client - .send(TdsBackendMessage::TokenPartial(BackendToken::ColMetaData(token))) + .send(TdsBackendMessage::TokenPartial(BackendToken::ColMetaData( + token, + ))) .await?; Ok(Self { client, columns }) @@ -537,7 +550,8 @@ where .into(), ) })?; - let mut dst_ti = BytesMutWithTypeInfo::new(&mut payload).with_type_info(&column.base.ty); + let mut dst_ti = + BytesMutWithTypeInfo::new(&mut payload).with_type_info(&column.base.ty); value.encode(&mut dst_ti)?; while payload.len() >= chunk_size { let chunk = payload.split_to(chunk_size); @@ -571,11 +585,7 @@ where /// Send a batch of rows using a columnar accessor. /// /// The accessor should provide a value for each (row, col) pair. - pub async fn send_batch_rows<'b, F>( - &mut self, - rows: usize, - mut value_at: F, - ) -> Result<()> + pub async fn send_batch_rows<'b, F>(&mut self, rows: usize, mut value_at: F) -> Result<()> where F: FnMut(usize, usize) -> ColumnData<'b>, { @@ -629,11 +639,7 @@ where } /// Send a batch of rows using NBCROW encoding. - pub async fn send_batch_rows_nbc<'b, F>( - &mut self, - rows: usize, - mut value_at: F, - ) -> Result<()> + pub async fn send_batch_rows_nbc<'b, F>(&mut self, rows: usize, mut value_at: F) -> Result<()> where F: FnMut(usize, usize) -> ColumnData<'b>, { diff --git a/src/server/router.rs b/src/server/router.rs index a38db6d4..70d4cdec 100644 --- a/src/server/router.rs +++ b/src/server/router.rs @@ -52,11 +52,7 @@ use crate::{Error, Result}; pub struct RejectUnknownProc; impl RpcHandler for RejectUnknownProc { - fn on_rpc<'a, C>( - &'a self, - _client: &'a mut C, - message: RpcMessage, - ) -> BoxFuture<'a, Result<()>> + fn on_rpc<'a, C>(&'a self, _client: &'a mut C, message: RpcMessage) -> BoxFuture<'a, Result<()>> where C: TdsClient + 'a, { @@ -213,8 +209,7 @@ impl std::fmt::Debug } } -impl RpcHandler - for SystemProcRouter +impl RpcHandler for SystemProcRouter where ES: SpExecuteSqlHandler, P: SpPrepareHandler, @@ -226,11 +221,7 @@ where CC: SpCursorCloseHandler, F: RpcHandler, { - fn on_rpc<'a, C>( - &'a self, - client: &'a mut C, - message: RpcMessage, - ) -> BoxFuture<'a, Result<()>> + fn on_rpc<'a, C>(&'a self, client: &'a mut C, message: RpcMessage) -> BoxFuture<'a, Result<()>> where C: TdsClient + 'a, { @@ -327,9 +318,7 @@ pub struct SystemProcRouterBuilder { fallback: F, } -impl Default - for SystemProcRouterBuilder<(), (), (), (), (), (), (), (), RejectUnknownProc> -{ +impl Default for SystemProcRouterBuilder<(), (), (), (), (), (), (), (), RejectUnknownProc> { fn default() -> Self { Self::new() } diff --git a/src/server/server.rs b/src/server/server.rs index 4717b2a3..e4fc8240 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -15,8 +15,8 @@ use crate::server::messages::{AllHeaders, BackendToken, TdsBackendMessage, TdsFr use crate::server::state::TdsConnectionState; use crate::server::tls::{MaybeTlsStream, NoTls, TlsAccept}; use crate::tds::codec::{DoneStatus, TokenDone}; -use crate::Error; use crate::EncryptionLevel; +use crate::Error; /// Default startup timeout (60 seconds). const STARTUP_TIMEOUT: Duration = Duration::from_secs(60); @@ -211,8 +211,7 @@ fn apply_request_headers( conn: &mut TdsConnection>, headers: &AllHeaders, request_flags: crate::server::messages::RequestFlags, -) -where +) where S: NetStream, T: TlsAccept, { @@ -237,7 +236,9 @@ where EncryptionLevel::NotSupported => Ok(()), _ => { let Some(acceptor) = tls_acceptor.as_ref() else { - return Err(Error::Protocol("TLS requested but no acceptor configured".into())); + return Err(Error::Protocol( + "TLS requested but no acceptor configured".into(), + )); }; if conn.is_secure() { diff --git a/src/server/sp_cursor.rs b/src/server/sp_cursor.rs index 2d137394..b3e075fd 100644 --- a/src/server/sp_cursor.rs +++ b/src/server/sp_cursor.rs @@ -603,7 +603,11 @@ mod tests { use crate::tds::codec::RpcStatus; DecodedRpcParam { name: name.to_string(), - flags: if output { RpcStatus::ByRefValue.into() } else { BitFlags::empty() }, + flags: if output { + RpcStatus::ByRefValue.into() + } else { + BitFlags::empty() + }, ty: TypeInfo::FixedLen(FixedLenType::Int4), value: ColumnData::I32(value), } diff --git a/src/server/sp_executesql.rs b/src/server/sp_executesql.rs index 5ae9c831..16d8e8b6 100644 --- a/src/server/sp_executesql.rs +++ b/src/server/sp_executesql.rs @@ -407,7 +407,7 @@ pub fn parse_executesql(params: RpcParamSet) -> Result ColumnData::String(None) => None, // NULL ColumnData::I32(Some(0)) => None, // Special marker for no params ColumnData::I32(None) => None, // NULL as I32 - _ => None, // Other types treated as no param defs + _ => None, // Other types treated as no param defs } } else { None @@ -534,11 +534,7 @@ impl RpcHandler for SpExecuteSqlRpcHandler where H: SpExecuteSqlHandler, { - fn on_rpc<'a, C>( - &'a self, - client: &'a mut C, - message: RpcMessage, - ) -> BoxFuture<'a, Result<()>> + fn on_rpc<'a, C>(&'a self, client: &'a mut C, message: RpcMessage) -> BoxFuture<'a, Result<()>> where C: TdsClient + 'a, { @@ -634,10 +630,7 @@ mod tests { #[test] fn test_parse_executesql_empty_param_defs() { - let params = vec![ - make_string_param("", "SELECT 1"), - make_string_param("", ""), - ]; + let params = vec![make_string_param("", "SELECT 1"), make_string_param("", "")]; let param_set = RpcParamSet::new(params); let parsed = parse_executesql(param_set).unwrap(); diff --git a/src/server/sp_prepare.rs b/src/server/sp_prepare.rs index a335c4de..56a7efd3 100644 --- a/src/server/sp_prepare.rs +++ b/src/server/sp_prepare.rs @@ -609,11 +609,7 @@ impl RpcHandler for SpPrepareRpcHandler where H: SpPrepareHandler, { - fn on_rpc<'a, C>( - &'a self, - client: &'a mut C, - message: RpcMessage, - ) -> BoxFuture<'a, Result<()>> + fn on_rpc<'a, C>(&'a self, client: &'a mut C, message: RpcMessage) -> BoxFuture<'a, Result<()>> where C: TdsClient + 'a, { @@ -631,20 +627,12 @@ where Ok(()) } Some(other) => Err(Error::Protocol( - format!( - "SpPrepareRpcHandler: unsupported RPC proc ID {:?}", - other - ) - .into(), + format!("SpPrepareRpcHandler: unsupported RPC proc ID {:?}", other).into(), )), None => { let name = message.proc_name.as_deref().unwrap_or(""); Err(Error::Protocol( - format!( - "SpPrepareRpcHandler: unsupported RPC procedure '{}'", - name - ) - .into(), + format!("SpPrepareRpcHandler: unsupported RPC procedure '{}'", name).into(), )) } } @@ -700,11 +688,7 @@ impl RpcHandler for SpExecuteRpcHandler where H: SpExecuteHandler, { - fn on_rpc<'a, C>( - &'a self, - client: &'a mut C, - message: RpcMessage, - ) -> BoxFuture<'a, Result<()>> + fn on_rpc<'a, C>(&'a self, client: &'a mut C, message: RpcMessage) -> BoxFuture<'a, Result<()>> where C: TdsClient + 'a, { @@ -718,20 +702,12 @@ where self.inner.execute(client, request).await } Some(other) => Err(Error::Protocol( - format!( - "SpExecuteRpcHandler: unsupported RPC proc ID {:?}", - other - ) - .into(), + format!("SpExecuteRpcHandler: unsupported RPC proc ID {:?}", other).into(), )), None => { let name = message.proc_name.as_deref().unwrap_or(""); Err(Error::Protocol( - format!( - "SpExecuteRpcHandler: unsupported RPC procedure '{}'", - name - ) - .into(), + format!("SpExecuteRpcHandler: unsupported RPC procedure '{}'", name).into(), )) } } @@ -787,11 +763,7 @@ impl RpcHandler for SpUnprepareRpcHandler where H: SpUnprepareHandler, { - fn on_rpc<'a, C>( - &'a self, - client: &'a mut C, - message: RpcMessage, - ) -> BoxFuture<'a, Result<()>> + fn on_rpc<'a, C>(&'a self, client: &'a mut C, message: RpcMessage) -> BoxFuture<'a, Result<()>> where C: TdsClient + 'a, { @@ -805,11 +777,7 @@ where self.inner.unprepare(client, request).await } Some(other) => Err(Error::Protocol( - format!( - "SpUnprepareRpcHandler: unsupported RPC proc ID {:?}", - other - ) - .into(), + format!("SpUnprepareRpcHandler: unsupported RPC proc ID {:?}", other).into(), )), None => { let name = message.proc_name.as_deref().unwrap_or(""); @@ -914,11 +882,7 @@ where E: SpExecuteHandler, U: SpUnprepareHandler, { - fn on_rpc<'a, C>( - &'a self, - client: &'a mut C, - message: RpcMessage, - ) -> BoxFuture<'a, Result<()>> + fn on_rpc<'a, C>(&'a self, client: &'a mut C, message: RpcMessage) -> BoxFuture<'a, Result<()>> where C: TdsClient + 'a, { @@ -1061,7 +1025,10 @@ mod tests { let result = parse_prepare(param_set); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("expected at least 3")); + assert!(result + .unwrap_err() + .to_string() + .contains("expected at least 3")); } // ------------------------------------------------------------------------- @@ -1070,10 +1037,7 @@ mod tests { #[test] fn test_parse_execute_basic() { - let params = vec![ - make_i32_param("@handle", 42), - make_i32_param("@id", 100), - ]; + let params = vec![make_i32_param("@handle", 42), make_i32_param("@id", 100)]; let param_set = RpcParamSet::new(params); let parsed = parse_execute(param_set).unwrap(); @@ -1207,15 +1171,15 @@ mod tests { assert_eq!(parsed.param_defs(), Some("@a int, @b varchar(50)")); assert_eq!(parsed.options(), 5); // Check that handle_type_info is preserved - assert!(matches!(parsed.handle_type_info(), TypeInfo::FixedLen(FixedLenType::Int4))); + assert!(matches!( + parsed.handle_type_info(), + TypeInfo::FixedLen(FixedLenType::Int4) + )); } #[test] fn test_parsed_execute_into_params() { - let params = vec![ - make_i32_param("@handle", 1), - make_i32_param("@id", 42), - ]; + let params = vec![make_i32_param("@handle", 1), make_i32_param("@id", 42)]; let param_set = RpcParamSet::new(params); let parsed = parse_execute(param_set).unwrap(); diff --git a/src/server/sp_prepexec.rs b/src/server/sp_prepexec.rs index 5ab09ba8..de23e965 100644 --- a/src/server/sp_prepexec.rs +++ b/src/server/sp_prepexec.rs @@ -194,11 +194,7 @@ impl RpcHandler for SpPrepExecRpcHandler where H: SpPrepExecHandler, { - fn on_rpc<'a, C>( - &'a self, - client: &'a mut C, - message: RpcMessage, - ) -> BoxFuture<'a, Result<()>> + fn on_rpc<'a, C>(&'a self, client: &'a mut C, message: RpcMessage) -> BoxFuture<'a, Result<()>> where C: TdsClient + 'a, { @@ -211,20 +207,13 @@ where Ok(()) } Some(other) => Err(Error::Protocol( - format!( - "SpPrepExecRpcHandler: unsupported RPC proc ID {:?}", - other - ) - .into(), + format!("SpPrepExecRpcHandler: unsupported RPC proc ID {:?}", other).into(), )), None => { let name = message.proc_name.as_deref().unwrap_or(""); Err(Error::Protocol( - format!( - "SpPrepExecRpcHandler: unsupported RPC procedure '{}'", - name - ) - .into(), + format!("SpPrepExecRpcHandler: unsupported RPC procedure '{}'", name) + .into(), )) } } diff --git a/src/tds.rs b/src/tds.rs index 11694f4a..3f5d7c64 100644 --- a/src/tds.rs +++ b/src/tds.rs @@ -3,9 +3,9 @@ mod collation; mod context; pub mod numeric; pub mod stream; +pub mod time; #[cfg(feature = "server-rustls")] pub(crate) mod tls; -pub mod time; pub mod xml; pub use collation::Collation; diff --git a/src/tds/codec/column_data.rs b/src/tds/codec/column_data.rs index 460e6072..b6ab6630 100644 --- a/src/tds/codec/column_data.rs +++ b/src/tds/codec/column_data.rs @@ -20,20 +20,20 @@ mod xml; use super::{Encode, FixedLenType, TypeInfo, VarLenContext, VarLenType}; use crate::tds::time::{Date, DateTime2, DateTimeOffset, Time}; +use crate::tds::Context; use crate::{ tds::{time::DateTime, time::SmallDateTime, xml::XmlData, Collation, Numeric}, SqlReadBytes, }; -use enumflags2::BitFlags; use bytes::{Buf, BufMut, BytesMut}; pub(crate) use bytes_mut_with_type_info::BytesMutWithTypeInfo; +use enumflags2::BitFlags; +use futures_util::io::{AsyncRead, AsyncReadExt}; use std::borrow::{BorrowMut, Cow}; use std::future::Future; use std::pin::Pin; use std::task::Poll; use uuid::Uuid; -use crate::tds::Context; -use futures_util::io::{AsyncRead, AsyncReadExt}; const MAX_NVARCHAR_SIZE: usize = 1 << 30; @@ -258,10 +258,7 @@ impl<'a> VariantData<'a> { } /// Build a typed sql_variant payload from a base type and value. - pub fn from_typed( - ty: TypeInfo, - value: ColumnData<'_>, - ) -> crate::Result> { + pub fn from_typed(ty: TypeInfo, value: ColumnData<'_>) -> crate::Result> { let payload = encode_variant_payload(ty, value)?; Ok(VariantData::new(payload)) } @@ -414,9 +411,10 @@ fn encode_variant_money_bytes(value: f64, len: usize) -> crate::Result> fn encode_variant_numeric_bytes(value: Numeric) -> crate::Result> { let raw = value.value(); - let abs = raw.checked_abs().ok_or_else(|| { - crate::Error::BulkInput("sql_variant: numeric overflow".into()) - })? as u128; + let abs = raw + .checked_abs() + .ok_or_else(|| crate::Error::BulkInput("sql_variant: numeric overflow".into()))? + as u128; let mut buf = BytesMut::with_capacity(17); buf.put_u8(if raw < 0 { 0 } else { 1 }); buf.put_u128_le(abs); @@ -433,11 +431,7 @@ fn encode_variant_non_unicode( .max_buffer_length_from_utf8_without_replacement(value.len()) .unwrap(); let mut bytes = Vec::with_capacity(len); - let (res, _) = encoder.encode_from_utf8_to_vec_without_replacement( - value, - &mut bytes, - true, - ); + let (res, _) = encoder.encode_from_utf8_to_vec_without_replacement(value, &mut bytes, true); if let encoding_rs::EncoderResult::Unmappable(_) = res { return Err(crate::Error::Encoding( "sql_variant: unrepresentable character".into(), @@ -638,7 +632,8 @@ fn encode_variant_payload(ty: TypeInfo, value: ColumnData<'_>) -> crate::Result< ) if matches!( ty, VarLenType::Decimaln | VarLenType::Numericn | VarLenType::Decimal | VarLenType::Numeric - ) => { + ) => + { if num.scale() != scale { return Err(crate::Error::BulkInput( format!( @@ -678,9 +673,9 @@ fn encode_variant_payload(ty: TypeInfo, value: ColumnData<'_>) -> crate::Result< "sql_variant: char length exceeds u16".into(), )); } - let collation = ctx.collation().ok_or_else(|| { - crate::Error::BulkInput("sql_variant: missing collation".into()) - })?; + let collation = ctx + .collation() + .ok_or_else(|| crate::Error::BulkInput("sql_variant: missing collation".into()))?; let bytes = encode_variant_non_unicode(value.as_ref(), collation, max_len)?; prop_bytes.put_u32_le(collation.info()); prop_bytes.put_u8(collation.sort_id()); @@ -697,9 +692,9 @@ fn encode_variant_payload(ty: TypeInfo, value: ColumnData<'_>) -> crate::Result< "sql_variant: nchar length exceeds u16".into(), )); } - let collation = ctx.collation().ok_or_else(|| { - crate::Error::BulkInput("sql_variant: missing collation".into()) - })?; + let collation = ctx + .collation() + .ok_or_else(|| crate::Error::BulkInput("sql_variant: missing collation".into()))?; let bytes = encode_variant_unicode(value.as_ref(), max_len)?; prop_bytes.put_u32_le(collation.info()); prop_bytes.put_u8(collation.sort_id()); @@ -751,9 +746,7 @@ fn encode_variant_payload(ty: TypeInfo, value: ColumnData<'_>) -> crate::Result< Ok(payload.to_vec()) } -async fn decode_variant_payload( - payload: &[u8], -) -> crate::Result<(TypeInfo, ColumnData<'static>)> { +async fn decode_variant_payload(payload: &[u8]) -> crate::Result<(TypeInfo, ColumnData<'static>)> { let (base_type, prop_bytes, value_bytes) = split_variant_payload(payload)?; let mut reader = VariantReader::new(BytesMut::from(value_bytes)); @@ -778,7 +771,9 @@ async fn decode_variant_payload( FixedLenType::Float8 => ColumnData::F64(Some(reader.read_f64_le().await?)), FixedLenType::Money => money::decode(&mut reader, 8).await?, FixedLenType::Money4 => money::decode(&mut reader, 4).await?, - FixedLenType::Datetime => ColumnData::DateTime(Some(DateTime::decode(&mut reader).await?)), + FixedLenType::Datetime => { + ColumnData::DateTime(Some(DateTime::decode(&mut reader).await?)) + } FixedLenType::Datetime4 => { ColumnData::SmallDateTime(Some(SmallDateTime::decode(&mut reader).await?)) } @@ -846,11 +841,8 @@ async fn decode_variant_payload( "sql_variant payload has trailing bytes".into(), )); } - let ty = TypeInfo::VarLenSized(VarLenContext::new( - VarLenType::Timen, - scale as usize, - None, - )); + let ty = + TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Timen, scale as usize, None)); Ok((ty, value)) } VarLenType::Datetime2 => { @@ -936,9 +928,7 @@ async fn decode_variant_payload( let mut numeric_reader = VariantReader::new(buf); let numeric = Numeric::decode(&mut numeric_reader, scale) .await? - .ok_or_else(|| { - crate::Error::Protocol("sql_variant: numeric null".into()) - })?; + .ok_or_else(|| crate::Error::Protocol("sql_variant: numeric null".into()))?; if numeric_reader.remaining() != 0 { return Err(crate::Error::Protocol( "sql_variant payload has trailing bytes".into(), @@ -1156,9 +1146,8 @@ async fn decode_variant_payload( )); } let len = value_bytes.len(); - let len_u8 = u8::try_from(len).map_err(|_| { - crate::Error::Protocol("sql_variant: money length overflow".into()) - })?; + let len_u8 = u8::try_from(len) + .map_err(|_| crate::Error::Protocol("sql_variant: money length overflow".into()))?; let value = money::decode(&mut reader, len_u8).await?; if reader.remaining() != 0 { return Err(crate::Error::Protocol( @@ -1189,11 +1178,7 @@ async fn decode_variant_payload( "sql_variant payload has trailing bytes".into(), )); } - let ty = TypeInfo::VarLenSized(VarLenContext::new( - VarLenType::Datetimen, - len, - None, - )); + let ty = TypeInfo::VarLenSized(VarLenContext::new(VarLenType::Datetimen, len, None)); Ok((ty, value)) } _ => Err(crate::Error::Protocol( @@ -1286,9 +1271,7 @@ impl<'a> ColumnData<'a> { VarLenType::Decimaln | VarLenType::Numericn | VarLenType::Decimal - | VarLenType::Numeric => { - ColumnData::Numeric(Numeric::decode(src, *scale).await?) - } + | VarLenType::Numeric => ColumnData::Numeric(Numeric::decode(src, *scale).await?), _ => { return Err(crate::Error::Protocol( format!("unexpected precision type {:?}", ty).into(), @@ -2112,11 +2095,7 @@ where } rows.push(row); } - _ => { - return Err(crate::Error::Protocol( - "tvp: invalid row token".into(), - )) - } + _ => return Err(crate::Error::Protocol("tvp: invalid row token".into())), } } @@ -2165,9 +2144,7 @@ fn encode_tvp_value<'a>( for row in rows { if row.len() != columns.len() { - return Err(crate::Error::BulkInput( - "tvp: row length mismatch".into(), - )); + return Err(crate::Error::BulkInput("tvp: row length mismatch".into())); } dst.put_u8(TVP_ROW_TOKEN); for (value, column) in row.into_iter().zip(columns.iter()) { @@ -2228,19 +2205,15 @@ fn encode_text_value<'a>( let bytes = match ty { VarLenType::Text => { - let collation = collation.ok_or_else(|| { - crate::Error::BulkInput("text: missing collation".into()) - })?; + let collation = collation + .ok_or_else(|| crate::Error::BulkInput("text: missing collation".into()))?; let mut encoder = collation.encoding()?.new_encoder(); let len = encoder .max_buffer_length_from_utf8_without_replacement(value.len()) .unwrap(); let mut buf = Vec::with_capacity(len); - let (res, _) = encoder.encode_from_utf8_to_vec_without_replacement( - value, - &mut buf, - true, - ); + let (res, _) = + encoder.encode_from_utf8_to_vec_without_replacement(value, &mut buf, true); if let encoding_rs::EncoderResult::Unmappable(_) = res { return Err(crate::Error::Encoding( "text: unrepresentable character".into(), @@ -2255,11 +2228,7 @@ fn encode_text_value<'a>( } buf } - _ => { - return Err(crate::Error::Protocol( - "text: unsupported type".into(), - )) - } + _ => return Err(crate::Error::Protocol("text: unsupported type".into())), }; if bytes.len() > u32::MAX as usize { @@ -2311,16 +2280,14 @@ fn encode_short_len_string<'a>( )); } - let collation = collation.ok_or_else(|| { - crate::Error::BulkInput("char/varchar: missing collation".into()) - })?; + let collation = collation + .ok_or_else(|| crate::Error::BulkInput("char/varchar: missing collation".into()))?; let mut encoder = collation.encoding()?.new_encoder(); let len = encoder .max_buffer_length_from_utf8_without_replacement(value.len()) .unwrap(); let mut bytes = Vec::with_capacity(len); - let (res, _) = - encoder.encode_from_utf8_to_vec_without_replacement(value, &mut bytes, true); + let (res, _) = encoder.encode_from_utf8_to_vec_without_replacement(value, &mut bytes, true); if let encoding_rs::EncoderResult::Unmappable(_) = res { return Err(crate::Error::Encoding( "char/varchar: unrepresentable character".into(), @@ -2437,11 +2404,7 @@ mod tests { name: "label".into(), user_type: 0, flags: ColumnFlag::Nullable.into(), - ty: TypeInfo::VarLenSized(VarLenContext::new( - VarLenType::NVarchar, - 40, - collation, - )), + ty: TypeInfo::VarLenSized(VarLenContext::new(VarLenType::NVarchar, 40, collation)), }, ]; let rows = vec![ @@ -2453,9 +2416,7 @@ mod tests { ]; // Type name/schema/db_name live in the TypeInfo header, not the TVP data // payload, so round-tripped TvpData always has empty names. - let tvp = TvpData::new("") - .columns(columns) - .rows(rows); + let tvp = TvpData::new("").columns(columns).rows(rows); test_round_trip( TypeInfo::Tvp(TvpInfo { db_name: "db".into(), @@ -3147,12 +3108,10 @@ mod tests { async fn ssvariant_typed_payload_round_trip() { let ty = TypeInfo::FixedLen(FixedLenType::Int4); let value = ColumnData::I32(Some(42)); - let payload = VariantData::from_typed(ty.clone(), value.clone()) - .expect("typed variant payload"); - let (decoded_ty, decoded_value) = payload - .decode_typed() - .await - .expect("decode typed variant"); + let payload = + VariantData::from_typed(ty.clone(), value.clone()).expect("typed variant payload"); + let (decoded_ty, decoded_value) = + payload.decode_typed().await.expect("decode typed variant"); assert_eq!(decoded_ty, ty); assert_eq!(decoded_value, value); } @@ -3231,7 +3190,10 @@ mod tests { // Decode: first read the TypeInfo header, then decode the value let reader = &mut buf.into_sql_read_bytes(); let ti = TypeInfo::decode(reader).await.expect("TypeInfo decode"); - assert!(matches!(ti, TypeInfo::SsVariant(SsVariantInfo { max_len: 8016 }))); + assert!(matches!( + ti, + TypeInfo::SsVariant(SsVariantInfo { max_len: 8016 }) + )); let decoded = ColumnData::decode(reader, &ti) .await @@ -3283,12 +3245,10 @@ mod tests { )), }, ]) - .rows(vec![ - vec![ - ColumnData::I32(Some(1)), - ColumnData::String(Some("hello".into())), - ], - ]); + .rows(vec![vec![ + ColumnData::I32(Some(1)), + ColumnData::String(Some("hello".into())), + ]]); let mut buf = BytesMut::new(); let mut dst = BytesMutWithTypeInfo::new(&mut buf); @@ -3429,14 +3389,12 @@ mod tests { let cd = ColumnData::Variant(Some(variant.clone())); // FromSql (borrowed) - let result: Option<&VariantData<'static>> = - <&VariantData<'static>>::from_sql(&cd).unwrap(); + let result: Option<&VariantData<'static>> = <&VariantData<'static>>::from_sql(&cd).unwrap(); assert!(result.is_some()); assert_eq!(result.unwrap().payload(), variant.payload()); // FromSqlOwned - let result: Option> = - VariantData::from_sql_owned(cd).unwrap(); + let result: Option> = VariantData::from_sql_owned(cd).unwrap(); assert!(result.is_some()); } @@ -3457,14 +3415,12 @@ mod tests { let cd = ColumnData::Tvp(Some(tvp)); // FromSql (borrowed) - let result: Option<&TvpData<'static>> = - <&TvpData<'static>>::from_sql(&cd).unwrap(); + let result: Option<&TvpData<'static>> = <&TvpData<'static>>::from_sql(&cd).unwrap(); assert!(result.is_some()); assert_eq!(result.unwrap().rows.len(), 1); // FromSqlOwned - let result: Option> = - TvpData::from_sql_owned(cd).unwrap(); + let result: Option> = TvpData::from_sql_owned(cd).unwrap(); assert!(result.is_some()); } } diff --git a/src/tds/codec/column_data/string.rs b/src/tds/codec/column_data/string.rs index a60638b0..ff0fe466 100644 --- a/src/tds/codec/column_data/string.rs +++ b/src/tds/codec/column_data/string.rs @@ -70,7 +70,8 @@ where match ty { VarLenType::Char | VarLenType::VarChar => { - let collation = collation.ok_or_else(|| Error::Protocol("varchar: missing collation".into()))?; + let collation = + collation.ok_or_else(|| Error::Protocol("varchar: missing collation".into()))?; let encoder = collation.encoding()?; let s = encoder .decode_without_bom_handling_and_without_replacement(buf.as_ref()) diff --git a/src/tds/codec/login.rs b/src/tds/codec/login.rs index 7f8c05d8..b4895635 100644 --- a/src/tds/codec/login.rs +++ b/src/tds/codec/login.rs @@ -480,8 +480,7 @@ impl<'a> Decode for LoginMessage<'a> { BitFlags::from_bits(cursor.read_u8()?).expect("option_flags_1 verification"); ret.option_flags_2 = BitFlags::from_bits(cursor.read_u8()?).expect("option_flags_2 verification"); - ret.type_flags = - BitFlags::from_bits(cursor.read_u8()?).expect("type_flags verification"); + ret.type_flags = BitFlags::from_bits(cursor.read_u8()?).expect("type_flags verification"); ret.option_flags_3 = BitFlags::from_bits(cursor.read_u8()?).expect("option_flags_3 verification"); diff --git a/src/tds/codec/rpc_request.rs b/src/tds/codec/rpc_request.rs index e4d63ef8..397ce457 100644 --- a/src/tds/codec/rpc_request.rs +++ b/src/tds/codec/rpc_request.rs @@ -180,11 +180,7 @@ mod tests { #[test] fn encode_named_proc_writes_utf16_length_prefixed() { - let req = TokenRpcRequest::new( - Cow::Borrowed("my_sp"), - Vec::new(), - [0; 8], - ); + let req = TokenRpcRequest::new(Cow::Borrowed("my_sp"), Vec::new(), [0; 8]); let mut buf = BytesMut::new(); req.encode(&mut buf).unwrap(); diff --git a/src/tds/codec/token.rs b/src/tds/codec/token.rs index 37a83844..245ca327 100644 --- a/src/tds/codec/token.rs +++ b/src/tds/codec/token.rs @@ -1,40 +1,40 @@ -mod token_col_metadata; +mod token_alt_metadata; +mod token_alt_row; mod token_col_info; +mod token_col_metadata; mod token_col_name; mod token_done; mod token_env_change; mod token_error; -mod token_fed_auth_info; mod token_feature_ext_ack; +mod token_fed_auth_info; mod token_info; mod token_login_ack; mod token_order; -mod token_alt_metadata; -mod token_alt_row; mod token_return_value; mod token_row; -mod token_sspi; mod token_session_state; +mod token_sspi; mod token_tab_name; mod token_type; -pub use token_col_metadata::*; +pub use token_alt_metadata::*; +pub use token_alt_row::*; pub use token_col_info::*; +pub use token_col_metadata::*; pub use token_col_name::*; pub use token_done::*; pub use token_env_change::*; pub use token_error::*; -pub use token_fed_auth_info::*; pub use token_feature_ext_ack::*; +pub use token_fed_auth_info::*; pub use token_info::*; pub use token_login_ack::*; pub use token_order::*; -pub use token_alt_metadata::*; -pub use token_alt_row::*; pub use token_return_value::*; pub use token_row::*; -pub use token_sspi::*; pub use token_session_state::*; +pub use token_sspi::*; pub use token_tab_name::*; pub use token_type::*; diff --git a/src/tds/codec/token/token_col_metadata.rs b/src/tds/codec/token/token_col_metadata.rs index dbb955ba..0fb372f7 100644 --- a/src/tds/codec/token/token_col_metadata.rs +++ b/src/tds/codec/token/token_col_metadata.rs @@ -264,7 +264,10 @@ impl Encode for BaseMetaDataColumn { dst.put_u16_le(BitFlags::bits(self.flags)); let table_parts = match &self.ty { TypeInfo::VarLenSized(cx) - if matches!(cx.r#type(), VarLenType::Text | VarLenType::NText | VarLenType::Image) => + if matches!( + cx.r#type(), + VarLenType::Text | VarLenType::NText | VarLenType::Image + ) => { Some(self.table_name.as_deref().unwrap_or(&[])) } diff --git a/src/tds/codec/token/token_done.rs b/src/tds/codec/token/token_done.rs index 1e62af51..d0c071c8 100644 --- a/src/tds/codec/token/token_done.rs +++ b/src/tds/codec/token/token_done.rs @@ -93,11 +93,7 @@ impl TokenDone { } } - pub(crate) fn encode_with_type( - self, - dst: &mut BytesMut, - ty: TokenType, - ) -> crate::Result<()> { + pub(crate) fn encode_with_type(self, dst: &mut BytesMut, ty: TokenType) -> crate::Result<()> { self.encode_with_type_and_count_bytes(dst, ty, 8) } @@ -120,11 +116,7 @@ impl TokenDone { } dst.put_u32_le(self.done_rows as u32); } - _ => { - return Err(Error::Protocol( - "done: invalid row count width".into(), - )) - } + _ => return Err(Error::Protocol("done: invalid row count width".into())), } Ok(()) } diff --git a/src/tds/codec/token/token_env_change.rs b/src/tds/codec/token/token_env_change.rs index 96e85a1b..a89a7d84 100644 --- a/src/tds/codec/token/token_env_change.rs +++ b/src/tds/codec/token/token_env_change.rs @@ -1,6 +1,6 @@ use crate::{tds::codec::Encode, tds::Collation, Error, SqlReadBytes, TokenType}; -use bytes::{BufMut, BytesMut}; use byteorder::{LittleEndian, ReadBytesExt}; +use bytes::{BufMut, BytesMut}; use fmt::Debug; use futures_util::io::AsyncReadExt; use std::{ @@ -502,7 +502,9 @@ fn read_b_varbyte(buf: &mut Cursor>) -> crate::Result> { fn parse_tx_descriptor(bytes: Vec) -> crate::Result<[u8; 8]> { if bytes.len() != 8 { - return Err(Error::Protocol("invalid transaction descriptor length".into())); + return Err(Error::Protocol( + "invalid transaction descriptor length".into(), + )); } let mut desc = [0u8; 8]; desc.copy_from_slice(&bytes); diff --git a/src/tds/codec/token/token_error.rs b/src/tds/codec/token/token_error.rs index f1525875..402f22cb 100644 --- a/src/tds/codec/token/token_error.rs +++ b/src/tds/codec/token/token_error.rs @@ -1,6 +1,6 @@ +use crate::tds::codec::token::{write_b_varchar, write_us_varchar}; use crate::{tds::codec::Encode, tds::codec::FeatureLevel, SqlReadBytes, TokenType}; use bytes::{BufMut, BytesMut}; -use crate::tds::codec::token::{write_b_varchar, write_us_varchar}; use std::fmt; #[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)] diff --git a/src/tds/codec/token/token_fed_auth_info.rs b/src/tds/codec/token/token_fed_auth_info.rs index 52299821..ab1e7bf3 100644 --- a/src/tds/codec/token/token_fed_auth_info.rs +++ b/src/tds/codec/token/token_fed_auth_info.rs @@ -1,6 +1,6 @@ use crate::{Error, SqlReadBytes, TokenType}; -use bytes::{BufMut, BytesMut}; use byteorder::{LittleEndian, ReadBytesExt}; +use bytes::{BufMut, BytesMut}; use futures_util::io::AsyncReadExt; use std::io::Cursor; diff --git a/src/tds/codec/token/token_info.rs b/src/tds/codec/token/token_info.rs index 3cc10250..1bd32818 100644 --- a/src/tds/codec/token/token_info.rs +++ b/src/tds/codec/token/token_info.rs @@ -1,6 +1,6 @@ +use crate::tds::codec::token::{write_b_varchar, write_us_varchar}; use crate::{tds::codec::Encode, SqlReadBytes, TokenType}; use bytes::{BufMut, BytesMut}; -use crate::tds::codec::token::{write_b_varchar, write_us_varchar}; #[allow(dead_code)] // we might want to debug the values #[derive(Debug)] diff --git a/src/tds/codec/token/token_login_ack.rs b/src/tds/codec/token/token_login_ack.rs index dada44c2..cc057469 100644 --- a/src/tds/codec/token/token_login_ack.rs +++ b/src/tds/codec/token/token_login_ack.rs @@ -1,7 +1,7 @@ +use crate::tds::codec::token::write_b_varchar; use crate::{tds::codec::Encode, Error, FeatureLevel, SqlReadBytes, TokenType}; -use bytes::BytesMut; use bytes::BufMut; -use crate::tds::codec::token::write_b_varchar; +use bytes::BytesMut; use std::convert::TryFrom; #[allow(dead_code)] // we might want to debug the values diff --git a/src/tds/codec/token/token_row/bytes_mut_with_data_columns.rs b/src/tds/codec/token/token_row/bytes_mut_with_data_columns.rs index e3f00281..6f460d29 100644 --- a/src/tds/codec/token/token_row/bytes_mut_with_data_columns.rs +++ b/src/tds/codec/token/token_row/bytes_mut_with_data_columns.rs @@ -11,7 +11,10 @@ pub(crate) struct BytesMutWithDataColumns<'a, 'c> { impl<'a, 'c> BytesMutWithDataColumns<'a, 'c> { pub fn new(bytes: &'a mut BytesMut, data_columns: &'c [MetaDataColumn<'c>]) -> Self { - BytesMutWithDataColumns { bytes, data_columns } + BytesMutWithDataColumns { + bytes, + data_columns, + } } pub fn data_columns(&self) -> &'c [MetaDataColumn<'c>] { diff --git a/src/tds/codec/token/token_session_state.rs b/src/tds/codec/token/token_session_state.rs index 1ec837f6..318e22b5 100644 --- a/src/tds/codec/token/token_session_state.rs +++ b/src/tds/codec/token/token_session_state.rs @@ -1,6 +1,6 @@ use crate::{Error, SqlReadBytes, TokenType}; -use bytes::{BufMut, BytesMut}; use byteorder::{LittleEndian, ReadBytesExt}; +use bytes::{BufMut, BytesMut}; use futures_util::io::AsyncReadExt; use std::io::{Cursor, Read}; diff --git a/src/tds/tls.rs b/src/tds/tls.rs index e7f35910..9c0901b6 100644 --- a/src/tds/tls.rs +++ b/src/tds/tls.rs @@ -45,7 +45,10 @@ impl TlsPreloginWrapper { pub fn handshake_complete(&mut self) { self.pending_handshake = false; - debug_assert!(self.pending_len == 0, "pending TLS handshake data not flushed"); + debug_assert!( + self.pending_len == 0, + "pending TLS handshake data not flushed" + ); self.wr_buf.clear(); self.wr_pos = 0; self.pending_len = 0; @@ -191,10 +194,8 @@ impl AsyncWrite for TlsPreloginWrapper inner.wr_buf.len() - inner.wr_pos, ); - let written = ready!( - Pin::new(&mut inner.stream.as_mut().unwrap()) - .poll_write(cx, &inner.wr_buf[inner.wr_pos..]) - )?; + let written = ready!(Pin::new(&mut inner.stream.as_mut().unwrap()) + .poll_write(cx, &inner.wr_buf[inner.wr_pos..]))?; if written == 0 { return Poll::Ready(Err(io::Error::new( diff --git a/tests/query.rs b/tests/query.rs index 59ecac68..5396f3f3 100644 --- a/tests/query.rs +++ b/tests/query.rs @@ -3121,11 +3121,8 @@ where let table = random_table().await; - conn.execute( - format!("CREATE TABLE ##{} (v sql_variant)", table), - &[], - ) - .await?; + conn.execute(format!("CREATE TABLE ##{} (v sql_variant)", table), &[]) + .await?; conn.execute( format!("INSERT INTO ##{} (v) VALUES (@P1)", table), @@ -3134,7 +3131,7 @@ where .await?; let row = conn - .query(format!("SELECT v FROM ##{}",table), &[]) + .query(format!("SELECT v FROM ##{}", table), &[]) .await? .into_row() .await? @@ -3224,11 +3221,7 @@ where name: std::borrow::Cow::Borrowed("label"), user_type: 0, flags: ColumnFlag::Nullable.into(), - ty: TypeInfo::VarLenSized(VarLenContext::new( - VarLenType::NVarchar, - 100, - collation, - )), + ty: TypeInfo::VarLenSized(VarLenContext::new(VarLenType::NVarchar, 100, collation)), }, ]) .rows(vec![ @@ -3248,10 +3241,7 @@ where // Execute the proc with TVP parameter let stream = conn - .query( - format!("EXEC {proc_name} @tvp = @P1"), - &[&tvp], - ) + .query(format!("EXEC {proc_name} @tvp = @P1"), &[&tvp]) .await?; let rows: Vec<_> = stream.into_first_result().await?; @@ -3292,12 +3282,10 @@ where .into_results() .await?; - conn.simple_query(format!( - "CREATE TYPE dbo.{type_name} AS TABLE (id INT)", - )) - .await? - .into_results() - .await?; + conn.simple_query(format!("CREATE TYPE dbo.{type_name} AS TABLE (id INT)",)) + .await? + .into_results() + .await?; conn.simple_query(format!( "CREATE PROCEDURE {proc_name} @tvp dbo.{type_name} READONLY AS SELECT COUNT(*) AS cnt FROM @tvp", @@ -3318,10 +3306,7 @@ where .rows(vec![]); let row = conn - .query( - format!("EXEC {proc_name} @tvp = @P1"), - &[&tvp], - ) + .query(format!("EXEC {proc_name} @tvp = @P1"), &[&tvp]) .await? .into_row() .await?