Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class name: BaseEstimator
from skbase.base._clone_base import _check_clone, _clone
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager
from skbase.utils._hybridmethod import hybridmethod

__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"]
__all__: List[str] = ["BaseEstimator", "BaseObject"]
Expand Down Expand Up @@ -204,7 +205,7 @@ def _get_clone_plugins(cls):

Returns
-------
list of str
list of BaseCloner descendants, default = None
List of clone plugins for descendants.
Each plugin must inherit from ``BaseCloner``
in ``skbase.base._clone_plugins``, and implement
Expand Down Expand Up @@ -567,6 +568,7 @@ def get_class_tag(cls, tag_name, tag_value_default=None):
flag_attr_name="_tags",
)

@hybridmethod
def get_tags(self):
"""Get tags from instance, with tag level inheritance and overrides.

Expand Down Expand Up @@ -599,8 +601,13 @@ def get_tags(self):
class attribute via nested inheritance and then any overrides
and new tags from ``_tags_dynamic`` object attribute.
"""
if isinstance(self, type):
# if called on class, return class tags
return self.get_class_tags()
# if called on instance, return instance tags with overrides
return self._get_flags(flag_attr_name="_tags")

@hybridmethod
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from instance, with tag level inheritance and overrides.

Expand Down Expand Up @@ -645,6 +652,12 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
The ``ValueError`` is then raised if ``tag_name`` is
not in ``self.get_tags().keys()``.
"""
if isinstance(self, type):
# if called on class, return class tag
return self.get_class_tag(
tag_name=tag_name,
tag_value_default=tag_value_default,
)
return self._get_flag(
flag_name=tag_name,
flag_value_default=tag_value_default,
Expand Down
20 changes: 20 additions & 0 deletions skbase/utils/_hybridmethod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Decorator for methods that can be called both on the class and on instances."""


class hybridmethod:
"""Decorator for methods that can be called both on the class and on instances.

The decorated method will receive the class as the first argument when called
on the class, and the instance when called on an instance.
"""

def __init__(self, func):
self.func = func

def __get__(self, obj, cls):
"""Get method that can be called on both class and instance."""

def wrapper(*args, **kwargs):
return self.func(obj if obj is not None else cls, *args, **kwargs)

return wrapper
1 change: 0 additions & 1 deletion skbase/utils/dependencies/_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def _safe_import(import_path, pkg_name=None, condition=True, return_object="Magi

Example: ``clone = _safe_import("sklearn.clone", pkg_name="scikit-learn")``.


Parameters
----------
import_path : str
Expand Down
29 changes: 29 additions & 0 deletions skbase/utils/tests/test_hybridmethod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Tests for hybridmethod decorator."""

from inspect import isclass

from skbase.utils._hybridmethod import hybridmethod


class HybridmethodTestclass:

def __init__(self):
self.ref_to_self = self

@hybridmethod
def method(self):

if isclass(self):
assert self is self.ref_to_self
else:
assert self is self.ref_to_self
assert isinstance(self, self.__class__)


HybridmethodTestclass.ref_to_self = HybridmethodTestclass


def test_hybridmethod():
"""Test that hybridmethod works as expected."""
HybridmethodTestclass.method()
HybridmethodTestclass().method()
Loading