Commit fbc1658

Merge pull request #56 from JuliaAI/kind-of-learner
Bring dev up to date with master
2 parents 283de3f + 77de486 commit fbc1658

29 files changed, 618 additions and 309 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ sandbox/
 /docs/site/
 /docs/Manifest.toml
 .vscode
+LocalPreferences.toml

Project.toml

Lines changed: 7 additions & 3 deletions
@@ -1,13 +1,17 @@
+authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
 name = "LearnAPI"
 uuid = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
-authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
-version = "1.0.1"
+version = "2.0.0"

 [compat]
+Preferences = "1.5.0"
 julia = "1.10"

+[deps]
+Preferences = "21216c6a-2e73-6563-6e65-726566657250"
+
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Test",]
+test = ["Test"]

docs/Project.toml

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
 LearnAPI = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
-LearnTestAPI = "3111ed91-c4f2-40e7-bb19-7f6c618409b8"
 MLCore = "c2834f40-e789-41da-a90e-33b280584a8c"
 ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

docs/make.jl

Lines changed: 3 additions & 2 deletions
@@ -2,12 +2,12 @@ using Documenter
 using LearnAPI
 using ScientificTypesBase
 using DocumenterInterLinks
-using LearnTestAPI
+# using LearnTestAPI

 const REPO = Remotes.GitHub("JuliaAI", "LearnAPI.jl")

 makedocs(
-modules=[LearnAPI, LearnTestAPI],
+modules=[LearnAPI, ], #LearnTestAPI],
 format=Documenter.HTML(
 prettyurls = true,#get(ENV, "CI", nothing) == "true",
 collapselevel = 1,
@@ -18,6 +18,7 @@ makedocs(
 "Reference" => [
 "Overview" => "reference.md",
 "Public Names" => "list_of_public_names.md",
+"Kinds of learner" => "kinds_of_learner.md",
 "fit/update" => "fit_update.md",
 "predict/transform" => "predict_transform.md",
 "Kinds of Target Proxy" => "kinds_of_target_proxy.md",

docs/src/anatomy_of_an_implementation.md

Lines changed: 72 additions & 54 deletions
@@ -1,6 +1,7 @@
 # Anatomy of an Implementation

-The core LearnAPI.jl pattern looks like this:
+LearnAPI.jl supports three core patterns. The default pattern, known as the
+[`LearnAPI.Descriminative`](@ref) pattern, looks like this:

 ```julia
 model = fit(learner, data)
@@ -10,38 +11,51 @@ predict(model, newdata)
 Here `learner` specifies [hyperparameters](@ref hyperparameters), while `model` stores
 learned parameters and any byproducts of algorithm execution.

-Variations on this pattern:
+[Transformers](@ref) ordinarily implement `transform` instead of `predict`. For more on
+`predict` versus `transform`, see [Predict or transform?](@ref)

-- [Transformers](@ref) ordinarily implement `transform` instead of `predict`. For more on
-`predict` versus `transform`, see [Predict or transform?](@ref)
+Two other `fit`/`predict`/`transform` patterns supported by LearnAPI.jl are:
+[`LearnAPI.Generative`](@ref) which has the form:

-- ["Static" (non-generalizing) algorithms](@ref static_algorithms), which includes some
-simple transformers and some clustering algorithms, have a `fit` that consumes no
-`data`. Instead `predict` or `transform` does the heavy lifting.
+```julia
+model = fit(learner, data)
+predict(model) # a single distribution, for example
+```

-- In [density estimation](@ref density_estimation), the `newdata` argument in `predict` is
-missing.
+and [`LearnAPI.Static`](@ref), which looks like this:
+
+```julia
+model = fit(learner) # no `data` argument
+predict(model, data) # may mutate `model` to record byproducts of computation
+```

-These are the basic possibilities.
+Do not read too much into the names for these patterns, which are formalized [here](@ref kinds_of_learner). Use may not always correspond to prior associations.

-Elaborating on the core pattern above, this tutorial details an implementation of the
-LearnAPI.jl for naive [ridge regression](https://en.wikipedia.org/wiki/Ridge_regression)
-with no intercept. The kind of workflow we want to enable has been previewed in [Sample
-workflow](@ref). Readers can also refer to the [demonstration](@ref workflow) of the
-implementation given later.
+Elaborating on the common `Descriminative` pattern above, this tutorial details an
+implementation of the LearnAPI.jl for naive [ridge
+regression](https://en.wikipedia.org/wiki/Ridge_regression) with no intercept. The kind of
+workflow we want to enable has been previewed in [Sample workflow](@ref). Readers can also
+refer to the [demonstration](@ref workflow) of the implementation given later.

-## A basic implementation
+!!! tip "Quick Start for new implementations"

-See [here](@ref code) for code without explanations.
+1. From this tutorial, read at least "[A basic implementation](@ref)" below.
+1. Looking over the examples in "[Common Implementation Patterns](@ref patterns)", identify the appropriate core learner pattern above for your algorithm.
+1. Implement `fit` (probably following an existing example). Read the [`fit`](@ref) document string to see what else may need to be implemented, paying particular attention to the "New implementations" section.
+3. Rinse and repeat with each new method implemented.
+4. Identify any additional [learner traits](@ref traits) that have appropriate overloadings; use the [`@trait`](@ref) macro to define these in one block.
+5. Ensure your implementation includes the compulsory method [`LearnAPI.learner`](@ref) and compulsory traits [`LearnAPI.constructor`](@ref) and [`LearnAPI.functions`](@ref). Read and apply "[Testing your implementation](@ref)".

-We suppose our algorithm's `fit` method consumes data in the form `(X, y)`, where
-`X` is a suitable table¹ (the features) and `y` a vector (the target).
+If you get stuck, refer back to this tutorial and the [Reference](@ref reference) sections.

-!!! important

-Implementations wishing to support other data
-patterns may need to take additional steps explained under
-[Other data patterns](@ref di) below.
+## A basic implementation
+
+See [here](@ref code) for code without explanations.
+
+Let us suppose our algorithm's `fit` method is to consume data in the form `(X, y)`, where
+`X` is a suitable table¹ (the features, a.k.a., covariates or predictors) and `y` a vector
+(the target, a.k.a., labels or response).

 The first line below imports the lightweight package LearnAPI.jl whose methods we will be
 extending. The second imports libraries needed for the core algorithm.
@@ -110,7 +124,7 @@ Note that we also include `learner` in the struct, for it must be possible to re
 The implementation of `fit` looks like this:

 ```@example anatomy
-function LearnAPI.fit(learner::Ridge, data; verbosity=1)
+function LearnAPI.fit(learner::Ridge, data; verbosity=LearnAPI.default_verbosity())
 X, y = data

 # data preprocessing:
@@ -158,6 +172,22 @@ If the kind of proxy is omitted, as in `predict(model, Xnew)`, then a fallback g
 first element of the tuple returned by [`LearnAPI.kinds_of_proxy(learner)`](@ref), which
 we overload appropriately below.

+### Data deconstructors: `target` and `features`
+
+LearnAPI.jl is flexible about the form of training `data`. However, to buy into
+meta-functionality, such as cross-validation, we'll need to say something about the
+structure of this data. We implement [`LearnAPI.target`](@ref) to say what
+part of the data constitutes a [target variable](@ref proxy), and
+[`LearnAPI.features`](@ref) to say what are the features (valid `newdata` in a
+`predict(model, newdata)` call):
+
+```@example anatomy
+LearnAPI.target(learner::Ridge, (X, y)) = y
+LearnAPI.features(learner::Ridge, (X, y)) = X
+```
+
+Another data deconstructor, for learners that support per-observation weights in training,
+is [`LearnAPI.weights`](@ref).

 ### [Accessor functions](@id af)

@@ -241,15 +271,11 @@ the *type* of the argument.
 ### The `functions` trait

 The last trait, `functions`, above returns a list of all LearnAPI.jl methods that can be
-meaningfully applied to the learner or associated model, with the exception of traits. You
-always include the first five you see here: `fit`, `learner`, `clone` ,`strip`,
-`obs`. Here [`clone`](@ref) is a utility function provided by LearnAPI that you never
-overload, while [`obs`](@ref) is discussed under [Providing a separate data front
-end](@ref) below and is always included because it has a meaningful fallback. The
-`features` method, here provided by a fallback, articulates how the features `X` can be
-extracted from the training data `(X, y)`. We must also include `target` here to flag our
-model as supervised; again the method itself is provided by a fallback valid in the
-present case.
+meaningfully applied to the learner or the output of `fit` (denoted `model` above), with
+the exception of traits. You always include the first five you see here: `fit`, `learner`,
+`clone` ,`strip`, `obs`. Here [`clone`](@ref) is a utility function provided by LearnAPI
+that you never overload, while [`obs`](@ref) is discussed under [Providing a separate data
+front end](@ref) below and is always included because it has a meaningful fallback.

 See [`LearnAPI.functions`](@ref) for a checklist of what the `functions` trait needs to
 return.
@@ -340,11 +366,6 @@ assumptions about data from those made above.
 under [Providing a separate data front end](@ref) below; or (ii) overload the trait
 [`LearnAPI.data_interface`](@ref) to specify a more relaxed data API.

-- Where the form of data consumed by `fit` is different from that consumed by
-`predict/transform` (as in classical supervised learning) it may be necessary to
-explicitly overload the functions [`LearnAPI.features`](@ref) and (if supervised)
-[`LearnAPI.target`](@ref). The same holds if overloading [`obs`](@ref); see below.
-

 ## Providing a separate data front end

@@ -448,14 +469,14 @@ newobservations = MLCore.getobs(observations, test_indices)
 predict(model, newobservations)
 ```

-which works for any non-static learner implementing `predict`, no matter how one is
-supposed to accesses the individual observations of `data` or `newdata`. See also the
-demonstration [below](@ref advanced_demo). Furthermore, fallbacks ensure the above pattern
-still works if we choose not to implement a front end at all, which is allowed, if
-supported `data` and `newdata` already implement `getobs`/`numobs`.
+which works for any [`LearnAPI.Descriminative`](@ref) learner implementing `predict`, no
+matter how one is supposed to accesses the individual observations of `data` or
+`newdata`. See also the demonstration [below](@ref advanced_demo). Furthermore, fallbacks
+ensure the above pattern still works if we choose not to implement a front end at all,
+which is allowed, if supported `data` and `newdata` already implement `getobs`/`numobs`.

-Here we specifically wrap all the preprocessed data into single object, for which we
-introduce a new type:
+In the ridge regression example we specifically wrap all the preprocessed data into single
+object, for which we introduce a new type:

 ```@example anatomy2
 struct RidgeFitObs{T,M<:AbstractMatrix{T}}
@@ -476,13 +497,13 @@ function LearnAPI.obs(::Ridge, data)
 end
 ```

-We informally refer to the output of `obs` as "observations" (see [The `obs`
-contract](@ref) below). The previous core `fit` signature is now replaced with two
+We informally refer to the output of `obs` as "observations" (see "[The `obs`
+contract](@ref)" below). The previous core `fit` signature is now replaced with two
 methods - one to handle "regular" input, and one to handle the pre-processed data
 (observations) which appears first below:

 ```@example anatomy2
-function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=1)
+function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=LearnAPI.default_verbosity())

 lambda = learner.lambda

@@ -545,13 +566,10 @@ LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
 predict(model, Point(), obs(model, Xnew))
 ```

-### `features` and `target` methods
+### Data deconstructors: `features` and `target`

-Two methods [`LearnAPI.features`](@ref) and [`LearnAPI.target`](@ref) articulate how
-features and target can be extracted from `data` consumed by LearnAPI.jl
-methods. Fallbacks provided by LearnAPI.jl sufficed in our basic implementation
-above. Here we must explicitly overload them, so that they also handle the output of
-`obs(learner, data)`:
+These methods must be able to handle any `data` supported by `fit`, which includes the
+output of `obs(learner, data)`:

 ```@example anatomy2
 LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
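
The three call patterns named in the revised introduction of this file can be contrasted in a short, self-contained sketch. The `fit` and `predict` functions below are local stand-ins defined in the snippet, not the LearnAPI.jl methods, and the learner types `MeanRegressor`, `NormalEstimator` and `Scaler` are hypothetical names chosen for illustration.

```julia
using Statistics

# Local stand-ins for illustration only; these are NOT LearnAPI.jl functions.

# Pattern 1 (called `Descriminative` in the docs): fit on data, predict on new data.
struct MeanRegressor end                 # hypothetical learner with no hyperparameters
struct MeanRegressorFitted
    mu::Float64
end
fit(::MeanRegressor, (X, y)) = MeanRegressorFitted(mean(y))
predict(model::MeanRegressorFitted, Xnew) = fill(model.mu, size(Xnew, 1))

# Pattern 2 (`Generative`): fit on data; predict takes no data argument.
struct NormalEstimator end               # hypothetical learner
struct NormalEstimatorFitted
    mu::Float64
    sigma::Float64
end
fit(::NormalEstimator, y) = NormalEstimatorFitted(mean(y), std(y))
predict(model::NormalEstimatorFitted) = (mu = model.mu, sigma = model.sigma)

# Pattern 3 (`Static`): fit sees no data; predict does the work and may
# mutate the model to record byproducts of the computation.
struct Scaler                            # hypothetical learner
    factor::Float64
end
mutable struct ScalerFitted
    learner::Scaler
    ncalls::Int
end
fit(learner::Scaler) = ScalerFitted(learner, 0)
function predict(model::ScalerFitted, data)
    model.ncalls += 1                    # byproduct recorded on the model
    return model.learner.factor .* data
end

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
y = [1.0, 2.0, 3.0]

m1 = fit(MeanRegressor(), (X, y))
yhat = predict(m1, X)                    # one prediction per row of X

m2 = fit(NormalEstimator(), y)
dist = predict(m2)                       # named tuple of fitted parameters

m3 = fit(Scaler(2.0))
scaled = predict(m3, y)                  # the scaling happens at predict time
```

Only the argument lists differ between the three patterns, which is the distinction the new introduction draws.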

docs/src/common_implementation_patterns.md

Lines changed: 4 additions & 2 deletions
@@ -9,8 +9,10 @@ This guide is intended to be consulted after reading [Anatomy of an Implementati
 which introduces the main interface objects and terminology.

 Although an implementation is defined purely by the methods and traits it implements, many
-implementations fall into one (or more) of the following informally understood patterns or
-tasks:
+implementations fall into one (or more) of the informally understood patterns or tasks
+below. While some generally fall into one of the core `Descriminative`, `Generative` or
+`Static` patterns detailed [here](@id kinds_of_learner), there are exceptions (such as
+clustering, which has both `Descriminative` and `Static` variations).

 - [Regression](@ref): Supervised learners for continuous targets

docs/src/examples.md

Lines changed: 12 additions & 4 deletions
@@ -32,7 +32,7 @@ struct RidgeFitted{T,F}
 named_coefficients::F
 end

-function LearnAPI.fit(learner::Ridge, data; verbosity=1)
+function LearnAPI.fit(learner::Ridge, data; verbosity=LearnAPI.default_verbosity())
 X, y = data

 # data preprocessing:
@@ -57,6 +57,10 @@ end
 LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
 Tables.matrix(Xnew)*model.coefficients

+# data deconstructors:
+LearnAPI.target(learner::Ridge, (X, y)) = y
+LearnAPI.features(learner::Ridge, (X, y)) = X
+
 # accessor functions:
 LearnAPI.learner(model::RidgeFitted) = model.learner
 LearnAPI.coefficients(model::RidgeFitted) = model.named_coefficients
@@ -125,7 +129,11 @@ function LearnAPI.obs(::Ridge, data)
 end
 LearnAPI.obs(::Ridge, observations::RidgeFitObs) = observations

-function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=1)
+function LearnAPI.fit(
+learner::Ridge,
+observations::RidgeFitObs;
+verbosity=LearnAPI.default_verbosity(),
+)

 lambda = learner.lambda

@@ -159,7 +167,7 @@ LearnAPI.predict(model::RidgeFitted, ::Point, observations::AbstractMatrix) =
 LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
 predict(model, Point(), obs(model, Xnew))

-# methods to deconstruct training data:
+# training data deconstructors:
 LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
 LearnAPI.target(::Ridge, observations::RidgeFitObs) = observations.y
 LearnAPI.features(learner::Ridge, data) = LearnAPI.features(learner, obs(learner, data))
@@ -222,7 +230,7 @@ frontend = FrontEnds.Saffron()
 LearnAPI.obs(learner::Ridge, data) = FrontEnds.fitobs(learner, data, frontend)
 LearnAPI.obs(model::RidgeFitted, data) = obs(model, data, frontend)

-function LearnAPI.fit(learner::Ridge, observations::FrontEnds.Obs; verbosity=1)
+function LearnAPI.fit(learner::Ridge, observations::FrontEnds.Obs; verbosity=LearnAPI.default_verbosity())

 lambda = learner.lambda
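
Several hunks in this commit replace the literal default `verbosity=1` with the call `verbosity=LearnAPI.default_verbosity()`. How LearnAPI.jl implements that function is not shown in this excerpt. The sketch below only illustrates the general Julia idiom of a keyword default that is evaluated on each call; it assumes a hypothetical global `Ref` for storage, and `DEFAULT_VERBOSITY`, `default_verbosity` and `toy_fit` are names local to the snippet.

```julia
# Hypothetical storage for the default; an assumption made for this sketch,
# not how LearnAPI.jl stores its setting.
const DEFAULT_VERBOSITY = Ref(1)

default_verbosity() = DEFAULT_VERBOSITY[]                        # read the current default
default_verbosity(level::Int) = (DEFAULT_VERBOSITY[] = level)    # change it

# The keyword default is an expression, evaluated each time `toy_fit` is
# called without an explicit `verbosity`.
function toy_fit(data; verbosity=default_verbosity())
    verbosity > 0 && @info "Fitting on $(length(data)) observations."
    return sum(data) / length(data)
end

loud = toy_fit([1.0, 2.0, 3.0])                  # logs, because the default is 1
default_verbosity(0)                             # silence all later calls
quiet = toy_fit([1.0, 3.0])                      # no log output
explicit = toy_fit([2.0, 4.0]; verbosity=1)      # an explicit keyword still wins
```

With this idiom a caller can change the default once rather than passing `verbosity` to every `fit` call.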

docs/src/features_target_weights.md

Lines changed: 6 additions & 6 deletions
@@ -4,7 +4,7 @@ Methods for extracting certain parts of `data` for all supported calls of the fo
 [`fit(learner, data)`](@ref).

 ```julia
-LearnAPI.features(learner, data) -> <training "features", suitable input for `predict` or `transform`>
+LearnAPI.features(learner, data) -> <training "features"; suitable input for `predict` or `transform`>
 LearnAPI.target(learner, data) -> <target variable>
 LearnAPI.weights(learner, data) -> <per-observation weights>
 ```
@@ -29,11 +29,11 @@ training_loss = sum(ŷ .!= y)

 # Implementation guide

-| method | fallback return value | compulsory? |
-|:-------------------------------------------|:---------------------------------------------:|--------------------------|
-| [`LearnAPI.features(learner, data)`](@ref) | `first(data)` if `data` is tuple, else `data` | if fallback insufficient |
-| [`LearnAPI.target(learner, data)`](@ref) | `last(data)` | if fallback insufficient |
-| [`LearnAPI.weights(learner, data)`](@ref) | `nothing` | no |
+| method | fallback return value | compulsory? |
+|:-------------------------------------------|:---------------------:|-------------|
+| [`LearnAPI.features(learner, data)`](@ref) | no fallback | no |
+| [`LearnAPI.target(learner, data)`](@ref) | no fallback | no |
+| [`LearnAPI.weights(learner, data)`](@ref) | no fallback | no |


 # Reference
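
With the fallbacks removed in the table above, generic tooling can use `features` and `target` only where a learner defines them explicitly. The sketch below shows why such deconstructors are useful to meta-functionality: the utility `training_rmse` never inspects the layout of `data` itself. All names here (`fit`, `predict`, `features`, `target`, `training_rmse`, `ToyRidge`) are local stand-ins defined in the snippet, not LearnAPI.jl functions, and `ToyRidge` is assumed to take a matrix and a vector rather than a table.

```julia
using LinearAlgebra, Statistics

# Hypothetical learner and fitted model, for illustration only.
struct ToyRidge
    lambda::Float64
end
struct ToyRidgeFitted
    learner::ToyRidge
    coefficients::Vector{Float64}
end

# Data deconstructors: the only place where the `(X, y)` layout is spelled out.
features(::ToyRidge, (X, y)) = X
target(::ToyRidge, (X, y)) = y

# Ridge regression without intercept, via the normal equations.
function fit(learner::ToyRidge, data)
    X, y = data
    coefficients = (X' * X + learner.lambda * I) \ (X' * y)
    return ToyRidgeFitted(learner, coefficients)
end
predict(model::ToyRidgeFitted, Xnew) = Xnew * model.coefficients

# Generic utility: works for any learner that defines the four functions above.
function training_rmse(learner, data)
    model = fit(learner, data)
    yhat = predict(model, features(learner, data))
    y = target(learner, data)
    return sqrt(mean(abs2, yhat .- y))
end

X = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y = X * [2.0, -1.0]                       # noiseless linear target

unregularized = training_rmse(ToyRidge(0.0), (X, y))
regularized = training_rmse(ToyRidge(1.0), (X, y))
```

A learner consuming data in another form would supply its own `features` and `target` methods, and `training_rmse` would need no change.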
