diff --git a/setup.py b/setup.py index 8aae81862..4e97e09d2 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ "mock~=2.0", "moto~=1.3.7", "testfixtures~=4.10.0", - "flake8-future-import", + "flake8-future-import" ] scripts = [ diff --git a/stacker/actions/base.py b/stacker/actions/base.py index 0763e5245..65a09afd5 100644 --- a/stacker/actions/base.py +++ b/stacker/actions/base.py @@ -8,20 +8,15 @@ import threading from ..dag import walk, ThreadedWalker, UnlimitedSemaphore -from ..plan import Step, build_plan, build_graph +from ..plan import Graph, Plan, Step +from ..target import Target import botocore.exceptions from stacker.session_cache import get_session from stacker.exceptions import PlanFailed +from stacker.status import COMPLETE +from stacker.util import ensure_s3_bucket, get_s3_endpoint -from ..status import ( - COMPLETE -) - -from stacker.util import ( - ensure_s3_bucket, - get_s3_endpoint, -) logger = logging.getLogger(__name__) @@ -61,42 +56,6 @@ def build_walker(concurrency): return ThreadedWalker(semaphore).walk -def plan(description, stack_action, context, - tail=None, reverse=False): - """A simple helper that builds a graph based plan from a set of stacks. - - Args: - description (str): a description of the plan. - action (func): a function to call for each stack. - context (:class:`stacker.context.Context`): a - :class:`stacker.context.Context` to build the plan from. - tail (func): an optional function to call to tail the stack progress. - reverse (bool): if True, execute the graph in reverse (useful for - destroy actions). - - Returns: - :class:`plan.Plan`: The resulting plan object - """ - - def target_fn(*args, **kwargs): - return COMPLETE - - steps = [ - Step(stack, fn=stack_action, watch_func=tail) - for stack in context.get_stacks()] - - steps += [ - Step(target, fn=target_fn) for target in context.get_targets()] - - graph = build_graph(steps) - - return build_plan( - description=description, - graph=graph, - targets=context.stack_names, - reverse=reverse) - - def stack_template_key_name(blueprint): """Given a blueprint, produce an appropriate key name. @@ -156,6 +115,119 @@ def __init__(self, context, provider_builder=None, cancel=None): self.bucket_region = provider_builder.region self.s3_conn = get_session(self.bucket_region).client('s3') + def plan(self, description, action_name, action, context, tail=None, + reverse=False, run_hooks=True): + """A helper that builds a graph based plan from a set of stacks. + + Args: + description (str): a description of the plan. + action_name (str): name of the action being run. Used to generate + target names and filter out which hooks to run. + action (func): a function to call for each stack. + context (stacker.context.Context): a context to build the plan + from. + tail (func): an optional function to call to tail the stack + progress. + reverse (bool): whether to flip the direction of dependencies. + Use it when planning an action for destroying resources, + which usually must happen in the reverse order of creation. + Note: this does not change the order of execution of pre/post + action hooks, as the build and destroy hooks are currently + configured in separate. 
+ run_hooks (bool): whether to run hooks configured for this action + + Returns: stacker.plan.Plan: the resulting plan for this action + """ + + def target_fn(*args, **kwargs): + return COMPLETE + + def hook_fn(hook, *args, **kwargs): + return hook.run_step(provider_builder=self.provider_builder, + context=self.context) + + pre_hooks_target = Target( + name="pre_{}_hooks".format(action_name)) + pre_action_target = Target( + name="pre_{}".format(action_name), + requires=[pre_hooks_target.name]) + action_target = Target( + name=action_name, + requires=[pre_action_target.name]) + post_action_target = Target( + name="post_{}".format(action_name), + requires=[action_target.name]) + post_hooks_target = Target( + name="post_{}_hooks".format(action_name), + requires=[post_action_target.name]) + + def steps(): + yield Step.from_target(pre_hooks_target, fn=target_fn) + yield Step.from_target(pre_action_target, fn=target_fn) + yield Step.from_target(action_target, fn=target_fn) + yield Step.from_target(post_action_target, fn=target_fn) + yield Step.from_target(post_hooks_target, fn=target_fn) + + if run_hooks: + # Since we need to maintain compatibility with legacy hooks, + # we separate them completely from the new hooks. + # The legacy hooks will run in two separate phases, completely + # isolated from regular stacks and targets, and any of the new + # hooks. + # Hence, all legacy pre-hooks will finish before any of the + # new hooks, and all legacy post-hooks will only start after + # the new hooks. + + hooks = self.context.get_hooks_for_action(action_name) + logger.debug("Found hooks for action {}: {}".format( + action_name, hooks)) + + for hook in hooks.pre: + yield Step.from_hook( + hook, fn=hook_fn, + required_by=[pre_hooks_target.name]) + + for hook in hooks.custom: + step = Step.from_hook( + hook, fn=hook_fn) + if reverse: + step.reverse_requirements() + + step.requires.add(pre_action_target.name) + step.required_by.add(post_action_target.name) + yield step + + for hook in hooks.post: + yield Step.from_hook( + hook, fn=hook_fn, + requires=[post_hooks_target.name]) + + for target in context.get_targets(): + step = Step.from_target(target, fn=target_fn) + if reverse: + step.reverse_requirements() + + yield step + + for stack in context.get_stacks(): + step = Step.from_stack(stack, fn=action, watch_func=tail) + if reverse: + step.reverse_requirements() + + # Contain stack execution in the boundaries of the pre_action + # and post_action targets. + step.requires.add(pre_action_target.name) + step.required_by.add(action_target.name) + + yield step + + graph = Graph.from_steps(list(steps())) + + return Plan.from_graph( + description=description, + graph=graph, + targets=context.stack_names) + def ensure_cfn_bucket(self): """The CloudFormation bucket where templates will be stored.""" if self.bucket_name: diff --git a/stacker/actions/build.py b/stacker/actions/build.py index bd2b91714..64c042e03 100644 --- a/stacker/actions/build.py +++ b/stacker/actions/build.py @@ -3,11 +3,10 @@ from __future__ import absolute_import import logging -from .base import BaseAction, plan, build_walker +from .base import BaseAction, build_walker from .base import STACK_POLL_TIME from ..providers.base import Template -from .. 
import util from ..exceptions import ( MissingParameterException, StackDidNotChange, @@ -181,29 +180,6 @@ def _handle_missing_parameters(parameter_values, all_params, required_params, return list(parameter_values.items()) -def handle_hooks(stage, hooks, provider, context, dump, outline): - """Handle pre/post hooks. - - Args: - stage (str): The name of the hook stage - pre_build/post_build. - hooks (list): A list of dictionaries containing the hooks to execute. - provider (:class:`stacker.provider.base.BaseProvider`): The provider - the current stack is using. - context (:class:`stacker.context.Context`): The current stacker - context. - dump (bool): Whether running with dump set or not. - outline (bool): Whether running with outline set or not. - - """ - if not outline and not dump and hooks: - util.handle_hooks( - stage=stage, - hooks=hooks, - provider=provider, - context=context - ) - - class Action(BaseAction): """Responsible for building & coordinating CloudFormation stacks. @@ -273,8 +249,6 @@ def _launch_stack(self, stack, **kwargs): provider_stack = None if provider_stack and not should_update(stack): - stack.set_outputs( - self.provider.get_output_dict(provider_stack)) return NotUpdatedStatus() recreate = False @@ -316,8 +290,6 @@ def _launch_stack(self, stack, **kwargs): return FailedStatus(reason) elif provider.is_stack_completed(provider_stack): - stack.set_outputs( - provider.get_output_dict(provider_stack)) return CompleteStatus(old_status.reason) else: return old_status @@ -366,10 +338,8 @@ def _launch_stack(self, stack, **kwargs): else: return SubmittedStatus("destroying stack for re-creation") except CancelExecution: - stack.set_outputs(provider.get_output_dict(provider_stack)) return SkippedStatus(reason="canceled execution") except StackDidNotChange: - stack.set_outputs(provider.get_output_dict(provider_stack)) return DidNotChangeStatus() def _template(self, blueprint): @@ -391,26 +361,19 @@ def _stack_policy(self, stack): if stack.stack_policy: return Template(body=stack.stack_policy) - def _generate_plan(self, tail=False): - return plan( + def _generate_plan(self, tail=False, outline=False, dump=False): + return self.plan( description="Create/Update stacks", - stack_action=self._launch_stack, + action_name="build", + action=self._launch_stack, tail=self._tail_stack if tail else None, - context=self.context) + context=self.context, + run_hooks=not outline and not dump) def pre_run(self, outline=False, dump=False, *args, **kwargs): """Any steps that need to be taken prior to running the action.""" if should_ensure_cfn_bucket(outline, dump): self.ensure_cfn_bucket() - hooks = self.context.config.pre_build - handle_hooks( - "pre_build", - hooks, - self.provider, - self.context, - dump, - outline - ) def run(self, concurrency=0, outline=False, tail=False, dump=False, *args, **kwargs): @@ -419,7 +382,7 @@ def run(self, concurrency=0, outline=False, This is the main entry point for the Builder. 
""" - plan = self._generate_plan(tail=tail) + plan = self._generate_plan(tail=tail, outline=outline, dump=dump) if not plan.keys(): logger.warn('WARNING: No stacks detected (error in config?)') if not outline and not dump: @@ -433,15 +396,3 @@ def run(self, concurrency=0, outline=False, if dump: plan.dump(directory=dump, context=self.context, provider=self.provider) - - def post_run(self, outline=False, dump=False, *args, **kwargs): - """Any steps that need to be taken after running the action.""" - hooks = self.context.config.post_build - handle_hooks( - "post_build", - hooks, - self.provider, - self.context, - dump, - outline - ) diff --git a/stacker/actions/destroy.py b/stacker/actions/destroy.py index 4f26692ad..e5245cca7 100644 --- a/stacker/actions/destroy.py +++ b/stacker/actions/destroy.py @@ -3,10 +3,9 @@ from __future__ import absolute_import import logging -from .base import BaseAction, plan, build_walker +from .base import BaseAction, build_walker from .base import STACK_POLL_TIME from ..exceptions import StackDoesNotExist -from .. import util from ..status import ( CompleteStatus, SubmittedStatus, @@ -37,12 +36,14 @@ class Action(BaseAction): """ def _generate_plan(self, tail=False): - return plan( + return self.plan( description="Destroy stacks", - stack_action=self._destroy_stack, + action_name='destroy', + action=self._destroy_stack, tail=self._tail_stack if tail else None, context=self.context, - reverse=True) + reverse=True, + run_hooks=True) def _destroy_stack(self, stack, **kwargs): old_status = kwargs.get("status") @@ -78,16 +79,6 @@ def _destroy_stack(self, stack, **kwargs): provider.destroy_stack(provider_stack) return DestroyingStatus - def pre_run(self, outline=False, *args, **kwargs): - """Any steps that need to be taken prior to running the action.""" - pre_destroy = self.context.config.pre_destroy - if not outline and pre_destroy: - util.handle_hooks( - stage="pre_destroy", - hooks=pre_destroy, - provider=self.provider, - context=self.context) - def run(self, force, concurrency=0, tail=False, *args, **kwargs): plan = self._generate_plan(tail=tail) if not plan.keys(): @@ -101,13 +92,3 @@ def run(self, force, concurrency=0, tail=False, *args, **kwargs): else: plan.outline(message="To execute this plan, run with \"--force\" " "flag.") - - def post_run(self, outline=False, *args, **kwargs): - """Any steps that need to be taken after running the action.""" - post_destroy = self.context.config.post_destroy - if not outline and post_destroy: - util.handle_hooks( - stage="post_destroy", - hooks=post_destroy, - provider=self.provider, - context=self.context) diff --git a/stacker/actions/diff.py b/stacker/actions/diff.py index 97801ae7d..84d067f37 100644 --- a/stacker/actions/diff.py +++ b/stacker/actions/diff.py @@ -8,7 +8,7 @@ import logging from operator import attrgetter -from .base import plan, build_walker +from .base import build_walker from . import build from ..ui import ui from .. 
import exceptions @@ -272,15 +272,13 @@ def _diff_stack(self, stack, **kwargs): new_params, old_params)) ui.info('\n' + '\n'.join(output)) - stack.set_outputs( - provider.get_output_dict(provider_stack)) - return COMPLETE def _generate_plan(self): - return plan( + return self.plan( description="Diff stacks", - stack_action=self._diff_stack, + action_name="diff", + action=self._diff_stack, context=self.context) def run(self, concurrency=0, *args, **kwargs): diff --git a/stacker/actions/graph.py b/stacker/actions/graph.py index 1f069a68d..f7cdffb50 100644 --- a/stacker/actions/graph.py +++ b/stacker/actions/graph.py @@ -5,7 +5,7 @@ import sys import json -from .base import BaseAction, plan +from .base import BaseAction logger = logging.getLogger(__name__) @@ -55,9 +55,10 @@ def json_format(out, graph): class Action(BaseAction): def _generate_plan(self): - return plan( + return self.plan( description="Print graph", - stack_action=None, + action_name='graph', + action=None, context=self.context) def run(self, format=None, reduce=False, *args, **kwargs): diff --git a/stacker/config/__init__.py b/stacker/config/__init__.py index 5fdde4162..f3666b0a6 100644 --- a/stacker/config/__init__.py +++ b/stacker/config/__init__.py @@ -280,6 +280,8 @@ class PackageSources(Model): class Hook(Model): + name = StringType(serialize_when_none=None) + path = StringType(required=True) required = BooleanType(default=True) @@ -290,6 +292,14 @@ class Hook(Model): args = DictType(AnyType) + required_by = ListType(StringType, serialize_when_none=False) + + requires = ListType(StringType, serialize_when_none=False) + + region = StringType(serialize_when_none=False) + + profile = StringType(serialize_when_none=False) + class Target(Model): name = StringType(required=True) @@ -414,10 +424,14 @@ class Config(Model): post_build = ListType(ModelType(Hook), serialize_when_none=False) + build_hooks = ListType(ModelType(Hook), serialize_when_none=False) + pre_destroy = ListType(ModelType(Hook), serialize_when_none=False) post_destroy = ListType(ModelType(Hook), serialize_when_none=False) + destroy_hooks = ListType(ModelType(Hook), serialize_when_none=False) + tags = DictType(StringType, serialize_when_none=False) template_indent = StringType(serialize_when_none=False) diff --git a/stacker/context.py b/stacker/context.py index 0eac9236f..8fba2731c 100644 --- a/stacker/context.py +++ b/stacker/context.py @@ -4,10 +4,12 @@ from builtins import object import collections import logging +import threading from stacker.config import Config from .stack import Stack from .target import Target +from .hooks import ActionHooks logger = logging.getLogger(__name__) @@ -57,6 +59,8 @@ def __init__(self, environment=None, self.force_stacks = force_stacks or [] self.hook_data = {} + self._hook_lock = threading.RLock() + @property def namespace(self): return self.config.namespace @@ -134,8 +138,9 @@ def get_targets(self): if not hasattr(self, "_targets"): targets = [] for target_def in self.config.targets or []: - target = Target(target_def) + target = Target.from_definition(target_def) targets.append(target) + self._targets = targets return self._targets @@ -183,6 +188,9 @@ def get_fqn(self, name=None): """ return get_fqn(self._base_fqn, self.namespace_delimiter, name) + def get_hooks_for_action(self, action_name): + return ActionHooks.from_config(self.config, action_name) + def set_hook_data(self, key, data): """Set hook data for the given key. 
@@ -201,4 +209,5 @@ def set_hook_data(self, key, data): raise KeyError("Hook data for key %s already exists, each hook " "must have a unique data_key.", key) - self.hook_data[key] = data + with self._hook_lock: + self.hook_data[key] = data diff --git a/stacker/exceptions.py b/stacker/exceptions.py index e1ae8339f..9b6cbd50e 100644 --- a/stacker/exceptions.py +++ b/stacker/exceptions.py @@ -126,6 +126,7 @@ def __init__(self, stack_name, *args, **kwargs): message = ("Stack: \"%s\" does not exist in outputs or the lookup is " "not available in this stacker run") % (stack_name,) super(StackDoesNotExist, self).__init__(message, *args, **kwargs) + self.stack_name = stack_name class MissingParameterException(Exception): @@ -273,3 +274,20 @@ def __init__(self, exception, stack, dependency): "as a dependency of '%s': %s" ) % (dependency, stack, str(exception)) super(GraphError, self).__init__(message) + + +class HookExecutionFailed(Exception): + """Raised when running a required hook fails""" + + def __init__(self, hook, result=None, cause=None): + self.hook = hook + self.result = result + self.cause = cause + + if self.cause: + message = ("Hook '{}' threw exception: {}".format( + hook.name, cause)) + else: + message = ("Hook '{}' failed (result: {})".format( + hook.name, result)) + super(HookExecutionFailed, self).__init__(message) diff --git a/stacker/hooks/__init__.py b/stacker/hooks/__init__.py index e69de29bb..7869e0489 100644 --- a/stacker/hooks/__init__.py +++ b/stacker/hooks/__init__.py @@ -0,0 +1,202 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from collections import Mapping, namedtuple + +from stacker.exceptions import HookExecutionFailed, StackDoesNotExist +from stacker.util import load_object_from_string +from stacker.status import ( + COMPLETE, SKIPPED, FailedStatus, NotSubmittedStatus, SkippedStatus +) +from stacker.variables import Variable + +logger = logging.getLogger(__name__) + + +def no_op(*args, **kwargs): + logger.info("No-op hook called with arguments: {}".format(kwargs)) + return True + + +class Hook(object): + @classmethod + def from_definition(cls, definition, name_fallback=None): + """Create a hook instance from a config definition""" + name = definition.name or name_fallback + if not name: + raise ValueError('Hook definition does not include name and no ' + 'fallback provided') + + data_key = definition.data_key or name + return cls( + name=name, + path=definition.path, + required=definition.required, + enabled=definition.enabled, + data_key=data_key, + args=definition.args, + required_by=definition.required_by, + requires=definition.requires, + profile=definition.profile, + region=definition.region) + + def __init__(self, name, path, required=True, enabled=True, + data_key=None, args=None, required_by=None, requires=None, + profile=None, region=None): + self.path = path + self.name = name + self.required = required + self.enabled = enabled + self.data_key = data_key + self.args = args + self.required_by = set(required_by or []) + self.requires = set(requires or []) + self.profile = profile + self.region = region + + self._args = {} + self._args, deps = self.parse_args(args) + self.requires.update(deps) + + self._callable = self.resolve_path() + + def parse_args(self, args): + arg_vars = {} + deps = set() + + if args: + for key, value in args.items(): + var = arg_vars[key] = \ + Variable('{}.args.{}'.format(self.name, key), value) + deps.update(var.dependencies()) + + return 
arg_vars, deps
+
+    def resolve_path(self):
+        try:
+            return load_object_from_string(self.path)
+        except (AttributeError, ImportError) as e:
+            raise ValueError("Unable to load method at %s for hook %s: %s" %
+                             (self.path, self.name, str(e)))
+
+    def check_args_dependencies(self, provider, context):
+        # When running hooks for destruction, we might rely on outputs of
+        # stacks that we assume have been deployed. Unfortunately, since
+        # destruction must happen in the reverse order of creation, those stack
+        # dependencies will not be present on `requires`, but in `required_by`,
+        # meaning the execution engine won't stop the hook from running early.
+
+        # To deal with that, manually find the dependencies coming from
+        # lookups in the hook arguments, select those that represent stacks,
+        # and check if they are actually available.
+
+        dependencies = set()
+        for value in self._args.values():
+            dependencies.update(value.dependencies())
+
+        for dep in dependencies:
+            # We assume all dependency names are valid here. Hence, if we can't
+            # find a stack with that same name, it must be a target or a hook,
+            # and hence we don't need to check it.
+            stack = context.get_stack(dep)
+            if stack is None:
+                continue
+
+            # This will raise if the stack is missing
+            provider.get_stack(stack.fqn)
+
+    def resolve_args(self, provider, context):
+        for key, value in self._args.items():
+            value.resolve(context, provider)
+            yield key, value.value
+
+    def run(self, provider, context):
+        """Run a Hook and capture its result
+
+        These are pieces of external code that we want to run in addition to
+        CloudFormation deployments, to perform actions that are not easily
+        handled in a template.
+
+        Args:
+            provider (:class:`stacker.provider.base.BaseProvider`):
+                Provider to pass to the hook
+            context (:class:`stacker.context.Context`): The current stacker
+                context
+        Raises:
+            :class:`stacker.exceptions.HookExecutionFailed`:
+                if the hook failed
+        Returns: the result of the hook if it was run, ``None`` if it was
+            skipped.
+        """
+
+        logger.info("Executing hook %s", self)
+        kwargs = dict(self.resolve_args(provider, context))
+        try:
+            result = self._callable(context=context, provider=provider,
+                                    **kwargs)
+        except Exception as e:
+            if self.required:
+                raise HookExecutionFailed(self, cause=e)
+
+            return None
+
+        if not result:
+            if self.required:
+                raise HookExecutionFailed(self, result=result)
+
+            logger.warning("Non-required hook %s failed. 
Return value: %s", + self.name, result) + return result + + if isinstance(result, Mapping): + if self.data_key: + logger.debug("Adding result for hook %s to context in " + "data_key %s.", self.name, self.data_key) + context.set_hook_data(self.data_key, result) + + return result + + def run_step(self, provider_builder, context): + if not self.enabled: + return NotSubmittedStatus() + + provider = provider_builder.build(profile=self.profile, + region=self.region) + + try: + self.check_args_dependencies(provider, context) + except StackDoesNotExist as e: + reason = "required stack not deployed: {}".format(e.stack_name) + return SkippedStatus(reason=reason) + + try: + result = self.run(provider, context) + except HookExecutionFailed as e: + return FailedStatus(reason=str(e)) + + if not result: + return SKIPPED + + return COMPLETE + + def __str__(self): + return 'Hook(name={}, path={}, profile={}, region={})'.format( + self.name, self.path, self.profile, self.region) + + +class ActionHooks(namedtuple('ActionHooks', 'action_name pre post custom')): + @classmethod + def from_config(cls, config, action_name): + def from_key(key): + for i, hook_def in enumerate(config.get(key) or [], 1): + name_fallback = '{}_{}_{}'.format(key, i, hook_def.path) + yield Hook.from_definition(hook_def, + name_fallback=name_fallback) + + return ActionHooks( + action_name=action_name, + pre=list(from_key('pre_{}'.format(action_name))), + post=list(from_key('post_{}'.format(action_name))), + custom=list(from_key('{}_hooks'.format(action_name)))) diff --git a/stacker/hooks/aws_lambda.py b/stacker/hooks/aws_lambda.py index 4b388f40c..9e8c201b7 100644 --- a/stacker/hooks/aws_lambda.py +++ b/stacker/hooks/aws_lambda.py @@ -11,10 +11,10 @@ import hashlib from io import BytesIO as StringIO from zipfile import ZipFile, ZIP_DEFLATED + import botocore import formic from troposphere.awslambda import Code -from stacker.session_cache import get_session from stacker.util import ( get_config_directory, @@ -508,7 +508,7 @@ def create_template(self): payload_acl = kwargs.get('payload_acl', 'private') # Always use the global client for s3 - session = get_session(bucket_region) + session = provider.get_session(region=bucket_region) s3_client = session.client('s3') ensure_s3_bucket(s3_client, bucket_name, bucket_region) diff --git a/stacker/hooks/ecs.py b/stacker/hooks/ecs.py index 308c2eccc..daad432d3 100644 --- a/stacker/hooks/ecs.py +++ b/stacker/hooks/ecs.py @@ -7,7 +7,6 @@ from past.builtins import basestring import logging -from stacker.session_cache import get_session logger = logging.getLogger(__name__) @@ -26,7 +25,7 @@ def create_clusters(provider, context, **kwargs): Returns: boolean for whether or not the hook succeeded. 
""" - conn = get_session(provider.region).client('ecs') + conn = provider.get_session().client('ecs') try: clusters = kwargs["clusters"] diff --git a/stacker/hooks/iam.py b/stacker/hooks/iam.py index 009888157..2fe9c345a 100644 --- a/stacker/hooks/iam.py +++ b/stacker/hooks/iam.py @@ -5,7 +5,6 @@ import copy import logging -from stacker.session_cache import get_session from botocore.exceptions import ClientError from awacs.aws import Statement, Allow, Policy @@ -32,7 +31,7 @@ def create_ecs_service_role(provider, context, **kwargs): """ role_name = kwargs.get("role_name", "ecsServiceRole") - client = get_session(provider.region).client('iam') + client = provider.get_session().client('iam') try: client.create_role( @@ -125,7 +124,7 @@ def get_cert_contents(kwargs): def ensure_server_cert_exists(provider, context, **kwargs): - client = get_session(provider.region).client('iam') + client = provider.get_session().client('iam') cert_name = kwargs["cert_name"] status = "unknown" try: diff --git a/stacker/hooks/keypair.py b/stacker/hooks/keypair.py index 3114729cd..100d3ef72 100644 --- a/stacker/hooks/keypair.py +++ b/stacker/hooks/keypair.py @@ -8,7 +8,6 @@ from botocore.exceptions import ClientError -from stacker.session_cache import get_session from stacker.hooks import utils from stacker.ui import get_raw_input @@ -220,8 +219,8 @@ def ensure_keypair_exists(provider, context, **kwargs): "specified at the same time") return False - session = get_session(region=provider.region, - profile=kwargs.get("profile")) + session = provider.get_session( + profile=kwargs.get("profile")) ec2 = session.client("ec2") keypair = get_existing_key_pair(ec2, keypair_name) diff --git a/stacker/hooks/route53.py b/stacker/hooks/route53.py index c163e091d..01bc04b41 100644 --- a/stacker/hooks/route53.py +++ b/stacker/hooks/route53.py @@ -3,8 +3,6 @@ from __future__ import absolute_import import logging -from stacker.session_cache import get_session - from stacker.util import create_route53_zone logger = logging.getLogger(__name__) @@ -21,7 +19,7 @@ def create_domain(provider, context, **kwargs): Returns: boolean for whether or not the hook succeeded. """ - session = get_session(provider.region) + session = provider.get_session() client = session.client("route53") domain = kwargs.get("domain") if not domain: diff --git a/stacker/lookups/handlers/ami.py b/stacker/lookups/handlers/ami.py index 8d51c0619..634dce500 100644 --- a/stacker/lookups/handlers/ami.py +++ b/stacker/lookups/handlers/ami.py @@ -1,7 +1,7 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -from stacker.session_cache import get_session + import re import operator @@ -22,31 +22,31 @@ def __init__(self, search_string): class AmiLookup(LookupHandler): @classmethod - def handle(cls, value, provider, **kwargs): + def handle(cls, value, context, provider): """Fetch the most recent AMI Id using a filter - + For example: - + ${ami [@]owners:self,account,amazon name_regex:serverX-[0-9]+ architecture:x64,i386} - + The above fetches the most recent AMI where owner is self account or amazon and the ami name matches the regex described, the architecture will be either x64 or i386 - + You can also optionally specify the region in which to perform the AMI lookup. - + Valid arguments: - + owners (comma delimited) REQUIRED ONCE: aws_account_id | amazon | self - + name_regex (a regex) REQUIRED ONCE: e.g. 
my-ubuntu-server-[0-9]+ - + executable_users (comma delimited) OPTIONAL ONCE: aws_account_id | amazon | self - + Any other arguments specified are sent as filters to the aws api For example, "architecture:x86_64" will add a filter """ # noqa @@ -57,13 +57,13 @@ def handle(cls, value, provider, **kwargs): else: region = provider.region - ec2 = get_session(region).client('ec2') + ec2 = provider.get_session(region=region).client('ec2') values = {} describe_args = {} # now find any other arguments that can be filters - matches = re.findall('([0-9a-zA-z_-]+:[^\s$]+)', value) + matches = re.findall(r'([0-9a-zA-z_-]+:[^\s$]+)', value) for match in matches: k, v = match.split(':', 1) values[k] = v @@ -77,10 +77,9 @@ def handle(cls, value, provider, **kwargs): raise Exception("'name_regex' value required when using ami") name_regex = values.pop('name_regex') - executable_users = None - if values.get('executable_users'): - executable_users = values.pop('executable_users').split(',') - describe_args["ExecutableUsers"] = executable_users + executable_users = values.get('executable_users') + if executable_users: + describe_args["ExecutableUsers"] = executable_users.split(',') filters = [] for k, v in values.items(): diff --git a/stacker/lookups/handlers/dynamodb.py b/stacker/lookups/handlers/dynamodb.py index 9dcd97ce8..44df1b4b0 100644 --- a/stacker/lookups/handlers/dynamodb.py +++ b/stacker/lookups/handlers/dynamodb.py @@ -4,7 +4,6 @@ from builtins import str from botocore.exceptions import ClientError import re -from stacker.session_cache import get_session from . import LookupHandler from ...util import read_value_from_path @@ -14,7 +13,7 @@ class DynamodbLookup(LookupHandler): @classmethod - def handle(cls, value, **kwargs): + def handle(cls, value, context, provider): """Get a value from a dynamodb table dynamodb field types should be in the following format: @@ -53,7 +52,7 @@ def handle(cls, value, **kwargs): projection_expression = _build_projection_expression(clean_table_keys) # lookup the data from dynamodb - dynamodb = get_session(region).client('dynamodb') + dynamodb = provider.get_session(region=region).client('dynamodb') try: response = dynamodb.get_item( TableName=table_name, diff --git a/stacker/lookups/handlers/kms.py b/stacker/lookups/handlers/kms.py index ba80d2779..921925768 100644 --- a/stacker/lookups/handlers/kms.py +++ b/stacker/lookups/handlers/kms.py @@ -2,7 +2,6 @@ from __future__ import division from __future__ import absolute_import import codecs -from stacker.session_cache import get_session from . import LookupHandler from ...util import read_value_from_path @@ -12,7 +11,7 @@ class KmsLookup(LookupHandler): @classmethod - def handle(cls, value, **kwargs): + def handle(cls, value, context, provider): """Decrypt the specified value with a master key in KMS. kmssimple field types should be in the following format: @@ -55,7 +54,7 @@ def handle(cls, value, **kwargs): if "@" in value: region, value = value.split("@", 1) - kms = get_session(region).client('kms') + kms = provider.get_session(region=region).client('kms') # encode str value as an utf-8 bytestring for use with codecs.decode. value = value.encode('utf-8') diff --git a/stacker/lookups/handlers/output.py b/stacker/lookups/handlers/output.py index a40ba0fb3..7d67162cd 100644 --- a/stacker/lookups/handlers/output.py +++ b/stacker/lookups/handlers/output.py @@ -5,6 +5,9 @@ import re from collections import namedtuple +import yaml + +from stacker.exceptions import StackDoesNotExist from . 
import LookupHandler TYPE_NAME = "output" @@ -14,25 +17,28 @@ class OutputLookup(LookupHandler): @classmethod - def handle(cls, value, context=None, **kwargs): - """Fetch an output from the designated stack. - - Args: - value (str): string with the following format: - ::, ie. some-stack::SomeOutput - context (:class:`stacker.context.Context`): stacker context - - Returns: - str: output from the specified stack - - """ - - if context is None: - raise ValueError('Context is required') + def handle(cls, value, context, provider): + """Fetch an output from the designated stack.""" d = deconstruct(value) - stack = context.get_stack(d.stack_name) - return stack.outputs[d.output_name] + try: + stack = context.get_stack(d.stack_name) + if not stack: + raise StackDoesNotExist(d.stack_name) + outputs = provider.get_outputs(stack.fqn) + except StackDoesNotExist: + raise LookupError("Stack is missing from configuration or not " + "deployed: {}".format(d.stack_name)) + + try: + return outputs[d.output_name] + except KeyError: + available_lookups = yaml.safe_dump( + list(outputs.keys()), default_flow_style=False) + msg = ("Lookup missing from stack: {}::{}. " + "Available lookups:\n{}") + raise LookupError(msg.format( + d.stack_name, d.output_name, available_lookups)) @classmethod def dependencies(cls, lookup_data): diff --git a/stacker/lookups/handlers/rxref.py b/stacker/lookups/handlers/rxref.py index 858a13a3d..cfb327d04 100644 --- a/stacker/lookups/handlers/rxref.py +++ b/stacker/lookups/handlers/rxref.py @@ -22,7 +22,7 @@ class RxrefLookup(LookupHandler): @classmethod - def handle(cls, value, provider=None, context=None, **kwargs): + def handle(cls, value, context, provider): """Fetch an output from the designated stack. Args: diff --git a/stacker/lookups/handlers/ssmstore.py b/stacker/lookups/handlers/ssmstore.py index 2da724d30..8f6d8eaae 100644 --- a/stacker/lookups/handlers/ssmstore.py +++ b/stacker/lookups/handlers/ssmstore.py @@ -3,8 +3,6 @@ from __future__ import absolute_import from builtins import str -from stacker.session_cache import get_session - from . import LookupHandler from ...util import read_value_from_path @@ -13,7 +11,7 @@ class SsmstoreLookup(LookupHandler): @classmethod - def handle(cls, value, **kwargs): + def handle(cls, value, context, provider): """Retrieve (and decrypt if applicable) a parameter from AWS SSM Parameter Store. @@ -48,7 +46,7 @@ def handle(cls, value, **kwargs): if "@" in value: region, value = value.split("@", 1) - client = get_session(region).client("ssm") + client = provider.get_session(region=region).client("ssm") response = client.get_parameters( Names=[ value, diff --git a/stacker/lookups/handlers/xref.py b/stacker/lookups/handlers/xref.py index a318d252b..6e591fa72 100644 --- a/stacker/lookups/handlers/xref.py +++ b/stacker/lookups/handlers/xref.py @@ -21,7 +21,7 @@ class XrefLookup(LookupHandler): @classmethod - def handle(cls, value, provider=None, **kwargs): + def handle(cls, value, context, provider): """Fetch an output from the designated stack. 
         Args:
@@ -34,9 +34,6 @@ def handle(cls, value, provider=None, **kwargs):
             str: output from the specified stack
 
         """
-        if provider is None:
-            raise ValueError('Provider is required')
-
         d = deconstruct(value)
         stack_fqn = d.stack_name
         output = provider.get_output(stack_fqn, d.output_name)
diff --git a/stacker/plan.py b/stacker/plan.py
index 24b415e04..81f968852 100644
--- a/stacker/plan.py
+++ b/stacker/plan.py
@@ -8,6 +8,7 @@
 import uuid
 import threading
 
+from .stack import Stack
 from .util import stack_template_key_name
 from .exceptions import (
     GraphError,
@@ -43,27 +44,54 @@ def log_step(step):
 class Step(object):
     """State machine for executing generic actions related to stacks.
 
+
     Args:
-        stack (:class:`stacker.stack.Stack`): the stack associated
-            with this step
-        fn (func): the function to run to execute the step. This function will
-            be ran multiple times until the step is "done".
-        watch_func (func): an optional function that will be called to "tail"
-            the step action.
+        subject: the subject associated with this
+            step. Usually a :class:`stacker.stack.Stack`,
+            :class:`stacker.target.Target` or :class:`stacker.hooks.Hook`
+        fn (func): the function to run to execute the step. This function
+            will be run multiple times until the step is "done".
+        watch_func (func): an optional function that will be called to
+            monitor the step action.
     """
 
-    def __init__(self, stack, fn, watch_func=None):
-        self.stack = stack
-        self.status = PENDING
-        self.last_updated = time.time()
+    @classmethod
+    def from_stack(cls, stack, fn, **kwargs):
+        kwargs.setdefault('logging', stack.logging)
+        return cls(stack.name, subject=stack, fn=fn, **kwargs)
+
+    @classmethod
+    def from_target(cls, target, fn, **kwargs):
+        kwargs.setdefault('logging', True)
+        return cls(target.name, subject=target, fn=fn, **kwargs)
+
+    @classmethod
+    def from_hook(cls, hook, fn, **kwargs):
+        kwargs.setdefault('logging', True)
+        return cls(hook.name, subject=hook, fn=fn, **kwargs)
+
+    def __init__(self, name, fn, subject=None, watch_func=None, requires=None,
+                 required_by=None, logging=False):
+        self.name = name
+        self.subject = subject
         self.fn = fn
+        self.watch_func = watch_func
+        self.requires = set(requires or [])
+        self.required_by = set(required_by or [])
+        if subject is not None:
+            self.requires.update(subject.requires or [])
+            self.required_by.update(subject.required_by or [])
+        self.logging = logging
+
+        self.status = PENDING
+        self.last_updated = time.time()
 
     def __repr__(self):
-        return "<stacker.plan.Step:%s>" % (self.stack.name,)
+        return "<stacker.plan.Step:%s>" % (self.name,)
 
     def __str__(self):
-        return self.stack.name
+        return self.name
 
     def run(self):
         """Runs this step until it has completed successfully, or been
@@ -75,7 +103,7 @@ def run(self):
         if self.watch_func:
             watcher = threading.Thread(
                 target=self.watch_func,
-                args=(self.stack, stop_watcher)
+                args=(self.subject, stop_watcher)
             )
             watcher.start()
 
@@ -90,25 +118,13 @@ def run(self):
 
     def _run_once(self):
         try:
-            status = self.fn(self.stack, status=self.status)
+            status = self.fn(self.subject, status=self.status)
         except Exception as e:
             logger.exception(e)
             status = FailedStatus(reason=str(e))
         self.set_status(status)
         return status
 
-    @property
-    def name(self):
-        return self.stack.name
-
-    @property
-    def requires(self):
-        return self.stack.requires
-
-    @property
-    def required_by(self):
-        return self.stack.required_by
-
     @property
     def completed(self):
         """Returns True if the step is in a COMPLETE state."""
@@ -126,18 +142,20 @@ def failed(self):
 
     @property
     def done(self):
-        """Returns True if the step is finished (either COMPLETE, SKIPPED or
-        FAILED)
+        """Whether this step is finished (either COMPLETE, SKIPPED or FAILED)
         """
         return self.completed or self.skipped or self.failed
 
     @property
     def ok(self):
-        """Returns True if the step is finished (either COMPLETE or SKIPPED)"""
+        """Whether this step is finished (either COMPLETE or SKIPPED)"""
         return self.completed or self.skipped
 
     @property
     def submitted(self):
-        """Returns True if the step is SUBMITTED, COMPLETE, or SKIPPED."""
+        """Whether this step has been submitted (SUBMITTED, COMPLETE, or
+        SKIPPED).
+        """
         return self.status >= SUBMITTED
 
     def set_status(self, status):
@@ -147,11 +165,10 @@ def set_status(self, status):
             step to.
         """
         if status is not self.status:
-            logger.debug("Setting %s state to %s.", self.stack.name,
-                         status.name)
+            logger.debug("Setting %s state to %s.", self.name, status.name)
             self.status = status
             self.last_updated = time.time()
-            if self.stack.logging:
+            if self.logging:
                 log_step(self)
 
     def complete(self):
@@ -166,58 +183,15 @@ def submit(self):
         """A shortcut for set_status(SUBMITTED)"""
         self.set_status(SUBMITTED)
 
+    def reverse_requirements(self):
+        """
+        Change this step so it is suitable for use in operations in reverse
+        dependency order.
-
-def build_plan(description, graph,
-               targets=None, reverse=False):
-    """Builds a plan from a list of steps.
-    Args:
-        description (str): an arbitrary string to
-            describe the plan.
-        graph (:class:`Graph`): a list of :class:`Graph` to execute.
-        targets (list): an optional list of step names to filter the graph to.
-            If provided, only these steps, and their transitive dependencies
-            will be executed. If no targets are specified, every node in the
-            graph will be executed.
-        reverse (bool): If provided, the graph will be walked in reverse order
-            (dependencies last).
-    """
-
-    # If we want to execute the plan in reverse (e.g. Destroy), transpose the
-    # graph.
-    if reverse:
-        graph = graph.transposed()
-
-    # If we only want to build a specific target, filter the graph.
-    if targets:
-        nodes = []
-        for target in targets:
-            for k, step in graph.steps.items():
-                if step.name == target:
-                    nodes.append(step.name)
-        graph = graph.filtered(nodes)
-
-    return Plan(description=description, graph=graph)
-
-
-def build_graph(steps):
-    """Builds a graph of steps.
-    Args:
-        steps (list): a list of :class:`Step` objects to execute.
-    """
-
-    graph = Graph()
-
-    for step in steps:
-        graph.add_step(step)
-
-    for step in steps:
-        for dep in step.requires:
-            graph.connect(step.name, dep)
-
-        for parent in step.required_by:
-            graph.connect(parent, step.name)
-
-    return graph
+        This can be used to correctly generate an action graph when destroying
+        stacks.
+        """
+        self.required_by, self.requires = self.requires, self.required_by
 
 
 class Graph(object):
@@ -231,6 +205,7 @@ class Graph(object):
 
     Example:
 
     >>> dag = DAG()
+    >>> def build(*args, **kwargs): return COMPLETE
     >>> a = Step("a", fn=build)
     >>> b = Step("b", fn=build)
     >>> dag.add_step(a)
@@ -238,11 +213,34 @@ class Graph(object):
     >>> dag.connect(a, b)
 
     Args:
-        steps (list): an optional list of :class:`Step` objects to execute.
+        steps (dict): an optional dict of :class:`Step` objects to execute.
        dag (:class:`stacker.dag.DAG`): an optional :class:`stacker.dag.DAG`
            object. If one is not provided, a new one will be initialized. 
""" + @classmethod + def from_steps(cls, steps): + """Builds a graph of steps respecting dependencies + + Args: + steps (List[Step]): steps to include in the graph + Returns: :class:`Graph`: the resulting graph + """ + + graph = Graph() + + for step in steps: + graph.add_step(step) + + for step in steps: + for dep in step.requires: + graph.connect(step.name, dep) + + for parent in step.required_by: + graph.connect(parent, step.name) + + return graph + def __init__(self, steps=None, dag=None): self.steps = steps or {} self.dag = dag or DAG() @@ -287,6 +285,9 @@ def topological_sort(self): nodes = self.dag.topological_sort() return [self.steps[step_name] for step_name in nodes] + def get(self, name, default=None): + return self.steps.get(name, default) + def to_dict(self): return self.dag.graph @@ -298,6 +299,26 @@ class Plan(object): graph (:class:`Graph`): a graph of steps. """ + @classmethod + def from_graph(cls, description, graph, targets=None): + """Builds a plan from a list of steps. + + Args: + description (str): an arbitrary string to describe the plan. + graph (Graph): a :class:`Graph` to base the plan on + targets (list, optional): names of steps to include in the graph. + If provided, only these steps, and their transitive + dependencies will be executed. Otherwise, every node in the + graph will be executed. + Returns: Plan: the resulting plan + """ + + # If we only want to build a specific target, filter the graph. + if targets: + graph = graph.filtered(targets) + + return Plan(description=description, graph=graph) + def __init__(self, description, graph): self.id = uuid.uuid4() self.description = description @@ -335,11 +356,14 @@ def dump(self, directory, context, provider=None): os.makedirs(directory) def walk_func(step): - step.stack.resolve( + if not isinstance(step.subject, Stack): + return True + + step.subject.resolve( context=context, provider=provider, ) - blueprint = step.stack.blueprint + blueprint = step.subject.blueprint filename = stack_template_key_name(blueprint) path = os.path.join(directory, filename) @@ -401,3 +425,10 @@ def step_names(self): def keys(self): return self.step_names + + def get(self, name, default=None): + for step in self.steps: + if step.name == name: + return step + + return default diff --git a/stacker/providers/aws/default.py b/stacker/providers/aws/default.py index 808531346..1805c70c2 100644 --- a/stacker/providers/aws/default.py +++ b/stacker/providers/aws/default.py @@ -11,19 +11,16 @@ import time import urllib.parse import sys - -# thread safe, memoized, provider builder. from threading import Lock import botocore.exceptions from botocore.config import Config -from ..base import BaseProvider -from ... 
import exceptions -from ...ui import ui +from stacker import exceptions +from stacker.ui import ui +from stacker.providers.base import BaseProvider from stacker.session_cache import get_session - -from ...actions.diff import ( +from stacker.actions.diff import ( DictValue, diff_parameters, format_params_diff as format_diff @@ -550,17 +547,19 @@ def __init__(self, session, region=None, interactive=False, replacements_only=False, recreate_failed=False, service_role=None, **kwargs): self._outputs = {} - self.region = region - self.cloudformation = get_cloudformation_client(session) + self.region = region or session.region_name self.interactive = interactive # replacements only is only used in interactive mode self.replacements_only = interactive and replacements_only self.recreate_failed = interactive or recreate_failed self.service_role = service_role + self._session = session + self._cloudformation = get_cloudformation_client(session) + def get_stack(self, stack_name, **kwargs): try: - return self.cloudformation.describe_stacks( + return self._cloudformation.describe_stacks( StackName=stack_name)['Stacks'][0] except botocore.exceptions.ClientError as e: if "does not exist" not in str(e): @@ -630,11 +629,11 @@ def get_events(self, stack_name, chronological=True): event_list = [] while True: if next_token is not None: - events = self.cloudformation.describe_stack_events( + events = self._cloudformation.describe_stack_events( StackName=stack_name, NextToken=next_token ) else: - events = self.cloudformation.describe_stack_events( + events = self._cloudformation.describe_stack_events( StackName=stack_name ) event_list.append(events['StackEvents']) @@ -690,7 +689,7 @@ def destroy_stack(self, stack, **kwargs): if self.service_role: args["RoleARN"] = self.service_role - self.cloudformation.delete_stack(**args) + self._cloudformation.delete_stack(**args) return True def create_stack(self, fqn, template, parameters, tags, @@ -723,11 +722,11 @@ def create_stack(self, fqn, template, parameters, tags, logger.debug("force_change_set set to True, creating stack with " "changeset.") _changes, change_set_id = create_change_set( - self.cloudformation, fqn, template, parameters, tags, + self._cloudformation, fqn, template, parameters, tags, 'CREATE', service_role=self.service_role, **kwargs ) - self.cloudformation.execute_change_set( + self._cloudformation.execute_change_set( ChangeSetName=change_set_id, ) else: @@ -738,14 +737,14 @@ def create_stack(self, fqn, template, parameters, tags, ) try: - self.cloudformation.create_stack(**args) + self._cloudformation.create_stack(**args) except botocore.exceptions.ClientError as e: if e.response['Error']['Message'] == ('TemplateURL must ' 'reference a valid S3 ' 'object to which you ' 'have access.'): s3_fallback(fqn, template, parameters, tags, - self.cloudformation.create_stack, + self._cloudformation.create_stack, self.service_role) else: raise @@ -887,7 +886,7 @@ def deal_with_changeset_stack_policy(self, fqn, stack_policy): kwargs = generate_stack_policy_args(stack_policy) kwargs["StackName"] = fqn logger.debug("Setting stack policy on %s.", fqn) - self.cloudformation.set_stack_policy(**kwargs) + self._cloudformation.set_stack_policy(**kwargs) def interactive_update_stack(self, fqn, template, old_parameters, parameters, stack_policy, tags, @@ -909,7 +908,7 @@ def interactive_update_stack(self, fqn, template, old_parameters, """ logger.debug("Using interactive provider mode for %s.", fqn) changes, change_set_id = create_change_set( - self.cloudformation, fqn, 
template, parameters, tags, + self._cloudformation, fqn, template, parameters, tags, 'UPDATE', service_role=self.service_role, **kwargs ) old_parameters_as_dict = self.params_as_dict(old_parameters) @@ -944,7 +943,7 @@ def interactive_update_stack(self, fqn, template, old_parameters, self.deal_with_changeset_stack_policy(fqn, stack_policy) - self.cloudformation.execute_change_set( + self._cloudformation.execute_change_set( ChangeSetName=change_set_id, ) @@ -972,13 +971,13 @@ def noninteractive_changeset_update(self, fqn, template, old_parameters, logger.debug("Using noninterative changeset provider mode " "for %s.", fqn) _changes, change_set_id = create_change_set( - self.cloudformation, fqn, template, parameters, tags, + self._cloudformation, fqn, template, parameters, tags, 'UPDATE', service_role=self.service_role, **kwargs ) self.deal_with_changeset_stack_policy(fqn, stack_policy) - self.cloudformation.execute_change_set( + self._cloudformation.execute_change_set( ChangeSetName=change_set_id, ) @@ -1008,7 +1007,7 @@ def default_update_stack(self, fqn, template, old_parameters, parameters, ) try: - self.cloudformation.update_stack(**args) + self._cloudformation.update_stack(**args) except botocore.exceptions.ClientError as e: if "No updates are to be performed." in str(e): logger.debug( @@ -1021,7 +1020,7 @@ def default_update_stack(self, fqn, template, old_parameters, parameters, 'S3 object to which ' 'you have access.'): s3_fallback(fqn, template, parameters, tags, - self.cloudformation.update_stack, + self._cloudformation.update_stack, self.service_role) else: raise @@ -1038,9 +1037,6 @@ def get_outputs(self, stack_name, *args, **kwargs): self._outputs[stack_name] = get_output_dict(stack) return self._outputs[stack_name] - def get_output_dict(self, stack): - return get_output_dict(stack) - def get_stack_info(self, stack): """ Get the template and parameters of the stack currently in AWS @@ -1049,7 +1045,7 @@ def get_stack_info(self, stack): stack_name = stack['StackId'] try: - template = self.cloudformation.get_template( + template = self._cloudformation.get_template( StackName=stack_name)['TemplateBody'] except botocore.exceptions.ClientError as e: if "does not exist" not in str(e): @@ -1066,3 +1062,9 @@ def params_as_dict(parameters_list): for p in parameters_list: parameters[p['ParameterKey']] = p['ParameterValue'] return parameters + + def get_session(self, **kwargs): + kwargs.setdefault('region', self._session.region_name) + kwargs.setdefault('profile', self._session.profile_name) + + return get_session(**kwargs) diff --git a/stacker/providers/base.py b/stacker/providers/base.py index c48291f13..36208d1f0 100644 --- a/stacker/providers/base.py +++ b/stacker/providers/base.py @@ -43,6 +43,10 @@ def get_output(self, stack_name, output): # pylint: disable=unused-argument return self.get_outputs(stack_name)[output] + def get_session(self, region=None, profile=None): + # pylint: disable=unused-argument + not_implemented("get_session") + class Template(object): """A value object that represents a CloudFormation stack template, which diff --git a/stacker/stack.py b/stacker/stack.py index aa5ab81b4..60fb2f564 100644 --- a/stacker/stack.py +++ b/stacker/stack.py @@ -73,7 +73,6 @@ def __init__(self, definition, context, variables=None, mappings=None, self.enabled = enabled self.protected = protected self.context = context - self.outputs = None self.in_progress_behavior = definition.in_progress_behavior def __repr__(self): @@ -192,6 +191,3 @@ def resolve(self, context, provider): """ 
resolve_variables(self.variables, context, provider) self.blueprint.resolve_variables(self.variables) - - def set_outputs(self, outputs): - self.outputs = outputs diff --git a/stacker/target.py b/stacker/target.py index b57b3e672..a2171fd9d 100644 --- a/stacker/target.py +++ b/stacker/target.py @@ -9,8 +9,15 @@ class Target(object): a set of stacks together that can be targeted with the `--targets` flag. """ - def __init__(self, definition): - self.name = definition.name - self.requires = definition.requires or [] - self.required_by = definition.required_by or [] - self.logging = False + @classmethod + def from_definition(cls, definition): + return cls(name=definition.name, + requires=definition.requires, + required_by=definition.required_by, + logging=False) + + def __init__(self, name, requires=None, required_by=None, logging=False): + self.name = name + self.requires = list(requires or []) + self.required_by = list(required_by or []) + self.logging = logging diff --git a/stacker/tests/actions/test_build.py b/stacker/tests/actions/test_build.py index 018101401..1c0983ceb 100644 --- a/stacker/tests/actions/test_build.py +++ b/stacker/tests/actions/test_build.py @@ -2,9 +2,8 @@ from __future__ import division from __future__ import absolute_import from builtins import str -import unittest from collections import namedtuple - +import unittest import mock from stacker import exceptions @@ -18,7 +17,6 @@ from stacker.blueprints.variables.types import CFNString from stacker.context import Context, Config from stacker.exceptions import StackDidNotChange, StackDoesNotExist -from stacker.providers.base import BaseProvider from stacker.providers.aws.default import Provider from stacker.status import ( NotSubmittedStatus, @@ -29,7 +27,7 @@ FAILED ) -from ..factories import MockThreadingEvent, MockProviderBuilder +from ..factories import MockThreadingEvent, MockProviderBuilder, mock_provider def mock_stack_parameters(parameters): @@ -41,27 +39,10 @@ def mock_stack_parameters(parameters): } -class TestProvider(BaseProvider): - def __init__(self, outputs=None, *args, **kwargs): - self._outputs = outputs or {} - - def set_outputs(self, outputs): - self._outputs = outputs - - def get_stack(self, stack_name, **kwargs): - if stack_name not in self._outputs: - raise exceptions.StackDoesNotExist(stack_name) - return {"name": stack_name, "outputs": self._outputs[stack_name]} - - def get_outputs(self, stack_name, *args, **kwargs): - stack = self.get_stack(stack_name) - return stack["outputs"] - - class TestBuildAction(unittest.TestCase): def setUp(self): self.context = Context(config=Config({"namespace": "namespace"})) - self.provider = TestProvider() + self.provider = mock_provider() self.build_action = build.Action( self.context, provider_builder=MockProviderBuilder(self.provider)) @@ -80,6 +61,22 @@ def _get_context(self, **kwargs): "else": "${output bastion::something}"}}, {"name": "other", "variables": {}} ], + "build_hooks": [ + {"name": "before-db-hook", + "path": "stacker.hooks.no_op", + "required_by": ["db"]}, + {"name": "after-db-hook", + "path": "stacker.hooks.no_op", + "requires": ["db"]} + ], + "pre_build": [ + {"name": "pre-build-hook", + "path": "stacker.hooks.no_op"} + ], + "post_build": [ + {"name": "post-build-hook", + "path": "stacker.hooks.no_op"} + ] }) return Context(config=config, **kwargs) @@ -130,14 +127,28 @@ def test_existing_stack_params_dont_override_given_params(self): def test_generate_plan(self): context = self._get_context() build_action = build.Action(context, 
cancel=MockThreadingEvent()) + plan = build_action._generate_plan() + plan.graph.transitive_reduction() + self.assertEqual( - { - 'db': set(['bastion', 'vpc']), - 'bastion': set(['vpc']), - 'other': set([]), - 'vpc': set([])}, - plan.graph.to_dict() + sorted({ + 'pre-build-hook': set(), + 'pre_build_hooks': {'pre-build-hook'}, + 'pre_build': {'pre_build_hooks'}, + 'build': {'other', 'db'}, + 'post_build': {'build', 'after-db-hook'}, + 'post_build_hooks': {'post_build'}, + 'post-build-hook': {'post_build_hooks'}, + + 'other': {'pre_build'}, + 'vpc': {'pre_build'}, + 'bastion': {'vpc'}, + 'before-db-hook': {'pre_build'}, + 'db': {'before-db-hook', 'bastion'}, + 'after-db-hook': {'db'}, + }.items()), + sorted(plan.graph.to_dict().items()) ) def test_dont_execute_plan_when_outline_specified(self): @@ -227,8 +238,9 @@ def setUp(self): self.stack_status = None plan = self.build_action._generate_plan() - self.step = plan.steps[0] - self.step.stack = self.stack + self.step = next(step for step in plan.steps + if step.name == self.stack.name) + self.step.subject = self.stack def patch_object(*args, **kwargs): m = mock.patch.object(*args, **kwargs) @@ -244,9 +256,9 @@ def get_stack(name, *args, **kwargs): 'Outputs': [], 'Tags': []} - def get_events(name, *args, **kwargs): + def get_events(*args, **kwargs): return [{'ResourceStatus': 'ROLLBACK_IN_PROGRESS', - 'ResourceStatusReason': 'CFN fail'}] + 'ResourceStatusReason': 'CFN fail'}] patch_object(self.provider, 'get_stack', side_effect=get_stack) patch_object(self.provider, 'update_stack') diff --git a/stacker/tests/actions/test_destroy.py b/stacker/tests/actions/test_destroy.py index 697afd660..7eb3f0dd6 100644 --- a/stacker/tests/actions/test_destroy.py +++ b/stacker/tests/actions/test_destroy.py @@ -38,10 +38,29 @@ def setUp(self): "stacks": [ {"name": "vpc"}, {"name": "bastion", "requires": ["vpc"]}, - {"name": "instance", "requires": ["vpc", "bastion"]}, - {"name": "db", "requires": ["instance", "vpc", "bastion"]}, - {"name": "other", "requires": ["db"]}, + {"name": "db", "requires": ["vpc", "bastion"]}, + {"name": "instance", "requires": ["db", "vpc", "bastion"]}, + {"name": "other", "requires": []}, ], + "destroy_hooks": [ + {"name": "before-db-hook-1", + "path": "stacker.hooks.no_op", + "args": {"x": "${output db::whatever}"}}, + {"name": "before-db-hook-2", + "path": "stacker.hooks.no_op", + "requires": ["db"]}, + {"name": "after-db-hook", + "path": "stacker.hooks.no_op", + "required_by": ["db"]} + ], + "pre_destroy": [ + {"name": "pre-destroy-hook", + "path": "stacker.hooks.no_op"} + ], + "post_destroy": [ + {"name": "post-destroy-hook", + "path": "stacker.hooks.no_op"} + ] }) self.context = Context(config=config) self.action = destroy.Action(self.context, @@ -49,18 +68,29 @@ def setUp(self): def test_generate_plan(self): plan = self.action._generate_plan() + plan.graph.transitive_reduction() + self.assertEqual( { - 'vpc': set( - ['db', 'instance', 'bastion']), - 'other': set([]), - 'bastion': set( - ['instance', 'db']), - 'instance': set( - ['db']), - 'db': set( - ['other'])}, - plan.graph.to_dict() + 'pre-destroy-hook': set(), + 'pre_destroy_hooks': {'pre-destroy-hook'}, + 'pre_destroy': {'pre_destroy_hooks'}, + 'destroy': {'vpc', 'other'}, + 'post_destroy': {'destroy', 'after-db-hook'}, + 'post_destroy_hooks': {'post_destroy'}, + 'post-destroy-hook': {'post_destroy_hooks'}, + + 'before-db-hook-1': {'pre_destroy'}, + 'before-db-hook-2': {'pre_destroy'}, + 'after-db-hook': {'db'}, + + 'instance': {'pre_destroy'}, + 'db': {'instance', 
'before-db-hook-1', 'before-db-hook-2'}, + 'bastion': {'db'}, + 'vpc': {'bastion'}, + 'other': {'pre_destroy'}, + }, + dict(plan.graph.to_dict()) ) def test_only_execute_plan_when_forced(self): @@ -98,7 +128,7 @@ def get_stack(stack_name): return stacks_dict.get(stack_name) plan = self.action._generate_plan() - step = plan.steps[0] + step = plan.get("vpc") # we need the AWS provider to generate the plan, but swap it for # the mock one to make the test easier self.action.provider_builder = MockProviderBuilder(mock_provider) diff --git a/stacker/tests/conftest.py b/stacker/tests/conftest.py index 6597ebc81..81b40aa86 100644 --- a/stacker/tests/conftest.py +++ b/stacker/tests/conftest.py @@ -3,8 +3,11 @@ import logging import os +import mock import pytest import py.path +from boto3 import Session + logger = logging.getLogger(__name__) @@ -42,3 +45,13 @@ def stacker_fixture_dir(): path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'fixtures') return py.path.local(path) + + +@pytest.fixture(scope='session', autouse=True) +def boto3_disable_session_caching(): + def get_session(**kwargs): + return Session(**kwargs) + + with mock.patch('boto3._get_default_session', + side_effect=get_session): + yield diff --git a/stacker/tests/factories.py b/stacker/tests/factories.py index f930c5177..e46e2cf37 100644 --- a/stacker/tests/factories.py +++ b/stacker/tests/factories.py @@ -2,11 +2,15 @@ from __future__ import division from __future__ import absolute_import from builtins import object -from mock import MagicMock + +import mock + +import boto3 from stacker.context import Context from stacker.config import Config, Stack -from stacker.lookups import Lookup +from stacker.exceptions import StackDoesNotExist, StackUpdateBadStatus +from stacker.providers.base import BaseProvider class MockThreadingEvent(object): @@ -23,23 +27,103 @@ def build(self, region=None, profile=None): return self.provider -def mock_provider(**kwargs): - return MagicMock(**kwargs) +class MockProvider(BaseProvider): + def __init__(self, outputs=None, region=None, profile=None): + self.region = region + self.profile = profile + + self._stacks = {} + for stack_name, stack_outputs in (outputs or {}).items(): + self._stacks[stack_name] = { + "StackName": stack_name, + "Outputs": stack_outputs, + "StackStatus": "CREATED" + } + + def get_stack(self, stack_name, **kwargs): + try: + return self._stacks[stack_name] + except KeyError: + raise StackDoesNotExist(stack_name) + + def get_outputs(self, stack_name, *args, **kwargs): + return self.get_stack(stack_name)["Outputs"] + + def get_stack_status(self, stack_name, *args, **kwargs): + return self.get_stack(stack_name)["StackStatus"] + + def create_stack(self, stack_name, *args, **kwargs): + try: + stack = self.get_stack(stack_name) + status = self.get_stack_status(stack) + if status != "DELETED": + raise StackUpdateBadStatus(stack_name, status, "can't create") + except StackDoesNotExist: + pass + + return None + + def update_stack(self, stack_name, *args, **kwargs): + stack = self.get_stack(stack_name) + status = self.get_stack_status(stack) + if status == "DELETED": + raise StackUpdateBadStatus(stack_name, status, "can't update") + + stack["StackStatus"] = "UPDATED" + return None + + def destroy_stack(self, stack_name, *args, **kwargs): + stack = self.get_stack(stack_name) + status = self.get_stack_status(stack) + if status == "DELETED": + raise StackUpdateBadStatus(stack_name, status, "can't destroy") + + stack["StackStatus"] = "DELETED" + return None + def get_session(self, 
region=None, profile=None): + return boto3.Session(region_name=region or self.region, + profile_name=profile or self.profile) -def mock_context(namespace="default", extra_config_args=None, **kwargs): + +def mock_provider(outputs=None, region=None, profile=None, **kwargs): + provider = MockProvider(outputs, region=region, profile=profile) + return provider + + +def mock_context(namespace="default", extra_config_args=None, + environment=None, **kwargs): config_args = {"namespace": namespace} if extra_config_args: config_args.update(extra_config_args) + config = Config(config_args) - if kwargs.get("environment"): - return Context( - config=config, - **kwargs) - return Context( - config=config, - environment={}, - **kwargs) + environment = environment or {} + return Context(config=config, environment=environment, **kwargs) + + +def mock_boto3_client(service_name, region=None, profile=None): + client = boto3.client(service_name, region_name=region) + default_session = boto3._get_default_session() + + region = region or default_session.region_name + profile = profile or default_session.profile_name + svc_name = service_name + + def create_client(self, service_name, region_name=None, **kwargs): + region_name = region_name or self.region_name + profile_name = self.profile_name + if (svc_name, region, profile) == \ + (service_name, region_name, profile_name): + return client + + raise AssertionError( + "Attempted to create non-mocked AWS client: service={} region={} " + "profile={}".format(service_name, region_name, profile_name)) + + mock_ = mock.patch('boto3.Session.client', autospec=True, + side_effect=create_client) + return client, mock_ def generate_definition(base_name, stack_id, **overrides): @@ -53,12 +137,6 @@ def generate_definition(base_name, stack_id, **overrides): return Stack(definition) -def mock_lookup(lookup_input, lookup_type, raw=None): - if raw is None: - raw = "%s %s" % (lookup_type, lookup_input) - return Lookup(type=lookup_type, input=lookup_input, raw=raw) - - class SessionStub(object): """Stubber class for boto3 sessions made with session_cache.get_session() diff --git a/stacker/tests/hooks/test_aws_lambda.py b/stacker/tests/hooks/test_aws_lambda.py index 67acc934d..d184d82f2 100644 --- a/stacker/tests/hooks/test_aws_lambda.py +++ b/stacker/tests/hooks/test_aws_lambda.py @@ -12,22 +12,18 @@ from io import BytesIO as StringIO from zipfile import ZipFile -import boto3 import botocore from troposphere.awslambda import Code from moto import mock_s3 from testfixtures import TempDirectory, ShouldRaise, compare -from stacker.context import Context -from stacker.config import Config from stacker.hooks.aws_lambda import ( upload_lambda_functions, ZIP_PERMS_MASK, _calculate_hash, select_bucket_region, ) -from ..factories import mock_provider - +from ..factories import mock_provider, mock_context, mock_boto3_client REGION = "us-east-1" ALL_FILES = ( @@ -52,12 +48,6 @@ def temp_directory_with_files(cls, files=ALL_FILES): d.write(f, b'') return d - @property - def s3(self): - if not hasattr(self, '_s3'): - self._s3 = boto3.client('s3', region_name=REGION) - return self._s3 - def assert_s3_zip_file_list(self, bucket, key, files): object_info = self.s3.get_object(Bucket=bucket, Key=key) zip_data = StringIO(object_info['Body'].read()) @@ -82,11 +72,6 @@ def assert_s3_bucket(self, bucket, present=True): if present: self.fail('s3: bucket {} does not exist'.format(bucket)) - def setUp(self): - self.context = Context( - config=Config({'namespace': 'test', 'stacker_bucket': 'test'})) - 
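The mock_boto3_client() helper above builds a real boto3 client up front and returns it together with a mock.patch object that makes boto3.Session.client hand back that same client, failing loudly if the code under test asks for a different service, region, or profile. A minimal usage sketch combined with moto, mirroring how the reworked hook tests drive it; the bucket name and dummy credentials are illustrative.

    import os

    from moto import mock_s3

    from stacker.tests.factories import mock_boto3_client

    # Dummy credentials: moto intercepts the request, but botocore still signs it.
    os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
    os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")

    with mock_s3():                            # no real AWS calls from here on
        s3, client_patch = mock_boto3_client("s3", region="us-east-1")
        with client_patch:                     # boto3.Session.client now returns `s3`
            s3.create_bucket(Bucket="example-bucket")
            names = [b["Name"] for b in s3.list_buckets()["Buckets"]]
            assert names == ["example-bucket"]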
self.provider = mock_provider(region="us-east-1") - def run_hook(self, **kwargs): real_kwargs = { 'context': self.context, @@ -96,14 +81,26 @@ def run_hook(self, **kwargs): return upload_lambda_functions(**real_kwargs) - @mock_s3 + def setUp(self): + self.context = mock_context( + extra_config_args={'stacker_bucket': 'test'}) + self.provider = mock_provider(region="us-east-1") + + self.mock_s3 = mock_s3() + self.mock_s3.start() + self.s3, self.client_mock = mock_boto3_client('s3', 'us-east-1') + self.client_mock.start() + + def tearDown(self): + self.client_mock.stop() + self.mock_s3.stop() + def test_bucket_default(self): self.assertIsNotNone( self.run_hook(functions={})) self.assert_s3_bucket('test') - @mock_s3 def test_bucket_custom(self): self.assertIsNotNone( self.run_hook(bucket='custom', functions={})) @@ -111,7 +108,6 @@ def test_bucket_custom(self): self.assert_s3_bucket('test', present=False) self.assert_s3_bucket('custom') - @mock_s3 def test_prefix(self): with self.temp_directory_with_files() as d: results = self.run_hook(prefix='cloudformation-custom-resources/', @@ -129,7 +125,6 @@ def test_prefix(self): self.assertTrue(code.S3Key.startswith( 'cloudformation-custom-resources/lambda-MyFunction-')) - @mock_s3 def test_prefix_missing(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -145,7 +140,6 @@ def test_prefix_missing(self): self.assert_s3_zip_file_list(code.S3Bucket, code.S3Key, F1_FILES) self.assertTrue(code.S3Key.startswith('lambda-MyFunction-')) - @mock_s3 def test_path_missing(self): msg = "missing required property 'path' in function 'MyFunction'" with ShouldRaise(ValueError(msg)): @@ -154,7 +148,6 @@ def test_path_missing(self): } }) - @mock_s3 def test_path_relative(self): get_config_directory = 'stacker.hooks.aws_lambda.get_config_directory' with self.temp_directory_with_files(['test/test.py']) as d, \ @@ -173,7 +166,6 @@ def test_path_relative(self): self.assertIsInstance(code, Code) self.assert_s3_zip_file_list(code.S3Bucket, code.S3Key, ['test.py']) - @mock_s3 def test_path_home_relative(self): test_path = '~/test' @@ -195,7 +187,6 @@ def test_path_home_relative(self): self.assertIsInstance(code, Code) self.assert_s3_zip_file_list(code.S3Bucket, code.S3Key, ['test.py']) - @mock_s3 def test_multiple_functions(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -217,7 +208,6 @@ def test_multiple_functions(self): self.assertIsInstance(f2_code, Code) self.assert_s3_zip_file_list(f2_code.S3Bucket, f2_code.S3Key, F2_FILES) - @mock_s3 def test_patterns_invalid(self): msg = ("Invalid file patterns in key 'include': must be a string or " 'list of strings') @@ -230,7 +220,6 @@ def test_patterns_invalid(self): } }) - @mock_s3 def test_patterns_include(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -252,7 +241,6 @@ def test_patterns_include(self): 'test2/test.txt' ]) - @mock_s3 def test_patterns_exclude(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -272,7 +260,6 @@ def test_patterns_exclude(self): 'test2/test.txt' ]) - @mock_s3 def test_patterns_include_exclude(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -292,7 +279,6 @@ def test_patterns_include_exclude(self): '__init__.py' ]) - @mock_s3 def test_patterns_exclude_all(self): msg = ('Empty list of files for Lambda payload. 
Check your ' 'include/exclude options for errors.') @@ -309,7 +295,6 @@ def test_patterns_exclude_all(self): self.assertIsNone(results) - @mock_s3 def test_idempotence(self): bucket_name = 'test' @@ -396,7 +381,6 @@ def test_select_bucket_region(self): for args, result in tests: self.assertEqual(select_bucket_region(*args), result) - @mock_s3 def test_follow_symlink_nonbool(self): msg = "follow_symlinks option must be a boolean" with ShouldRaise(ValueError(msg)): @@ -405,7 +389,6 @@ def test_follow_symlink_nonbool(self): } }) - @mock_s3 def test_follow_symlink_true(self): # Testing if symlinks are followed with self.temp_directory_with_files() as d1: @@ -439,7 +422,6 @@ def test_follow_symlink_true(self): 'f3/test2/test.txt' ]) - @mock_s3 def test_follow_symlink_false(self): # testing if syminks are present and not folllowed with self.temp_directory_with_files() as d1: @@ -466,7 +448,6 @@ def test_follow_symlink_false(self): 'f2/f2.js', ]) - @mock_s3 def test_follow_symlink_omitted(self): # same as test_follow_symlink_false, but default behaivor with self.temp_directory_with_files() as d1: diff --git a/stacker/tests/hooks/test_ecs.py b/stacker/tests/hooks/test_ecs.py index 12998590f..0b1980dcc 100644 --- a/stacker/tests/hooks/test_ecs.py +++ b/stacker/tests/hooks/test_ecs.py @@ -3,15 +3,12 @@ from __future__ import absolute_import import unittest -import boto3 from moto import mock_ecs from testfixtures import LogCapture from stacker.hooks.ecs import create_clusters -from ..factories import ( - mock_context, - mock_provider, -) +from ..factories import mock_boto3_client, mock_context, mock_provider + REGION = "us-east-1" @@ -22,14 +19,48 @@ def setUp(self): self.provider = mock_provider(region=REGION) self.context = mock_context(namespace="fake") + self.mock_ecs = mock_ecs() + self.mock_ecs.start() + self.ecs, self.ecs_mock = mock_boto3_client("ecs", region=REGION) + self.ecs_mock.start() + + def tearDown(self): + self.ecs_mock.stop() + self.mock_ecs.stop() + def test_create_single_cluster(self): - with mock_ecs(): - cluster = "test-cluster" - logger = "stacker.hooks.ecs" - client = boto3.client("ecs", region_name=REGION) - response = client.list_clusters() + cluster = "test-cluster" + logger = "stacker.hooks.ecs" + response = self.ecs.list_clusters() + + self.assertEqual(len(response["clusterArns"]), 0) + with LogCapture(logger) as logs: + self.assertTrue( + create_clusters( + provider=self.provider, + context=self.context, + clusters=cluster, + ) + ) - self.assertEqual(len(response["clusterArns"]), 0) + logs.check( + ( + logger, + "DEBUG", + "Creating ECS cluster: %s" % cluster + ) + ) + + response = self.ecs.list_clusters() + self.assertEqual(len(response["clusterArns"]), 1) + + def test_create_multiple_clusters(self): + clusters = ("test-cluster0", "test-cluster1") + logger = "stacker.hooks.ecs" + response = self.ecs.list_clusters() + + self.assertEqual(len(response["clusterArns"]), 0) + for cluster in clusters: with LogCapture(logger) as logs: self.assertTrue( create_clusters( @@ -47,58 +78,27 @@ def test_create_single_cluster(self): ) ) - response = client.list_clusters() - self.assertEqual(len(response["clusterArns"]), 1) - - def test_create_multiple_clusters(self): - with mock_ecs(): - clusters = ("test-cluster0", "test-cluster1") - logger = "stacker.hooks.ecs" - client = boto3.client("ecs", region_name=REGION) - response = client.list_clusters() - - self.assertEqual(len(response["clusterArns"]), 0) - for cluster in clusters: - with LogCapture(logger) as logs: - self.assertTrue( - 
create_clusters( - provider=self.provider, - context=self.context, - clusters=cluster, - ) - ) - - logs.check( - ( - logger, - "DEBUG", - "Creating ECS cluster: %s" % cluster - ) - ) - - response = client.list_clusters() - self.assertEqual(len(response["clusterArns"]), 2) + response = self.ecs.list_clusters() + self.assertEqual(len(response["clusterArns"]), 2) def test_fail_create_cluster(self): - with mock_ecs(): - logger = "stacker.hooks.ecs" - client = boto3.client("ecs", region_name=REGION) - response = client.list_clusters() - - self.assertEqual(len(response["clusterArns"]), 0) - with LogCapture(logger) as logs: - create_clusters( - provider=self.provider, - context=self.context - ) - - logs.check( - ( - logger, - "ERROR", - "setup_clusters hook missing \"clusters\" argument" - ) + logger = "stacker.hooks.ecs" + response = self.ecs.list_clusters() + + self.assertEqual(len(response["clusterArns"]), 0) + with LogCapture(logger) as logs: + create_clusters( + provider=self.provider, + context=self.context + ) + + logs.check( + ( + logger, + "ERROR", + "setup_clusters hook missing \"clusters\" argument" ) + ) - response = client.list_clusters() - self.assertEqual(len(response["clusterArns"]), 0) + response = self.ecs.list_clusters() + self.assertEqual(len(response["clusterArns"]), 0) diff --git a/stacker/tests/hooks/test_iam.py b/stacker/tests/hooks/test_iam.py index d194f4f06..55d799ec0 100644 --- a/stacker/tests/hooks/test_iam.py +++ b/stacker/tests/hooks/test_iam.py @@ -3,22 +3,15 @@ from __future__ import absolute_import import unittest -import boto3 +from awacs.helpers.trust import get_ecs_assumerole_policy from botocore.exceptions import ClientError - from moto import mock_iam from stacker.hooks.iam import ( create_ecs_service_role, _get_cert_arn_from_response, ) - -from awacs.helpers.trust import get_ecs_assumerole_policy - -from ..factories import ( - mock_context, - mock_provider, -) +from ..factories import mock_boto3_client, mock_context, mock_provider REGION = "us-east-1" @@ -34,6 +27,15 @@ def setUp(self): self.context = mock_context(namespace="fake") self.provider = mock_provider(region=REGION) + self.mock_iam = mock_iam() + self.mock_iam.start() + self.iam, self.client_mock = mock_boto3_client("iam", region=REGION) + self.client_mock.start() + + def tearDown(self): + self.client_mock.stop() + self.mock_iam.stop() + def test_get_cert_arn_from_response(self): arn = "fake-arn" # Creation response @@ -52,50 +54,48 @@ def test_get_cert_arn_from_response(self): def test_create_service_role(self): role_name = "ecsServiceRole" policy_name = "AmazonEC2ContainerServiceRolePolicy" - with mock_iam(): - client = boto3.client("iam", region_name=REGION) - with self.assertRaises(ClientError): - client.get_role(RoleName=role_name) + with self.assertRaises(ClientError): + self.iam.get_role(RoleName=role_name) - self.assertTrue( - create_ecs_service_role( - context=self.context, - provider=self.provider, - ) + self.assertTrue( + create_ecs_service_role( + context=self.context, + provider=self.provider, ) + ) - role = client.get_role(RoleName=role_name) + role = self.iam.get_role(RoleName=role_name) - self.assertIn("Role", role) - self.assertEqual(role_name, role["Role"]["RoleName"]) - client.get_role_policy( - RoleName=role_name, - PolicyName=policy_name - ) + self.assertIn("Role", role) + self.assertEqual(role_name, role["Role"]["RoleName"]) + + self.iam.get_role_policy( + RoleName=role_name, + PolicyName=policy_name + ) def test_create_service_role_already_exists(self): role_name = 
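The ECS and IAM hook tests above follow the same refactor as the Lambda ones: the per-test @mock_ecs/@mock_iam decorators move into setUp/tearDown, which start moto and the patched boto3 client once per test and undo both afterwards. A condensed sketch of that lifecycle with a hypothetical test case; the dummy credentials are only there to keep the snippet self-contained.

    import os
    import unittest

    from moto import mock_ecs

    from stacker.tests.factories import mock_boto3_client

    # Dummy credentials so botocore's signing step never needs real ones.
    os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
    os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")


    class ExampleEcsTest(unittest.TestCase):
        def setUp(self):
            # Start moto first so the client built below is already intercepted.
            self.mock_ecs = mock_ecs()
            self.mock_ecs.start()
            self.ecs, self.client_mock = mock_boto3_client("ecs", region="us-east-1")
            self.client_mock.start()

        def tearDown(self):
            # Undo in reverse order of setUp.
            self.client_mock.stop()
            self.mock_ecs.stop()

        def test_no_clusters_initially(self):
            self.assertEqual(self.ecs.list_clusters()["clusterArns"], [])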
"ecsServiceRole" policy_name = "AmazonEC2ContainerServiceRolePolicy" - with mock_iam(): - client = boto3.client("iam", region_name=REGION) - client.create_role( - RoleName=role_name, - AssumeRolePolicyDocument=get_ecs_assumerole_policy().to_json() - ) - self.assertTrue( - create_ecs_service_role( - context=self.context, - provider=self.provider, - ) + self.iam.create_role( + RoleName=role_name, + AssumeRolePolicyDocument=get_ecs_assumerole_policy().to_json() + ) + + self.assertTrue( + create_ecs_service_role( + context=self.context, + provider=self.provider, ) + ) - role = client.get_role(RoleName=role_name) + role = self.iam.get_role(RoleName=role_name) - self.assertIn("Role", role) - self.assertEqual(role_name, role["Role"]["RoleName"]) - client.get_role_policy( - RoleName=role_name, - PolicyName=policy_name - ) + self.assertIn("Role", role) + self.assertEqual(role_name, role["Role"]["RoleName"]) + self.iam.get_role_policy( + RoleName=role_name, + PolicyName=policy_name + ) diff --git a/stacker/tests/lookups/handlers/test_ami.py b/stacker/tests/lookups/handlers/test_ami.py index 0e34b7b47..0b4f46423 100644 --- a/stacker/tests/lookups/handlers/test_ami.py +++ b/stacker/tests/lookups/handlers/test_ami.py @@ -1,194 +1,209 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -import unittest -import mock -from botocore.stub import Stubber -from stacker.lookups.handlers.ami import AmiLookup, ImageNotFound -import boto3 -from stacker.tests.factories import SessionStub, mock_provider - -REGION = "us-east-1" +from botocore.stub import Stubber +import pytest -class TestAMILookup(unittest.TestCase): - client = boto3.client("ec2", region_name=REGION) - - def setUp(self): - self.stubber = Stubber(self.client) - self.provider = mock_provider(region=REGION) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_single_image(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 1", - "VirtualizationType": "hvm", - } - ] - } - ) - - with self.stubber: - value = AmiLookup.handle( - value="owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - self.assertEqual(value, image_id) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_with_region(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 1", - "VirtualizationType": "hvm", - } - ] - } - ) +from stacker.lookups.handlers.ami import AmiLookup, ImageNotFound +from ...factories import mock_boto3_client, mock_context, mock_provider - with self.stubber: - value = AmiLookup.handle( - value="us-west-1@owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - self.assertEqual(value, image_id) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_multiple_images(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", 
- "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": "ami-fffccc110", - "Name": "Fake Image 1", - "VirtualizationType": "hvm", - }, - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-14T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 2", - "VirtualizationType": "hvm", - }, - ] - } - ) - with self.stubber: - value = AmiLookup.handle( - value="owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - self.assertEqual(value, image_id) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_multiple_images_name_match(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": "ami-fffccc110", - "Name": "Fa---ke Image 1", - "VirtualizationType": "hvm", - }, - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-14T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 2", - "VirtualizationType": "hvm", - }, - ] - } +REGION = "us-east-1" +ALT_REGION = "us-east-2" + + +@pytest.fixture +def context(): + return mock_context() + + +@pytest.fixture(params=[dict(region=REGION)]) +def provider(request): + return mock_provider(**request.param) + + +@pytest.fixture(params=[dict(region=REGION)]) +def ec2(request): + client, mock = mock_boto3_client("ec2", **request.param) + with mock: + yield client + + +@pytest.fixture +def ec2_stubber(ec2): + with Stubber(ec2) as stubber: + yield stubber + + +def test_basic_lookup_single_image(ec2_stubber, context, provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 1", + "VirtualizationType": "hvm", + } + ] + } + ) + + value = AmiLookup.handle( + value=r"owners:self name_regex:Fake\sImage\s\d", + context=context, + provider=provider + ) + assert value == image_id + + +@pytest.mark.parametrize("ec2", [dict(region=ALT_REGION)], indirect=True) +def test_basic_lookup_with_region(ec2_stubber, context, provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 1", + "VirtualizationType": "hvm", + } + ] + } + ) + + key = r"{}@owners:self name_regex:Fake\sImage\s\d".format(ALT_REGION) + value = AmiLookup.handle( + value=key, + context=context, + provider=provider + ) + assert value == image_id + + +def test_basic_lookup_multiple_images(ec2_stubber, context, provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": "ami-fffccc110", + "Name": "Fake Image 1", + "VirtualizationType": "hvm", + }, + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-14T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 2", + 
"VirtualizationType": "hvm", + }, + ] + } + ) + + value = AmiLookup.handle( + value=r"owners:self name_regex:Fake\sImage\s\d", + context=context, + provider=provider + ) + assert value == image_id + + +def test_basic_lookup_multiple_images_name_match(ec2_stubber, context, + provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": "ami-fffccc110", + "Name": "Fa---ke Image 1", + "VirtualizationType": "hvm", + }, + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-14T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 2", + "VirtualizationType": "hvm", + }, + ] + } + ) + + value = AmiLookup.handle( + value=r"owners:self name_regex:Fake\sImage\s\d", + context=context, + provider=provider + ) + assert value == image_id + + +def test_basic_lookup_no_matching_images(ec2_stubber, context, provider): + ec2_stubber.add_response( + "describe_images", + { + "Images": [] + } + ) + + with pytest.raises(ImageNotFound): + AmiLookup.handle( + value=r"owners:self name_regex:Fake\sImage\s\d", + context=context, + provider=provider ) - with self.stubber: - value = AmiLookup.handle( - value="owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - self.assertEqual(value, image_id) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_no_matching_images(self, mock_client): - self.stubber.add_response( - "describe_images", - { - "Images": [] - } - ) - with self.stubber: - with self.assertRaises(ImageNotFound): - AmiLookup.handle( - value="owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_no_matching_images_from_name(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 1", - "VirtualizationType": "hvm", - } - ] - } +def test_basic_lookup_no_matching_images_from_name(ec2_stubber, context, + provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 1", + "VirtualizationType": "hvm", + } + ] + } + ) + + with pytest.raises(ImageNotFound): + AmiLookup.handle( + value=r"owners:self name_regex:MyImage\s\d", + context=context, + provider=provider ) - - with self.stubber: - with self.assertRaises(ImageNotFound): - AmiLookup.handle( - value="owners:self name_regex:MyImage\s\d", - provider=self.provider - ) diff --git a/stacker/tests/lookups/handlers/test_default.py b/stacker/tests/lookups/handlers/test_default.py index a59ccd6d8..990e510fe 100644 --- a/stacker/tests/lookups/handlers/test_default.py +++ b/stacker/tests/lookups/handlers/test_default.py @@ -1,22 +1,18 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -from mock import MagicMock import unittest -from stacker.context import Context from stacker.lookups.handlers.default 
import DefaultLookup +from ...factories import mock_context, mock_provider -class TestDefaultLookup(unittest.TestCase): +class TestDefaultLookup(unittest.TestCase): def setUp(self): - self.provider = MagicMock() - self.context = Context( - environment={ - 'namespace': 'test', - 'env_var': 'val_in_env'} - ) + self.provider = mock_provider() + self.context = mock_context( + namespace='test', environment={'env_var': 'val_in_env'}) def test_env_var_present(self): lookup_val = "env_var::fallback" diff --git a/stacker/tests/lookups/handlers/test_dynamodb.py b/stacker/tests/lookups/handlers/test_dynamodb.py index 44b6cc693..a23dcf3cf 100644 --- a/stacker/tests/lookups/handlers/test_dynamodb.py +++ b/stacker/tests/lookups/handlers/test_dynamodb.py @@ -2,18 +2,26 @@ from __future__ import division from __future__ import absolute_import import unittest -import mock + from botocore.stub import Stubber + from stacker.lookups.handlers.dynamodb import DynamodbLookup -import boto3 -from stacker.tests.factories import SessionStub +from ...factories import mock_context, mock_provider, mock_boto3_client +REGION = 'us-east-1' -class TestDynamoDBHandler(unittest.TestCase): - client = boto3.client('dynamodb', region_name='us-east-1') +class TestDynamoDBHandler(unittest.TestCase): def setUp(self): - self.stubber = Stubber(self.client) + self.context = mock_context() + self.provider = mock_provider(region=REGION) + + self.dynamodb, self.client_mock = \ + mock_boto3_client("dynamodb", region=REGION) + self.client_mock.start() + self.stubber = Stubber(self.dynamodb) + self.stubber.activate() + self.get_parameters_response = {'Item': {'TestMap': {'M': { 'String1': {'S': 'StringVal1'}, 'List1': {'L': [ @@ -21,9 +29,11 @@ def setUp(self): {'S': 'ListVal2'}]}, 'Number1': {'N': '12345'}, }}}} - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_handler(self, mock_client): + def tearDown(self): + self.client_mock.stop() + self.stubber.deactivate() + + def test_dynamodb_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -36,13 +46,11 @@ def test_dynamodb_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, expected_params) - with self.stubber: - value = DynamodbLookup.handle(base_lookup_key) - self.assertEqual(value, base_lookup_key_valid) + value = DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEqual(value, base_lookup_key_valid) - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_number_handler(self, mock_client): + def test_dynamodb_number_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -56,13 +64,12 @@ def test_dynamodb_number_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, expected_params) - with self.stubber: - value = DynamodbLookup.handle(base_lookup_key) - self.assertEqual(value, base_lookup_key_valid) - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_list_handler(self, mock_client): + value = DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEqual(value, base_lookup_key_valid) + + def test_dynamodb_list_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -76,13 +83,12 @@ def test_dynamodb_list_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, 
expected_params) - with self.stubber: - value = DynamodbLookup.handle(base_lookup_key) - self.assertEqual(value, base_lookup_key_valid) - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_empty_table_handler(self, mock_client): + value = DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEqual(value, base_lookup_key_valid) + + def test_dynamodb_empty_table_handler(self): expected_params = { 'TableName': '', 'Key': { @@ -94,17 +100,14 @@ def test_dynamodb_empty_table_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'Please make sure to include a dynamodb table name', - str(e)) - - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_missing_table_handler(self, mock_client): + + msg = 'Please make sure to include a dynamodb table name' + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) + + def test_dynamodb_missing_table_handler(self): expected_params = { 'Key': { 'TestKey': {'S': 'TestVal'} @@ -115,17 +118,14 @@ def test_dynamodb_missing_table_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'Please make sure to include a tablename', - str(e)) - - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_invalid_table_handler(self, mock_client): + + msg = 'Please make sure to include a tablename' + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) + + def test_dynamodb_invalid_table_handler(self): expected_params = { 'TableName': 'FakeTable', 'Key': { @@ -138,17 +138,14 @@ def test_dynamodb_invalid_table_handler(self, mock_client): self.stubber.add_client_error('get_item', service_error_code=service_error_code, expected_params=expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'Cannot find the dynamodb table: FakeTable', - str(e)) - - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_invalid_partition_key_handler(self, mock_client): + + msg = 'Cannot find the dynamodb table: FakeTable' + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) + + def test_dynamodb_invalid_partition_key_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -162,17 +159,13 @@ def test_dynamodb_invalid_partition_key_handler(self, mock_client): service_error_code=service_error_code, expected_params=expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'No dynamodb record matched the partition key: FakeKey', - str(e)) - - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def 
test_dynamodb_invalid_partition_val_handler(self, mock_client): + msg = 'No dynamodb record matched the partition key: FakeKey' + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) + + def test_dynamodb_invalid_partition_val_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -185,11 +178,10 @@ def test_dynamodb_invalid_partition_val_handler(self, mock_client): self.stubber.add_response('get_item', empty_response, expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'The dynamodb record could not be found using ' - 'the following key: {\'S\': \'FakeVal\'}', - str(e)) + + msg = ('The dynamodb record could not be found using the following ' + 'key: {\'S\': \'FakeVal\'}') + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) diff --git a/stacker/tests/lookups/handlers/test_kms.py b/stacker/tests/lookups/handlers/test_kms.py index bb199a639..254a082de 100644 --- a/stacker/tests/lookups/handlers/test_kms.py +++ b/stacker/tests/lookups/handlers/test_kms.py @@ -6,31 +6,39 @@ from moto import mock_kms -import boto3 - from stacker.lookups.handlers.kms import KmsLookup +from ...factories import mock_boto3_client, mock_context, mock_provider + +REGION = 'us-east-1' class TestKMSHandler(unittest.TestCase): def setUp(self): + self.context = mock_context() + self.provider = mock_provider(region=REGION) + + self.mock_kms = mock_kms() + self.mock_kms.start() + self.kms, self.client_mock = mock_boto3_client('kms', region=REGION) + self.client_mock.start() + self.plain = b"my secret" - with mock_kms(): - kms = boto3.client("kms", region_name="us-east-1") - self.secret = kms.encrypt( - KeyId="alias/stacker", - Plaintext=codecs.encode(self.plain, 'base64').decode('utf-8'), - )["CiphertextBlob"] - if isinstance(self.secret, bytes): - self.secret = self.secret.decode() + self.secret = self.kms.encrypt( + KeyId="alias/stacker", + Plaintext=codecs.encode(self.plain, 'base64').decode('utf-8'), + )["CiphertextBlob"] + if isinstance(self.secret, bytes): + self.secret = self.secret.decode() + + def tearDown(self): + self.client_mock.stop() + self.mock_kms.stop() def test_kms_handler(self): - with mock_kms(): - decrypted = KmsLookup.handle(self.secret) - self.assertEqual(decrypted, self.plain) + decrypted = KmsLookup.handle(self.secret, self.context, self.provider) + self.assertEqual(decrypted, self.plain) def test_kms_handler_with_region(self): - region = "us-east-1" - value = "%s@%s" % (region, self.secret) - with mock_kms(): - decrypted = KmsLookup.handle(value) - self.assertEqual(decrypted, self.plain) + value = "%s@%s" % (REGION, self.secret) + decrypted = KmsLookup.handle(value, self.context, self.provider) + self.assertEqual(decrypted, self.plain) diff --git a/stacker/tests/lookups/handlers/test_output.py b/stacker/tests/lookups/handlers/test_output.py index 3891dfe25..691f8814d 100644 --- a/stacker/tests/lookups/handlers/test_output.py +++ b/stacker/tests/lookups/handlers/test_output.py @@ -5,25 +5,22 @@ import unittest from stacker.stack import Stack -from ...factories import generate_definition from stacker.lookups.handlers.output import OutputLookup +from ...factories import generate_definition, mock_context, mock_provider -class TestOutputHandler(unittest.TestCase): +class 
TestOutputHandler(unittest.TestCase): def setUp(self): - self.context = MagicMock() + stack_def = generate_definition("vpc", 1) + self.context = mock_context() + self.stack = Stack(definition=stack_def, context=self.context) + self.context.get_stacks = MagicMock(return_value=[self.stack]) + self.provider = mock_provider( + outputs={self.stack.fqn: {"SomeOutput": "Test Output"}}) def test_output_handler(self): - stack = Stack( - definition=generate_definition("vpc", 1), - context=self.context) - stack.set_outputs({ - "SomeOutput": "Test Output"}) - self.context.get_stack.return_value = stack - value = OutputLookup.handle("stack-name::SomeOutput", - context=self.context) + value = OutputLookup.handle("{}::SomeOutput".format(self.stack.name), + context=self.context, + provider=self.provider) self.assertEqual(value, "Test Output") - self.assertEqual(self.context.get_stack.call_count, 1) - args = self.context.get_stack.call_args - self.assertEqual(args[0][0], "stack-name") diff --git a/stacker/tests/lookups/handlers/test_rxref.py b/stacker/tests/lookups/handlers/test_rxref.py index b5e7cb828..c6480b916 100644 --- a/stacker/tests/lookups/handlers/test_rxref.py +++ b/stacker/tests/lookups/handlers/test_rxref.py @@ -1,30 +1,24 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -from mock import MagicMock import unittest from stacker.lookups.handlers.rxref import RxrefLookup -from ....context import Context -from ....config import Config + +from ...factories import mock_context, mock_provider class TestRxrefHandler(unittest.TestCase): def setUp(self): - self.provider = MagicMock() - self.context = Context( - config=Config({"namespace": "ns"}) - ) + self.context = mock_context() + self.stack_name = "stack-name" + self.stack_fqn = self.context.get_fqn(self.stack_name) + self.provider = mock_provider( + outputs={self.stack_fqn: {"SomeOutput": "Test Output"}}) def test_rxref_handler(self): - self.provider.get_output.return_value = "Test Output" - - value = RxrefLookup.handle("fully-qualified-stack-name::SomeOutput", + value = RxrefLookup.handle("{}::SomeOutput".format(self.stack_name), provider=self.provider, context=self.context) self.assertEqual(value, "Test Output") - - args = self.provider.get_output.call_args - self.assertEqual(args[0][0], "ns-fully-qualified-stack-name") - self.assertEqual(args[0][1], "SomeOutput") diff --git a/stacker/tests/lookups/handlers/test_ssmstore.py b/stacker/tests/lookups/handlers/test_ssmstore.py index daff2444d..d0d89a81a 100644 --- a/stacker/tests/lookups/handlers/test_ssmstore.py +++ b/stacker/tests/lookups/handlers/test_ssmstore.py @@ -2,74 +2,93 @@ from __future__ import division from __future__ import absolute_import from builtins import str -import unittest -import mock + +import pytest from botocore.stub import Stubber + from stacker.lookups.handlers.ssmstore import SsmstoreLookup -import boto3 -from stacker.tests.factories import SessionStub - - -class TestSSMStoreHandler(unittest.TestCase): - client = boto3.client('ssm', region_name='us-east-1') - - def setUp(self): - self.stubber = Stubber(self.client) - self.get_parameters_response = { - 'Parameters': [ - { - 'Name': 'ssmkey', - 'Type': 'String', - 'Value': 'ssmvalue' - } - ], - 'InvalidParameters': [ - 'invalidssmparam' - ] - } - self.invalid_get_parameters_response = { - 'InvalidParameters': [ - 'ssmkey' - ] - } - self.expected_params = { - 'Names': ['ssmkey'], - 'WithDecryption': True +from ...factories import mock_context, mock_provider, 
mock_boto3_client + +REGION = 'us-east-1' +ALT_REGION = 'us-east-2' + + +@pytest.fixture +def context(): + return mock_context() + + +@pytest.fixture(params=[dict(region=REGION)]) +def provider(request): + return mock_provider(**request.param) + + +@pytest.fixture(params=[dict(region=REGION)]) +def ssm(request): + client, mock = mock_boto3_client("ssm", **request.param) + with mock: + yield client + + +@pytest.fixture +def ssm_stubber(ssm): + with Stubber(ssm) as stubber: + yield stubber + + +get_parameters_response = { + 'Parameters': [ + { + 'Name': 'ssmkey', + 'Type': 'String', + 'Value': 'ssmvalue' } - self.ssmkey = "ssmkey" - self.ssmvalue = "ssmvalue" - - @mock.patch('stacker.lookups.handlers.ssmstore.get_session', - return_value=SessionStub(client)) - def test_ssmstore_handler(self, mock_client): - self.stubber.add_response('get_parameters', - self.get_parameters_response, - self.expected_params) - with self.stubber: - value = SsmstoreLookup.handle(self.ssmkey) - self.assertEqual(value, self.ssmvalue) - self.assertIsInstance(value, str) - - @mock.patch('stacker.lookups.handlers.ssmstore.get_session', - return_value=SessionStub(client)) - def test_ssmstore_invalid_value_handler(self, mock_client): - self.stubber.add_response('get_parameters', - self.invalid_get_parameters_response, - self.expected_params) - with self.stubber: - try: - SsmstoreLookup.handle(self.ssmkey) - except ValueError: - assert True - - @mock.patch('stacker.lookups.handlers.ssmstore.get_session', - return_value=SessionStub(client)) - def test_ssmstore_handler_with_region(self, mock_client): - self.stubber.add_response('get_parameters', - self.get_parameters_response, - self.expected_params) - region = "us-east-1" - temp_value = "%s@%s" % (region, self.ssmkey) - with self.stubber: - value = SsmstoreLookup.handle(temp_value) - self.assertEqual(value, self.ssmvalue) + ], + 'InvalidParameters': [ + 'invalidssmparam' + ] +} + +invalid_get_parameters_response = { + 'InvalidParameters': [ + 'ssmkey' + ] +} + +expected_params = { + 'Names': ['ssmkey'], + 'WithDecryption': True +} + +ssmkey = "ssmkey" +ssmvalue = "ssmvalue" + + +def test_ssmstore_handler(ssm_stubber, context, provider): + ssm_stubber.add_response('get_parameters', + get_parameters_response, + expected_params) + + value = SsmstoreLookup.handle(ssmkey, context, provider) + assert value == ssmvalue + assert isinstance(value, str) + + +def test_ssmstore_invalid_value_handler(ssm_stubber, context, provider): + ssm_stubber.add_response('get_parameters', + invalid_get_parameters_response, + expected_params) + + with pytest.raises(ValueError): + SsmstoreLookup.handle(ssmkey, context, provider) + + +@pytest.mark.parametrize("ssm", [dict(region=ALT_REGION)], indirect=True) +def test_ssmstore_handler_with_region(ssm_stubber, context, provider): + ssm_stubber.add_response('get_parameters', + get_parameters_response, + expected_params) + temp_value = '%s@%s' % (ALT_REGION, ssmkey) + + value = SsmstoreLookup.handle(temp_value, context, provider) + assert value == ssmvalue diff --git a/stacker/tests/lookups/handlers/test_xref.py b/stacker/tests/lookups/handlers/test_xref.py index cb611ed65..c2b1d1b46 100644 --- a/stacker/tests/lookups/handlers/test_xref.py +++ b/stacker/tests/lookups/handlers/test_xref.py @@ -1,25 +1,23 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -from mock import MagicMock import unittest from stacker.lookups.handlers.xref import XrefLookup +from ...factories import mock_context, 
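The DynamoDB and SSM lookup tests above keep botocore's Stubber but move activation into setUp/fixtures instead of per-test with-blocks. A standalone sketch of the Stubber pattern those tests rely on, reusing the response and parameter shapes from the SSM test; the dummy credentials only keep the snippet self-contained, since Stubber short-circuits the request before anything is sent.

    import boto3
    from botocore.stub import Stubber

    ssm = boto3.client("ssm", region_name="us-east-1",
                       aws_access_key_id="testing",
                       aws_secret_access_key="testing")
    stubber = Stubber(ssm)
    stubber.add_response(
        "get_parameters",
        {"Parameters": [{"Name": "ssmkey", "Type": "String", "Value": "ssmvalue"}],
         "InvalidParameters": []},
        {"Names": ["ssmkey"], "WithDecryption": True},
    )

    with stubber:        # equivalent to stubber.activate()/stubber.deactivate()
        response = ssm.get_parameters(Names=["ssmkey"], WithDecryption=True)

    assert response["Parameters"][0]["Value"] == "ssmvalue"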
mock_provider + class TestXrefHandler(unittest.TestCase): def setUp(self): - self.provider = MagicMock() - self.context = MagicMock() + self.stack_fqn = "fully-qualified-stack-name" + self.context = mock_context() + self.provider = mock_provider( + outputs={self.stack_fqn: {"SomeOutput": "Test Output"}}) def test_xref_handler(self): - self.provider.get_output.return_value = "Test Output" - value = XrefLookup.handle("fully-qualified-stack-name::SomeOutput", + value = XrefLookup.handle("{}::SomeOutput".format(self.stack_fqn), provider=self.provider, context=self.context) self.assertEqual(value, "Test Output") - self.assertEqual(self.context.get_fqn.call_count, 0) - args = self.provider.get_output.call_args - self.assertEqual(args[0][0], "fully-qualified-stack-name") - self.assertEqual(args[0][1], "SomeOutput") diff --git a/stacker/tests/providers/aws/test_default.py b/stacker/tests/providers/aws/test_default.py index 10dc5577c..e9a729722 100644 --- a/stacker/tests/providers/aws/test_default.py +++ b/stacker/tests/providers/aws/test_default.py @@ -382,7 +382,7 @@ def setUp(self): self.session = get_session(region=region) self.provider = Provider( self.session, region=region, recreate_failed=False) - self.stubber = Stubber(self.provider.cloudformation) + self.stubber = Stubber(self.provider._cloudformation) def test_get_stack_stack_does_not_exist(self): stack_name = "MockStack" @@ -657,7 +657,7 @@ def setUp(self): self.session = get_session(region=region) self.provider = Provider( self.session, interactive=True, recreate_failed=True) - self.stubber = Stubber(self.provider.cloudformation) + self.stubber = Stubber(self.provider._cloudformation) def test_successful_init(self): replacements = True diff --git a/stacker/tests/test_context.py b/stacker/tests/test_context.py index 088fed5f0..689e1a36e 100644 --- a/stacker/tests/test_context.py +++ b/stacker/tests/test_context.py @@ -1,11 +1,15 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import + import unittest +import mock from stacker.context import Context, get_fqn from stacker.config import load, Config -from stacker.util import handle_hooks + + +FAKE_HOOK_PATH = "stacker.tests.fixtures.mock_hooks.mock_hook" class TestContext(unittest.TestCase): @@ -118,16 +122,69 @@ def test_hook_with_sys_path(self): "pre_build": [ { "data_key": "myHook", - "path": "fixtures.mock_hooks.mock_hook", + "path": FAKE_HOOK_PATH.replace('stacker.tests.', ''), "required": True, "args": { "value": "mockResult"}}]}) load(config) + context = Context(config=config) - stage = "pre_build" - handle_hooks(stage, context.config[stage], "mock-region-1", context) + provider = mock.Mock() + hooks = context.get_hooks_for_action('build') + hook = hooks.pre[0] + + hook.run(provider, context) self.assertEqual("mockResult", context.hook_data["myHook"]["result"]) + def test_get_hooks_for_action(self): + + config = Config({ + "pre_build": [ + {"path": FAKE_HOOK_PATH}, + {"name": "pre_build_test", "path": FAKE_HOOK_PATH}, + {"path": FAKE_HOOK_PATH} + ], + "post_build": [ + {"path": FAKE_HOOK_PATH}, + {"name": "post_build_test", "path": FAKE_HOOK_PATH}, + {"path": FAKE_HOOK_PATH} + ], + "build_hooks": [ + {"path": FAKE_HOOK_PATH}, + {"name": "build_test", "path": FAKE_HOOK_PATH}, + {"path": FAKE_HOOK_PATH} + ] + }) + + context = Context(config=config) + hooks = context.get_hooks_for_action('build') + + assert hooks.pre[0].name == "pre_build_1_{}".format(FAKE_HOOK_PATH) + assert hooks.pre[1].name == "pre_build_test" + assert 
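The output, rxref and xref handler tests above all swap MagicMock providers for the new MockProvider seeded with an outputs mapping keyed by stack FQN. A minimal sketch of that pattern for the xref handler, relying only on the behaviour exercised by the test above; the stack name and output value are illustrative.

    from stacker.lookups.handlers.xref import XrefLookup
    from stacker.tests.factories import mock_context, mock_provider

    provider = mock_provider(outputs={"prod-vpc": {"VpcId": "vpc-123456"}})

    # xref takes a fully-qualified stack name, so no namespace is prepended.
    value = XrefLookup.handle("prod-vpc::VpcId",
                              context=mock_context(), provider=provider)
    assert value == "vpc-123456"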
hooks.pre[2].name == "pre_build_3_{}".format(FAKE_HOOK_PATH) + + assert hooks.post[0].name == "post_build_1_{}".format(FAKE_HOOK_PATH) + assert hooks.post[1].name == "post_build_test" + assert hooks.post[2].name == "post_build_3_{}".format(FAKE_HOOK_PATH) + + assert hooks.custom[0].name == \ + "build_hooks_1_{}".format(FAKE_HOOK_PATH) + assert hooks.custom[1].name == "build_test" + assert hooks.custom[2].name == \ + "build_hooks_3_{}".format(FAKE_HOOK_PATH) + + def test_hook_data_key_fallback(self): + config = Config({ + "build_hooks": [ + {"name": "my-hook", "path": FAKE_HOOK_PATH} + ] + }) + context = Context(config=config) + hooks = context.get_hooks_for_action("build") + hook = hooks.custom[0] + + assert hook.data_key == "my-hook" + class TestFunctions(unittest.TestCase): """ Test the module level functions """ diff --git a/stacker/tests/test_hooks.py b/stacker/tests/test_hooks.py new file mode 100644 index 000000000..275f7f0b2 --- /dev/null +++ b/stacker/tests/test_hooks.py @@ -0,0 +1,147 @@ +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +import unittest +import mock + + +from stacker.exceptions import HookExecutionFailed +from stacker.hooks import Hook +from stacker.status import ( + COMPLETE, FailedStatus, NotSubmittedStatus, SkippedStatus +) +from .factories import MockProviderBuilder, mock_context, mock_provider + + +mock_hook = mock.Mock() + + +class TestHooks(unittest.TestCase): + mock_hook_path = __name__ + ".mock_hook" + + def setUp(self): + self.context = mock_context(extra_config_args={ + "stacks": [ + {"name": "undeployed-stack", "template_path": "missing"} + ] + }) + self.provider = mock_provider(region="us-east-1") + self.provider_builder = MockProviderBuilder(self.provider) + + global mock_hook + mock_hook = mock.Mock() + + def test_missing_module(self): + with self.assertRaises(ValueError): + Hook("test", path="not.a.real.path") + + def test_missing_method(self): + with self.assertRaises(ValueError): + Hook("test", path=self.mock_hook_path + "garbage") + + def test_valid_enabled_hook(self): + hook = Hook("test", path=self.mock_hook_path, + required=True, enabled=True) + + result = mock_hook.return_value = mock.Mock() + self.assertIs(result, hook.run(self.provider, self.context)) + mock_hook.assert_called_once() + + def test_context_provided_to_hook(self): + hook = Hook("test", path=self.mock_hook_path, + required=True) + + def return_context(*args, **kwargs): + return kwargs['context'] + + mock_hook.side_effect = return_context + result = hook.run(self.provider, self.context) + self.assertIs(result, self.context) + + def test_hook_failure(self): + hook = Hook("test", path=self.mock_hook_path, + required=True) + + err = Exception() + mock_hook.side_effect = err + + with self.assertRaises(HookExecutionFailed) as raised: + hook.run(self.provider, self.context) + + self.assertIs(hook, raised.exception.hook) + self.assertIs(err, raised.exception.cause) + + def test_hook_failure_skip(self): + hook = Hook("test", path=self.mock_hook_path, + required=False) + + mock_hook.side_effect = Exception() + result = hook.run(self.provider, self.context) + self.assertIsNone(result) + + def test_return_data_hook(self): + hook = Hook("test", path=self.mock_hook_path, + data_key='test') + hook_data = {'hello': 'world'} + mock_hook.return_value = hook_data + + result = hook.run(self.provider, self.context) + self.assertEqual(hook_data, result) + self.assertEqual(hook_data, self.context.hook_data.get('test')) + + def 
test_return_data_hook_duplicate_key(self): + hook = Hook("test", path=self.mock_hook_path, + data_key='test') + mock_hook.return_value = {'foo': 'bar'} + + hook_data = {'hello': 'world'} + self.context.set_hook_data('test', hook_data) + with self.assertRaises(KeyError): + hook.run(self.provider, self.context) + + self.assertEqual(hook_data, self.context.hook_data['test']) + + def test_run_step_disabled(self): + hook = Hook("test", path=self.mock_hook_path, enabled=False) + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertIsInstance(status, NotSubmittedStatus) + + def test_run_step_stack_dep_missing(self): + hook = Hook("test", path=self.mock_hook_path, + args={"hello": "${output undeployed-stack::Output}"}) + stack_fqn = self.context.get_stack("undeployed-stack").fqn + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertIsInstance(status, SkippedStatus) + self.assertEqual(status.reason, + "required stack not deployed: {}".format(stack_fqn)) + + def test_run_step_hook_raised(self): + hook = Hook("test", path=self.mock_hook_path) + err = HookExecutionFailed(hook, cause=RuntimeError("canary")) + hook.run = mock.Mock(side_effect=err) + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertIsInstance(status, FailedStatus) + self.assertIn("canary", status.reason) + self.assertIn("threw exception", status.reason) + + def test_run_step_hook_failed(self): + hook = Hook("test", path=self.mock_hook_path, required=True) + hook.run = mock.Mock(return_value=False) + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertIsInstance(status, SkippedStatus) + + def test_run_step_hook_succeeded(self): + hook = Hook("test", path=self.mock_hook_path) + hook.run = mock.Mock(return_value=True) + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertEqual(status, COMPLETE) diff --git a/stacker/tests/test_plan.py b/stacker/tests/test_plan.py index a88c5e460..e71b6cd95 100644 --- a/stacker/tests/test_plan.py +++ b/stacker/tests/test_plan.py @@ -16,11 +16,7 @@ register_lookup_handler, unregister_lookup_handler, ) -from stacker.plan import ( - Step, - build_plan, - build_graph, -) +from stacker.plan import Graph, Step, Plan from stacker.exceptions import ( CancelExecution, GraphError, @@ -45,7 +41,7 @@ def setUp(self): stack = mock.MagicMock() stack.name = "stack" stack.fqn = "namespace-stack" - self.step = Step(stack=stack, fn=None) + self.step = Step.from_stack(stack=stack, fn=None) def test_status(self): self.assertFalse(self.step.submitted) @@ -86,9 +82,11 @@ def test_plan(self): definition=generate_definition('bastion', 1, requires=[vpc.name]), context=self.context) - graph = build_graph([ - Step(vpc, fn=None), Step(bastion, fn=None)]) - plan = build_plan(description="Test", graph=graph) + graph = Graph.from_steps([ + Step.from_stack(vpc, fn=None), + Step.from_stack(bastion, fn=None) + ]) + plan = Plan.from_graph(description="Test", graph=graph) self.assertEqual(plan.graph.to_dict(), { 'bastion.1': set(['vpc.1']), @@ -108,8 +106,11 @@ def fn(stack, status=None): calls.append(stack.fqn) return COMPLETE - graph = build_graph([Step(vpc, fn), Step(bastion, fn)]) - plan = build_plan( + graph = Graph.from_steps([ + Step.from_stack(vpc, fn), + Step.from_stack(bastion, fn) + ]) + plan = Plan.from_graph( description="Test", graph=graph) plan.execute(walk) @@ -132,9 +133,12 @@ def 
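The new test_hooks.py above pins down the Hook wrapper's behaviour: run() imports the callable named by path, passes the provider and context as keyword arguments, stores the result under data_key, and run_step() translates the outcome into a plan status. A rough sketch of driving it directly, inferred only from those tests; the say_hello/always_true functions are hypothetical and the snippet assumes it lives in an importable module, as the tests themselves do.

    from stacker.hooks import Hook
    from stacker.status import COMPLETE
    from stacker.tests.factories import (
        MockProviderBuilder, mock_context, mock_provider
    )


    def say_hello(provider, context, **kwargs):
        # Hook callables receive the provider and context as keyword arguments.
        return {"greeting": "hello"}


    def always_true(provider, context, **kwargs):
        return True


    context = mock_context()
    provider = mock_provider(region="us-east-1")

    greeter = Hook("greeter", path=__name__ + ".say_hello", data_key="greeting")
    assert greeter.run(provider, context) == {"greeting": "hello"}
    assert context.hook_data["greeting"] == {"greeting": "hello"}

    # run_step() wraps run() and reports a plan status instead of raw data.
    noop = Hook("noop", path=__name__ + ".always_true")
    status = noop.run_step(provider_builder=MockProviderBuilder(provider),
                           context=context)
    assert status == COMPLETE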
fn(stack, status=None): calls.append(stack.fqn) return COMPLETE - graph = build_graph([ - Step(vpc, fn), Step(db, fn), Step(app, fn)]) - plan = build_plan( + graph = Graph.from_steps([ + Step.from_stack(vpc, fn), + Step.from_stack(db, fn), + Step.from_stack(app, fn) + ]) + plan = Plan.from_graph( description="Test", graph=graph, targets=['db.1']) @@ -159,11 +163,11 @@ def fn(stack, status=None): raise ValueError('Boom') return COMPLETE - vpc_step = Step(vpc, fn) - bastion_step = Step(bastion, fn) + vpc_step = Step.from_stack(vpc, fn) + bastion_step = Step.from_stack(bastion, fn) - graph = build_graph([vpc_step, bastion_step]) - plan = build_plan(description="Test", graph=graph) + graph = Graph.from_steps([vpc_step, bastion_step]) + plan = Plan.from_graph(description="Test", graph=graph) with self.assertRaises(PlanFailed): plan.execute(walk) @@ -187,11 +191,11 @@ def fn(stack, status=None): return SKIPPED return COMPLETE - vpc_step = Step(vpc, fn) - bastion_step = Step(bastion, fn) + vpc_step = Step.from_stack(vpc, fn) + bastion_step = Step.from_stack(bastion, fn) - graph = build_graph([vpc_step, bastion_step]) - plan = build_plan(description="Test", graph=graph) + graph = Graph.from_steps([vpc_step, bastion_step]) + plan = Plan.from_graph(description="Test", graph=graph) plan.execute(walk) self.assertEquals(calls, ['namespace-vpc.1', 'namespace-bastion.1']) @@ -215,13 +219,13 @@ def fn(stack, status=None): return FAILED return COMPLETE - vpc_step = Step(vpc, fn) - bastion_step = Step(bastion, fn) - db_step = Step(db, fn) + vpc_step = Step.from_stack(vpc, fn) + bastion_step = Step.from_stack(bastion, fn) + db_step = Step.from_stack(db, fn) - graph = build_graph([ + graph = Graph.from_steps([ vpc_step, bastion_step, db_step]) - plan = build_plan(description="Test", graph=graph) + plan = Plan.from_graph(description="Test", graph=graph) with self.assertRaises(PlanFailed): plan.execute(walk) @@ -245,11 +249,11 @@ def fn(stack, status=None): raise CancelExecution return COMPLETE - vpc_step = Step(vpc, fn) - bastion_step = Step(bastion, fn) + vpc_step = Step.from_stack(vpc, fn) + bastion_step = Step.from_stack(bastion, fn) - graph = build_graph([vpc_step, bastion_step]) - plan = build_plan(description="Test", graph=graph) + graph = Graph.from_steps([vpc_step, bastion_step]) + plan = Plan.from_graph(description="Test", graph=graph) plan.execute(walk) self.assertEquals(calls, ['namespace-vpc.1', 'namespace-bastion.1']) @@ -261,7 +265,7 @@ def test_build_graph_missing_dependency(self): context=self.context) with self.assertRaises(GraphError) as expected: - build_graph([Step(bastion, None)]) + Graph.from_steps([Step.from_stack(bastion, None)]) message_starts = ( "Error detected when adding 'vpc.1' " "as a dependency of 'bastion.1':" @@ -285,7 +289,11 @@ def test_build_graph_cyclic_dependencies(self): context=self.context) with self.assertRaises(GraphError) as expected: - build_graph([Step(vpc, None), Step(db, None), Step(app, None)]) + Graph.from_steps([ + Step.from_stack(vpc, None), + Step.from_stack(db, None), + Step.from_stack(app, None) + ]) message = ("Error detected when adding 'db.1' " "as a dependency of 'app.1': graph is " "not acyclic") @@ -311,19 +319,22 @@ def test_dump(self, *args): context=self.context) requires = [stack.name] - steps += [Step(stack, None)] + steps += [Step.from_stack(stack, None)] - graph = build_graph(steps) - plan = build_plan(description="Test", graph=graph) + graph = Graph.from_steps(steps) + plan = Plan.from_graph(description="Test", graph=graph) tmp_dir = 
tempfile.mkdtemp() try: plan.dump(tmp_dir, context=self.context) for step in plan.steps: + if not isinstance(step.subject, Stack): + continue + template_path = os.path.join( tmp_dir, - stack_template_key_name(step.stack.blueprint)) + stack_template_key_name(step.subject.blueprint)) self.assertTrue(os.path.isfile(template_path)) finally: shutil.rmtree(tmp_dir) diff --git a/stacker/tests/test_util.py b/stacker/tests/test_util.py index 9c4fa7635..218b594af 100644 --- a/stacker/tests/test_util.py +++ b/stacker/tests/test_util.py @@ -3,23 +3,19 @@ from __future__ import absolute_import from future import standard_library standard_library.install_aliases() - -import unittest - -import string import os -import queue +import string +import unittest import mock import boto3 -from stacker.config import Hook, GitPackageSource +from stacker.config import GitPackageSource from stacker.util import ( cf_safe_name, load_object_from_string, camel_to_snake, - handle_hooks, merge_map, yaml_to_ordered_dict, get_client_region, @@ -33,10 +29,6 @@ SourceProcessor ) -from .factories import ( - mock_context, - mock_provider, -) regions = ["us-east-1", "cn-north-1", "ap-northeast-1", "eu-west-1", "ap-southeast-1", "ap-southeast-2", "us-west-2", "us-gov-west-1", @@ -274,148 +266,6 @@ def test_SourceProcessor_helpers(self): ) -hook_queue = queue.Queue() - - -def mock_hook(*args, **kwargs): - hook_queue.put(kwargs) - return True - - -def fail_hook(*args, **kwargs): - return None - - -def exception_hook(*args, **kwargs): - raise Exception - - -def context_hook(*args, **kwargs): - return "context" in kwargs - - -def result_hook(*args, **kwargs): - return {"foo": "bar"} - - -class TestHooks(unittest.TestCase): - - def setUp(self): - self.context = mock_context(namespace="namespace") - self.provider = mock_provider(region="us-east-1") - - def test_empty_hook_stage(self): - hooks = [] - handle_hooks("fake", hooks, self.provider, self.context) - self.assertTrue(hook_queue.empty()) - - def test_missing_required_hook(self): - hooks = [Hook({"path": "not.a.real.path", "required": True})] - with self.assertRaises(ImportError): - handle_hooks("missing", hooks, self.provider, self.context) - - def test_missing_required_hook_method(self): - hooks = [{"path": "stacker.hooks.blah", "required": True}] - with self.assertRaises(AttributeError): - handle_hooks("missing", hooks, self.provider, self.context) - - def test_missing_non_required_hook_method(self): - hooks = [Hook({"path": "stacker.hooks.blah", "required": False})] - handle_hooks("missing", hooks, self.provider, self.context) - self.assertTrue(hook_queue.empty()) - - def test_default_required_hook(self): - hooks = [Hook({"path": "stacker.hooks.blah"})] - with self.assertRaises(AttributeError): - handle_hooks("missing", hooks, self.provider, self.context) - - def test_valid_hook(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.mock_hook", - "required": True})] - handle_hooks("missing", hooks, self.provider, self.context) - good = hook_queue.get_nowait() - self.assertEqual(good["provider"].region, "us-east-1") - with self.assertRaises(queue.Empty): - hook_queue.get_nowait() - - def test_valid_enabled_hook(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.mock_hook", - "required": True, "enabled": True})] - handle_hooks("missing", hooks, self.provider, self.context) - good = hook_queue.get_nowait() - self.assertEqual(good["provider"].region, "us-east-1") - with self.assertRaises(queue.Empty): - hook_queue.get_nowait() - - def 
test_valid_enabled_false_hook(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.mock_hook", - "required": True, "enabled": False})] - handle_hooks("missing", hooks, self.provider, self.context) - self.assertTrue(hook_queue.empty()) - - def test_context_provided_to_hook(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.context_hook", - "required": True})] - handle_hooks("missing", hooks, "us-east-1", self.context) - - def test_hook_failure(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.fail_hook", - "required": True})] - with self.assertRaises(SystemExit): - handle_hooks("fail", hooks, self.provider, self.context) - hooks = [{"path": "stacker.tests.test_util.exception_hook", - "required": True}] - with self.assertRaises(Exception): - handle_hooks("fail", hooks, self.provider, self.context) - hooks = [ - Hook({"path": "stacker.tests.test_util.exception_hook", - "required": False})] - # Should pass - handle_hooks("ignore_exception", hooks, self.provider, self.context) - - def test_return_data_hook(self): - hooks = [ - Hook({ - "path": "stacker.tests.test_util.result_hook", - "data_key": "my_hook_results" - }), - # Shouldn't return data - Hook({ - "path": "stacker.tests.test_util.context_hook" - }) - ] - handle_hooks("result", hooks, "us-east-1", self.context) - - self.assertEqual( - self.context.hook_data["my_hook_results"]["foo"], - "bar" - ) - # Verify only the first hook resulted in stored data - self.assertEqual( - list(self.context.hook_data.keys()), ["my_hook_results"] - ) - - def test_return_data_hook_duplicate_key(self): - hooks = [ - Hook({ - "path": "stacker.tests.test_util.result_hook", - "data_key": "my_hook_results" - }), - Hook({ - "path": "stacker.tests.test_util.result_hook", - "data_key": "my_hook_results" - }) - ] - - with self.assertRaises(KeyError): - handle_hooks("result", hooks, "us-east-1", self.context) - - class TestException1(Exception): pass diff --git a/stacker/tests/test_variables.py b/stacker/tests/test_variables.py index 2b1acbc55..74bc1d6ec 100644 --- a/stacker/tests/test_variables.py +++ b/stacker/tests/test_variables.py @@ -4,85 +4,49 @@ import unittest -from mock import MagicMock - from troposphere import s3 + from stacker.blueprints.variables.types import TroposphereType +from stacker.lookups.handlers import LookupHandler from stacker.variables import Variable from stacker.lookups import register_lookup_handler -from stacker.stack import Stack +from .factories import mock_context, mock_provider -from .factories import generate_definition + +class MockLookup(LookupHandler): + @classmethod + def handle(cls, value, context, provider): + return str(value) class TestVariables(unittest.TestCase): def setUp(self): - self.provider = MagicMock() - self.context = MagicMock() + self.provider = mock_provider() + self.context = mock_context() + + register_lookup_handler("test", MockLookup) def test_variable_replace_no_lookups(self): var = Variable("Param1", "2") self.assertEqual(var.value, "2") - def test_variable_replace_simple_lookup(self): - var = Variable("Param1", "${output fakeStack::FakeOutput}") - var._value._resolve("resolved") - self.assertEqual(var.value, "resolved") - def test_variable_resolve_simple_lookup(self): - stack = Stack( - definition=generate_definition("vpc", 1), - context=self.context) - stack.set_outputs({ - "FakeOutput": "resolved", - "FakeOutput2": "resolved2", - }) - - self.context.get_stack.return_value = stack - - var = Variable("Param1", "${output fakeStack::FakeOutput}") - var.resolve(self.context, 
self.provider) - self.assertTrue(var.resolved) - self.assertEqual(var.value, "resolved") - - def test_variable_resolve_default_lookup_empty(self): - var = Variable("Param1", "${default fakeStack::}") + var = Variable("Param1", "${noop test}") var.resolve(self.context, self.provider) self.assertTrue(var.resolved) - self.assertEqual(var.value, "") + self.assertEqual(var.value, "test") def test_variable_replace_multiple_lookups_string(self): var = Variable( "Param1", "url://" # 0 - "${output fakeStack::FakeOutput}" # 1 + "${test resolved}" # 1 "@" # 2 - "${output fakeStack::FakeOutput2}", # 3 - ) - var._value[1]._resolve("resolved") - var._value[3]._resolve("resolved2") - self.assertEqual(var.value, "url://resolved@resolved2") - - def test_variable_resolve_multiple_lookups_string(self): - var = Variable( - "Param1", - "url://${output fakeStack::FakeOutput}@" - "${output fakeStack::FakeOutput2}", + "${test resolved2}", # 3 ) - - stack = Stack( - definition=generate_definition("vpc", 1), - context=self.context) - stack.set_outputs({ - "FakeOutput": "resolved", - "FakeOutput2": "resolved2", - }) - - self.context.get_stack.return_value = stack var.resolve(self.context, self.provider) - self.assertTrue(var.resolved) self.assertEqual(var.value, "url://resolved@resolved2") def test_variable_replace_no_lookups_list(self): @@ -90,77 +54,52 @@ def test_variable_replace_no_lookups_list(self): self.assertEqual(var.value, ["something", "here"]) def test_variable_replace_lookups_list(self): - value = ["something", # 0 - "${output fakeStack::FakeOutput}", # 1 - "${output fakeStack::FakeOutput2}" # 2 - ] + value = ["something", "${test resolved}", "${test resolved2}"] var = Variable("Param1", value) - - var._value[1]._resolve("resolved") - var._value[2]._resolve("resolved2") + var.resolve(self.context, self.provider) self.assertEqual(var.value, ["something", "resolved", "resolved2"]) def test_variable_replace_lookups_dict(self): value = { - "something": "${output fakeStack::FakeOutput}", - "other": "${output fakeStack::FakeOutput2}", + "something": "${test resolved}", + "other": "${test resolved2}", } var = Variable("Param1", value) - var._value["something"]._resolve("resolved") - var._value["other"]._resolve("resolved2") - self.assertEqual(var.value, {"something": "resolved", "other": - "resolved2"}) + var.resolve(self.context, self.provider) + self.assertEqual(var.value, {"something": "resolved", + "other": "resolved2"}) def test_variable_replace_lookups_mixed(self): value = { - "something": [ - "${output fakeStack::FakeOutput}", - "other", + "list": [ + "${test 1}", + "2", ], - "here": { - "other": "${output fakeStack::FakeOutput2}", - "same": "${output fakeStack::FakeOutput}", - "mixed": "something:${output fakeStack::FakeOutput3}", + "dict": { + "1": "${test a}", + "2": "${test b}", + "3": "c:${test d}", }, } var = Variable("Param1", value) - var._value["something"][0]._resolve("resolved") - var._value["here"]["other"]._resolve("resolved2") - var._value["here"]["same"]._resolve("resolved") - var._value["here"]["mixed"][1]._resolve("resolved3") + var.resolve(self.context, self.provider) self.assertEqual(var.value, { - "something": [ - "resolved", - "other", - ], - "here": { - "other": "resolved2", - "same": "resolved", - "mixed": "something:resolved3", + "list": ["1", "2"], + "dict": { + "1": "a", + "2": "b", + "3": "c:d", }, }) def test_variable_resolve_nested_lookup(self): - stack = Stack( - definition=generate_definition("vpc", 1), - context=self.context) - stack.set_outputs({ - "FakeOutput": 
"resolved", - "FakeOutput2": "resolved2", - }) - - def mock_handler(value, context, provider, **kwargs): - return "looked up: {}".format(value) - - register_lookup_handler("lookup", mock_handler) - self.context.get_stack.return_value = stack var = Variable( "Param1", - "${lookup ${lookup ${output fakeStack::FakeOutput}}}", + "${test a:${test b:${test c}}}", ) var.resolve(self.context, self.provider) self.assertTrue(var.resolved) - self.assertEqual(var.value, "looked up: looked up: resolved") + self.assertEqual(var.value, "a:b:c") def test_troposphere_type_no_from_dict(self): with self.assertRaises(ValueError): diff --git a/stacker/util.py b/stacker/util.py index 4f95a52f6..41595dbec 100644 --- a/stacker/util.py +++ b/stacker/util.py @@ -16,7 +16,6 @@ import tempfile import zipfile -import collections from collections import OrderedDict import botocore.client @@ -26,7 +25,7 @@ from yaml.constructor import ConstructorError from yaml.nodes import MappingNode -from .awscli_yamlhelper import yaml_parse +from stacker.awscli_yamlhelper import yaml_parse from stacker.session_cache import get_session logger = logging.getLogger(__name__) @@ -337,74 +336,6 @@ def cf_safe_name(name): return "".join([uppercase_first_letter(part) for part in parts]) -def handle_hooks(stage, hooks, provider, context): - """ Used to handle pre/post_build hooks. - - These are pieces of code that we want to run before/after the builder - builds the stacks. - - Args: - stage (string): The current stage (pre_run, post_run, etc). - hooks (list): A list of :class:`stacker.config.Hook` containing the - hooks to execute. - provider (:class:`stacker.provider.base.BaseProvider`): The provider - the current stack is using. - context (:class:`stacker.context.Context`): The current stacker - context. - """ - if not hooks: - logger.debug("No %s hooks defined.", stage) - return - - hook_paths = [] - for i, h in enumerate(hooks): - try: - hook_paths.append(h.path) - except KeyError: - raise ValueError("%s hook #%d missing path." % (stage, i)) - - logger.info("Executing %s hooks: %s", stage, ", ".join(hook_paths)) - for hook in hooks: - data_key = hook.data_key - required = hook.required - kwargs = hook.args or {} - enabled = hook.enabled - if not enabled: - logger.debug("hook with method %s is disabled, skipping", - hook.path) - continue - try: - method = load_object_from_string(hook.path) - except (AttributeError, ImportError): - logger.exception("Unable to load method at %s:", hook.path) - if required: - raise - continue - try: - result = method(context=context, provider=provider, **kwargs) - except Exception: - logger.exception("Method %s threw an exception:", hook.path) - if required: - raise - continue - if not result: - if required: - logger.error("Required hook %s failed. Return value: %s", - hook.path, result) - sys.exit(1) - logger.warning("Non-required hook %s failed. Return value: %s", - hook.path, result) - else: - if isinstance(result, collections.Mapping): - if data_key: - logger.debug("Adding result for hook %s to context in " - "data_key %s.", hook.path, data_key) - context.set_hook_data(data_key, result) - else: - logger.debug("Hook %s returned result data, but no data " - "key set, so ignoring.", hook.path) - - def get_config_directory(): """Return the directory the config file is located in. 
diff --git a/tests/fixtures/blueprints/bucket.yaml.j2 b/tests/fixtures/blueprints/bucket.yaml.j2
new file mode 100644
index 000000000..2687f4473
--- /dev/null
+++ b/tests/fixtures/blueprints/bucket.yaml.j2
@@ -0,0 +1,10 @@
+AWSTemplateFormatVersion: 2010-09-09
+Resources:
+  Bucket:
+    Type: AWS::S3::Bucket
+    Properties:
+      BucketName: {{ variables.BucketName }}
+      AccessControl: Private
+Outputs:
+  BucketName:
+    Value: !Ref Bucket
diff --git a/tests/test_helper.bash b/tests/test_helper.bash
index 1d0d52194..1392df715 100644
--- a/tests/test_helper.bash
+++ b/tests/test_helper.bash
@@ -36,7 +36,23 @@ assert() {
 # Checks that the given line is in $output.
 assert_has_line() {
-  echo "$output" | grep "$@" 1>/dev/null
+  echo "$output" | grep -q "$@"
+}
+
+assert_has_lines_in_order() {
+  local search_line
+  read -r search_line || return $?
+
+  for line in "${lines[@]}"; do
+    if grep -q "$@" "$search_line" <<< "$line"; then
+      if ! read -r search_line && [ -z "$search_line" ]; then
+        return 0
+      fi
+    fi
+  done
+
+  echo "Error: did not match line in correct order: '$search_line'" >&2
+  return 1
 }
 
 # This helper wraps "stacker" with bats' "run" and also outputs debug
diff --git a/tests/test_suite/34_stacker_build-integrated-hooks.bats b/tests/test_suite/34_stacker_build-integrated-hooks.bats
new file mode 100644
index 000000000..b1da384f5
--- /dev/null
+++ b/tests/test_suite/34_stacker_build-integrated-hooks.bats
@@ -0,0 +1,130 @@
+#!/usr/bin/env bats
+
+# This test exercises the integration of hooks with the execution of stacks,
+# relying on the fact that S3 buckets cannot be deleted while they still
+# contain objects. The test creates the bucket and populates it during build,
+# and removes the objects before destruction. If the hooks are not executed
+# in the proper order, the destruction will fail, and so will the test.
+
+load ../test_helper
+
+@test "stacker build - integrated hooks" {
+  needs_aws
+
+  config() {
+    echo "namespace: ${STACKER_NAMESPACE}-integrated-hooks"
+    cat <<'EOF'
+stacks:
+  - name: bucket
+    profile: stacker
+    template_path: fixtures/blueprints/bucket.yaml.j2
+    variables:
+      BucketName: "stacker-${envvar STACKER_NAMESPACE}-integrated-hooks-${awsparam AccountId}"
+
+build_hooks:
+  - name: write-hello
+    path: stacker.hooks.command.run_command
+    args:
+      command: 'echo "Hello from Stacker!" 
> /tmp/hello.txt' + shell: true + + - name: send-hello + path: stacker.hooks.command.run_command + requires: + - write-hello + args: + command: 'aws s3 cp /tmp/hello.txt "s3://$BUCKET/hello.txt"' + shell: true + env: + BUCKET: "${output bucket::BucketName}" + AWS_PROFILE: stacker + + - name: send-world + path: stacker.hooks.command.run_command + requires: + - send-hello + args: + command: 'aws s3 cp "s3://$BUCKET/hello.txt" "s3://$BUCKET/world.txt"' + shell: true + env: + BUCKET: "${output bucket::BucketName}" + AWS_PROFILE: stacker + +destroy_hooks: + - name: remove-world + path: stacker.hooks.command.run_command + args: + command: 'aws s3 rm "s3://$BUCKET/world.txt"' + shell: true + env: + BUCKET: "${output bucket::BucketName}" + AWS_PROFILE: stacker + + - name: remove-hello + path: stacker.hooks.command.run_command + required_by: + - remove-world + args: + command: 'aws s3 rm "s3://$BUCKET/hello.txt"' + shell: true + env: + BUCKET: "${output bucket::BucketName}" + AWS_PROFILE: stacker + + - name: clean-hello + path: stacker.hooks.command.run_command + required_by: + - bucket + args: + command: [rm, -f, /tmp/hello.txt] +EOF + } + + teardown() { + stacker destroy --force <(config) + } + + stacker build -t --recreate-failed <(config) + assert "$status" -eq 0 + assert_has_line "Using default AWS provider mode" + assert_has_lines_in_order -E <<'EOF' +pre_build_hooks: complete +write-hello: complete +bucket: submitted \(creating new stack\) +bucket: complete \(creating new stack\) +upload: [^ ]*/hello.txt to s3://[^ ]*/hello.txt +send-hello: complete +copy: s3://[^ ]*/hello.txt to s3://[^ ]*/world.txt +send-world: complete +post_build_hooks: complete +EOF + + stacker destroy --force <(config) + assert "$status" -eq 0 + assert_has_line "Using default AWS provider mode" + assert_has_lines_in_order -E <<'EOF' +pre_destroy_hooks: complete +delete: s3://[^ ]*/world.txt +remove-world: complete +delete: s3://[^ ]*/hello.txt +remove-hello: complete +bucket: submitted \(submitted for destruction\) +bucket: complete \(stack destroyed\) +clean-hello: complete +post_destroy_hooks: complete +EOF + assert ! -e /tmp/hello.txt + + # Check that hooks that use lookups from stacks that do not exist anymore are + # not run + stacker destroy --force <(config) + assert "$status" -eq 0 + assert_has_lines_in_order <<'EOF' +pre_destroy_hooks: complete +remove-world: skipped +remove-hello: skipped +bucket: skipped +clean-hello: complete +post_destroy_hooks: complete +EOF +}
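As a quick reference for the API exercised above, here is a minimal sketch (not part of the patch) of how the new Graph, Plan and Step class methods compose, mirroring the updated tests in stacker/tests/test_plan.py. The names deploy_stack, build_example_plan and my_stacks are hypothetical placeholders; real actions supply their own per-stack function and DAG walker.

    # Sketch only: assumes the Graph/Plan/Step API shown in the updated unit
    # tests; deploy_stack, build_example_plan and my_stacks are made-up names.
    from stacker.dag import walk
    from stacker.plan import Graph, Plan, Step
    from stacker.status import COMPLETE

    def deploy_stack(stack, status=None):
        # Per-step functions return a Status; real actions live in
        # stacker.actions.* and do the actual CloudFormation work.
        return COMPLETE

    def build_example_plan(stacks):
        # One Step per stack; inter-stack dependencies are carried by the
        # stack definitions, as in the tests above.
        steps = [Step.from_stack(stack, fn=deploy_stack) for stack in stacks]
        graph = Graph.from_steps(steps)
        return Plan.from_graph(description="example", graph=graph)

    # Usage (hypothetical): build_example_plan(my_stacks).execute(walk)

On the hook side, the unit tests above pin down the run_step contract: a disabled hook yields a NotSubmittedStatus, a hook whose lookups reference an undeployed stack yields a SkippedStatus, a hook that raises HookExecutionFailed yields a FailedStatus, and a successful run returns COMPLETE.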