Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
preupgrade_validator*.tgz

# Translations
*.mo
Expand Down
48 changes: 41 additions & 7 deletions aci-preupgrade-validation-script.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,13 +1023,17 @@ def __init__(
common_kwargs,
monitor_interval=0.5, # sec
monitor_timeout=600, # sec
max_threads=None,
callback_on_monitoring=None,
callback_on_start_failure=None,
callback_on_timeout=None,
):
self.funcs = funcs
self.threads = None
self.common_kwargs = common_kwargs
# Semaphore to cap the number of concurrently running check threads.
# None means unlimited.
self.semaphore = threading.Semaphore(max_threads) if max_threads and max_threads > 0 else None

# Not using `thread.join(timeout)` because it waits for each thread sequentially,
# which means the program may wait for "timeout * num of threads" at worst case.
Expand All @@ -1053,7 +1057,7 @@ def start(self):
raise RuntimeError("Threading on going. Cannot start again.")

self.threads = [
self._generate_thread(target=func, kwargs=self.common_kwargs)
self._generate_thread(target=func, kwargs=self.common_kwargs, use_semaphore=True)
for func in self.funcs
]

Expand All @@ -1080,9 +1084,19 @@ def join(self):
def is_timeout(self):
return self.timeout_event.is_set()

def _generate_thread(self, target, args=(), kwargs=None):
def _generate_thread(self, target, args=(), kwargs=None, use_semaphore=False):
if kwargs is None:
kwargs = {}
if use_semaphore and self.semaphore is not None:
semaphore = self.semaphore
original_target = target
def _wrapped_target(*a, **kw):
try:
original_target(*a, **kw)
finally:
semaphore.release()
_wrapped_target.__name__ = target.__name__
target = _wrapped_target
thread = CustomThread(
target=target, name=target.__name__, args=args, kwargs=kwargs
)
Expand All @@ -1102,11 +1116,16 @@ def _start_thread(self, thread):
thread_started = False
while not self.is_timeout():
try:
if self.semaphore is not None:
log.info("({}) Waiting for an available thread slot.".format(thread.name))
self.semaphore.acquire()
log.info("({}) Starting thread.".format(thread.name))
thread.start()
thread_started = True
break
except RuntimeError as e:
if self.semaphore is not None:
self.semaphore.release()
if str(e) != "can't start new thread":
log.error("({}) Unexpected error to start a thread.".format(thread.name), exc_info=True)
break
Expand All @@ -1121,6 +1140,8 @@ def _start_thread(self, thread):
time_elapsed += queue_interval
continue
except Exception:
if self.semaphore is not None:
self.semaphore.release()
log.error("({}) Unexpected error to start a thread.".format(thread.name), exc_info=True)
break

Expand Down Expand Up @@ -1493,11 +1514,14 @@ def get_row(widths, values, spad=" ", lpad=""):

def prints(objects, sep=' ', end='\n'):
with open(RESULT_FILE, 'a') as f:
print(objects, sep=sep, end=end, file=sys.stdout)
try:
print(objects, sep=sep, end=end, file=sys.stdout)
sys.stdout.flush()
except OSError:
pass
if end == "\r":
end = "\n" # easier to read with \n in a log file
print(objects, sep=sep, end=end, file=f)
sys.stdout.flush()
f.flush()


Expand Down Expand Up @@ -4714,6 +4738,7 @@ def fabricPathEp_target_check(**kwargs):
fex_a = groups.get("fexA")
fex_b = groups.get("fexB")
path = groups.get("path")
print(path)

# CHECK FEX ID(s) of extpath(s) is 101 or greater
if fex_a:
Expand Down Expand Up @@ -4748,7 +4773,13 @@ def fabricPathEp_target_check(**kwargs):
elif int(third) > 16:
data.append([dn, "eth port {} is invalid (1-16 expected) for breakout ports".format(third)])
else:
data.append([dn, "PathEp 'eth' syntax is invalid"])
# CHECK eth1//0 malform scenario (double slashes)
if "//" in path:
data.append([dn, "PathEp 'eth' syntax is invalid"])
# CHECK Ethx/y malform scenario (should not be caps)
elif path.startswith("Eth"):
data.append([dn, "PathEp 'eth' should be lowercase 'eth'"])

else:
data.append([dn, "target is not a valid fabricPathEp DN"])

Expand Down Expand Up @@ -6039,6 +6070,7 @@ def parse_args(args):
parser.add_argument("-v", "--version", action="store_true", help="Only show the script version, then end.")
parser.add_argument("--total-checks", action="store_true", help="Only show the total number of checks, then end.")
parser.add_argument("--timeout", action="store", nargs="?", type=int, const=-1, default=DEFAULT_TIMEOUT, help="Show default script timeout (sec) or overwrite it when a number is provided (e.g. --timeout 1200).")
parser.add_argument("--max-threads", action="store", type=int, default=None, help="Maximum number of check threads to run concurrently. Defaults to unlimited.")
parsed_args = parser.parse_args(args)
return parsed_args

Expand Down Expand Up @@ -6209,11 +6241,12 @@ class CheckManager:
apic_ca_cert_validation,
]

