 #= ===================
    LOGISTIC CLASSIFIER
-   =================== =#
+   =================== =#
 
 """
-Logistic Classifier (typically called "Logistic Regression"). This model is
-a standard classifier for both binary and multiclass classification.
-In the binary case it corresponds to the LogisticLoss, in the multiclass to the
-Multinomial (softmax) loss. An elastic net penalty can be applied with
-overall objective function
+$(doc_header(LogisticClassifier))
 
-``L(y, Xθ) + n⋅λ|θ|₂²/2 + n⋅γ|θ|₁``
+This model is more commonly known as "logistic regression". It is a standard classifier
+for both binary and multiclass classification. The objective function applies either a
+logistic loss (binary target) or multinomial (softmax) loss, and has a mixed L1/L2
+penalty:
 
-where ``L`` is either the logistic or multinomial loss and ``λ`` and ``γ`` indicate
-the strength of the L2 (resp. L1) regularisation components and
-``n`` is the number of samples `size(X, 1)`.
-With `scale_penalty_with_samples = false` the objective function is
-``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁``
+``L(y, Xθ) + n⋅λ|θ|₂²/2 + n⋅γ|θ|₁``.
 
-## Parameters
+Here ``L`` is either `MLJLinearModels.LogisticLoss` or `MLJLinearModels.MultiClassLoss`,
+``λ`` and ``γ`` indicate the strength of the L2 (resp. L1) regularization components,
+and ``n`` is the number of training observations.
+
+With `scale_penalty_with_samples = false` the objective function is instead
+
+``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁``.
+
+# Training data
+
+In MLJ or MLJBase, bind an instance `model` to data with
+
+    mach = machine(model, X, y)
+
+where:
+
+- `X` is any table of input features (eg, a `DataFrame`) whose columns
+  have `Continuous` scitype; check column scitypes with `schema(X)`
+
+- `y` is the target, which can be any `AbstractVector` whose element
+  scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
+  with `scitype(y)`
+
+Train the machine using `fit!(mach, rows=...)`.
+
+# Hyperparameters
 
 $TYPEDFIELDS
 
 $(example_docstring("LogisticClassifier", nclasses = 2))
+
+See also [`MultinomialClassifier`](@ref).
+
 """
 @with_kw_noshow mutable struct LogisticClassifier <: MMI.Probabilistic
-    "strength of the regulariser if `penalty` is `:l2` or `:l1` and strength of the L2
-    regulariser if `penalty` is `:en`."
+    "strength of the regularizer if `penalty` is `:l2` or `:l1` and strength of the L2
+    regularizer if `penalty` is `:en`."
     lambda::Real = eps()
-    "strength of the L1 regulariser if `penalty` is `:en`."
+    "strength of the L1 regularizer if `penalty` is `:en`."
     gamma::Real = 0.0
     "the penalty to use, either `:l2`, `:l1`, `:en` (elastic net) or `:none`."
     penalty::SymStr = :l2
@@ -37,7 +62,19 @@ $(example_docstring("LogisticClassifier", nclasses = 2))
     penalize_intercept::Bool = false
     "whether to scale the penalty with the number of samples."
     scale_penalty_with_samples::Bool = true
-    "type of solver to use, default if `nothing`."
+    """some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`, `Newton`,
+    `NewtonCG`, `ProxGrad`; but subject to the following restrictions:
+
+    - If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxGrad` is the only
+      option.
+
+    - Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.
+
+    If `solver = nothing` (default) then `ProxGrad(accel=true)` (FISTA) is used,
+    unless `gamma = 0`, in which case `LBFGS()` is used.
+
+    Solver aliases: `FISTA(; kwargs...) = ProxGrad(accel=true, kwargs...)`,
+    `ISTA(; kwargs...) = ProxGrad(accel=false, kwargs...)`"""
     solver::Option{Solver} = nothing
 end
 
@@ -50,27 +87,49 @@ glr(m::LogisticClassifier, nclasses::Integer) =
         scale_penalty_with_samples=m.scale_penalty_with_samples,
         nclasses=nclasses)
 
-descr(::Type{LogisticClassifier}) = "Classifier corresponding to the loss function ``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is the logistic loss."
-
 #= ======================
    MULTINOMIAL CLASSIFIER
    ====================== =#
 
 """
-See `LogisticClassifier`, it's the same except that multiple classes are assumed
-by default. The other parameters are the same.
+$(doc_header(MultinomialClassifier))
+
+This model coincides with [`LogisticClassifier`](@ref), except certain optimizations
+possible in the special binary case will not be applied. Its hyperparameters are
+identical.
+
+# Training data
+
+In MLJ or MLJBase, bind an instance `model` to data with
+
+    mach = machine(model, X, y)
+
+where:
+
+- `X` is any table of input features (eg, a `DataFrame`) whose columns
+  have `Continuous` scitype; check column scitypes with `schema(X)`
+
+- `y` is the target, which can be any `AbstractVector` whose element
+  scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
+  with `scitype(y)`
+
+Train the machine using `fit!(mach, rows=...)`.
 
-## Parameters
+# Hyperparameters
 
 $TYPEDFIELDS
 
-$(example_docstring("LogisticClassifier", nclasses = 3))
+$(example_docstring("MultinomialClassifier", nclasses = 3))
+
+See also [`LogisticClassifier`](@ref).
+
 """
 @with_kw_noshow mutable struct MultinomialClassifier <: MMI.Probabilistic
-    "strength of the regulariser if `penalty` is `:l2` or `:l1`.
-    Strength of the L2 regulariser if `penalty` is `:en`."
+    "strength of the regularizer if `penalty` is `:l2` or `:l1`.
+    Strength of the L2 regularizer if `penalty` is `:en`."
     lambda::Real = eps()
-    "strength of the L1 regulariser if `penalty` is `:en`."
+    "strength of the L1 regularizer if `penalty` is `:en`."
     gamma::Real = 0.0
     "the penalty to use, either `:l2`, `:l1`, `:en` (elastic net) or `:none`."
     penalty::SymStr = :l2
@@ -80,7 +139,19 @@ $(example_docstring("LogisticClassifier", nclasses = 3))
     penalize_intercept::Bool = false
     "whether to scale the penalty with the number of samples."
     scale_penalty_with_samples::Bool = true
-    "type of solver to use, default if `nothing`."
+    """some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`,
+    `NewtonCG`, `ProxGrad`; but subject to the following restrictions:
+
+    - If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxGrad` is the only
+      option.
+
+    - Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.
+
+    If `solver = nothing` (default) then `ProxGrad(accel=true)` (FISTA) is used,
+    unless `gamma = 0`, in which case `LBFGS()` is used.
+
+    Solver aliases: `FISTA(; kwargs...) = ProxGrad(accel=true, kwargs...)`,
+    `ISTA(; kwargs...) = ProxGrad(accel=false, kwargs...)`"""
     solver::Option{Solver} = nothing
 end
 
@@ -91,7 +162,3 @@ glr(m::MultinomialClassifier, nclasses::Integer) =
         penalize_intercept=m.penalize_intercept,
         scale_penalty_with_samples=m.scale_penalty_with_samples,
         nclasses=nclasses)
-
-descr(::Type{MultinomialClassifier}) =
-    "Classifier corresponding to the loss function " *
-    "``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is the multinomial loss."
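
The workflow documented in the new docstrings can be sketched as follows. This is a minimal illustration, not part of the commit; it assumes MLJ.jl and MLJLinearModels.jl are installed, and uses synthetic data:

```julia
using MLJ  # provides machine, fit!, predict, coerce, @load

# Load the model type from MLJLinearModels via the MLJ model registry
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels

# Synthetic binary classification data: a table X with Continuous columns
# and a categorical target y (Multiclass scitype)
X = (x1 = randn(100), x2 = randn(100))
y = coerce(rand(["a", "b"], 100), Multiclass)

# Elastic-net penalty: lambda scales the L2 term, gamma the L1 term
model = LogisticClassifier(penalty=:en, lambda=0.1, gamma=0.05)

mach = machine(model, X, y)
fit!(mach)

yhat = predict(mach, X)  # probabilistic predictions (UnivariateFinite)
predict_mode(mach, X)    # point predictions
```

With `gamma=0.05` the default solver would be `ProxGrad(accel=true)` (FISTA); setting `gamma=0` would select `LBFGS()` instead.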