diff --git a/multipledispatch/core.py b/multipledispatch/core.py index ecbf84a..7f2ad04 100644 --- a/multipledispatch/core.py +++ b/multipledispatch/core.py @@ -57,7 +57,7 @@ def dispatch(*types, **kwargs): def _(func): name = func.__name__ - if ismethod(func): + if ismethod(func) and isinclass(): dispatcher = inspect.currentframe().f_back.f_locals.get(name, MethodDispatcher(name)) else: @@ -78,3 +78,11 @@ def ismethod(func): """ spec = inspect.getargspec(func) return spec and spec.args and spec.args[0] == 'self' + + +def isinclass(n=1): + """ Is the nth previous frame in a class definition?""" + frame = inspect.currentframe().f_back # escape from current function + for _ in range(n): + frame = getattr(frame, 'f_back') + return '__module__' in frame.f_locals diff --git a/multipledispatch/dispatcher.py b/multipledispatch/dispatcher.py index 6fd3728..bb5cb19 100644 --- a/multipledispatch/dispatcher.py +++ b/multipledispatch/dispatcher.py @@ -203,14 +203,16 @@ class MethodDispatcher(Dispatcher): Dispatcher """ def __get__(self, instance, owner): - self.obj = instance - self.cls = owner - return self + dispatcher = self + def method(self, *args, **kwargs): + return dispatcher(self, *args, **kwargs) + method.__name__ = self.name + return method.__get__(instance, owner) - def __call__(self, *args, **kwargs): + def __call__(self, obj, *args, **kwargs): types = tuple([type(arg) for arg in args]) func = self.resolve(types) - return func(self.obj, *args, **kwargs) + return func(obj, *args, **kwargs) def str_signature(sig): diff --git a/multipledispatch/tests/test_core.py b/multipledispatch/tests/test_core.py index 373ce99..4ae2b51 100644 --- a/multipledispatch/tests/test_core.py +++ b/multipledispatch/tests/test_core.py @@ -189,6 +189,21 @@ def g(self, x): def test_methods_multiple_dispatch(): + class Foo(object): + @dispatch(A) + def f(self, y): + return 1 + + @dispatch(C) + def f(self, y): + return 2 + + foo = Foo() + assert foo.f(A()) == 1 + assert foo.f(C()) == 2 + + +def test_methods_multiple_dispatch_fail(): class Foo(object): @dispatch(A, A) def f(x, y): @@ -198,8 +213,63 @@ def f(x, y): def f(x, y): return 2 + @dispatch(int) + def f(x, y): # 'x' as self + return 1 + y foo = Foo() + # We require the 'self' argument to be used to infer methods assert foo.f(A(), A()) == 1 assert foo.f(A(), C()) == 2 assert foo.f(C(), C()) == 2 + assert raises(TypeError, lambda: foo.f(2)) + + +def test_function_with_self(): + @dispatch(A, A) + def f(self, x): + return 1 + + @dispatch(A, C) + def f(self, x): + return 2 + + @dispatch(C, A) + def f(self, x): + return 3 + + @dispatch(C, C) + def f(self, x): + return 4 + + assert f(A(), A()) == 1 + assert f(A(), C()) == 2 + assert f(C(), A()) == 3 + assert f(C(), C()) == 4 + + +def test_method_dispatch_is_safe(): + class Foo(object): + def __init__(self, x): + self.x = x + + @dispatch(int) + def f(self, y): + return self.x + y + + @dispatch(float) + def f(self, y): + return self.x - y + + foo1 = Foo(1) + foo2 = Foo(2) + assert foo1.f(1) == 2 + assert foo1.f(1.0) == 0.0 + assert foo2.f(1) == 3 + assert foo2.f(1.0) == 1.0 + f1 = foo1.f + f2 = foo2.f + assert f1(1) == 2 + assert f1(1.0) == 0.0 + assert f2(1) == 3 + assert f2(1.0) == 1.0