diff --git a/builder-pattern-macro/src/attributes.rs b/builder-pattern-macro/src/attributes.rs index de8002c..bfb5260 100644 --- a/builder-pattern-macro/src/attributes.rs +++ b/builder-pattern-macro/src/attributes.rs @@ -1,5 +1,6 @@ use bitflags::bitflags; -use syn::{Attribute, Expr, Meta, NestedMeta}; +use proc_macro2::TokenTree; +use syn::{Attribute, Expr, Ident, Meta, NestedMeta}; bitflags! { pub struct Setters: u32 { @@ -23,6 +24,17 @@ pub struct FieldAttributes { pub documents: Vec, pub setters: Setters, pub vis: FieldVisibility, + pub late_bound_default: bool, + pub infer: Vec, +} + +pub fn ident_add_underscore(ident: &Ident) -> Ident { + let ident_ = ident.to_string() + "_"; + Ident::new(&ident_, ident.span()) +} + +pub fn ident_add_underscore_tree(ident: &Ident) -> TokenTree { + TokenTree::Ident(ident_add_underscore(ident)) } impl Default for FieldAttributes { @@ -34,6 +46,8 @@ impl Default for FieldAttributes { documents: vec![], setters: Setters::VALUE, vis: FieldVisibility::Default, + late_bound_default: false, + infer: vec![], } } } @@ -62,6 +76,7 @@ impl From> for FieldAttributes { unimplemented!("Duplicated `hidden` attributes.") } attributes.vis = FieldVisibility::Hidden; + attributes.late_bound_default = true; } else if attr.path.is_ident("public") { if attributes.vis != FieldVisibility::Default { unimplemented!("Duplicated `public` attributes.") @@ -75,6 +90,10 @@ impl From> for FieldAttributes { attributes.documents = get_documents(&attrs); } else if attr.path.is_ident("setter") { parse_setters(attr, &mut attributes) + } else if attr.path.is_ident("infer") { + parse_infer(attr, &mut attributes) + } else if attr.path.is_ident("late_bound_default") { + attributes.late_bound_default = true; } }); match attributes.validate() { @@ -129,6 +148,28 @@ fn parse_setters(attr: &Attribute, attributes: &mut FieldAttributes) { attributes.setters = setters; } +fn parse_infer(attr: &Attribute, attributes: &mut FieldAttributes) { + let meta = attr.parse_meta().unwrap(); + let mut params = vec![]; + if let Meta::List(l) = meta { + let it = l.nested.iter(); + it.for_each(|m| { + if let NestedMeta::Meta(Meta::Path(p)) = m { + if let Some(ident) = p.get_ident() { + params.push(ident.clone()); + } else { + unimplemented!("Invalid infer, write a type parameter.") + } + } else { + unimplemented!("Invalid setter.") + } + }); + } else { + unimplemented!("Invalid setter.") + } + attributes.infer = params; +} + pub fn get_documents(attrs: &[Attribute]) -> Vec { let mut documents: Vec = vec![]; diff --git a/builder-pattern-macro/src/builder/builder_decl.rs b/builder-pattern-macro/src/builder/builder_decl.rs index b256580..4656441 100644 --- a/builder-pattern-macro/src/builder/builder_decl.rs +++ b/builder-pattern-macro/src/builder/builder_decl.rs @@ -22,9 +22,9 @@ impl<'a> ToTokens for BuilderDecl<'a> { let builder_name = self.input.builder_name(); let where_clause = &self.input.generics.where_clause; - let impl_tokens = self.input.tokenize_impl(); + let impl_tokens = self.input.tokenize_impl(&[]); let all_generics = self.input.all_generics().collect::>(); - let ty_tokens = self.input.tokenize_types(); + let ty_tokens = self.input.tokenize_types(&[], false); let fn_lifetime = self.input.fn_lifetime(); let builder_fields = self.input.builder_fields(&fn_lifetime); @@ -41,7 +41,8 @@ impl<'a> ToTokens for BuilderDecl<'a> { AsyncFieldMarker, ValidatorOption > #where_clause { - _phantom: ::core::marker::PhantomData<( + __builder_phantom: ::core::marker::PhantomData<( + &#fn_lifetime (), #ty_tokens #(#all_generics,)* AsyncFieldMarker, diff --git a/builder-pattern-macro/src/builder/builder_functions.rs b/builder-pattern-macro/src/builder/builder_functions.rs index 66ad37a..7f8ea0a 100644 --- a/builder-pattern-macro/src/builder/builder_functions.rs +++ b/builder-pattern-macro/src/builder/builder_functions.rs @@ -1,11 +1,11 @@ use crate::{ - attributes::{FieldVisibility, Setters}, + attributes::{ident_add_underscore_tree, FieldVisibility, Setters}, field::Field, struct_input::StructInput, }; use core::str::FromStr; -use proc_macro2::{Ident, Span, TokenStream}; +use proc_macro2::{Group, Ident, Span, TokenStream, TokenTree}; use quote::ToTokens; use syn::{parse_quote, spanned::Spanned, Attribute}; @@ -22,7 +22,21 @@ impl<'a> ToTokens for BuilderFunctions<'a> { .chain(self.input.optional_fields.iter()) .map(|f| { let ident = &f.ident; - quote! { #ident: self.#ident } + if f.attrs.late_bound_default { + quote! { + #ident: match self.#ident { + Some(::builder_pattern::setter::Setter::LateBoundDefault(d)) => { + Some(::builder_pattern::setter::Setter::LateBoundDefault(d)) + } + Some(::builder_pattern::setter::Setter::Value(val)) => { + Some(::builder_pattern::setter::Setter::Value(val)) + } + _ => unreachable!(), + } + } + } else { + quote! { #ident: self.#ident } + } }) .collect::>(); @@ -52,6 +66,25 @@ impl<'a> ToTokens for BuilderFunctions<'a> { } } +pub fn replace_type_params_in( + stream: TokenStream, + replacements: &[Ident], + with: &impl Fn(&Ident) -> TokenTree, +) -> TokenStream { + stream + .into_iter() + .map(|tt| match tt { + TokenTree::Group(g) => { + let delim = g.delimiter(); + let stream = replace_type_params_in(g.stream(), replacements, with); + TokenTree::Group(Group::new(delim, stream)) + } + TokenTree::Ident(ident) if replacements.contains(&ident) => with(&ident), + x => x, + }) + .collect() +} + impl<'a> BuilderFunctions<'a> { pub fn new(input: &'a StructInput) -> Self { Self { input } @@ -101,21 +134,38 @@ impl<'a> BuilderFunctions<'a> { index: usize, builder_fields: &mut Vec, ) { - let (ident, ty, vis) = (&f.ident, &f.ty, &f.vis); + let (ident, orig_ty, vis) = (&f.ident, &f.ty, &f.vis); let builder_name = self.input.builder_name(); let where_clause = &self.input.generics.where_clause; let lifetimes = self.input.lifetimes(); let fn_lifetime = self.input.fn_lifetime(); - let impl_tokens = self.input.tokenize_impl(); - let ty_tokens = self.input.tokenize_types(); - let (other_generics, before_generics, after_generics) = self.get_generics(f, index); - let (arg_type_gen, arg_type) = if f.attrs.use_into { - ( - Some(quote! {>}), - TokenStream::from_str("IntoType").unwrap(), - ) + let impl_tokens = self.input.tokenize_impl(&[]); + let ty_tokens = self.input.tokenize_types(&[], false); + let ty_tokens_ = self.input.tokenize_types(&f.attrs.infer, false); + let fn_where_clause = self.input.setter_where_clause(&f.attrs.infer); + let (other_generics, before_generics, mut after_generics) = self.get_generics(f, index); + let replaced_ty = replace_type_params_in( + quote! { #orig_ty }, + &f.attrs.infer, + &ident_add_underscore_tree, + ); + after_generics + .iter_mut() + .for_each(|ty_tokens: &mut TokenStream| { + let tokens = std::mem::take(ty_tokens); + *ty_tokens = + replace_type_params_in(tokens, &f.attrs.infer, &ident_add_underscore_tree); + }); + let into_generics = if f.attrs.use_into { + vec![quote! {IntoType: Into<#replaced_ty>}] + } else { + vec![] + }; + let fn_generics = f.tokenize_replacement_params(&into_generics); + let arg_type = if f.attrs.use_into { + quote! { IntoType } } else { - (None, quote! {#ty}) + quote! { #replaced_ty } }; let documents = Self::documents(f, Setters::VALUE); @@ -131,7 +181,7 @@ impl<'a> BuilderFunctions<'a> { Result<#builder_name < #fn_lifetime, #(#lifetimes,)* - #ty_tokens + #ty_tokens_ #(#after_generics,)* AsyncFieldMarker, ValidatorOption @@ -142,7 +192,7 @@ impl<'a> BuilderFunctions<'a> { match #v (value.into()) { Ok(value) => Ok( #builder_name { - _phantom: ::core::marker::PhantomData, + __builder_phantom: ::core::marker::PhantomData, #(#builder_fields),* }), Err(e) => Err(format!("Validation failed: {:?}", e)) @@ -161,7 +211,7 @@ impl<'a> BuilderFunctions<'a> { #builder_name < #fn_lifetime, #(#lifetimes,)* - #ty_tokens + #ty_tokens_ #(#after_generics,)* AsyncFieldMarker, ValidatorOption @@ -169,7 +219,7 @@ impl<'a> BuilderFunctions<'a> { }, quote! { #builder_name { - _phantom: ::core::marker::PhantomData, + __builder_phantom: ::core::marker::PhantomData, #(#builder_fields),* } }, @@ -195,7 +245,9 @@ impl<'a> BuilderFunctions<'a> { #where_clause { #(#documents)* - #vis fn #ident #arg_type_gen(self, value: #arg_type) -> #ret_type { + #vis fn #ident #fn_generics(self, value: #arg_type) -> #ret_type + #fn_where_clause + { #ret_expr } } @@ -215,8 +267,8 @@ impl<'a> BuilderFunctions<'a> { let where_clause = &self.input.generics.where_clause; let lifetimes = self.input.lifetimes(); let fn_lifetime = self.input.fn_lifetime(); - let impl_tokens = self.input.tokenize_impl(); - let ty_tokens = self.input.tokenize_types(); + let impl_tokens = self.input.tokenize_impl(&[]); + let ty_tokens = self.input.tokenize_types(&[], false); let (other_generics, before_generics, after_generics) = self.get_generics(f, index); let arg_type_gen = if f.attrs.use_into { quote! {, ValType: #fn_lifetime + ::core::ops::Fn() -> IntoType>} @@ -244,7 +296,7 @@ impl<'a> BuilderFunctions<'a> { }; let ret_expr_val = quote! { #builder_name { - _phantom: ::core::marker::PhantomData, + __builder_phantom: ::core::marker::PhantomData, #(#builder_fields),* } }; @@ -305,8 +357,8 @@ impl<'a> BuilderFunctions<'a> { let where_clause = &self.input.generics.where_clause; let lifetimes = self.input.lifetimes(); let fn_lifetime = self.input.fn_lifetime(); - let impl_tokens = self.input.tokenize_impl(); - let ty_tokens = self.input.tokenize_types(); + let impl_tokens = self.input.tokenize_impl(&[]); + let ty_tokens = self.input.tokenize_types(&[], false); let (other_generics, before_generics, after_generics) = self.get_generics(f, index); let arg_type_gen = if f.attrs.use_into { quote! {< @@ -343,7 +395,7 @@ impl<'a> BuilderFunctions<'a> { }; let ret_expr_val = quote! { #builder_name { - _phantom: ::core::marker::PhantomData, + __builder_phantom: ::core::marker::PhantomData, #(#builder_fields),* } }; diff --git a/builder-pattern-macro/src/builder/builder_impl.rs b/builder-pattern-macro/src/builder/builder_impl.rs index 10706a3..8cf9e19 100644 --- a/builder-pattern-macro/src/builder/builder_impl.rs +++ b/builder-pattern-macro/src/builder/builder_impl.rs @@ -1,8 +1,9 @@ use crate::{attributes::Setters, struct_input::StructInput}; use core::str::FromStr; -use proc_macro2::TokenStream; +use proc_macro2::{Ident, Span, TokenStream}; use quote::ToTokens; +use syn::spanned::Spanned; pub struct BuilderImpl<'a> { pub input: &'a StructInput, @@ -74,10 +75,11 @@ impl<'a> BuilderImpl<'a> { let fn_lifetime = self.input.fn_lifetime(); - let impl_tokens = self.input.tokenize_impl(); + let impl_tokens = self.input.tokenize_impl(&[]); let optional_generics = self.optional_generics().collect::>(); let satisfied_generics = self.satified_generics().collect::>(); - let ty_tokens = self.input.tokenize_types(); + + let ty_tokens = self.input.tokenize_types(&[], false); let mut struct_init_args = vec![]; let mut validated_init_fields = vec![]; @@ -89,23 +91,62 @@ impl<'a> BuilderImpl<'a> { .chain(self.input.optional_fields.iter()) .for_each(|f| { let ident = &f.ident; + let ty = &f.ty; + // let substituted_ty = replace_defaults(quote! { #ty }); struct_init_args.push(ident.to_token_stream()); + let mk_default_case = + |wrap: fn(TokenStream) -> TokenStream| match &f.attrs.default.as_ref() { + Some((expr, setters)) if f.attrs.late_bound_default => { + let expr = match *setters { + Setters::VALUE => quote_spanned! { expr.span() => #expr }, + Setters::LAZY => quote_spanned! { expr.span() => (#expr)() }, + _ => unimplemented!(), + }; + let wrapped_expr = wrap(expr); + quote! { + None => unreachable!("field should have had default"), + Some(::builder_pattern::setter::Setter::LateBoundDefault(id)) => { + { let val: #ty = id.cast(#wrapped_expr); val } + } + Some(::builder_pattern::setter::Setter::Default(..)) => + unreachable!("late-bound optional field was set in new()"), + } + } + Some((_expr, _setters)) => { + let id = Ident::new("id", Span::call_site()); + let default = Ident::new("default", Span::call_site()); + let expr = wrap(quote!{ #id.cast(#default) }); + quote! { + None => unreachable!("early-bound optional field had no default set in new()"), + Some(::builder_pattern::setter::Setter::LateBoundDefault(..)) => unreachable!("early-bound optional field had no default set in new()"), + Some(::builder_pattern::setter::Setter::Default(#default, #id)) => #expr, + } + } + _ => quote! { + Some(::builder_pattern::setter::Setter::LateBoundDefault(..)) | + Some(::builder_pattern::setter::Setter::Default(..)) | + None => unreachable!("required field not set"), + }, + }; + if f.attrs.validator.is_some() && !(f.attrs.setters & (Setters::LAZY | Setters::ASYNC)).is_empty() { let async_case = if is_async { quote! { - ::builder_pattern::setter::Setter::Async(f) => Ok(f().await), - ::builder_pattern::setter::Setter::AsyncValidated(f) => f().await, + Some(::builder_pattern::setter::Setter::Async(f)) => Ok(f().await), + Some(::builder_pattern::setter::Setter::AsyncValidated(f)) => f().await, } } else { quote! {_ => unimplemented!()} }; + let default_case = mk_default_case(|expr| quote! { Ok(#expr) }); validated_init_fields.push(quote! { - let #ident = match match self.#ident.unwrap() { - ::builder_pattern::setter::Setter::Value(v) => Ok(v), - ::builder_pattern::setter::Setter::Lazy(f) => Ok(f()), - ::builder_pattern::setter::Setter::LazyValidated(f) => f(), + let #ident = match match self.#ident { + #default_case + Some(::builder_pattern::setter::Setter::Value(v)) => Ok(v), + Some(::builder_pattern::setter::Setter::Lazy(f)) => Ok(f()), + Some(::builder_pattern::setter::Setter::LazyValidated(f)) => f(), #async_case } { Ok(v) => v, @@ -115,32 +156,36 @@ impl<'a> BuilderImpl<'a> { } else { let async_case = if is_async { quote! { - ::builder_pattern::setter::Setter::Async(f) => f().await, + Some(::builder_pattern::setter::Setter::Async(f)) => f().await, _ => unimplemented!(), } } else { quote! {_ => unimplemented!()} }; + let default_case = mk_default_case(|expr| quote! { #expr }); init_fields.push(quote! { - let #ident = match self.#ident.unwrap() { - ::builder_pattern::setter::Setter::Value(v) => v, - ::builder_pattern::setter::Setter::Lazy(f) => f(), + let #ident = match self.#ident { + #default_case + Some(::builder_pattern::setter::Setter::Value(v)) => v, + Some(::builder_pattern::setter::Setter::Lazy(f)) => f(), #async_case }; }); } let async_case = if is_async { quote! { - ::builder_pattern::setter::Setter::Async(f) => f().await, + Some(::builder_pattern::setter::Setter::Async(f)) => f().await, _ => unimplemented!(), } } else { quote! {_ => unimplemented!()} }; + let default_case = mk_default_case(|expr| quote! { #expr }); no_lazy_validation_fields.push(quote! { - let #ident = match self.#ident.unwrap() { - ::builder_pattern::setter::Setter::Value(v) => v, - ::builder_pattern::setter::Setter::Lazy(f) => f(), + let #ident = match self.#ident { + Some(::builder_pattern::setter::Setter::Value(v)) => v, + Some(::builder_pattern::setter::Setter::Lazy(f)) => f(), + #default_case #async_case }; }); @@ -155,10 +200,11 @@ impl<'a> BuilderImpl<'a> { }; tokens.extend(quote! { impl <#fn_lifetime, #impl_tokens #(#optional_generics,)*> #builder_name - <#fn_lifetime, #(#lifetimes,)* #ty_tokens #(#satisfied_generics),*, #async_generic, ()> + <#fn_lifetime, #(#lifetimes,)* #ty_tokens #(#satisfied_generics,)* #async_generic, ()> #where_clause { #[allow(dead_code)] + #[allow(clippy::redundant_closure_call)] #vis #kw_async fn build(self) -> #ident <#(#lifetimes,)* #ty_tokens> { #(#no_lazy_validation_fields)* #ident { diff --git a/builder-pattern-macro/src/field.rs b/builder-pattern-macro/src/field.rs index 2578c5a..79ca098 100644 --- a/builder-pattern-macro/src/field.rs +++ b/builder-pattern-macro/src/field.rs @@ -1,9 +1,11 @@ +use crate::attributes::ident_add_underscore; + use super::attributes::FieldAttributes; use core::cmp::Ordering; -use proc_macro2::Ident; -use quote::ToTokens; -use syn::{Attribute, Type, Visibility}; +use proc_macro2::{Ident, TokenStream}; +use quote::{ToTokens, TokenStreamExt}; +use syn::{token::Comma, Attribute, Token, Type, Visibility}; pub struct Field { pub vis: Visibility, @@ -30,6 +32,23 @@ impl Field { ty_token.to_string() } } + + pub fn tokenize_replacement_params(&self, additional: &[TokenStream]) -> TokenStream { + let mut stream = TokenStream::new(); + if self.attrs.infer.is_empty() && additional.is_empty() { + return stream; + } + let underscored = self + .attrs + .infer + .iter() + .map(|ident| ident_add_underscore(ident)); + ::default().to_tokens(&mut stream); + stream.append_terminated(underscored, Comma::default()); + stream.append_terminated(additional.iter(), Comma::default()); + ]>::default().to_tokens(&mut stream); + stream + } } impl Ord for Field { diff --git a/builder-pattern-macro/src/lib.rs b/builder-pattern-macro/src/lib.rs index 18c59e6..d13debc 100644 --- a/builder-pattern-macro/src/lib.rs +++ b/builder-pattern-macro/src/lib.rs @@ -30,7 +30,9 @@ extern crate proc_macro2; into, public, setter, - validator + validator, + infer, + late_bound_default, ) )] pub fn derive_builder(input: TokenStream) -> TokenStream { diff --git a/builder-pattern-macro/src/struct_impl.rs b/builder-pattern-macro/src/struct_impl.rs index 9eaa6de..34aea83 100644 --- a/builder-pattern-macro/src/struct_impl.rs +++ b/builder-pattern-macro/src/struct_impl.rs @@ -1,7 +1,10 @@ -use crate::{attributes::Setters, struct_input::StructInput}; +use crate::{ + attributes::Setters, builder::builder_functions::replace_type_params_in, + struct_input::StructInput, +}; use core::str::FromStr; -use proc_macro2::TokenStream; +use proc_macro2::{Ident, TokenStream}; use quote::ToTokens; use syn::{parse_quote, spanned::Spanned, Attribute}; @@ -15,13 +18,23 @@ impl<'a> ToTokens for StructImpl<'a> { fn to_tokens(&self, tokens: &mut TokenStream) { let ident = &self.input.ident; let vis = &self.input.vis; - let where_clause = &self.input.generics.where_clause; let builder_name = self.input.builder_name(); let lifetimes = self.input.lifetimes(); - let impl_tokens = self.input.tokenize_impl(); let empty_generics = self.empty_generics(); - let ty_tokens = self.input.tokenize_types(); + let defaulted_generics = self.input.defaulted_generics(); + + let with_prmdef = |ident: &Ident| self.input.with_param_default(&defaulted_generics, ident); + let replace_defaults = + |stream: TokenStream| replace_type_params_in(stream, &defaulted_generics, &with_prmdef); + + let impl_tokens = self.input.tokenize_impl(&defaulted_generics); + + let where_clause = &self.input.generics.where_clause; + let where_tokens = + replace_type_params_in(quote! { #where_clause }, &defaulted_generics, &with_prmdef); + + let ty_tokens = replace_defaults(self.input.tokenize_types(&[], false)); let fn_lifetime = self.input.fn_lifetime(); @@ -29,20 +42,20 @@ impl<'a> ToTokens for StructImpl<'a> { let docs = self.documents(); tokens.extend(quote! { - impl <#impl_tokens> #ident <#(#lifetimes,)* #ty_tokens> #where_clause { + impl <#impl_tokens> #ident <#(#lifetimes,)* #ty_tokens> #where_tokens { #(#docs)* #[allow(clippy::new_ret_no_self)] #vis fn new<#fn_lifetime>() -> #builder_name< #fn_lifetime, #(#lifetimes,)* #ty_tokens - #(#empty_generics),*, + #(#empty_generics,)* (), () > { #[allow(clippy::redundant_closure_call)] #builder_name { - _phantom: ::core::marker::PhantomData, + __builder_phantom: ::core::marker::PhantomData, #(#builder_init_args),* } } @@ -78,20 +91,31 @@ impl<'a> StructImpl<'a> { }) .chain(self.input.optional_fields.iter().map(|f| { if let (ident, Some((expr, setters))) = (&f.ident, &f.attrs.default.as_ref()) { - match *setters { - Setters::VALUE => quote_spanned! { expr.span() => - #ident: Some(::builder_pattern::setter::Setter::Value(#expr)) - }, - Setters::LAZY => { - quote_spanned! { expr.span() => - #ident: Some( - ::builder_pattern::setter::Setter::Lazy( - Box::new(#expr) + if f.attrs.late_bound_default { + quote_spanned! { expr.span() => + #ident: Some(::builder_pattern::setter::Setter::LateBoundDefault( + ::builder_pattern::refl::refl() + )) + } + } else { + match *setters { + Setters::VALUE => quote_spanned! { expr.span() => + #ident: Some(::builder_pattern::setter::Setter::Default( + #expr, + ::builder_pattern::refl::refl() + )) + }, + Setters::LAZY => { + quote_spanned! { expr.span() => + #ident: Some( + ::builder_pattern::setter::Setter::Lazy( + Box::new(#expr) + ) ) - ) + } } + _ => unimplemented!(), } - _ => unimplemented!(), } } else { unimplemented!() diff --git a/builder-pattern-macro/src/struct_input.rs b/builder-pattern-macro/src/struct_input.rs index 72f8a32..7adc36b 100644 --- a/builder-pattern-macro/src/struct_input.rs +++ b/builder-pattern-macro/src/struct_input.rs @@ -1,4 +1,7 @@ -use crate::attributes::{FieldAttributes, FieldVisibility}; +use crate::attributes::{ + ident_add_underscore, ident_add_underscore_tree, FieldAttributes, FieldVisibility, +}; +use crate::builder::builder_functions::replace_type_params_in; use crate::builder::{ builder_decl::BuilderDecl, builder_functions::BuilderFunctions, builder_impl::BuilderImpl, }; @@ -6,13 +9,15 @@ use crate::field::Field; use crate::struct_impl::StructImpl; use core::str::FromStr; -use proc_macro2::{Ident, Span, TokenStream}; +use proc_macro2::{Group, Ident, Span, TokenStream, TokenTree}; use quote::{ToTokens, TokenStreamExt}; +use syn::token::Comma; use syn::{ parse::{Parse, ParseStream, Result}, AttrStyle, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, Lifetime, Token, VisPublic, Visibility, }; +use syn::{WhereClause, WherePredicate}; pub struct StructInput { pub vis: Visibility, @@ -138,82 +143,122 @@ impl StructInput { &'a self, fn_lifetime: &'a Lifetime, ) -> impl 'a + Iterator { + // TODO: just store defaulted_generics in a field + let defaulted_generics = self.defaulted_generics(); + let defaulted_generics2 = defaulted_generics.clone(); + let with_prmdef = move |ident: &Ident| self.with_param_default(&defaulted_generics, ident); + let replace_defaults = move |stream: TokenStream| { + replace_type_params_in(stream, &defaulted_generics2, &with_prmdef) + }; self.required_fields .iter() .chain(self.optional_fields.iter()) .map(move |f| { let (ident, ty) = (&f.ident, &f.ty); + let subst = replace_defaults(quote! { #ty }); quote! { - #ident: Option<::builder_pattern::setter::Setter<#fn_lifetime, #ty>> + #ident: Option<::builder_pattern::setter::Setter<#fn_lifetime, #ty, #subst>> } }) } /// Tokenize type parameters. /// It skips lifetimes and has no outer brackets. - pub fn tokenize_types(&self) -> TokenStream { + pub fn tokenize_types(&self, infer: &[Ident], omit_replaced: bool) -> TokenStream { let generics = &self.generics; let mut tokens = TokenStream::new(); if generics.params.is_empty() { return tokens; } - - let mut trailing_or_empty = true; - for param in generics.params.pairs() { - if let GenericParam::Lifetime(_) = *param.value() { - trailing_or_empty = param.punct().is_some(); - } + if omit_replaced + && generics.params.iter().all(|x| match x { + GenericParam::Type(param) => infer.contains(¶m.ident), + GenericParam::Const(_) => false, + _ => true, + }) + { + return tokens; } + for param in generics.params.pairs() { if let GenericParam::Lifetime(_) = **param.value() { continue; } - if !trailing_or_empty { - ::default().to_tokens(&mut tokens); - trailing_or_empty = true; - } match *param.value() { GenericParam::Lifetime(_) => unreachable!(), GenericParam::Type(param) => { // Leave off the type parameter defaults - param.ident.to_tokens(&mut tokens); + if infer.contains(¶m.ident) { + if omit_replaced { + continue; + } + ident_add_underscore(¶m.ident).to_tokens(&mut tokens); + } else { + param.ident.to_tokens(&mut tokens); + } } GenericParam::Const(param) => { // Leave off the const parameter defaults param.ident.to_tokens(&mut tokens); } } - param.punct().to_tokens(&mut tokens); + Comma::default().to_tokens(&mut tokens); } - ::default().to_tokens(&mut tokens); tokens } + pub fn setter_where_clause(&self, infer: &[Ident]) -> TokenStream { + let mut stream = TokenStream::new(); + if infer.is_empty() { + return stream; + } + let clauses = self + .generics + .where_clause + .iter() + .flat_map(|where_clause: &WhereClause| { + where_clause.predicates.iter().map(|pred: &WherePredicate| { + replace_type_params_in(quote! { #pred }, infer, &ident_add_underscore_tree) + }) + }); + stream.extend(quote! { where }); + stream.append_terminated(clauses, quote! { , }); + stream + } + /// Tokenize parameters for `impl` blocks. /// It doesn't contain outer brackets, but lifetimes and trait bounds. - pub fn tokenize_impl(&self) -> TokenStream { + pub fn tokenize_impl(&self, filter_out: &[Ident]) -> TokenStream { let mut tokens = TokenStream::new(); let generics = &self.generics; - let mut trailing_or_empty = true; + if generics.params.is_empty() { + return tokens; + } + if generics.params.iter().all(|x| match x { + GenericParam::Type(param) => filter_out.contains(¶m.ident), + _ => false, + }) { + return tokens; + } + for param in generics.params.pairs() { - if let GenericParam::Lifetime(_) = **param.value() { - param.to_tokens(&mut tokens); - trailing_or_empty = param.punct().is_some(); + if let GenericParam::Lifetime(l) = *param.value() { + l.to_tokens(&mut tokens); + Comma::default().to_tokens(&mut tokens); } } for param in generics.params.pairs() { if let GenericParam::Lifetime(_) = **param.value() { continue; } - if !trailing_or_empty { - ::default().to_tokens(&mut tokens); - trailing_or_empty = true; - } match *param.value() { GenericParam::Lifetime(_) => unreachable!(), GenericParam::Type(param) => { + if filter_out.contains(¶m.ident) { + continue; + } // Leave off the type parameter defaults tokens.append_all(param.attrs.iter().filter(|attr| match attr.style { AttrStyle::Outer => true, @@ -240,11 +285,37 @@ impl StructInput { param.ty.to_tokens(&mut tokens); } } - param.punct().to_tokens(&mut tokens); - } - if !tokens.is_empty() { - ::default().to_tokens(&mut tokens); + Comma::default().to_tokens(&mut tokens); } tokens } + + pub fn defaulted_generics(&self) -> Vec { + self.generics + .type_params() + .filter(|x| x.default.is_some()) + .map(|x| x.ident.clone()) + .collect() + } + + pub fn with_param_default(&self, defaulted_generics: &[Ident], ident: &Ident) -> TokenTree { + let with_prmdef = |ident: &Ident| self.with_param_default(defaulted_generics, ident); + self.generics + .type_params() + .find_map(|x| { + if x.ident == *ident { + let default = x.default.as_ref().unwrap(); + let stream = quote! { #default }; + let replaced_within = + replace_type_params_in(stream, defaulted_generics, &with_prmdef); + Some(TokenTree::Group(Group::new( + proc_macro2::Delimiter::None, + replaced_within, + ))) + } else { + None + } + }) + .expect("hmmmm") + } } diff --git a/builder-pattern/examples/default-generics.rs b/builder-pattern/examples/default-generics.rs new file mode 100644 index 0000000..9b5ffa0 --- /dev/null +++ b/builder-pattern/examples/default-generics.rs @@ -0,0 +1,156 @@ +use builder_pattern::Builder; +use std::any::{Any, TypeId}; +use std::marker::PhantomData; +use std::ops::Add; + +#[allow(unused)] +#[derive(Builder)] +struct Op { + #[infer(T)] + #[default(None)] + optional_field: Option, +} + +fn defaulted() { + // Should be inferred as Op, i.e. the macro should notice the defaulted type param. + let a = Op::new().build(); + assert_eq!(a.type_id(), TypeId::of::>()); +} + +fn override_default() { + // Should be inferred as Op + let a = Op::new().optional_field(Some(5i32)).build(); + assert_eq!(a.type_id(), TypeId::of::>()); +} + +#[allow(unused)] +#[derive(Builder)] +struct IterExtra> +where + I: IntoIterator, +{ + single: T, + #[default(None)] + extra: Option, +} + +fn inferred() { + let a = IterExtra::new().single(1).build(); + assert_eq!(a.type_id(), TypeId::of::>>()); +} + +#[allow(unused)] +#[derive(Builder)] +struct DefaultedClosure R> +where + F1: for<'a> FnMut(R, &T) -> R, + F2: for<'a> FnMut(R, &T) -> R, +{ + mandatory: F1, + #[infer(F2)] + #[late_bound_default] + #[default(|r, _t| r)] + optional: F2, + #[hidden] + #[default(PhantomData)] + phantom: PhantomData<(T, R)>, +} + +trait Callable { + fn call_fn(&mut self, r: R, t: &T) -> R; + fn call_inverse(&mut self, r: R, t: &T) -> R; +} +impl Callable for DefaultedClosure +where + F1: for<'a> FnMut(R, &T) -> R, + F2: for<'a> FnMut(R, &T) -> R, +{ + fn call_fn(&mut self, r: R, t: &T) -> R { + (self.mandatory)(r, t) + } + fn call_inverse(&mut self, r: R, t: &T) -> R { + let f = &mut self.optional; + f(r, t) + } +} + +fn accumulate_sum(acc: T, next: &T) -> T +where + T: for<'a> Add<&'a T, Output = T>, +{ + acc + next +} + +fn infer_f_generic() { + let mut _a = DefaultedClosure::new() + .mandatory(|acc: f64, x| acc + x) + .optional(accumulate_sum) + .build(); +} + +fn infer_f_missing() { + let mut _a = DefaultedClosure::new() + .mandatory(|acc: f64, x| acc + x) + .build(); +} + +fn fold_with_closure<'b, F1, T: 'b, R, F2>( + iter: impl Iterator, + init: R, + mut c: DefaultedClosure, +) -> R +where + F1: for<'a> FnMut(R, &T) -> R, + F2: for<'a> FnMut(R, &T) -> R, +{ + iter.fold(init, move |acc, x| c.call_fn(acc, x)) +} + +fn infer_using_fold() { + let _ = fold_with_closure( + core::iter::once(&5i32), + 0, + DefaultedClosure::new() + .mandatory(|acc, &x| acc + x) + .optional(|_acc, &x| x) + .build(), + ); +} + +fn infer_before_fold() { + let folder = DefaultedClosure::new().mandatory(|acc, &x| acc + x).build(); + let _ = fold_with_closure(core::iter::once(&5i32), 0, folder); +} + +fn infer_t_r() { + let mut a = DefaultedClosure::new() + // The types of the closure params should be inferred + .mandatory(|acc, x| acc + x) + .build(); + let _called: i32 = a.call_fn(5i32, &5); +} + +fn build_with_optional_new_type() { + let mut captured = String::from("hello"); + let mut a = DefaultedClosure::new() + // The types of the closure params should be inferred + .mandatory(|acc, x| acc + x) + .optional(move |acc, x| { + captured.push_str("hello"); + acc - x + }) + .build(); + let _called: i32 = a.call_fn(5i32, &5); +} + +fn main() { + defaulted(); + override_default(); + inferred(); + infer_f_generic(); + infer_f_missing(); + infer_using_fold(); + infer_before_fold(); + build_with_optional_new_type(); + infer_t_r(); +} diff --git a/builder-pattern/examples/default-infer.rs b/builder-pattern/examples/default-infer.rs new file mode 100644 index 0000000..a151abc --- /dev/null +++ b/builder-pattern/examples/default-infer.rs @@ -0,0 +1,38 @@ +use builder_pattern::Builder; + +#[allow(unused)] +#[derive(Builder)] +struct LateBound B = fn(B) -> B> { + field_a: A, + field_b: B, + #[late_bound_default] + #[default(|x| x)] + transform_b: F, +} + +impl LateBound +where + B: Clone, +{ + fn get_b(&self) -> B { + (self.transform_b)(self.field_b.clone()) + } +} + +fn with() { + let l = LateBound::new() + .field_a(String::new()) + .field_b(5) + .transform_b(|x| x + 10) + .build(); + assert_eq!(l.get_b(), 15); +} +fn without() { + let l = LateBound::new().field_a(String::new()).field_b(200).build(); + assert_eq!(l.get_b(), 200); +} + +fn main() { + with(); + without(); +} diff --git a/builder-pattern/examples/default.rs b/builder-pattern/examples/default.rs index 306cecf..957e887 100644 --- a/builder-pattern/examples/default.rs +++ b/builder-pattern/examples/default.rs @@ -1,6 +1,7 @@ use builder_pattern::Builder; use uuid::Uuid; +#[allow(unused)] #[derive(Builder, Debug)] struct Test { #[default(String::from("Jack"))] diff --git a/builder-pattern/examples/documentation.rs b/builder-pattern/examples/documentation.rs index 6417460..92008cb 100644 --- a/builder-pattern/examples/documentation.rs +++ b/builder-pattern/examples/documentation.rs @@ -17,6 +17,7 @@ use builder_pattern::Builder; /// /// println!("{:?}", person); /// ``` +#[allow(unused)] #[derive(Builder, Debug)] struct Person { /** diff --git a/builder-pattern/examples/fail-visibility1.rs b/builder-pattern/examples/fail-visibility1.rs index cbd05e4..eaa1020 100644 --- a/builder-pattern/examples/fail-visibility1.rs +++ b/builder-pattern/examples/fail-visibility1.rs @@ -3,13 +3,15 @@ mod test { // Private structure #[derive(Builder, Debug)] - struct PrivateTest { + pub struct PrivateTest { pub a: i32, pub b: &'static str, c: i32, } } +use test::*; + pub fn main() { let t1 = PrivateTest::new().a(333).c(1.234).b("hello").build(); } diff --git a/builder-pattern/examples/into-with-default.rs b/builder-pattern/examples/into-with-default.rs new file mode 100644 index 0000000..619c88c --- /dev/null +++ b/builder-pattern/examples/into-with-default.rs @@ -0,0 +1,25 @@ +use std::any::TypeId; + +use builder_pattern::Builder; + +#[allow(unused)] +#[derive(Builder)] +struct Test { + // Note that without the #[infer(T)], we would still have T = f64 from the + // type param default. + // The setter method will have a `T_` parameter, and return a TestBuilder. + #[infer(T)] + #[into] + vector: Vec, +} + +fn main() { + let _ = Test::new().vector(&b"byte slice"[..]).build(); + + // in more detail: + // we can't use a mutable builder and re-assign it, because they are different types. + let builder = Test::new(); + let builder = builder.vector::(&b"hello"[..]); + let t = builder.build(); + assert_eq!(std::any::Any::type_id(&t.vector), TypeId::of::>()); +} diff --git a/builder-pattern/examples/visibility1.rs b/builder-pattern/examples/visibility1.rs index 3a4b8bf..b2a6407 100644 --- a/builder-pattern/examples/visibility1.rs +++ b/builder-pattern/examples/visibility1.rs @@ -2,6 +2,7 @@ mod test { use builder_pattern::Builder; // Public structure + #[allow(unused)] #[derive(Builder, Debug)] pub struct PublicTest { pub a: i32, diff --git a/builder-pattern/examples/visibility2.rs b/builder-pattern/examples/visibility2.rs index c8f961a..36d871b 100644 --- a/builder-pattern/examples/visibility2.rs +++ b/builder-pattern/examples/visibility2.rs @@ -3,6 +3,7 @@ mod test { // Public structure #[derive(Builder, Debug)] + #[allow(unused)] pub struct PublicTest { pub a: i32, pub b: Option, diff --git a/builder-pattern/src/lib.rs b/builder-pattern/src/lib.rs index 266ebe1..66f564f 100644 --- a/builder-pattern/src/lib.rs +++ b/builder-pattern/src/lib.rs @@ -648,3 +648,6 @@ pub use builder_pattern_macro::Builder; #[doc(hidden)] pub mod setter; + +#[doc(hidden)] +pub mod refl; diff --git a/builder-pattern/src/refl.rs b/builder-pattern/src/refl.rs new file mode 100644 index 0000000..08f96c1 --- /dev/null +++ b/builder-pattern/src/refl.rs @@ -0,0 +1,60 @@ +//! https://github.com/Centril/refl +//! +//! Used under the MIT license. Just the basics. + +use core::marker::PhantomData; +use core::mem; + +/// +/// ```compile_fail +/// use builder_pattern::refl::Id; +/// let id = Id::>::REFL; +/// ``` +/// +/// ``` +/// use builder_pattern::refl::{refl, Id}; +/// fn get_i32(t: T, id: Id) -> i32 { +/// id.cast(t) +/// } +/// let five = get_i32(5, refl()); +/// assert_eq!(five, 5i32); +/// ``` +/// +pub struct Id(PhantomData<(fn(S) -> S, fn(T) -> T)>); + +impl Id { + pub const REFL: Self = Id(PhantomData); +} + +pub fn refl() -> Id { + Id::REFL +} + +impl Id { + /// Casts a value of type `S` to `T`. + /// + /// This is safe because the `Id` type is always guaranteed to + /// only be inhabited by `Id` types by construction. + pub fn cast(self, value: S) -> T + where + S: Sized, + T: Sized, + { + unsafe { + // Transmute the value; + // This is safe since we know by construction that + // S == T (including lifetime invariance) always holds. + let cast_value = mem::transmute_copy(&value); + + // Forget the value; + // otherwise the destructor of S would be run. + mem::forget(value); + + cast_value + } + } + /// Converts `Id` into `Id` since type equality is symmetric. + pub fn sym(self) -> Id { + Id(PhantomData) + } +} diff --git a/builder-pattern/src/setter.rs b/builder-pattern/src/setter.rs index 9ca8312..73cce48 100644 --- a/builder-pattern/src/setter.rs +++ b/builder-pattern/src/setter.rs @@ -1,7 +1,11 @@ #[cfg(feature = "future")] use futures::future::LocalBoxFuture; -pub enum Setter<'a, T> { +use super::refl::Id; + +pub enum Setter<'a, T, D = T> { + Default(D, Id), + LateBoundDefault(Id), Value(T), Lazy(Box T>), LazyValidated(Box Result>), diff --git a/test-no-future/examples/empty.rs b/test-no-future/examples/empty.rs new file mode 100644 index 0000000..ddff689 --- /dev/null +++ b/test-no-future/examples/empty.rs @@ -0,0 +1,8 @@ +use builder_pattern::Builder; + +#[derive(Builder)] +struct Thing {} + +fn main() { + let _: Thing = Thing::new().build(); +}