diff --git a/skbase/base/_base.py b/skbase/base/_base.py index c6a0d16a..ebb89ef8 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -64,6 +64,7 @@ class name: BaseEstimator from skbase.base._clone_base import _check_clone, _clone from skbase.base._pretty_printing._object_html_repr import _object_html_repr from skbase.base._tagmanager import _FlagManager +from skbase.utils._hybridmethod import hybridmethod __author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"] __all__: List[str] = ["BaseEstimator", "BaseObject"] @@ -204,7 +205,7 @@ def _get_clone_plugins(cls): Returns ------- - list of str + list of BaseCloner descendants, default = None List of clone plugins for descendants. Each plugin must inherit from ``BaseCloner`` in ``skbase.base._clone_plugins``, and implement @@ -567,6 +568,7 @@ def get_class_tag(cls, tag_name, tag_value_default=None): flag_attr_name="_tags", ) + @hybridmethod def get_tags(self): """Get tags from instance, with tag level inheritance and overrides. @@ -599,8 +601,13 @@ def get_tags(self): class attribute via nested inheritance and then any overrides and new tags from ``_tags_dynamic`` object attribute. """ + if isinstance(self, type): + # if called on class, return class tags + return self.get_class_tags() + # if called on instance, return instance tags with overrides return self._get_flags(flag_attr_name="_tags") + @hybridmethod def get_tag(self, tag_name, tag_value_default=None, raise_error=True): """Get tag value from instance, with tag level inheritance and overrides. @@ -645,6 +652,12 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True): The ``ValueError`` is then raised if ``tag_name`` is not in ``self.get_tags().keys()``. """ + if isinstance(self, type): + # if called on class, return class tag + return self.get_class_tag( + tag_name=tag_name, + tag_value_default=tag_value_default, + ) return self._get_flag( flag_name=tag_name, flag_value_default=tag_value_default, diff --git a/skbase/utils/_hybridmethod.py b/skbase/utils/_hybridmethod.py new file mode 100644 index 00000000..10f3c79f --- /dev/null +++ b/skbase/utils/_hybridmethod.py @@ -0,0 +1,20 @@ +"""Decorator for methods that can be called both on the class and on instances.""" + + +class hybridmethod: + """Decorator for methods that can be called both on the class and on instances. + + The decorated method will receive the class as the first argument when called + on the class, and the instance when called on an instance. + """ + + def __init__(self, func): + self.func = func + + def __get__(self, obj, cls): + """Get method that can be called on both class and instance.""" + + def wrapper(*args, **kwargs): + return self.func(obj if obj is not None else cls, *args, **kwargs) + + return wrapper diff --git a/skbase/utils/dependencies/_import.py b/skbase/utils/dependencies/_import.py index 9ba41f2b..1275ce34 100644 --- a/skbase/utils/dependencies/_import.py +++ b/skbase/utils/dependencies/_import.py @@ -23,7 +23,6 @@ def _safe_import(import_path, pkg_name=None, condition=True, return_object="Magi Example: ``clone = _safe_import("sklearn.clone", pkg_name="scikit-learn")``. - Parameters ---------- import_path : str diff --git a/skbase/utils/tests/test_hybridmethod.py b/skbase/utils/tests/test_hybridmethod.py new file mode 100644 index 00000000..3932457c --- /dev/null +++ b/skbase/utils/tests/test_hybridmethod.py @@ -0,0 +1,29 @@ +"""Tests for hybridmethod decorator.""" + +from inspect import isclass + +from skbase.utils._hybridmethod import hybridmethod + + +class HybridmethodTestclass: + + def __init__(self): + self.ref_to_self = self + + @hybridmethod + def method(self): + + if isclass(self): + assert self is self.ref_to_self + else: + assert self is self.ref_to_self + assert isinstance(self, self.__class__) + + +HybridmethodTestclass.ref_to_self = HybridmethodTestclass + + +def test_hybridmethod(): + """Test that hybridmethod works as expected.""" + HybridmethodTestclass.method() + HybridmethodTestclass().method()