
Commit cab3054

more stuff
1 parent 3e2a5da commit cab3054

10 files changed

+95
-57
lines changed


Project.toml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
44
version = "0.1.0"
55

66
[deps]
7+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
78
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89

910
[extras]

docs/make.jl

Lines changed: 1 addition & 1 deletion
@@ -9,11 +9,11 @@ makedocs(;
99
pages=[
1010
"Introduction" => "index.md",
1111
"Anatomy of an Implementation" => "anatomy_of_an_implementation.md",
12-
"Common Implementation Patterns" => "common_implementation_patterns.md",
1312
"Reference" => "reference.md",
1413
"Fit, update and ingest" => "fit_update_and_ingest.md",
1514
"Predict and other operations" => "operations.md",
1615
"Model Traits" => "model_traits.md",
16+
"Common Implementation Patterns" => "common_implementation_patterns.md",
1717
],
1818
repo="https://$REPO/blob/{commit}{path}#L{line}",
1919
sitename="LearnAPI.jl"

docs/src/anatomy_of_an_implementation.md

Lines changed: 10 additions & 8 deletions
@@ -5,9 +5,10 @@
55
> dispatched on the model type; `predict` is an example of an **operation** (another is
66
> `transform`). In this example we also implement an **accessor function**, called
77
> `feature_importance`, returning the absolute values of the linear coefficients. The
8-
> ridge regressor has a target variable and one trait declaration flags the output of
9-
> `predict` as being a [proxy](@ref scope) for the target. Other traits articulate the
10-
> model's training data type requirements and the input/output type of `predict`.
8+
> ridge regressor has a target variable and `predict` makes literal predictions of the
9+
> target (rather than, say, probabilistic predictions); this behaviour is flagged by the
10+
> `target_proxies` model trait. Other traits articulate the model's training data type
11+
> requirements and the input/output type of `predict`.
1112
1213
We begin by describing an implementation of LearnAPI.jl for basic ridge
1314
regression (no intercept) to introduce the main actors in any implementation.
@@ -58,7 +59,8 @@ function LearnAPI.fit(model::MyRidge, verbosity, X, y)
5859
5960
# process input:
6061
x = Tables.matrix(X) # convert table to matrix
61-
features = Tables.columnnames(X)
62+
s = Tables.schema(X)
63+
features = s.names
6264
6365
# core solver:
6466
coefficients = (x'x + model.lambda*I)\(x'y)
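As a standalone check of the core solver line in this hunk, the following minimal sketch solves the same regularized normal equations on toy data. The data and the names `x`, `y` and `lambda` are illustrative assumptions, and nothing here depends on LearnAPI.jl or Tables.jl.

```julia
using LinearAlgebra  # provides the uniform scaling operator `I`

# Toy data (assumed for illustration): 3 observations, 2 features.
x = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y = [1.0, 2.0, 3.0]
lambda = 0.1  # ridge penalty, playing the role of `model.lambda`

# Same expression as the core solver above: solve (x'x + lambda*I) b = x'y.
coefficients = (x'x + lambda*I) \ (x'y)

# The solution satisfies the regularized normal equations, up to rounding.
residual = (x'x + lambda*I) * coefficients - x'y

# Absolute values of the coefficients, as an accessor such as
# `feature_importance` might report them.
importances = abs.(coefficients)
```

Because `lambda > 0`, the matrix `x'x + lambda*I` is positive definite, so the linear solve is well posed even when the columns of `x` are collinear.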
@@ -140,28 +142,28 @@ nothing # hide
140142
Another example of an accessor function is [`training_losses`](@ref).
141143

142144

143-
## [Model traits](@id traits)
145+
## [Model traits](@id traits)
144146

145147
Our model has a target variable, in the sense outlined in [Scope and undefined
146148
notions](@ref scope), and `predict` returns an object with exactly the same form as the
147149
target. We indicate this behaviour by declaring
148150

149151
```@example anatomy
150-
LearnAPI.target_proxy(::Type{<:MyRidge}) = (; predict=LearnAPI.TrueTarget())
152+
LearnAPI.target_proxies(::Type{<:MyRidge}) = (; predict=LearnAPI.TrueTarget())
151153
nothing # hide
152154
```
153155
Or, you can use the shorthand
154156

155157
```@example anatomy
156-
@trait MyRidge target_proxy = (; predict=LearnAPI.TrueTarget())
158+
@trait MyRidge target_proxies = (; predict=LearnAPI.TrueTarget())
157159
nothing # hide
158160
```
159161

160162
More generally, `predict` only returns a *proxy* for the target, such as probability
161163
distributions, and we would make a different declaration here. See [Target proxies](@ref)
162164
for details.
163165

164-
`LearnAPI.target_proxy` is an example of a **model trait**. A complete list of traits
166+
`LearnAPI.target_proxies` is an example of a **model trait**. A complete list of traits
165167
and the contracts they imply is given in [Model Traits](@ref).
166168

167169
> **MLJ only.** The values of all traits constitute a model's **metadata**, which is

docs/src/index.md

Lines changed: 18 additions & 12 deletions
@@ -41,6 +41,11 @@ model. Probability distributions, confidence intervals and survival functions ar
4141
of [Target proxies](@ref). LearnAPI provides a trait for distinguishing such models based
4242
on the target proxy.
4343

44+
LearnAPI does not provide an interface for data access or data resampling, and could be
45+
used in conjunction with one or more such interfaces (e.g.,
46+
[Tables.jl](https://github.com/JuliaData/Tables.jl),
47+
[MLUtils.jl](https://github.com/JuliaML/MLUtils.jl)).
48+
4449
## Methods
4550

4651
In LearnAPI.jl a *model* is just a container for the hyper-parameters of some machine
@@ -56,17 +61,17 @@ The following methods, dispatched on model type, are provided:
5661

5762
- `ingest!` for incremental learning
5863

59-
- **operations**, such as `predict`, `transform` and `inverse_transform` for applying the
60-
model to data not used for training
64+
- **operations**, `predict`, `predict_joint`, `transform` and `inverse_transform` for
65+
applying the model to data not used for training
6166

6267
- common **accessor functions**, such as `feature_importances` and `training_losses`, for
63-
extracting, from training outcomes, information common to different types of models
68+
extracting, from training outcomes, information common to some models
6469

65-
- **model traits**, such as `target_proxy(model)`, for promising specific behaviour
70+
- **model traits**, such as `target_proxies(model)`, for promising specific behaviour
6671

6772
There is flexibility about how much of the interface is implemented by a given model
6873
object `model`. A special trait `functions(model)` declares what has been explicitly
69-
implemented or overloaded to work with `model`, excluding traits.
74+
implemented to work with `model`, excluding traits.
7075

7176
Since this is a functional-style interface, `fit` returns model `state`, in addition to
7277
learned parameters, for passing to the optional `update!` and `ingest!` methods. These
@@ -77,12 +82,13 @@ component (important for models that do not generalize to new data).
7782
Models can be supervised or not supervised, can generalize to new data observations, or
7883
not generalize. To ensure proper handling by client packages of probabilistic and other
7984
non-literal forms of target predictions (pdfs, confidence intervals, survival functions,
80-
etc) the kind of prediction can be flagged appropriately; see more at "target" below.
85+
etc) the output of `predict` and `predict_joint` can be flagged appropriately; see more at
86+
"target" below.
8187

8288

8389
## [Scope and undefined notions](@id scope)
8490

85-
The Learn API provides methods for training, applying, and saving machine learning models,
91+
LearnAPI.jl provides methods for training, applying, and saving machine learning models,
8692
and that is all. *It does not specify an interface for data access or data
8793
resampling*. However, LearnAPI.jl is predicated on a few basic undefined notions (in
8894
**boldface**) which some higher-level interface might decide to formalize:
@@ -115,16 +121,16 @@ resampling*. However, LearnAPI.jl is predicated on a few basic undefined notions
115121

116122
## Contents
117123

118-
Our opening observations notwithstanding, it is useful to have a guide to the interface,
119-
linked below, organized around common *informally defined* patterns or "tasks". However,
120-
the definitive specification of the interface is the [Reference](@ref) section.
124+
It is useful to have a guide to the interface, linked below, organized around common
125+
*informally defined* patterns or "tasks". However, the definitive specification of the
126+
interface is the [Reference](@ref) section.
121127

122128
- [Anatomy of an Implementation](@ref) (Overview)
123129

124-
- [Common Implementation Patterns](@ref) (User Guide)
125-
126130
- [Reference](@ref) (Official Specification)
127131

132+
- [Common Implementation Patterns](@ref) (User Guide)
133+
128134
- [Testing an Implementation](@ref)
129135

130136
!!! info

docs/src/model_traits.md

Lines changed: 15 additions & 3 deletions
@@ -1,12 +1,17 @@
11
# Model Traits
22

3-
In this table, `Table` and `Continuous` are names owned by the package
4-
[ScientificTypesBase.jl](https://github.com/JuliaAI/ScientificTypesBase.jl/).
3+
Ordinary traits are available for overloading by a new model implementation. Derived
4+
traits are not.
5+
6+
## Ordinary traits
7+
8+
In the examples column of the table below, `Table` and `Continuous` are names owned by the
9+
package [ScientificTypesBase.jl](https://github.com/JuliaAI/ScientificTypesBase.jl/).
510

611
| trait | fallback value | return value | example |
712
|:-------------------------------------------------|:----------------------|:--------------|:--------|
813
| [`LearnAPI.functions`](@ref)`(model)` | `()` | implemented LearnAPI functions (traits excluded) | `(:fit, :predict)` |
9-
| [`LearnAPI.target_proxy`](@ref)`(model)` | `NamedTuple()` | details form of target proxy output | `(; predict=LearnAPI.Distribution()` |
14+
| [`LearnAPI.target_proxies`](@ref)`(model)`       | `NamedTuple()`        | details the form of target proxy output | `(; predict=LearnAPI.Distribution())` |
1015
| [`LearnAPI.position_of_target`](@ref)`(model)` | `0` | † the positional index of the **target** in `data` in `fit(..., data...; metadata)` calls | 2 |
1116
| [`LearnAPI.position_of_weights`](@ref)`(model)` | `0` | † the positional index of **observation weights** in `data` in `fit(..., data...; metadata)` | 3 |
1217
| [`LearnAPI.descriptors`](@ref)`(model)` | `()` | lists one or more suggestive model descriptors from `LearnAPI.descriptors()` | (:classifier, :probabilistic) |
@@ -35,3 +40,10 @@ is understood to exclude the variable, but note that `fit` can have multiple sig
3540
varying lengths, as in `fit(model, verbosity, X, y)` and `fit(model, verbosity, X, y,
3641
w)`. A non-zero value is a promise that `fit` includes a signature of sufficient length to
3742
include the variable.
43+
44+
## Derived traits
45+
46+
| trait | return value | example |
47+
|:---------------------------------------|:--------------------------|:--------|
48+
| [`LearnAPI.name`](@ref)`(model)` | model type name as string | "PCA" |
49+
| [`LearnAPI.ismodel`](@ref)`(model)` | `true` if `functions(model)` is not empty | `true` |
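
The distinction between ordinary and derived traits can be illustrated with a self-contained sketch. Everything below is a hypothetical stand-in rather than the LearnAPI.jl source: `functions` is treated as the only overloadable trait, and `name` and `ismodel` are computed from the model type as the table describes. The instance-to-type fallbacks are an assumption made so the traits can be called on model instances.

```julia
# Hypothetical stand-ins for illustration; not the LearnAPI.jl source.
functions(::Type) = ()                        # ordinary trait: fallback value
functions(model) = functions(typeof(model))   # convenience for instances

# Derived traits are computed from the type or from ordinary traits,
# so an implementation never overloads them directly.
name(M::Type) = string(nameof(M))
name(model) = name(typeof(model))
ismodel(M::Type) = !isempty(functions(M))
ismodel(model) = ismodel(typeof(model))

struct PCA end                                # a toy model type
functions(::Type{<:PCA}) = (:fit, :transform) # the only overload needed

name(PCA())      # "PCA"
ismodel(PCA())   # true
ismodel(42)      # false, since `functions(Int)` is empty
```

Under this design a new implementation overloads `functions` alone and the derived values follow automatically.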

docs/src/operations.md

Lines changed: 16 additions & 13 deletions
@@ -2,8 +2,8 @@
22

33
> **Summary** Methods like `predict` and `transform`, that generally depend on learned
44
> parameters, are called **operations**. All implemented operations must be included in
5-
> the output of the `methods` model trait. When an operation returns a [target
6-
> proxy](@ref scope), it must make a `target_proxy` declaration.
5+
> the output of the `functions` model trait. When an operation returns a [target
6+
> proxy](@ref scope), it must make a `target_proxies` declaration.
77
88
An *operation* is any method with signature `some_operation(model, fitted_params,
99
data...)`. Here `fitted_params` is the learned parameters object, as returned by
@@ -23,17 +23,14 @@ ŷ, predict_report = LearnAPI.predict(some_model, fitted_params, Xnew)
2323
[`LearnAPI.transform`](@ref) | no | none | |
2424
[`LearnAPI.inverse_transform`](@ref) | no | none | `transform` |
2525

26-
> **† MLJ only.** MLJBase provides fallbacks for `predict_mode`, `predict_mean` and
27-
> `predict_median` by broadcasting methods from `Statistics` and `StatsBase` over the
28-
> results of `predict`.
2926

3027
## General requirements
3128

3229
- Only implement `predict_joint` for outputting a *single* multivariate probability
3330
distribution for multiple target predictions, as described further at
3431
[`LearnAPI.predict_joint`](@ref).
3532

36-
- Each operation explicitly implemented or overloaded must be included in the return value
33+
- Each operation explicitly overloaded must be included in the return value
3734
of [`LearnAPI.functions`](@ref).
3835

3936
## Predict or transform?
@@ -91,27 +88,33 @@ have no fields.
9188
| `LearnAPI.SurvivalFunction` | survival function (possible requirement: observation is single-argument function mapping `Real` to `Real`) |
9289
| `LearnAPI.SurvivalDistribution` | probability distribution for survival time (possible requirement: observation have type `Distributions.ContinuousUnivariateDistribution`) |
9390

94-
> **† MLJ only.** To avoid [ambiguities in
95-
> representation](https://github.com/alan-turing-institute/MLJ.jl/blob/dev/paper/paper.md#a-unified-approach-to-probabilistic-predictions-and-their-evaluation),
96-
> these options are disallowed, in favour of preceding alternatives.
91+
† Provided for completeness but discouraged to avoid [ambiguities in
92+
representation](https://github.com/alan-turing-institute/MLJ.jl/blob/dev/paper/paper.md#a-unified-approach-to-probabilistic-predictions-and-their-evaluation).
93+
9794

9895
!!! warning
9996

10097
The "possible requirement"s listed are not part of LearnAPI.jl.
10198

10299
An operation with target proxy as output must declare a `TargetProxy` instance using the
103-
[`LearnAPI.target_proxy`](@ref), as in
100+
[`LearnAPI.target_proxies`](@ref), as in
101+
102+
```julia
103+
LearnAPI.target_proxies(::Type{<:SomeModel}) = (predict=LearnAPI.Distribution(),)
104+
```
105+
106+
which has the short form
104107

105108
```julia
106-
LearnAPI.target_proxy(::Type{<:SomeModel}) = (predict=LearnAPI.Distribution(),)
109+
LearnAPI.@trait SomeModel target_proxies = (predict=LearnAPI.Distribution(),)
107110
```
108111

109-
If `predict_joint` is implemented, then a `target_proxy` declaration is also
112+
If `predict_joint` is implemented, then a `target_proxies` declaration is also
110113
required, but the interpretation is slightly different. This is because the output of
111114
`predict_joint` is not a number of observations but a single object. The trait value
112115
should be an instance of one of the following types:
113116

114-
| type | form of output of `predict_joint(model, _, data)`
117+
| type | form of output of `predict_joint(model, fitted_params, data)`
115118
|:-------------------------------:|:--------------------------------------------------|
116119
| `LearnAPI.Sampleable` | object that can be sampled to obtain a *vector* whose elements have the form of target observations; the vector length matches the number of observations in `data`. |
117120
| `LearnAPI.Distribution` | explicit probability density/mass function whose sample space is vectors of target observations; the vector length matches the number of observations in `data` |

docs/src/reference.md

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
11
# Reference
22

3-
Here we give the definitive specification of interface provided by LearnAPI.jl. For a more
4-
informal guide see [Common Implementation Patterns](@ref).
3+
Here we give the definitive specification of the interface provided by LearnAPI.jl. For a
4+
more informal guide see [Common Implementation Patterns](@ref).
55

66
## Models
77

src/LearnAPI.jl

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
11
module LearnAPI
22

33
using Statistics
4+
using InteractiveUtils
45

56
include("tools.jl")
67
include("models.jl")

src/model_traits.jl

Lines changed: 17 additions & 2 deletions
@@ -4,7 +4,7 @@
44
const DERIVED_TRAITS = (:name, :ismodel)
55
const ORDINARY_TRAITS = (
66
:functions,
7-
:target_proxy,
7+
:target_proxies,
88
:position_of_target,
99
:position_of_weights,
1010
:descriptors,
@@ -62,7 +62,22 @@ See also [`LearnAPI.Model`](@ref).
6262
"""
6363
functions(::Type) = ()
6464

65-
target_proxy(::Type) = NamedTuple()
65+
target_proxies() = subtypes(TargetProxy)
66+
67+
"""
68+
target_proxies(model)
69+
70+
Return a named tuple of target proxies, keyed on operation name, applying to `model`. For
71+
example, a value of
72+
73+
(predict=LearnAPI.Distribution(),)
74+
75+
means that `LearnAPI.predict` returns probability distributions, rather than actual values
76+
of the target. View all target proxy types with `target_proxies()`. For more information
77+
on target variables and target proxies, refer to the LearnAPI documentation.
78+
79+
"""
80+
target_proxies(::Type) = NamedTuple()
6681

6782
position_of_target(::Type) = 0
6883
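
The dispatch pattern introduced in this file can be tried in isolation. The sketch below is an assumption-laden mock, not the package code: `TargetProxy` and its three subtypes are redefined locally, `SomeModel` is a toy type, and the instance-to-type fallback is assumed so that the trait can be called on a model instance.

```julia
using InteractiveUtils: subtypes  # the dependency added in this commit

# Local stand-ins for the LearnAPI types (assumed for illustration).
abstract type TargetProxy end
struct TrueTarget <: TargetProxy end
struct Distribution <: TargetProxy end
struct Sampleable <: TargetProxy end

# Zero-argument method: list the available proxy types.
target_proxies() = subtypes(TargetProxy)

# One-argument trait: fallback on types, plus a convenience for instances.
target_proxies(::Type) = NamedTuple()
target_proxies(model) = target_proxies(typeof(model))

# A toy model declaring that `predict` returns probability distributions.
struct SomeModel end
target_proxies(::Type{<:SomeModel}) = (predict=Distribution(),)

declared = target_proxies(SomeModel())  # the overloaded value
fallback = target_proxies(3)            # no declaration: empty named tuple
available = target_proxies()            # the concrete proxy types
```

The zero-argument and one-argument methods share a name without conflict because Julia dispatches on the number of arguments as well as their types.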

src/operations.jl

Lines changed: 14 additions & 16 deletions
@@ -1,9 +1,7 @@
1-
const PREDICT_OPERATIONS = (
2-
31
const OPERATIONS = (:predict, :predict_joint, :transform, :inverse_transform)
42

53
const DOC_NEW_DATA =
6-
"Here `report` contains ancilliary byproducts of the computation, or "*
4+
"The `report` contains ancillary byproducts of the computation, or "*
75
"is `nothing`; `data` is a tuple of data objects, "*
86
"generally a single object representing new observations "*
97
"not seen in training. "
@@ -16,7 +14,7 @@ const DOC_NEW_DATA =
1614
1715
Return `(ŷ, report)` where `ŷ` are the predictions, or prediction-like output (such as
1816
probabilities), for a machine learning model `model`, with learned parameters
19-
`fitted_params`, as returned by a preceding call to [`LearnAPI.fit`](@ref)`(model, ...)`.
17+
`fitted_params` (first object returned by [`LearnAPI.fit`](@ref)`(model, ...)`).
2018
$DOC_NEW_DATA
2119
2220
@@ -36,13 +34,13 @@ implementation itself promises, by making an optional [`LearnAPI.output_scitypes
3634
declaration.
3735
3836
If `predict` is computing a target proxy, as defined in the LearnAPI.jl documentation, then a
39-
[`LearnAPI.target_proxy`](@ref) declaration is required, as in
37+
[`LearnAPI.target_proxies`](@ref) declaration is required, as in
4038
4139
```julia
42-
LearnAPI.target_proxy(::Type{<:SomeModel}) = (predict=LearnAPI.Distribution,)
40+
LearnAPI.target_proxies(::Type{<:SomeModel}) = (predict=LearnAPI.Distribution(),)
4341
```
4442
45-
Do `LearnAPI.target_proxy()` to list the available kinds.
43+
Do `LearnAPI.target_proxies()` to list the available kinds.
4644
4745
By default, it is expected that `data` has length one. Otherwise,
4846
[`LearnAPI.input_scitypes`](@ref) must be overloaded.
@@ -96,10 +94,10 @@ For a supervised learning model, return `(d, report)`, where `d` is some represe
9694
the *single* probability distribution for the sample space ``Y^n``. Here ``Y`` is the
9795
space in which the target variable associated with `model` takes its values, and `n` is
9896
the number of observations in `data`. The specific form of the representation is given by
99-
`LearnAPI.target_proxy(model)`.
97+
[`LearnAPI.target_proxies(model)`](@ref).
10098
101-
Here `fitted_params` are the model's learned parameters, as returned by a preceding call
102-
to [`LearnAPI.fit`](@ref). $DOC_NEW_DATA.
99+
Here `fitted_params` are the model's learned parameters (the first object returned by
100+
[`LearnAPI.fit`](@ref)). $DOC_NEW_DATA.
103101
104102
While the interpretation of this distribution depends on the model, marginalizing
105103
component-wise will generally deliver *correlated* univariate distributions, and these will
@@ -109,10 +107,10 @@ generally not agree with those returned by `LearnAPI.predict`, if also implement
109107
110108
Only implement this method if `model` has an associated concept of target variable, as
111109
defined in the LearnAPI.jl documentation. A trait declaration
112-
[`LearnAPI.target_proxy`](@ref), such as
110+
[`LearnAPI.target_proxies`](@ref), such as
113111
114112
```julia
115-
LearnAPI.target_proxy(::Type{SomeModel}) = (; predict_joint=Sampleable())
113+
LearnAPI.target_proxies(::Type{SomeModel}) = (; predict_joint=Sampleable())
116114
```
117115
118116
is required. Here the possible kinds of target proxies are `LearnAPI.Sampleable`,
@@ -129,9 +127,9 @@ function predict_joint end
129127
LearnAPI.transform(model, fitted_params, data...)
130128
131129
Return `(output, report)`, where `output` is some kind of transformation of `data`,
132-
provided by `model`, based on the learned parameters `fitted_params`, as returned by a
133-
preceding call to [`LearnAPI.fit`](@ref)`(model, ...)` (which could be `nothing` for
134-
models that do not generalize to new data, such as "static transformers"). $DOC_NEW_DATA
130+
provided by `model`, based on the learned parameters `fitted_params` (the first object
131+
returned by [`LearnAPI.fit`](@ref)`(model, ...)`). The `fitted_params` could be `nothing`,
132+
in the case of models that do not generalize to new data. $DOC_NEW_DATA
135133
136134
137135
# New model implementations
@@ -167,7 +165,7 @@ the map
167165
data -> first(transform(model, fitted_params, data))
168166
```
169167
170-
For example, if `transform` corresponds to a projection, `inverse_transform` is the
168+
For example, if `transform` corresponds to a projection, `inverse_transform` might be the
171169
corresponding embedding.
172170
173171
