Skip to content

Commit 36aeb1c

Browse files
ENH: Fix doc link in HTML representation of estimators (#2131)
* add doc link in interactive mode * Update sklearnex/_utils.py Co-authored-by: Ian Faust <icfaust@gmail.com> * rename base class * reword * Update sklearnex/_utils.py Co-authored-by: Ian Faust <icfaust@gmail.com> * linter --------- Co-authored-by: Ian Faust <icfaust@gmail.com>
1 parent 34617fc commit 36aeb1c

File tree

5 files changed

+27
-8
lines changed

5 files changed

+27
-8
lines changed

sklearnex/_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
import sys
2020
import warnings
21+
from abc import ABC
2122

2223
from daal4py.sklearn._utils import (
2324
PatchingConditionsChain as daal4py_PatchingConditionsChain,
@@ -113,3 +114,19 @@ def get_hyperparameters(self, op):
113114
return estimator_class
114115

115116
return wrap_class
117+
118+
119+
# This abstract class is meant to generate a clickable doc link for classses
120+
# in sklearnex that are not part of base scikit-learn. It should be inherited
121+
# before inheriting from a scikit-learn estimator, otherwise will get overriden
122+
# by the estimator's original.
123+
class IntelEstimator(ABC):
124+
@property
125+
def _doc_link_module(self) -> str:
126+
return "sklearnex"
127+
128+
@property
129+
def _doc_link_template(self) -> str:
130+
module_path, _ = self.__class__.__module__.rsplit(".", 1)
131+
class_name = self.__class__.__name__
132+
return f"https://intel.github.io/scikit-learn-intelex/latest/non-scikit-algorithms.html#{module_path}.{class_name}"

sklearnex/basic_statistics/basic_statistics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
2727

2828
from .._device_offload import dispatch
29-
from .._utils import PatchingConditionsChain
29+
from .._utils import IntelEstimator, PatchingConditionsChain
3030

3131
if sklearn_check_version("1.6"):
3232
from sklearn.utils.validation import validate_data
@@ -38,7 +38,7 @@
3838

3939

4040
@control_n_jobs(decorated_methods=["fit"])
41-
class BasicStatistics(BaseEstimator):
41+
class BasicStatistics(IntelEstimator, BaseEstimator):
4242
"""
4343
Estimator for basic statistics.
4444
Allows to compute basic statistics for provided data.

sklearnex/basic_statistics/incremental_basic_statistics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727

2828
from .._device_offload import dispatch
29-
from .._utils import PatchingConditionsChain
29+
from .._utils import IntelEstimator, PatchingConditionsChain
3030

3131
if sklearn_check_version("1.2"):
3232
from sklearn.utils._param_validation import Interval, StrOptions
@@ -41,7 +41,7 @@
4141

4242

4343
@control_n_jobs(decorated_methods=["partial_fit", "_onedal_finalize_fit"])
44-
class IncrementalBasicStatistics(BaseEstimator):
44+
class IncrementalBasicStatistics(IntelEstimator, BaseEstimator):
4545
"""
4646
Calculates basic statistics on the given data, allows for computation when the data are split into
4747
batches. The user can use ``partial_fit`` method to provide a single batch of data or use the ``fit`` method to provide

sklearnex/covariance/incremental_covariance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from sklearnex import config_context
3434

3535
from .._device_offload import dispatch, wrap_output_data
36-
from .._utils import PatchingConditionsChain, register_hyperparameters
36+
from .._utils import IntelEstimator, PatchingConditionsChain, register_hyperparameters
3737
from ..metrics import pairwise_distances
3838
from ..utils._array_api import get_namespace
3939

@@ -47,7 +47,7 @@
4747

4848

4949
@control_n_jobs(decorated_methods=["partial_fit", "fit", "_onedal_finalize_fit"])
50-
class IncrementalEmpiricalCovariance(BaseEstimator):
50+
class IncrementalEmpiricalCovariance(IntelEstimator, BaseEstimator):
5151
"""
5252
Maximum likelihood covariance estimator that allows for the estimation when the data are split into
5353
batches. The user can use the ``partial_fit`` method to provide a single batch of data or use the ``fit`` method to provide

sklearnex/linear_model/incremental_linear.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from onedal.common.hyperparameters import get_hyperparameters
4141

4242
from .._device_offload import dispatch, wrap_output_data
43-
from .._utils import PatchingConditionsChain, register_hyperparameters
43+
from .._utils import IntelEstimator, PatchingConditionsChain, register_hyperparameters
4444

4545

4646
@register_hyperparameters(
@@ -52,7 +52,9 @@
5252
@control_n_jobs(
5353
decorated_methods=["fit", "partial_fit", "predict", "score", "_onedal_finalize_fit"]
5454
)
55-
class IncrementalLinearRegression(MultiOutputMixin, RegressorMixin, BaseEstimator):
55+
class IncrementalLinearRegression(
56+
IntelEstimator, MultiOutputMixin, RegressorMixin, BaseEstimator
57+
):
5658
"""
5759
Trains a linear regression model, allows for computation if the data are split into
5860
batches. The user can use the ``partial_fit`` method to provide a single batch of data or use the ``fit`` method to provide

0 commit comments

Comments
 (0)