From 4dbc9572a3cd52fd746234b2d6e64df06bd906d9 Mon Sep 17 00:00:00 2001 From: Adam Cavendish Date: Thu, 8 Dec 2016 18:58:48 +0800 Subject: [PATCH 1/3] OPSM initial commit --- ruskit/opsm/__init__.py | 3 + ruskit/opsm/lib.py | 239 ++++++++++++++++++++++++++ tests/test_opsm.py | 365 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 607 insertions(+) create mode 100644 ruskit/opsm/__init__.py create mode 100644 ruskit/opsm/lib.py create mode 100644 tests/test_opsm.py diff --git a/ruskit/opsm/__init__.py b/ruskit/opsm/__init__.py new file mode 100644 index 0000000..ed1e331 --- /dev/null +++ b/ruskit/opsm/__init__.py @@ -0,0 +1,3 @@ +from .lib import Task, SequenceTask, ParallelTask +from .lib import TaskSuccess, TaskFailure, TASK_SUCCESS, TASK_FAILURE +from .lib import PreviousTaskFailedError diff --git a/ruskit/opsm/lib.py b/ruskit/opsm/lib.py new file mode 100644 index 0000000..8038034 --- /dev/null +++ b/ruskit/opsm/lib.py @@ -0,0 +1,239 @@ +# -*- coding: utf-8 -*- +from __future__ import division, print_function + +from collections import namedtuple +import sys + +TASK_SUCCESS = 30000 +TASK_FAILURE = 30001 + + +class TaskSuccess(namedtuple('TaskSuccess', 'task_name, value')): + tid = TASK_SUCCESS + + def __py2_str__(self): + ret = u'{}({}) ✓'.encode('utf8') + ret = ret.format(self.task_name, self.value) + return ret + + def __py3_str__(self): + return '{}({}) ✓'.format(self.task_name, self.value) + + def __py2_repr__(self): + ret = u'{} ✓'.encode('utf8') + ret = ret.format(self.task_name) + return ret + + def __py3_repr__(self): + return '{} ✓'.format(self.task_name) + + def __str__(self): + if sys.version_info.major < 3: + return self.__py2_str__() + else: + return self.__py3_str__() + + def __repr__(self): + if sys.version_info.major < 3: + return self.__py2_repr__() + else: + return self.__py3_repr__() + + +class TaskFailure( + namedtuple('TaskFailure', 'task_name, error, grdst')): + tid = TASK_FAILURE + + def __py2_str__(self): + if self.grdst: + ret = u'{}({}) ✗ => {}'.encode('utf8') + ret = ret.format(self.task_name, self.error, self.grdst) + else: + ret = u'{}({}) ✗'.encode('utf8') + ret = ret.format(self.task_name, self.error) + return ret + + def __py3_str__(self): + if self.grdst: + ret = '{}({}) ✗ => {}'.format(self.task_name, self.error, + self.grdst) + else: + ret = '{}({}) ✗'.format(self.task_name, self.error) + return ret + + def __py2_repr__(self): + if self.grdst: + ret = u'{} ✗ => {}'.encode('utf8') + ret = ret.format(self.task_name, self.grdst.task_name) + else: + ret = u'{} ✗'.encode('utf8') + ret = ret.format(self.task_name) + return ret + + def __py3_repr__(self): + if self.grdst: + ret = '{} ✗ => {}'.format(self.task_name, self.grdst.task_name) + else: + ret = '{} ✗'.format(self.task_name) + return ret + + def __str__(self): + if sys.version_info.major < 3: + return self.__py2_str__() + else: + return self.__py3_str__() + + def __repr__(self): + if sys.version_info.major < 3: + return self.__py2_repr__() + else: + return self.__py3_repr__() + + +class PreviousTaskFailedError(Exception): + def __init__(self, msg='PTF'): + self.msg = msg + + def __str__(self): + return self.msg + + +class Task(object): + ''' + A Task is a base abstract object for implementing task-based + event scheduling. + + Anyone who'd like to create a task should follow the example: + + @example: + class EchoTask(Task): + def _setup(self, *args, **kwargs): + self.msg = kwargs.get('msg') + def _run(self): + if not self.msg: + raise ValueError('No message found') + print(self.msg) + + the `setup` method is optional, while the run method should + be implemented. + + A Guard is a `Task` used to guard of + If you have specified any `guard`, raising any exception + or setting `self.ok = False` will trigger the `guard`. + ''' + + def __init__(self, *args, **kwargs): + self.ok = True + self._task_name = self.__class__.__name__ + self.guard = kwargs.get('guard') + self._setup(*args, **kwargs) + + def _setup(self, *args, **kwargs): + ''' + Optional for derived classes + ''' + pass + + def _try_guard(self): + try: + if self.guard: + ret = self.guard.run() + else: + ret = None + except Exception as e: + ret = TaskFailure( + self._task_name, error=e, grdst=None) + return ret + + def run(self): + try: + rslt = self._run() + ret = TaskSuccess(self._task_name, value=rslt) + except Exception as e: + ret = TaskFailure( + self._task_name, error=e, grdst=None) + self.ok = False + finally: + if self.ok: + return ret + else: + grdst = self._try_guard() + return TaskFailure( + self._task_name, + error=ret.error, + grdst=grdst) + + def _run(self): + raise NotImplementedError("Should override _run") + + +class SequenceTask(Task): + def __init__(self, *tasks, **kwargs): + super(SequenceTask, self).__init__(**kwargs) + + self.subtasks = list(tasks) + + def add(self, task): + self.subtasks.append(task) + + def run(self): + return self._run() + + def _run_one(self, task): + if self.ok is True: + ret = None + try: + ret = task.run() + except Exception as e: + ret = TaskFailure( + task._task_name, error=e, grdst=None) + if ret.tid == TASK_FAILURE: + self.ok = False + return ret + else: + return TaskFailure( + task._task_name, + error=PreviousTaskFailedError(), + grdst=None) + + def _run(self): + ret = [self._run_one(task) for task in self.subtasks] + if self.ok is True: + return TaskSuccess(self._task_name, value=ret) + else: + grdst = self._try_guard() + return TaskFailure( + self._task_name, error=ret, grdst=grdst) + + +class ParallelTask(Task): + def __init__(self, pool, *tasks, **kwargs): + super(ParallelTask, self).__init__(**kwargs) + self.subtasks = list(tasks) + self.gevent_pool = pool + + def add(self, task): + self.subtasks.append(task) + + def run(self): + return self._run() + + def _run_one(self, task): + ret = None + try: + ret = task.run() + except Exception as e: + ret = TaskFailure( + task._task_name, error=e, grdst=None) + if ret.tid == TASK_FAILURE: + self.ok = False + return ret + + def _run(self): + ret = self.gevent_pool.map(self._run_one, self.subtasks) + if self.ok: + return TaskSuccess(self._task_name, value=ret) + else: + grdst = self._try_guard() + return TaskFailure( + self._task_name, error=ret, grdst=grdst) diff --git a/tests/test_opsm.py b/tests/test_opsm.py new file mode 100644 index 0000000..41dc8bb --- /dev/null +++ b/tests/test_opsm.py @@ -0,0 +1,365 @@ +from __future__ import absolute_import, print_function + +from six.moves import range +import contextlib +import operator +import random +from functools import reduce + +from pprint import pprint + +import gevent +import gevent.pool +import mock + +import ruskit.opsm as opsm + +sleep_time_lb = 0.001 +sleep_time_ub = 0.005 +raise_msg = 'raise' +cleanup_msg = 'cleanup' +rterr = RuntimeError('RAISE EXCEPTION') + + +@contextlib.contextmanager +def global_echo_mock(): + global echo + echo = mock.Mock() + yield + del echo + + +def typical_fail(task_name): + return opsm.TaskFailure( + task_name=task_name, error=rterr, grdst=None) + + +def typical_failclean(task_name, guard_name): + return opsm.TaskFailure( + task_name=task_name, + error=rterr, + grdst=opsm.TaskSuccess( + task_name=guard_name, value=cleanup_msg)) + + +def previous_fail(task_name): + return opsm.TaskFailure( + task_name=task_name, + error=opsm.PreviousTaskFailedError(), + grdst=None) + + +def assert_task_result(expect, actual): + def assert_task_success(expect, actual): + if expect.task_name != actual.task_name: + return False + + if hasattr(expect.value, '__iter__'): + if len(expect.value) != len(actual.value): + return False + return reduce(operator.and_, [ + assert_task_result_one(*pair) + for pair in zip(expect.value, actual.value) + ]) + else: + return expect.value == actual.value + + def assert_task_failure(expect, actual): + if expect.task_name != actual.task_name: + return False + + check_error = True + if hasattr(expect.error, '__iter__'): + if len(expect.error) != len(actual.error): + check_error = False + check_error = reduce(operator.and_, [ + assert_task_result_one(*pair) + for pair in zip(expect.error, actual.error) + ]) + elif isinstance(expect.error, Exception): + check_error = isinstance(actual, expect.__class__) + else: + check_error = expect.error == actual.error + assert isinstance(actual.grdst, expect.grdst.__class__) + if expect.grdst: + return check_error and assert_task_result_one( + expect.grdst, actual.grdst) + else: + return check_error + + def assert_task_result_one(expect, actual): + _type_dispatch = { + opsm.TaskSuccess: assert_task_success, + opsm.TaskFailure: assert_task_failure, + } + if not isinstance(actual, expect.__class__): + return False + return _type_dispatch[expect.__class__](expect, actual) + + assert assert_task_result_one(expect, actual), '''Mismatch: + Expect: {} + Actual: {}'''.format(expect, actual) + + +class EchoTaskS(opsm.Task): + def _setup(self, *args, **kwargs): + self.msg = kwargs['msg'] + + def _run(self): + if self.msg == raise_msg: + raise rterr + else: + echo(self.msg) + return self.msg + + +class EchoTaskP(opsm.Task): + def _setup(self, *args, **kwargs): + self.msg = kwargs['msg'] + + def _run(self): + if self.msg == raise_msg: + raise rterr + else: + gevent.sleep(random.uniform(sleep_time_lb, sleep_time_ub)) + echo(self.msg) + return self.msg + + +class CleanupTask(opsm.Task): + def _run(self): + echo(cleanup_msg) + return cleanup_msg + + +def test_task_success(): + with global_echo_mock(): + msg = 'hello' + ee = EchoTaskS(msg=msg) + ret = ee.run() + + echo.assert_called_once_with(msg) + assert_task_result( + opsm.TaskSuccess( + task_name='EchoTaskS', value=msg), ret) + + +def test_task_failure(): + with global_echo_mock(): + ee = EchoTaskS(msg=raise_msg, guard=CleanupTask()) + ret = ee.run() + + assert_task_result(typical_failclean('EchoTaskS', 'CleanupTask'), ret) + + +def test_sequence_task_all_success(): + with global_echo_mock(): + num = 10 + + ret_expect = opsm.TaskSuccess( + task_name='SequenceTask', + value=[ + opsm.TaskSuccess( + 'EchoTaskS', value=i) for i in range(num) + ]) + mock_call_expect = [mock.call(i) for i in range(num)] + + worker = opsm.SequenceTask(guard=CleanupTask()) + for i in range(num): + worker.add(EchoTaskS(msg=i)) + ret = worker.run() + + assert mock_call_expect == echo.mock_calls + assert_task_result(ret_expect, ret) + + +def test_sequence_task_partial_failure(): + with global_echo_mock(): + succ_num1 = 5 + fail_num1 = 3 + succ_num2 = 8 + + # Expects + ret_expect = [] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskS', value=i) for i in range(succ_num1) + ] + ret_expect += [typical_fail('EchoTaskS')] + ret_expect += [ + previous_fail('EchoTaskS') + for i in range(fail_num1 - 1 + succ_num2) + ] + ret_expect = opsm.TaskFailure( + task_name='SequenceTask', + error=ret_expect, + grdst=opsm.TaskSuccess( + task_name='CleanupTask', value=cleanup_msg)) + mock_calls_expect = [mock.call(i) for i in range(succ_num1)] + mock_calls_expect += [mock.call('cleanup')] + + # Actuals + worker = opsm.SequenceTask(guard=CleanupTask()) + for i in range(succ_num1): + worker.add(EchoTaskS(msg=i)) + for _ in range(fail_num1): + worker.add(EchoTaskS(msg=raise_msg)) + for i in range(succ_num2): + worker.add(EchoTaskS(msg=i)) + ret = worker.run() + + assert mock_calls_expect == echo.mock_calls + assert_task_result(ret_expect, ret) + + +def test_sequence_task_partial_failure_without_guard(): + with global_echo_mock(): + succ_num1 = 5 + fail_num1 = 3 + succ_num2 = 8 + + # Expects + ret_expect = [] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskS', value=i) for i in range(succ_num1) + ] + ret_expect += [typical_fail('EchoTaskS')] + ret_expect += [ + previous_fail('EchoTaskS') + for i in range(fail_num1 - 1 + succ_num2) + ] + ret_expect = opsm.TaskFailure( + task_name='SequenceTask', + error=ret_expect, + grdst=None) + mock_calls_expect = [mock.call(i) for i in range(succ_num1)] + + # Actuals + worker = opsm.SequenceTask() + for i in range(succ_num1): + worker.add(EchoTaskS(msg=i)) + for _ in range(fail_num1): + worker.add(EchoTaskS(msg=raise_msg)) + for i in range(succ_num2): + worker.add(EchoTaskS(msg=i)) + ret = worker.run() + + assert mock_calls_expect == echo.mock_calls + assert_task_result(ret_expect, ret) + + +def test_parallel_task_all_success(): + with global_echo_mock(): + num = 10 + thread_num = 3 + + ret_expect = opsm.TaskSuccess( + task_name='ParallelTask', + value=[ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(num) + ]) + mock_call_expect = [mock.call(i) for i in range(num)] + + pool = gevent.pool.Pool(thread_num) + worker = opsm.ParallelTask(pool, guard=CleanupTask()) + for i in range(num): + worker.add(EchoTaskP(msg=i)) + ret = worker.run() + + assert sorted(mock_call_expect) == sorted(echo.mock_calls) + assert_task_result(ret_expect, ret) + + +def test_parallel_task_partial_failure(): + with global_echo_mock(): + thread_num = 3 + succ_num1 = 5 + fail_num1 = 3 + succ_num2 = 8 + + # Expects + ret_expect = [] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(succ_num1) + ] + ret_expect += [ + typical_fail('EchoTaskP') + for i in range(fail_num1) + ] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(succ_num2) + ] + ret_expect = opsm.TaskFailure( + task_name='ParallelTask', + error=ret_expect, + grdst=opsm.TaskSuccess( + task_name='CleanupTask', value=cleanup_msg)) + mock_calls_expect = [mock.call(i) for i in range(succ_num1)] + mock_calls_expect += [mock.call(i) for i in range(succ_num2)] + mock_calls_expect += [mock.call('cleanup')] + + # Actuals + pool = gevent.pool.Pool(thread_num) + worker = opsm.ParallelTask(pool, guard=CleanupTask()) + for i in range(succ_num1): + worker.add(EchoTaskP(msg=i)) + for _ in range(fail_num1): + worker.add(EchoTaskP(msg=raise_msg)) + for i in range(succ_num2): + worker.add(EchoTaskP(msg=i)) + ret = worker.run() + + assert sorted(mock_calls_expect) == sorted(echo.mock_calls) + assert_task_result(ret_expect, ret) + + +def test_parallel_task_partial_failure_without_guard(): + with global_echo_mock(): + thread_num = 3 + succ_num1 = 5 + fail_num1 = 3 + succ_num2 = 8 + + # Expects + ret_expect = [] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(succ_num1) + ] + ret_expect += [typical_fail('EchoTaskP') for i in range(fail_num1)] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(succ_num2) + ] + ret_expect = opsm.TaskFailure( + task_name='ParallelTask', + error=ret_expect, + grdst=None) + mock_calls_expect = [mock.call(i) for i in range(succ_num1)] + mock_calls_expect += [mock.call(i) for i in range(succ_num2)] + + # Actuals + pool = gevent.pool.Pool(thread_num) + worker = opsm.ParallelTask(pool) + for i in range(succ_num1): + worker.add(EchoTaskP(msg=i)) + for _ in range(fail_num1): + worker.add(EchoTaskP(msg=raise_msg)) + for i in range(succ_num2): + worker.add(EchoTaskP(msg=i)) + ret = worker.run() + + assert sorted(mock_calls_expect) == sorted(echo.mock_calls) + assert_task_result(ret_expect, ret) + + +def test_complex_guard_successful(): + pass + + +def test_complex_guard_failed(): + pass From 7d1f96b0cc8f1d4bc4506af5fe5a2385c40cf41b Mon Sep 17 00:00:00 2001 From: Adam Cavendish Date: Tue, 13 Dec 2016 11:05:28 +0800 Subject: [PATCH 2/3] OPSM add complex guard test, split files --- ruskit/opsm/__init__.py | 10 ++++- ruskit/opsm/decorators.py | 16 ++++++++ ruskit/opsm/exceptions.py | 25 ++++++++++++ ruskit/opsm/lib.py | 76 +++++++++++++++++++++--------------- ruskit/opsm/utils.py | 5 +++ tests/test_opsm.py | 81 +++++++++++++++++++++++---------------- 6 files changed, 148 insertions(+), 65 deletions(-) create mode 100644 ruskit/opsm/decorators.py create mode 100644 ruskit/opsm/exceptions.py create mode 100644 ruskit/opsm/utils.py diff --git a/ruskit/opsm/__init__.py b/ruskit/opsm/__init__.py index ed1e331..16164ad 100644 --- a/ruskit/opsm/__init__.py +++ b/ruskit/opsm/__init__.py @@ -1,3 +1,9 @@ +# -*- coding: utf-8 -*- + from .lib import Task, SequenceTask, ParallelTask -from .lib import TaskSuccess, TaskFailure, TASK_SUCCESS, TASK_FAILURE -from .lib import PreviousTaskFailedError +from .lib import TaskSuccess, TaskFailure + +from .exceptions import OPSMReturnOnErrorShortcutException +from .exceptions import PreviousTaskFailedError + +from .decorators import enable_failure_unwrap diff --git a/ruskit/opsm/decorators.py b/ruskit/opsm/decorators.py new file mode 100644 index 0000000..6c2bf48 --- /dev/null +++ b/ruskit/opsm/decorators.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +import functools + +from . import exceptions + + +def enable_failure_unwrap(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + try: + ret = f(*args, **kwargs) + except exceptions.OPSMReturnOnErrorShortcutException as e: + return e.failure + return ret + return wrapper diff --git a/ruskit/opsm/exceptions.py b/ruskit/opsm/exceptions.py new file mode 100644 index 0000000..5278099 --- /dev/null +++ b/ruskit/opsm/exceptions.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +import sys +if sys.version_info.major < 3: + # Circular import issue in python 2, but works in python 3 + # Do not use absolute_import: from __future__ import absolute_import + import lib +else: + from . import lib + + +class OPSMReturnOnErrorShortcutException(Exception): + def __init__(self, failure): + assert isinstance(failure, lib.TaskFailure) + self.failure = failure + + def __str__(self): + return 'OPSM_ROESE: {}'.format(self.failure) + + +class PreviousTaskFailedError(Exception): + def __init__(self, msg='PTF'): + self.msg = msg + + def __str__(self): + return self.msg diff --git a/ruskit/opsm/lib.py b/ruskit/opsm/lib.py index 8038034..84a0e0c 100644 --- a/ruskit/opsm/lib.py +++ b/ruskit/opsm/lib.py @@ -1,16 +1,14 @@ # -*- coding: utf-8 -*- -from __future__ import division, print_function +from __future__ import division, print_function, absolute_import from collections import namedtuple import sys -TASK_SUCCESS = 30000 -TASK_FAILURE = 30001 +from . import exceptions +from . import utils class TaskSuccess(namedtuple('TaskSuccess', 'task_name, value')): - tid = TASK_SUCCESS - def __py2_str__(self): ret = u'{}({}) ✓'.encode('utf8') ret = ret.format(self.task_name, self.value) @@ -39,11 +37,28 @@ def __repr__(self): else: return self.__py3_repr__() + def ok(self): + return True + + def val(self): + return self.value + + def err(self): + return None + + def unwrap(self): + return self.value + + def aggregate(self): + if utils.is_iterable_not_str(self.value): + return tuple(v.aggregate() for v in self.value) + elif isinstance(self.value, tuple): + return self.value + else: + return (self.value,) -class TaskFailure( - namedtuple('TaskFailure', 'task_name, error, grdst')): - tid = TASK_FAILURE +class TaskFailure(namedtuple('TaskFailure', 'task_name, error, grdst')): def __py2_str__(self): if self.grdst: ret = u'{}({}) ✗ => {}'.encode('utf8') @@ -89,13 +104,20 @@ def __repr__(self): else: return self.__py3_repr__() + def ok(self): + return False -class PreviousTaskFailedError(Exception): - def __init__(self, msg='PTF'): - self.msg = msg + def val(self): + return None - def __str__(self): - return self.msg + def err(self): + return self.error + + def unwrap(self): + raise exceptions.OPSMReturnOnErrorShortcutException() + + def aggregate(self): + raise exceptions.OPSMReturnOnErrorShortcutException() class Task(object): @@ -141,8 +163,7 @@ def _try_guard(self): else: ret = None except Exception as e: - ret = TaskFailure( - self._task_name, error=e, grdst=None) + ret = TaskFailure(self._task_name, error=e, grdst=None) return ret def run(self): @@ -150,8 +171,7 @@ def run(self): rslt = self._run() ret = TaskSuccess(self._task_name, value=rslt) except Exception as e: - ret = TaskFailure( - self._task_name, error=e, grdst=None) + ret = TaskFailure(self._task_name, error=e, grdst=None) self.ok = False finally: if self.ok: @@ -159,9 +179,7 @@ def run(self): else: grdst = self._try_guard() return TaskFailure( - self._task_name, - error=ret.error, - grdst=grdst) + self._task_name, error=ret.error, grdst=grdst) def _run(self): raise NotImplementedError("Should override _run") @@ -185,15 +203,14 @@ def _run_one(self, task): try: ret = task.run() except Exception as e: - ret = TaskFailure( - task._task_name, error=e, grdst=None) - if ret.tid == TASK_FAILURE: + ret = TaskFailure(task._task_name, error=e, grdst=None) + if not ret.ok(): self.ok = False return ret else: return TaskFailure( task._task_name, - error=PreviousTaskFailedError(), + error=exceptions.PreviousTaskFailedError(), grdst=None) def _run(self): @@ -202,8 +219,7 @@ def _run(self): return TaskSuccess(self._task_name, value=ret) else: grdst = self._try_guard() - return TaskFailure( - self._task_name, error=ret, grdst=grdst) + return TaskFailure(self._task_name, error=ret, grdst=grdst) class ParallelTask(Task): @@ -223,9 +239,8 @@ def _run_one(self, task): try: ret = task.run() except Exception as e: - ret = TaskFailure( - task._task_name, error=e, grdst=None) - if ret.tid == TASK_FAILURE: + ret = TaskFailure(task._task_name, error=e, grdst=None) + if not ret.ok(): self.ok = False return ret @@ -235,5 +250,4 @@ def _run(self): return TaskSuccess(self._task_name, value=ret) else: grdst = self._try_guard() - return TaskFailure( - self._task_name, error=ret, grdst=grdst) + return TaskFailure(self._task_name, error=ret, grdst=grdst) diff --git a/ruskit/opsm/utils.py b/ruskit/opsm/utils.py new file mode 100644 index 0000000..0c90405 --- /dev/null +++ b/ruskit/opsm/utils.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- + + +def is_iterable_not_str(obj): + return hasattr(obj, '__iter__') and not isinstance(obj, str) diff --git a/tests/test_opsm.py b/tests/test_opsm.py index 41dc8bb..f548d86 100644 --- a/tests/test_opsm.py +++ b/tests/test_opsm.py @@ -1,13 +1,13 @@ from __future__ import absolute_import, print_function +import sys +import functools from six.moves import range import contextlib import operator import random from functools import reduce -from pprint import pprint - import gevent import gevent.pool import mock @@ -29,9 +29,8 @@ def global_echo_mock(): del echo -def typical_fail(task_name): - return opsm.TaskFailure( - task_name=task_name, error=rterr, grdst=None) +def typical_fail(task_name, grdst=None): + return opsm.TaskFailure(task_name=task_name, error=rterr, grdst=grdst) def typical_failclean(task_name, guard_name): @@ -44,17 +43,22 @@ def typical_failclean(task_name, guard_name): def previous_fail(task_name): return opsm.TaskFailure( - task_name=task_name, - error=opsm.PreviousTaskFailedError(), - grdst=None) + task_name=task_name, error=opsm.PreviousTaskFailedError(), grdst=None) + + +def usorted(lst): + return sorted(lst, key=lambda e: str(e)) def assert_task_result(expect, actual): + def is_iterable(obj): + return hasattr(obj, '__iter__') and not isinstance(obj, str) + def assert_task_success(expect, actual): if expect.task_name != actual.task_name: return False - if hasattr(expect.value, '__iter__'): + if is_iterable(expect.value): if len(expect.value) != len(actual.value): return False return reduce(operator.and_, [ @@ -69,7 +73,7 @@ def assert_task_failure(expect, actual): return False check_error = True - if hasattr(expect.error, '__iter__'): + if is_iterable(expect.error): if len(expect.error) != len(actual.error): check_error = False check_error = reduce(operator.and_, [ @@ -82,8 +86,8 @@ def assert_task_failure(expect, actual): check_error = expect.error == actual.error assert isinstance(actual.grdst, expect.grdst.__class__) if expect.grdst: - return check_error and assert_task_result_one( - expect.grdst, actual.grdst) + return check_error and assert_task_result_one(expect.grdst, + actual.grdst) else: return check_error @@ -135,13 +139,16 @@ def _run(self): def test_task_success(): with global_echo_mock(): msg = 'hello' + + # Expect + ret_expect = opsm.TaskSuccess(task_name='EchoTaskS', value=msg) + + # Actual ee = EchoTaskS(msg=msg) ret = ee.run() echo.assert_called_once_with(msg) - assert_task_result( - opsm.TaskSuccess( - task_name='EchoTaskS', value=msg), ret) + assert_task_result(ret_expect, ret) def test_task_failure(): @@ -230,9 +237,7 @@ def test_sequence_task_partial_failure_without_guard(): for i in range(fail_num1 - 1 + succ_num2) ] ret_expect = opsm.TaskFailure( - task_name='SequenceTask', - error=ret_expect, - grdst=None) + task_name='SequenceTask', error=ret_expect, grdst=None) mock_calls_expect = [mock.call(i) for i in range(succ_num1)] # Actuals @@ -268,7 +273,7 @@ def test_parallel_task_all_success(): worker.add(EchoTaskP(msg=i)) ret = worker.run() - assert sorted(mock_call_expect) == sorted(echo.mock_calls) + assert usorted(mock_call_expect) == usorted(echo.mock_calls) assert_task_result(ret_expect, ret) @@ -285,10 +290,7 @@ def test_parallel_task_partial_failure(): opsm.TaskSuccess( task_name='EchoTaskP', value=i) for i in range(succ_num1) ] - ret_expect += [ - typical_fail('EchoTaskP') - for i in range(fail_num1) - ] + ret_expect += [typical_fail('EchoTaskP') for i in range(fail_num1)] ret_expect += [ opsm.TaskSuccess( task_name='EchoTaskP', value=i) for i in range(succ_num2) @@ -313,7 +315,7 @@ def test_parallel_task_partial_failure(): worker.add(EchoTaskP(msg=i)) ret = worker.run() - assert sorted(mock_calls_expect) == sorted(echo.mock_calls) + assert usorted(mock_calls_expect) == usorted(echo.mock_calls) assert_task_result(ret_expect, ret) @@ -336,9 +338,7 @@ def test_parallel_task_partial_failure_without_guard(): task_name='EchoTaskP', value=i) for i in range(succ_num2) ] ret_expect = opsm.TaskFailure( - task_name='ParallelTask', - error=ret_expect, - grdst=None) + task_name='ParallelTask', error=ret_expect, grdst=None) mock_calls_expect = [mock.call(i) for i in range(succ_num1)] mock_calls_expect += [mock.call(i) for i in range(succ_num2)] @@ -353,13 +353,30 @@ def test_parallel_task_partial_failure_without_guard(): worker.add(EchoTaskP(msg=i)) ret = worker.run() - assert sorted(mock_calls_expect) == sorted(echo.mock_calls) + assert usorted(mock_calls_expect) == usorted(echo.mock_calls) assert_task_result(ret_expect, ret) -def test_complex_guard_successful(): - pass +def test_complex_guard_failure(): + with global_echo_mock(): + # Expects + grdst = opsm.TaskSuccess( + task_name='SequenceTask', + value=[ + opsm.TaskSuccess( + task_name='EchoTaskS', value='complex'), opsm.TaskSuccess( + task_name='EchoTaskS', value='guard') + ]) + ret_expect = typical_fail('EchoTaskS', grdst) + + mock_calls_expect = [mock.call('complex'), mock.call('guard')] + # Actuals + complex_guard = opsm.SequenceTask() + complex_guard.add(EchoTaskS(msg='complex')) + complex_guard.add(EchoTaskS(msg='guard')) + worker = EchoTaskS(msg=raise_msg, guard=complex_guard) + ret = worker.run() -def test_complex_guard_failed(): - pass + assert mock_calls_expect == echo.mock_calls + assert_task_result(ret_expect, ret) From 2eb8c38f2b37f9d29a956ca6f7f241ea08527c81 Mon Sep 17 00:00:00 2001 From: Adam Cavendish Date: Fri, 30 Dec 2016 19:01:27 +0800 Subject: [PATCH 3/3] Add RetryTask --- ruskit/opsm/__init__.py | 2 +- ruskit/opsm/lib.py | 15 +++++++++++++++ ruskit/utils.py | 9 +++++++++ tests/test_opsm.py | 38 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/ruskit/opsm/__init__.py b/ruskit/opsm/__init__.py index 16164ad..025d30b 100644 --- a/ruskit/opsm/__init__.py +++ b/ruskit/opsm/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from .lib import Task, SequenceTask, ParallelTask +from .lib import Task, SequenceTask, ParallelTask, RetryTask from .lib import TaskSuccess, TaskFailure from .exceptions import OPSMReturnOnErrorShortcutException diff --git a/ruskit/opsm/lib.py b/ruskit/opsm/lib.py index 84a0e0c..13eb5a3 100644 --- a/ruskit/opsm/lib.py +++ b/ruskit/opsm/lib.py @@ -251,3 +251,18 @@ def _run(self): else: grdst = self._try_guard() return TaskFailure(self._task_name, error=ret, grdst=grdst) + + +class RetryTask(Task): + def __init__(self, *args, **kwargs): + super(RetryTask, self).__init__(*args, **kwargs) + self.retry_times = kwargs.get('retry_times', 1) + + def run(self): + for i in range(self.retry_times): + # Cleanup self.ok flag + self.ok = True + ret = super(RetryTask, self).run() + if ret.ok(): + break + return ret diff --git a/ruskit/utils.py b/ruskit/utils.py index a340586..7a21d37 100644 --- a/ruskit/utils.py +++ b/ruskit/utils.py @@ -4,6 +4,7 @@ import os import sys from functools import wraps +import contextlib from ruskit import cli @@ -119,3 +120,11 @@ def _wrapper(*arguments): ClusterNode.socket_timeout = args.timeout return func(*arguments) return _wrapper + + +@contextlib.contextmanager +def contextlib_suppress(*exceptions): + try: + yield + except exceptions: + pass diff --git a/tests/test_opsm.py b/tests/test_opsm.py index f548d86..6308b14 100644 --- a/tests/test_opsm.py +++ b/tests/test_opsm.py @@ -1,7 +1,5 @@ from __future__ import absolute_import, print_function -import sys -import functools from six.moves import range import contextlib import operator @@ -380,3 +378,39 @@ def test_complex_guard_failure(): assert mock_calls_expect == echo.mock_calls assert_task_result(ret_expect, ret) + + +def test_retry_task(): + class EchoTaskR(opsm.RetryTask): + def _setup(self, *args, **kwargs): + self.error_times = kwargs['error_times'] + self.msg = kwargs['msg'] + + def _run(self): + if self.error_times > 0: + self.error_times -= 1 + raise rterr + else: + echo(self.msg) + return self.msg + + with global_echo_mock(): + msg = 'hello' + error_times = 2 + retry_times = 5 + + # Expects + ret_expect = opsm.TaskSuccess(task_name='EchoTaskR', value=msg) + mock_calls_expect = [mock.call('cleanup') for i in range(error_times)] + mock_calls_expect.append(mock.call(msg)) + + # Actuals + retry = EchoTaskR( + error_times=error_times, + msg=msg, + retry_times=retry_times, + guard=CleanupTask()) + ret = retry.run() + + assert mock_calls_expect == echo.mock_calls + assert_task_result(ret_expect, ret)