Skip to content

Commit 1029d9d

Browse files
authored
🐛 Notify users about dependency condition (#150)
1 parent 5fe3b9a commit 1029d9d

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

tests_unit/keras_assumptions_tests/test_experimental_calls_exist.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import inspect
2+
import unittest
23
from unittest import TestCase
34

45
import tensorflow as tf
56

7+
from uncertainty_wizard.internal_utils.tf_version_resolver import (
8+
current_tf_version_is_older_than,
9+
)
10+
611

712
class TestExperimentalAPIAreAvailable(TestCase):
813
"""
@@ -21,6 +26,9 @@ def test_list_physical_devices(self):
2126
self.assertTrue("device_type" in parameters)
2227
self.assertEqual(1, len(parameters))
2328

29+
@unittest.skipIf(
30+
not current_tf_version_is_older_than("2.10.0"), "Known to fail for tf >= 2.10.0"
31+
)
2432
def test_virtual_device_configuration(self):
2533
self.assertTrue("VirtualDeviceConfiguration" in dir(tf.config.experimental))
2634
parameters = inspect.signature(

uncertainty_wizard/models/ensemble_utils/_lazy_contexts.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
import tensorflow as tf
99

10+
from uncertainty_wizard.internal_utils.tf_version_resolver import (
11+
current_tf_version_is_older_than,
12+
)
1013
from uncertainty_wizard.models.ensemble_utils._save_config import SaveConfig
1114

1215
global number_of_tasks_in_this_process
@@ -35,7 +38,7 @@ def __init__(self, model_id: int, varargs: dict = None):
3538
it will have to generate a context.
3639
Later, to make it easier for custom child classes of EnsembleContextManager,
3740
a (now still empty) varargs is also passed which may be populated with more information
38-
in future versions of uncertainty_wizard.
41+
in future s of uncertainty_wizard.
3942
"""
4043
self.ensemble_id = (model_id,)
4144
self.varargs = varargs
@@ -225,6 +228,15 @@ class DeviceAllocatorContextManager(EnsembleContextManager, abc.ABC):
225228
the abstract methods.
226229
"""
227230

231+
def __init__(self):
232+
super().__init__()
233+
if not current_tf_version_is_older_than("2.10.0"):
234+
raise RuntimeError(
235+
"The DeviceAllocatorContextManager is not compatible with tensorflow 2.10.0 "
236+
"or newer. Please fall back to a single GPU for now (see issue #75),"
237+
"or downgrade to tensorflow 2.9.0."
238+
)
239+
228240
# docstr-coverage: inherited
229241
def __enter__(self) -> "DeviceAllocatorContextManager":
230242
super().__enter__()

0 commit comments

Comments
 (0)