11"""
2- ProbabilisticComposition {L,G}
2+ Pushforward {L,G}
33
44Differentiable composition of a probabilistic `layer` with an arbitrary function `post_processing`.
55
6- `ProbabilisticComposition ` can be used for direct regret minimization (aka learning by experience) when the post-processing returns a cost.
6+ `Pushforward ` can be used for direct regret minimization (aka learning by experience) when the post-processing returns a cost.
77
88# Fields
99- `layer::L`: anything that implements `compute_probability_distribution(layer, θ; kwargs...)`
1010- `post_processing::P`: callable
1111
1212See also: [`FixedAtomsProbabilityDistribution`](@ref).
1313"""
14- struct ProbabilisticComposition {L,P}
14+ struct Pushforward {L,P}
1515 layer:: L
1616 post_processing:: P
1717end
1818
19- function Base. show (io:: IO , composition:: ProbabilisticComposition )
19+ function Base. show (io:: IO , composition:: Pushforward )
2020 (; layer, post_processing) = composition
21- return print (io, " ProbabilisticComposition ($layer , $post_processing )" )
21+ return print (io, " Pushforward ($layer , $post_processing )" )
2222end
2323
2424"""
@@ -30,25 +30,23 @@ This function is not differentiable if `composition.post_processing` isn't.
3030
3131See also: [`apply_on_atoms`](@ref).
3232"""
33- function compute_probability_distribution (
34- composition:: ProbabilisticComposition , θ; kwargs...
35- )
33+ function compute_probability_distribution (composition:: Pushforward , θ; kwargs... )
3634 (; layer, post_processing) = composition
3735 probadist = compute_probability_distribution (layer, θ; kwargs... )
3836 post_processed_probadist = apply_on_atoms (post_processing, probadist; kwargs... )
3937 return post_processed_probadist
4038end
4139
4240"""
43- (composition::ProbabilisticComposition )(θ)
41+ (composition::Pushforward )(θ)
4442
4543Output the expectation of `composition.post_processing(X)`, where `X` follows the distribution defined by `composition.layer` applied to `θ`.
4644
4745Unlike [`compute_probability_distribution(composition, θ)`](@ref), this function is differentiable, even if `composition.post_processing` isn't.
4846
4947See also: [`compute_expectation`](@ref).
5048"""
51- function (composition:: ProbabilisticComposition )(θ:: AbstractArray{<:Real} ; kwargs... )
49+ function (composition:: Pushforward )(θ:: AbstractArray{<:Real} ; kwargs... )
5250 (; layer, post_processing) = composition
5351 probadist = compute_probability_distribution (layer, θ; kwargs... )
5452 return compute_expectation (probadist, post_processing; kwargs... )
0 commit comments