diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index 5b02dc1a22..e23749d5f4 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -24,5 +24,8 @@ from openfisca_core.errors import CycleError, NaNCreationError, SpiralError # noqa: F401 from .helpers import calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax # noqa: F401 +from .axis import Axis # noqa: F401 +from .axis_array import AxisArray # noqa: F401 +from .axis_expander import AxisExpander # noqa: F401 from .simulation import Simulation # noqa: F401 from .simulation_builder import SimulationBuilder # noqa: F401 diff --git a/openfisca_core/simulations/axis.py b/openfisca_core/simulations/axis.py new file mode 100644 index 0000000000..a1fec9bcba --- /dev/null +++ b/openfisca_core/simulations/axis.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import dataclasses +from typing import Optional, Union + + +@dataclasses.dataclass(frozen = True) +class Axis: + """ + Base data class for axes (no domain logic). + + Attributes: + + name: The name of the :class:`openfisca_core.variables.Variable` + whose values are to be expanded. + count: The Number of "steps" to take when expanding a + :class:`openfisca_core.variables.Variable` (between + :attr:`min_` and :attr:`max_`, we create a line and split it in + :attr:`count` number of parts). + min: The starting numerical value for the :class:`Axis` expansion. + max: The up-to numerical value for the :class:`Axis` expansion. + period: The period at which the expansion will take place over. + index: The :class:`Axis` position relative to other equidistant axes. + + Usage: + + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis + Axis(name='salary', count=3, min=0, max=3000, period=None, index=0) + + >>> axis.name + 'salary' + + Testing: + + pytest tests/core/test_axes.py openfisca_core/simulations/axis.py + + .. versionadded:: 35.4.0 + """ + + name: str + count: int + min: Union[int, float] + max: Union[int, float] + period: Optional[Union[int, str]] = dataclasses.field(default = None) + index: int = dataclasses.field(default = 0) diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py new file mode 100644 index 0000000000..c6734c09da --- /dev/null +++ b/openfisca_core/simulations/axis_array.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +import dataclasses +from typing import Any, Callable, Iterator, List, NoReturn, Union + +from . import Axis + + +@dataclasses.dataclass(frozen = True) +class AxisArray: + """ + A collection of :obj:`Axis` (some domain logic related to data integrity). + + As axes have a relative position relative to each other, they can be either + equidistant, that is parallel, or perpendicular. We assume the first + dimension to be a collection of parallel axes relative to themselves. + + Henceforward, we will consider each parallel axis as belonging to this + first dimension, and each perpendicular axis as belonging to a new + dimension, perpendicular to the previous one: that is, we won't be adding + more than one axis for each perpendicular dimension beyond the first one. + + As you might've already guess, it is not possible to add any parallel or + perpendicular axis relative to anything, so we assume the following when + our collection is yet empty: whenever you add a parallel axis it will by + default be added to the first dimension, and whenever you add a + perpendicular axis it will be added in isolation to second dimension and + beyond. + + Adding a perpendicular axis to an empty collection is a conceptual error + so instead of trying to mitigate this, we will rather error out and let + the user know why she can't do that and how she can correct they way she's + building her collection of axes (simply put to add first a parallel axis). + + Attributes: + + axes: A :type:`list` containing our collection of :obj:`Axis`. + + Usage: + + >>> axis_array = AxisArray() + >>> axis_array + AxisArray[[]] + + >>> salary = Axis(name = "salary", count = 3, min = 0, max = 3) + >>> axis_array = axis_array.add_parallel(salary) + >>> axis_array + AxisArray[AxisArray[Axis(name='salary', ..., index=0)]] + + >>> pension = Axis(name = "pension", count = 2, min = 0, max = 1) + >>> axis_array = axis_array.add_perpendicular(pension) + >>> axis_array + AxisArray[AxisArray[Axis(...)], AxisArray[Axis(...)]] + + >>> rent = Axis(name = "rent", count = 3, min = 0, max = 2) + >>> axis_array.add_parallel(rent) + AxisArray[AxisArray[Axis(...), Axis(...)], AxisArray[Axis(...)]] + + Testing: + + pytest tests/core/test_axes.py openfisca_core/simulations/axis_array.py + + .. versionadded:: 35.4.0 + """ + + axes: List[Union[AxisArray, Axis, list]] = \ + dataclasses \ + .field(default_factory = lambda: [[]]) + + def __post_init__(self) -> None: + axes = self.validate(isinstance, self.axes, list) + + for item in self.__flatten(axes): + self.validate(isinstance, item, (AxisArray, Axis)) + + def __contains__(self, item: Union[AxisArray, Axis]) -> bool: + return item in self.axes + + def __iter__(self) -> Iterator: + return (item for item in self.axes) + + def __len__(self) -> int: + return len(self.axes) + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}{repr(self.axes)}" + + def first(self) -> Union[AxisArray, Axis, List]: + """ + Retrieves the first axis from our axes collection. + + Usage: + + >>> axis_array = AxisArray() + >>> axis_array.first() + [] + + >>> axis_array = AxisArray([]) + >>> axis_array.first() + Traceback (most recent call last): + TypeError: Expecting a non empty list, but [] given. + + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> node_array = AxisArray([axis]) + >>> node_array.first() + Axis(name='salary', count=3, min=0, max=3000, period=None, index=0) + + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array = AxisArray() + >>> axis_array = axis_array.add_parallel(axis) + >>> axis_array.first() + AxisArray[Axis(name='salary', ..., index=0)] + + .. versionadded:: 35.4.0 + """ + self.validate(lambda item, _: item, self.axes, "a non empty list") + return self.axes[0] + + def last(self) -> Union[AxisArray, Axis, List]: + """ + Retrieves the last axis from our axes collection. + + Usage: + + >>> axis_array = AxisArray() + >>> axis_array.last() + [] + + >>> axis_array = AxisArray([]) + >>> axis_array.last() + Traceback (most recent call last): + TypeError: Expecting a non empty list, but [] given. + + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> node_array = AxisArray([axis]) + >>> node_array.last() + Axis(name='salary', count=3, min=0, max=3000, period=None, index=0) + + >>> axis1 = Axis(name = "salary", count = 3, min = 0, max = 3) + >>> axis2 = Axis(name = "pension", count = 2, min = 0, max = 2) + >>> axis3 = Axis(name = "rent", count = 3, min = 0, max = 1) + >>> axis_array = AxisArray() + >>> axis_array = axis_array.add_parallel(axis1) + >>> axis_array = axis_array.add_perpendicular(axis2) + >>> axis_array = axis_array.add_parallel(axis3) + + >>> axis_array.first() + AxisArray[Axis(name='salary', ...), Axis(name='rent', ...)] + + >>> axis_array.first().last() + Axis(name='rent', ..., index=0) + + >>> axis_array.last() + AxisArray[Axis(name='pension', ...)] + + >>> axis_array.last().last() + Axis(name='pension', ..., index=0) + + .. versionadded:: 35.4.0 + """ + self.validate(lambda item, _: item, self.axes, "a non empty list") + return self.axes[-1] + + def add_parallel(self, tail: Axis) -> Union[AxisArray, NoReturn]: + """ + Add an :obj:`Axis` to the first dimension of our collection. + + Args: + + tail: An :obj:`Axis` to add to the first dimension of our + collection. + + Usage: + + >>> axis_array = AxisArray() + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array = axis_array.add_parallel(axis) + >>> axis_array + AxisArray[AxisArray[Axis(name='salary', ...)]] + + >>> axis = Axis(name = "pension", count = 2, min = 0, max = 3000) + >>> axis_array.add_parallel(axis) + Traceback (most recent call last): + TypeError: Expecting counts to be equal... + + .. versionadded:: 35.4.0 + """ + node = self.validate(isinstance, self.first(), (AxisArray, list)) + tail = self.validate(isinstance, tail, Axis) + parallel = self.__class__([*node, tail]) + self.validate(self.__has_same_counts, parallel, "counts to be equal") + return self.__class__([parallel, *self.axes[1:]]) + + def add_perpendicular(self, tail: Axis) -> Union[AxisArray, NoReturn]: + """ + Add an :obj:`Axis` to the subsequent dimensions of our collection. + + Args: + + tail: An :obj:`Axis` to add to the subsequent dimensions of + our collection. + + Usage: + + >>> axis_array = AxisArray() + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array.add_perpendicular(axis) + Traceback (most recent call last): + TypeError: Expecting parallel axes set, but [] given. + + >>> axis_array = axis_array.add_parallel(axis) + >>> axis_array.add_perpendicular(axis) + AxisArray[AxisArray[Axis(name='salary', ...)]] + + .. versionadded:: 35.4.0 + """ + self.validate(lambda item, _: item, self.first(), "parallel axes set") + tail = self.validate(isinstance, tail, Axis) + perpendicular = self.__class__([tail]) + return self.__class__([*self.axes, perpendicular]) + + def validate( + self, + condition: Callable, + real: Any, + expected: Any, + ) -> Union[Any, NoReturn]: + """ + Validate that a condition holds true. + + Args: + + condition: A function reprensenting the condition to validate. + real: The value given as input, passed to :args:`condition`. + expected: The value we expect, passed to :args:`condition`. + + Usage: + + >>> axis_array = AxisArray() + >>> condition = isinstance + >>> real = tuple() + >>> expected = list + >>> axis_array.validate(condition, real, expected) + Traceback (most recent call last): + TypeError: Expecting , but () given. + + .. versionadded:: 35.4.0 + """ + if condition(real, expected): + return real + + raise TypeError(f"Expecting {expected}, but {real} given.") + + def __has_same_counts(self, axes: AxisArray, _) -> bool: + """ + Validate all counts on a collection are the same. + + They have to be the same for all parallel axes of a collection, + otherwise they become non expandable. + """ + counts = list(map(lambda axis: axis.count, axes)) + + if counts.count(counts[0]) == len(counts): + return True + + return False + + def __flatten(self, axes: list) -> List[Union[AxisArray, Axis]]: + """ + Flatten out our entire collection. + + Args: + + axes: Our collection. + + .. versionadded:: 35.4.0 + """ + if not axes: + return axes + + if isinstance(axes[0], list): + return self.__flatten(axes[0]) + self.__flatten(axes[1:]) + + return axes[:1] + self.__flatten(axes[1:]) diff --git a/openfisca_core/simulations/axis_expander.py b/openfisca_core/simulations/axis_expander.py new file mode 100644 index 0000000000..3c1f1aaff9 --- /dev/null +++ b/openfisca_core/simulations/axis_expander.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import functools +import typing + +if typing.TYPE_CHECKING: + from . import AxisArray + + +class AxisExpander: + """ + Expander of all axes for a given axes collection (lots of domain logic). + + Axis expansion is a feature in :module:`openfisca_core` that allows us to + parametrise some dimensions in order to create and to evaluate a range of + values for others. + + The most typical use of axis expansion is evaluate different numerical + values, starting from a :attr:`min_` and up to a :attr:`max_`, that could + take any given :class:`openfisca_core.variables.Variable` for any given + :class:`openfisca_core.periods.Period` for any given population (or a + collection of :module:`openfisca_core.entities`). + + Args: + + axis_array: An array of axes to expand. + + .. versionadded:: 35.4.0 + """ + + def __init__(self, axis_array: AxisArray) -> None: + self.__axis_array = axis_array + + @property + def axis_array(self) -> AxisArray: + """ + An array of axes to expand. + + .. versionadded:: 35.4.0 + """ + return self.__axis_array + + def count_cells(self): + """ + Count the total number of cells on an axes collection. + + As a collection of axes is comprised of several perpendicular + collections of axes, relative to each other, we're going to consider + the total number of cells as being the multiplication of the value + :attr:`Axis.count` for the first axis of each dimension. + + We assume however that all parallel axes should have the same count. + This method searches for a compatible axis (the first one). If none + exists, it should error out (we do not check here at it is the + responsability of :class:`AxisArray` to do so: data integrity is a + domain invariant related to the data model). + + Usage: + + >>> from . import Axis, AxisArray + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array = AxisArray() + >>> axis_array = axis_array.add_parallel(axis) + >>> axis_array = axis_array.add_perpendicular(axis) + >>> axis_expander = AxisExpander(axis_array) + >>> axis_expander.count_cells() + 9 + + .. versionadded:: 35.4.0 + """ + axis_count = map(lambda dim: dim.first().count, self.axis_array) + return functools.reduce(lambda acc, count: acc * count, axis_count, 1) diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 88553488db..bdea64fcfb 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,6 +1,7 @@ import copy import dpath -import typing +import warnings +from typing import Dict, Iterable, List import numpy @@ -8,9 +9,10 @@ from openfisca_core.entities import Entity from openfisca_core.errors import PeriodMismatchError, SituationParsingError, VariableNotFoundError from openfisca_core.populations import Population -from openfisca_core.simulations import helpers, Simulation from openfisca_core.variables import Variable +from . import helpers, Axis, AxisArray, AxisExpander, Simulation + class SimulationBuilder: @@ -19,24 +21,24 @@ def __init__(self): self.persons_plural = None # Plural name for person entity in current tax and benefits system # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: typing.Dict[Variable.name, typing.Dict[str(periods.period), numpy.array]] = {} - self.populations: typing.Dict[Entity.key, Population] = {} + self.input_buffer: Dict[Variable.name, Dict[str(periods.period), numpy.array]] = {} + self.populations: Dict[Entity.key, Population] = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. - self.entity_counts: typing.Dict[Entity.plural, int] = {} - # JSON input - typing.List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. - self.entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} + self.entity_counts: Dict[Entity.plural, int] = {} + # JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. + self.entity_ids: Dict[Entity.plural, List[int]] = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) - self.memberships: typing.Dict[Entity.plural, typing.List[int]] = {} - self.roles: typing.Dict[Entity.plural, typing.List[int]] = {} + self.memberships: Dict[Entity.plural, List[int]] = {} + self.roles: Dict[Entity.plural, List[int]] = {} - self.variable_entities: typing.Dict[Variable.name, Entity] = {} + self.variable_entities: Dict[Variable.name, Entity] = {} - self.axes = [[]] - self.axes_entity_counts: typing.Dict[Entity.plural, int] = {} - self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} - self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = {} - self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {} + self.axes = AxisArray() + self.axes_entity_counts: Dict[Entity.plural, int] = {} + self.axes_entity_ids: Dict[Entity.plural, List[int]] = {} + self.axes_memberships: Dict[Entity.plural, List[int]] = {} + self.axes_roles: Dict[Entity.plural, List[int]] = {} def build_from_dict(self, tax_benefit_system, input_dict): """ @@ -106,7 +108,13 @@ def build_from_entities(self, tax_benefit_system, input_dict): self.add_default_group_entity(persons_ids, entity_class) if axes: - self.axes = axes + for axis in axes[0]: + self.add_parallel_axis(axis) + + if len(axes) >= 1: + for axis in axes[1:]: + self.add_perpendicular_axis(axis[0]) + self.expand_axes() try: @@ -167,14 +175,14 @@ def build_default_simulation(self, tax_benefit_system, count = 1): def create_entities(self, tax_benefit_system): self.populations = tax_benefit_system.instantiate_entities() - def declare_person_entity(self, person_singular, persons_ids: typing.Iterable): + def declare_person_entity(self, person_singular, persons_ids: Iterable): person_instance = self.populations[person_singular] person_instance.ids = numpy.array(list(persons_ids)) person_instance.count = len(person_instance.ids) self.persons_plural = person_instance.entity.plural - def declare_entity(self, entity_singular, entity_ids: typing.Iterable): + def declare_entity(self, entity_singular, entity_ids: Iterable): entity_instance = self.populations[entity_singular] entity_instance.ids = numpy.array(list(entity_ids)) entity_instance.count = len(entity_instance.ids) @@ -183,7 +191,7 @@ def declare_entity(self, entity_singular, entity_ids: typing.Iterable): def nb_persons(self, entity_singular, role = None): return self.populations[entity_singular].nb_persons(role = role) - def join_with_persons(self, group_population, persons_group_assignment, roles: typing.Iterable[str]): + def join_with_persons(self, group_population, persons_group_assignment, roles: Iterable[str]): # Maps group's identifiers to a 0-based integer range, for indexing into members_roles (see PR#876) group_sorted_indices = numpy.unique(persons_group_assignment, return_inverse = True)[1] group_population.members_entity_id = numpy.argsort(group_population.ids)[group_sorted_indices] @@ -456,24 +464,66 @@ def get_roles(self, entity_name): # Return empty array for the "persons" entity return self.axes_roles.get(entity_name, self.roles.get(entity_name, [])) - def add_parallel_axis(self, axis): - # All parallel axes have the same count and entity. - # Search for a compatible axis, if none exists, error out - self.axes[0].append(axis) + def add_parallel_axis(self, axis: dict) -> None: + """ + Add a parallel axis to our collection of axes. + + Args: + + axis: An axis to add to our collection. + + .. deprecated:: 35.4.0 + + Use :meth:`AxisArray.add_parallel` instead. + """ + message = [ + "The 'add_parallel_axis' method has been deprecated since", + "version 35.4.0, and will be removed in the future. Please use", + "'AxisArray.add_parallel' instead", + ] + + warnings.warn(" ".join(message), DeprecationWarning) + self.axes = self.axes.add_parallel(Axis(**axis)) + + def add_perpendicular_axis(self, axis: dict) -> None: + """ + Add a perpendicular axis to all previous dimensions. + + Args: + + axis: An axis to add to our collection. + + .. deprecated:: 35.4.0 + + Use :meth:`AxisArray.add_parallel` instead. + """ + message = [ + "The 'add_perpendicular_axis' method has been deprecated since", + "version 35.4.0, and will be removed in the future. Please use", + "'AxisArray.add_perpendicular' instead", + ] - def add_perpendicular_axis(self, axis): - # This adds an axis perpendicular to all previous dimensions - self.axes.append([axis]) + warnings.warn(" ".join(message), DeprecationWarning) + self.axes = self.axes.add_perpendicular(Axis(**axis)) def expand_axes(self): - # This method should be idempotent & allow change in axes - perpendicular_dimensions = self.axes + """ + Expand all axes for the current simulation. + + .. deprecated:: 35.4.0 + + Use :class:`AxisExpander` instead. + """ + message = [ + "The 'expand_axes' method has been deprecated since", + "version 35.4.0, and will be removed in the future. Please use", + "'AxisExpander' instead", + ] + + warnings.warn(" ".join(message), DeprecationWarning) - cell_count = 1 - for parallel_axes in perpendicular_dimensions: - first_axis = parallel_axes[0] - axis_count = first_axis['count'] - cell_count *= axis_count + expander = AxisExpander(self.axes) + cell_count = expander.count_cells() # Scale the "prototype" situation, repeating it cell_count times for entity_name in self.entity_counts.keys(): @@ -498,17 +548,17 @@ def expand_axes(self): # Now generate input values along the specified axes # TODO - factor out the common logic here - if len(self.axes) == 1 and len(self.axes[0]): - parallel_axes = self.axes[0] - first_axis = parallel_axes[0] - axis_count: int = first_axis['count'] - axis_entity = self.get_variable_entity(first_axis['name']) + if len(self.axes) == 1 and len(self.axes.first()): + parallel_axes = self.axes.first() + first_axis = parallel_axes.first() + axis_count: int = first_axis.count + axis_entity = self.get_variable_entity(first_axis.name) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along axes for axis in parallel_axes: - axis_index = axis.get('index', 0) - axis_period = axis.get('period', self.default_period) - axis_name = axis['name'] + axis_index = axis.index + axis_period = axis.period or self.default_period + axis_name = axis.name variable = axis_entity.get_variable(axis_name) array = self.get_input(axis_name, str(axis_period)) if array is None: @@ -516,15 +566,15 @@ def expand_axes(self): elif array.size == axis_entity_step_size: array = numpy.tile(array, axis_count) array[axis_index:: axis_entity_step_size] = numpy.linspace( - axis['min'], - axis['max'], + axis.min, + axis.max, num = axis_count, ) # Set input self.input_buffer[axis_name][str(axis_period)] = array else: - first_axes_count: typing.List[int] = ( - parallel_axes[0]["count"] + first_axes_count: List[int] = ( + parallel_axes.first().count for parallel_axes in self.axes ) @@ -535,23 +585,23 @@ def expand_axes(self): ] axes_meshes = numpy.meshgrid(*axes_linspaces) for parallel_axes, mesh in zip(self.axes, axes_meshes): - first_axis = parallel_axes[0] - axis_count = first_axis['count'] - axis_entity = self.get_variable_entity(first_axis['name']) + first_axis = parallel_axes.first() + axis_count = first_axis.count + axis_entity = self.get_variable_entity(first_axis.name) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along the grid for axis in parallel_axes: - axis_index = axis.get('index', 0) - axis_period = axis['period'] or self.default_period - axis_name = axis['name'] + axis_index = axis.index + axis_period = axis.period or self.default_period + axis_name = axis.name variable = axis_entity.get_variable(axis_name) array = self.get_input(axis_name, str(axis_period)) if array is None: array = variable.default_array(cell_count * axis_entity_step_size) elif array.size == axis_entity_step_size: array = numpy.tile(array, cell_count) - array[axis_index:: axis_entity_step_size] = axis['min'] \ - + mesh.reshape(cell_count) * (axis['max'] - axis['min']) / (axis_count - 1) + array[axis_index:: axis_entity_step_size] = axis.min \ + + mesh.reshape(cell_count) * (axis.max - axis.min) / (axis_count - 1) self.input_buffer[axis_name][str(axis_period)] = array def get_variable_entity(self, variable_name): diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 686c9b27e7..248abd993c 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -1,15 +1,184 @@ import pytest -from pytest import fixture, approx -from openfisca_core.simulation_builder import SimulationBuilder +from openfisca_core.simulations import ( + Axis, + AxisArray, + AxisExpander, + SimulationBuilder, + ) + from .test_simulation_builder import * # noqa: F401 -@fixture +@pytest.fixture def simulation_builder(): return SimulationBuilder() +@pytest.fixture +def salary(): + return { + "name": "salary", + "count": 3, + "min": 0, + "max": 3000, + } + + +@pytest.fixture +def pension(): + return { + "name": "pension", + "count": 2, + "min": 0, + "max": 2000, + } + + +@pytest.fixture +def salary_axis(salary): + return Axis(**salary) + + +@pytest.fixture +def pension_axis(pension): + return Axis(**pension) + + +@pytest.fixture +def axis_array(): + return AxisArray() + + +@pytest.fixture +def axis_expander(axis_array, salary_axis, pension_axis): + axis_array = axis_array.add_parallel(salary_axis) + axis_array = axis_array.add_perpendicular(pension_axis) + return AxisExpander(axis_array) + + +# Axis + + +def test_create_axis(salary): + """ + Works! Missing fields are optional, so they default to None. + """ + result = Axis(**salary) + assert result.name == "salary" + assert not result.period + + +def test_create_empty_axis(): + """ + Fails because we're not providing the required fields. + """ + with pytest.raises(TypeError): + Axis() + + +# AxisArray + + +def test_empty_create_axis_array(): + """ + Nothing fancy, just an empty container. + """ + result = AxisArray() + assert isinstance(result, AxisArray) + + +def test_create_axis_array_with_axes(salary_axis): + """ + We can pass along some axes at initialisation time as well. + """ + result = AxisArray([salary_axis]) + assert result.first() == salary_axis + + +def test_create_axis_array_with_anything(salary_axis): + """ + If you don't pass a collection, it will fail! + """ + with pytest.raises(TypeError): + AxisArray(salary_axis) + + +def test_create_axis_array_with_a_collection_of_anything(): + """ + If you pass a collection of anything, it will fail! + """ + with pytest.raises(TypeError): + AxisArray(["axis"]) + + +def test_add_parallel_axis(axis_array, salary_axis): + """ + As there are no previously added axes in our collection, it adds the first + one to the first dimension (parallel). + """ + result = axis_array.add_parallel(salary_axis) + assert salary_axis in result.first() + assert result.first().first() == salary_axis + + +def test_add_parallel_axes_with_different_counts(axis_array, salary_axis, pension_axis): + """ + We can't, it should fail! + + Otherwise they become non expandable. + """ + with pytest.raises(TypeError): + axis_array.add_parallel(salary_axis).add_parallel(pension_axis) + + +def test_add_perpendicular_axis_before_parallel_axis(axis_array, pension_axis): + """ + As there are no previously added axes in our collection, it fails! + """ + with pytest.raises(TypeError): + axis_array.add_perpendicular(pension_axis) + + +def test_add_perpendicular_axis(axis_array, salary_axis, pension_axis): + """ + As there is at least one parallel axis already, it works! + """ + result = \ + axis_array \ + .add_parallel(salary_axis) \ + .add_perpendicular(pension_axis) + + # Parallel + assert salary_axis in result.first() + assert result.first().first() == salary_axis + + # Perpendicular + assert pension_axis in result.last() + assert result.last().first() == pension_axis + + +def test_add_perpendicular_axes_with_different_counts(axis_array, salary_axis, pension_axis): + """ + We can, because each perpendicular axis is added to the next dimension. + """ + assert \ + axis_array \ + .add_parallel(salary_axis) \ + .add_perpendicular(salary_axis) \ + .add_perpendicular(pension_axis) + + +# AxisExpander + + +def test_count_cells(axis_expander): + assert axis_expander.count_cells() == 6 + + +# SimulationBuilder + + # With periods @@ -19,7 +188,7 @@ def test_add_axis_without_period(simulation_builder, persons): simulation_builder.register_variable('salary', persons) simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) # With variables @@ -38,7 +207,7 @@ def test_add_axis_on_an_existing_variable_with_input(simulation_builder, persons simulation_builder.register_variable('salary', persons) simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) assert simulation_builder.get_count('persons') == 3 assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2'] @@ -51,7 +220,7 @@ def test_add_axis_on_persons(simulation_builder, persons): simulation_builder.register_variable('salary', persons) simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) assert simulation_builder.get_count('persons') == 3 assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2'] @@ -62,8 +231,8 @@ def test_add_two_axes(simulation_builder, persons): simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.add_parallel_axis({'count': 3, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == approx([0, 1000, 2000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) + assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 1000, 2000]) def test_add_axis_with_group(simulation_builder, persons): @@ -74,7 +243,7 @@ def test_add_axis_with_group(simulation_builder, persons): simulation_builder.expand_axes() assert simulation_builder.get_count('persons') == 4 assert simulation_builder.get_ids('persons') == ['Alicia0', 'Javier1', 'Alicia2', 'Javier3'] - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 0, 3000, 3000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 0, 3000, 3000]) def test_add_axis_with_group_int_period(simulation_builder, persons): @@ -83,7 +252,7 @@ def test_add_axis_with_group_int_period(simulation_builder, persons): simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018}) simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018, 'index': 1}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018') == approx([0, 0, 3000, 3000]) + assert simulation_builder.get_input('salary', '2018') == pytest.approx([0, 0, 3000, 3000]) def test_add_axis_on_group_entity(simulation_builder, persons, group_entity): @@ -155,8 +324,8 @@ def test_add_perpendicular_axes(simulation_builder, persons): simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000, 0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == approx([0, 0, 0, 2000, 2000, 2000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000]) + assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000]) def test_add_perpendicular_axis_on_an_existing_variable_with_input(simulation_builder, persons): @@ -171,8 +340,9 @@ def test_add_perpendicular_axis_on_an_existing_variable_with_input(simulation_bu simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000, 0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == approx([0, 0, 0, 2000, 2000, 2000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000]) + assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000]) + # Integration test @@ -199,5 +369,5 @@ def test_simulation_with_axes(simulation_builder): """ data = yaml.safe_load(input_yaml) simulation = simulation_builder.build_from_dict(tax_benefit_system, data) - assert simulation.get_array('salary', '2018-11') == approx([0, 0, 0, 0, 0, 0]) - assert simulation.get_array('rent', '2018-11') == approx([0, 0, 3000, 0]) + assert simulation.get_array('salary', '2018-11') == pytest.approx([0, 0, 0, 0, 0, 0]) + assert simulation.get_array('rent', '2018-11') == pytest.approx([0, 0, 3000, 0])