diff --git a/provy/core/runner.py b/provy/core/runner.py index 865fa2f..d060031 100644 --- a/provy/core/runner.py +++ b/provy/core/runner.py @@ -7,6 +7,7 @@ It's recommended not to tinker with this module, as it might prevent your provyfile from working. ''' + from os.path import abspath, dirname, join from fabric.context_managers import settings as _settings @@ -15,8 +16,11 @@ from provy.core.errors import ConfigurationError from jinja2 import FileSystemLoader, ChoiceLoader +from .server import ProvyServer + def run(provfile_path, server_name, password, extra_options): + module_name = provyfile_module_from(provfile_path) prov = import_module(module_name) servers = get_servers_for(prov, server_name) @@ -35,14 +39,15 @@ def print_header(msg): def provision_server(server, provfile_path, password, prov): - host_string = "%s@%s" % (server['user'], server['address'].strip()) - context = { 'abspath': dirname(abspath(provfile_path)), 'path': dirname(provfile_path), - 'owner': server['user'], + 'owner': server.username, 'cleanup': [], - 'registered_loaders': [] + 'registered_loaders': [], + '__provy': { + 'current_server': server + } } aggregate_node_options(server, context) @@ -52,19 +57,19 @@ def provision_server(server, provfile_path, password, prov): ]) context['loader'] = loader - print_header("Provisioning %s..." % host_string) + print_header("Provisioning %s..." % server.host_string) - settings_dict = dict(host_string=host_string, password=password) - if 'ssh_key' in server and server['ssh_key']: - settings_dict['key_filename'] = server['ssh_key'] + settings_dict = dict(host_string=server.host_string, password=password) + if server.ssh_key is not None: + settings_dict['key_filename'] = server.ssh_key with _settings(**settings_dict): - context['host'] = server['address'] - context['user'] = server['user'] + context['host'] = server.address + context['user'] = server.username role_instances = [] try: - for role in server['roles']: + for role in server.roles: context['role'] = role instance = role(prov, context) role_instances.append(instance) @@ -76,28 +81,30 @@ def provision_server(server, provfile_path, password, prov): for role in context['cleanup']: role.cleanup() - print_header("%s provisioned!" % host_string) + print_header("%s provisioned!" % server.host_string) def aggregate_node_options(server, context): - for key, value in server.get('options', {}).iteritems(): + for key, value in server.options.iteritems(): context[key] = value def build_prompt_options(servers, extra_options): for server in servers: - for option_name, option in server.get('options', {}).iteritems(): + for option_name, option in server.options.iteritems(): if isinstance(option, AskFor): if option.key in extra_options: value = extra_options[option.key] else: value = option.get_value(server) - server['options'][option_name] = value + server.options[option_name] = value def get_servers_for(prov, server_name): - return get_items(prov, server_name, 'servers', lambda item: isinstance(item, dict) and 'address' in item) - + result = [] + for name, server in get_items(prov, server_name, 'servers', lambda item: isinstance(item, dict) and 'address' in item): + result.append(ProvyServer.from_dict(name, server)) + return result def get_items(prov, item_name, item_key, test_func): if not hasattr(prov, item_key): @@ -105,23 +112,26 @@ def get_items(prov, item_name, item_key, test_func): items = getattr(prov, item_key) + key = None + for item_part in item_name.split('.'): + key = item_part items = items[item_part] found_items = [] - recurse_items(items, test_func, found_items) + recurse_items(items, test_func, found_items, key) return found_items -def recurse_items(col, test_func, found_items): +def recurse_items(col, test_func, found_items, key=None): if not isinstance(col, dict): return if test_func(col): - found_items.append(col) + found_items.append([key, col]) else: for key, val in col.iteritems(): if test_func(val): - found_items.append(val) + found_items.append([key, val]) else: recurse_items(val, test_func, found_items) diff --git a/provy/core/server.py b/provy/core/server.py new file mode 100644 index 0000000..68e6472 --- /dev/null +++ b/provy/core/server.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +from copy import copy + + +class ProvyServer(object): + + def __init__(self, name, address, username, roles= tuple(), password=None): + """ + :param name: Logical name of the server (key under which it resides + in the provy file) + :param address: Address of the server + :param username: Username you log into + :param roles: List of roles for the server + :param password: Login password + """ + super(ProvyServer, self).__init__() + self.name = name.strip() + self.address = address.strip() + self.roles = list(roles) + self.username = username.strip() + self.password = password + self.options = {} + self.ssh_key = None + + @staticmethod + def from_dict(name, server_dict): + d = copy(server_dict) + d['name'] = name + s = ProvyServer.__new__(ProvyServer) + s.__setstate__(d) + return s + + @property + def host_string(self): + return "{}@{}".format(self.username, self.address) + + def __getstate__(self): + dict = { + "name": self.name, + "address": self.address, + "roles": self.roles, + "user": self.username, + "options": self.options, + "ssh_key": self.ssh_key + } + if self.password is not None: + dict['password'] = self.password + return dict + + def __setstate__(self, state): + self.name = state['name'].strip() + self.address = state['address'].strip() + self.roles = state.get("roles", []) + self.username = state['user'].strip() + self.password = state.get('password', None) + self.options = state.get('options', {}) + self.ssh_key = state.get('ssh_key', None) + diff --git a/provy/core/utils.py b/provy/core/utils.py index 820d4a8..b2ac2c8 100644 --- a/provy/core/utils.py +++ b/provy/core/utils.py @@ -50,5 +50,5 @@ def __init__(self, key, question): self.question = question def get_value(self, server): - value = getpass("[Server at %s] - %s: " % (server['address'], self.question)) + value = getpass("[Server at %s] - %s: " % (server.address, self.question)) return value diff --git a/tests/unit/core/test_runner.py b/tests/unit/core/test_runner.py index 9e0d5f6..63afb6f 100644 --- a/tests/unit/core/test_runner.py +++ b/tests/unit/core/test_runner.py @@ -33,11 +33,11 @@ def foo_macher(item): recurse_items(collection, foo_macher, found_items) expected_items = [ - 'foo name', - 'some foo truck', - ['foo', 'bar'], - { - 'foo': 'something undescribable', - }, + ['my name', 'foo name'], + ['car', 'some foo truck'], + ['books', + ['foo', 'bar']], + ['others', + {'foo': 'something undescribable'}] ] self.assertListEqual(sorted(found_items), sorted(expected_items), found_items) diff --git a/tests/unit/core/test_server.py b/tests/unit/core/test_server.py new file mode 100644 index 0000000..b34fa0d --- /dev/null +++ b/tests/unit/core/test_server.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- + +import unittest +from provy.core.server import ProvyServer + + +class TestServer(unittest.TestCase): + + def test_init(self): + roles = (object, object) + server = ProvyServer("foo", "testserver", "user", roles) + self.assertIsNone(server.password) + self.assertIsInstance(server.roles, list) + + def test_getstste(self): + roles = (object, object) + server = ProvyServer("foo", "testserver", "user", roles) + self.assertEqual( + server.__getstate__(), + {'address': 'testserver', + 'name': 'foo', + 'options': {}, + 'roles': list(roles), + 'ssh_key': None, + 'user': 'user'}) + + def test_getstste_with_password(self): + roles = (object, object) + server = ProvyServer("foo", "testserver", "user", roles, password="pass") + self.assertEqual( + server.__getstate__(), + {'address': 'testserver', + 'name': 'foo', + 'options': {}, + 'roles': list(roles), + 'ssh_key': None, + 'user': 'user', + 'password': 'pass'}) + + def test_setstate(self): + roles = (object, object) + server = ProvyServer("foo", "testserver", "user", roles) + server.__setstate__({ + 'name': " bar ", + "address" : ' foo ', + 'user': ' user ', + }) + self.assertEqual(server.name, 'bar') + self.assertEqual(server.address, 'foo') + self.assertEqual(server.username, 'user') + +