Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 192 additions & 47 deletions tests/test_extensions/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# License: BSD 3-Clause
from __future__ import annotations

import inspect
from collections import OrderedDict

import inspect
import numpy as np
import pytest

from unittest.mock import patch
import openml.testing
from openml.extensions import get_extension_by_flow, get_extension_by_model, register_extension
from openml.extensions import Extension, get_extension_by_flow, get_extension_by_model, register_extension


class DummyFlow:
Expand Down Expand Up @@ -40,54 +42,197 @@ def can_handle_model(model):
return False


def _unregister():
# "Un-register" the test extensions
while True:
rem_dum_ext1 = False
rem_dum_ext2 = False
try:
openml.extensions.extensions.remove(DummyExtension1)
rem_dum_ext1 = True
except ValueError:
pass
try:
openml.extensions.extensions.remove(DummyExtension2)
rem_dum_ext2 = True
except ValueError:
pass
if not rem_dum_ext1 and not rem_dum_ext2:
break
class DummyExtension(Extension):
@classmethod
def can_handle_flow(cls, flow):
return isinstance(flow, DummyFlow)

@classmethod
def can_handle_model(cls, model):
return isinstance(model, DummyModel)

def flow_to_model(
self,
flow,
initialize_with_defaults=False,
strict_version=True,
):
if not isinstance(flow, DummyFlow):
raise ValueError("Invalid flow")

model = DummyModel()
model.defaults = initialize_with_defaults
model.strict_version = strict_version
return model

def model_to_flow(self, model):
if not isinstance(model, DummyModel):
raise ValueError("Invalid model")
return DummyFlow()

def get_version_information(self):
return ["dummy==1.0"]

def create_setup_string(self, model):
return "DummyModel()"

def is_estimator(self, model):
return isinstance(model, DummyModel)

def seed_model(self, model, seed):
model.seed = seed
return model

def _run_model_on_fold(
self,
model,
task,
X_train,
rep_no,
fold_no,
y_train=None,
X_test=None,
):
preds = np.zeros(len(X_train))
probs = None
measures = OrderedDict()
trace = None
return preds, probs, measures, trace

def obtain_parameter_values(self, flow, model=None):
return []

def check_if_model_fitted(self, model):
return False

def instantiate_model_from_hpo_class(self, model, trace_iteration):
return DummyModel()



class TestInit(openml.testing.TestBase):
def setUp(self):
super().setUp()
_unregister()

def test_get_extension_by_flow(self):
assert get_extension_by_flow(DummyFlow()) is None
with pytest.raises(ValueError, match="No extension registered which can handle flow:"):
get_extension_by_flow(DummyFlow(), raise_if_no_extension=True)
register_extension(DummyExtension1)
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
register_extension(DummyExtension2)
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
register_extension(DummyExtension1)
with pytest.raises(
ValueError, match="Multiple extensions registered which can handle flow:"
):
get_extension_by_flow(DummyFlow())
# We replace the global list with a new empty list [] ONLY for this block
with patch("openml.extensions.extensions", []):
assert get_extension_by_flow(DummyFlow()) is None

with pytest.raises(ValueError, match="No extension registered which can handle flow:"):
get_extension_by_flow(DummyFlow(), raise_if_no_extension=True)

register_extension(DummyExtension1)
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)

register_extension(DummyExtension2)
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)

register_extension(DummyExtension1)
with pytest.raises(
ValueError, match="Multiple extensions registered which can handle flow:"
):
get_extension_by_flow(DummyFlow())

def test_get_extension_by_model(self):
assert get_extension_by_model(DummyModel()) is None
with pytest.raises(ValueError, match="No extension registered which can handle model:"):
get_extension_by_model(DummyModel(), raise_if_no_extension=True)
register_extension(DummyExtension1)
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
register_extension(DummyExtension2)
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
register_extension(DummyExtension1)
with pytest.raises(
ValueError, match="Multiple extensions registered which can handle model:"
):
get_extension_by_model(DummyModel())
# Again, we start with a fresh empty list automatically
with patch("openml.extensions.extensions", []):
assert get_extension_by_model(DummyModel()) is None

with pytest.raises(ValueError, match="No extension registered which can handle model:"):
get_extension_by_model(DummyModel(), raise_if_no_extension=True)

register_extension(DummyExtension1)
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)

register_extension(DummyExtension2)
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)

register_extension(DummyExtension1)
with pytest.raises(
ValueError, match="Multiple extensions registered which can handle model:"
):
get_extension_by_model(DummyModel())


def test_flow_to_model_with_defaults():
"""Test flow_to_model with initialize_with_defaults=True."""
ext = DummyExtension()
flow = DummyFlow()

model = ext.flow_to_model(flow, initialize_with_defaults=True)

assert isinstance(model, DummyModel)
assert model.defaults is True

def test_flow_to_model_strict_version():
"""Test flow_to_model with strict_version parameter."""
ext = DummyExtension()
flow = DummyFlow()

model_strict = ext.flow_to_model(flow, strict_version=True)
model_non_strict = ext.flow_to_model(flow, strict_version=False)

assert isinstance(model_strict, DummyModel)
assert model_strict.strict_version is True

assert isinstance(model_non_strict, DummyModel)
assert model_non_strict.strict_version is False

def test_model_to_flow_conversion():
"""Test converting a model back to flow representation."""
ext = DummyExtension()
model = DummyModel()

flow = ext.model_to_flow(model)

assert isinstance(flow, DummyFlow)


def test_invalid_flow_raises_error():
"""Test that invalid flow raises appropriate error."""
class InvalidFlow:
pass

ext = DummyExtension()
flow = InvalidFlow()

with pytest.raises(ValueError, match="Invalid flow"):
ext.flow_to_model(flow)


@patch("openml.extensions.extensions", [])
def test_extension_not_found_error_message():
"""Test error message contains helpful information."""
class UnknownModel:
pass

with pytest.raises(ValueError, match="No extension registered"):
get_extension_by_model(UnknownModel(), raise_if_no_extension=True)


def test_register_same_extension_twice():
"""Test behavior when registering same extension twice."""
# Using a context manager here to isolate the list
with patch("openml.extensions.extensions", []):
register_extension(DummyExtension)
register_extension(DummyExtension)

matches = [
ext for ext in openml.extensions.extensions
if ext is DummyExtension
]
assert len(matches) == 2


@patch("openml.extensions.extensions", [])
def test_extension_priority_order():
"""Test that extensions are checked in registration order."""
class DummyExtensionA(DummyExtension):
pass
class DummyExtensionB(DummyExtension):
pass

register_extension(DummyExtensionA)
register_extension(DummyExtensionB)

assert openml.extensions.extensions[0] is DummyExtensionA
assert openml.extensions.extensions[1] is DummyExtensionB