def __init__(self, api_only=False, debug_function="", timeout=600, monitor_interval=0.5):
def __init__(self, api_only=False, debug_function="", timeout=600, monitor_interval=0.5, max_threads=None):
self.api_only = api_only
self.debug_function = debug_function
self.monitor_interval = monitor_interval # sec
self.monitor_timeout = timeout # sec
self.max_threads = max_threads
self.timeout_event = None

self.check_funcs = self.get_check_funcs()
Expand Down Expand Up @@ -6284,6 +6317,7 @@ def run_checks(self, common_data):
common_kwargs=dict({"finalize_check": self.finalize_check}, **common_data),
monitor_interval=self.monitor_interval,
monitor_timeout=self.monitor_timeout,
max_threads=self.max_threads,
callback_on_monitoring=print_progress,
callback_on_start_failure=self.finalize_check_on_thread_failure,
callback_on_timeout=self.finalize_check_on_thread_timeout,
Expand All @@ -6303,7 +6337,7 @@ def main(_args=None):
print("Timeout(sec): {}".format(DEFAULT_TIMEOUT))
return

cm = CheckManager(args.api_only, args.debug_function, args.timeout)
cm = CheckManager(args.api_only, args.debug_function, args.timeout, max_threads=args.max_threads)

if args.total_checks:
print("Total Number of Checks: {}".format(cm.total_checks))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,13 @@
"tDn": "topology/pod-1/paths-101/pathep-[eth1/1]"
}
}
}, {
"infraRsHPathAtt": {
"attributes": {
"dn": "uni/infra/hpaths-__ui_xxx_201-202_Eth49-50/rsHPathAtt-[topology/pod-1/paths-201/pathep-[xxx_201-202_Eth49-50]]",
"tCl": "fabricPathEp",
"tDn": "topology/pod-1/paths-201/pathep-[xxx_201-202_Eth49-50]"
}
}
}
]
7 changes: 5 additions & 2 deletions tests/test_CheckManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def mock_generate_thread(monkeypatch, request):
def thread_start_with_exception(timeout=5.0):
raise exception

def _mock_generate_thread(self, target, args=(), kwargs=None):
def _mock_generate_thread(self, target, args=(), kwargs=None, use_semaphore=False):
if kwargs is None:
kwargs = {}
thread = script.CustomThread(target=target, name=target.__name__, args=args, kwargs=kwargs)
Expand Down Expand Up @@ -85,7 +85,10 @@ def test_initialize_checks(self, caplog, cm):
assert cm.get_check_result("fake_10_check") is None

# Check number of initialized checks in result files
result_files = os.listdir(script.JSON_DIR)
result_files = [
f for f in os.listdir(script.JSON_DIR)
if f.replace(".json", "") in cm.check_ids
]
assert len(result_files) == cm.total_checks

# Check the filename of result files and their `ruleStatus`
Expand Down
38 changes: 32 additions & 6 deletions tests/test_ThreadManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,67 @@


def task1(data=""):
time.sleep(2.5)
time.sleep(2)
if not global_timeout:
print("Thread task1: Finishing with data {}".format(data))


def task2(data=""):
time.sleep(0.5)
time.sleep(2.5)
if not global_timeout:
print("Thread task2: Finishing with data {}".format(data))


def task3(data=""):
time.sleep(0.2)
time.sleep(1)
if not global_timeout:
print("Thread task3: Finishing with data {}".format(data))


def task4(data=""):
time.sleep(5)
if not global_timeout:
print("Thread task4: Finishing with data {}".format(data))


def task5(data=""):
time.sleep(5)
if not global_timeout:
print("Thread task5: Finishing with data {}".format(data))


def test_ThreadManager(capsys):
global global_timeout
tm = script.ThreadManager(
funcs=[task1, task2, task3],
funcs=[task1, task2, task3, task4, task5],
common_kwargs={"data": "common_data"},
monitor_timeout=1,
max_threads=2,
callback_on_timeout=lambda x: print("Timeout. Abort {}".format(x))
)
tm.start()
tm.join()

# Join each task thread to ensure any in-progress prints complete before
# capsys.readouterr() is called. Without this there is a race where a
# thread passes the `if not global_timeout` check and then tries to print
# after pytest has already torn down the captured stdout fd, causing
# OSError: [Errno 9] Bad file descriptor.
for thread in tm.threads:
try:
thread.join(timeout=1.5)
except RuntimeError:
pass # thread was never started

if tm.is_timeout():
global_timeout = True

expected_output = """\
Thread task3: Finishing with data common_data
Thread task2: Finishing with data common_data
Timeout. Abort task1
Timeout. Abort task2
Thread task1: Finishing with data common_data
Thread task2: Finishing with data common_data
Thread task3: Finishing with data common_data
"""
captured = capsys.readouterr()
assert captured.out == expected_output
13 changes: 13 additions & 0 deletions tests/test_parse_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_no_args():
assert args.no_cleanup is False
assert args.version is False
assert args.total_checks is False
assert args.max_threads is None


@pytest.mark.parametrize(
Expand Down Expand Up @@ -127,3 +128,15 @@ def test_version(args, expected_result):
def test_total_checks(args, expected_result):
args = script.parse_args(args)
assert args.total_checks == expected_result


@pytest.mark.parametrize(
"args, expected_result",
[
([], None),
(["--max-threads", "4"], 4),
],
)
def test_max_threads(args, expected_result):
args = script.parse_args(args)
assert args.max_threads == expected_result