Skip to content

Commit e8bc4ac

Browse files
committed
Update test
Signed-off-by: dannawang <dannawang@google.com>
1 parent cc97559 commit e8bc4ac

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

tests/distributed/offload/tpu_offload_connector_worker_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import functools
4+
import gc
45
import os
56
import random
67
from typing import List
@@ -94,9 +95,17 @@ def setUp(self):
9495

9596
def tearDown(self):
9697
super().tearDown()
98+
# Destroy references explicitly
99+
if hasattr(self, 'connector'):
100+
del self.connector
101+
102+
# Force JAX to release memory
97103
cc.reset_cache()
98104
jax.clear_caches()
99105

106+
# Force Python GC
107+
gc.collect()
108+
100109
def create_mesh(self, axis_shapes, axis_names):
101110
"""Creates a JAX device mesh with the default device order."""
102111
try:

0 commit comments

Comments
 (0)