Skip to content

Commit feb3b49

Browse files
Apply suggestion from @kylesayrs
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent f72f778 commit feb3b49

File tree

1 file changed

+3
-12
lines changed

1 file changed

+3
-12
lines changed

src/compressed_tensors/utils/helpers.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -391,19 +391,10 @@ def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]):
391391
>>> assert not hasattr(obj1, "attribute")
392392
>>> assert not hasattr(obj2, "attribute")
393393
"""
394-
_sentinel = object()
395-
original_values = [getattr(base, attr, _sentinel) for base in bases]
396-
397-
for base, value in zip(bases, values):
398-
setattr(base, attr, value)
399-
try:
394+
with contextlib.exitstack() as stack:
395+
for base, value in zip(bases, values):
396+
stack.add(patch_attr(base, attr, value))
400397
yield
401-
finally:
402-
for base, original_value in zip(bases, original_values):
403-
if original_value is not _sentinel:
404-
setattr(base, attr, original_value)
405-
else:
406-
delattr(base, attr)
407398

408399

409400
class ParameterizedDefaultDict(dict):

0 commit comments

Comments
 (0)