From dfc0ea18afccc234beb529c88beedea89af1d541 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 10 Apr 2021 15:41:14 +0200 Subject: [PATCH 01/11] Add Axis data class --- openfisca_core/simulations/__init__.py | 1 + openfisca_core/simulations/axis.py | 43 ++++++++++++ .../simulations/simulation_builder.py | 2 +- tests/core/test_axes.py | 66 ++++++++++++++----- 4 files changed, 95 insertions(+), 17 deletions(-) create mode 100644 openfisca_core/simulations/axis.py diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index 5b02dc1a22..e3caacca91 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -24,5 +24,6 @@ from openfisca_core.errors import CycleError, NaNCreationError, SpiralError # noqa: F401 from .helpers import calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax # noqa: F401 +from .axis import Axis # noqa: F401 from .simulation import Simulation # noqa: F401 from .simulation_builder import SimulationBuilder # noqa: F401 diff --git a/openfisca_core/simulations/axis.py b/openfisca_core/simulations/axis.py new file mode 100644 index 0000000000..3d20b118b8 --- /dev/null +++ b/openfisca_core/simulations/axis.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import typing +import dataclasses + + +@dataclasses.dataclass(frozen = True) +class Axis: + """ + Base data class for axes (no business logic). + + Axis expansion is a feature in :module:`openfisca_core` that allows us to + parametrise some dimensions in order to create and to evaluate a range of + values for others. + + The most typical use of axis expansion is evaluate different numerical + values, starting from a :attr:`min_` and up to a :attr:`max_`, that could + take any given :class:`openfisca_core.variables.Variable` for any given + :class:`openfisca_core.periods.Period` for any given population (or a + collection of :module:`openfisca_core.entities`). + + Attributes: + + name: The name of the :class:`openfisca_core.variables.Variable` + whose values are to be expanded. + count: The Number of "steps" to take when expanding a + :class:`openfisca_core.variables.Variable` (between + :attr:`min_` and :attr:`max_`, we create a line and split it in + :attr:`count` number of parts). + min: The starting numerical value for the :class:`Axis` expansion. + max: The up-to numerical value for the :class:`Axis` expansion. + period: The period at which the expansion will take place over. + index: The :class:`Axis` position relative to other equidistant axes. + + .. versionadded:: 3.4.0 + """ + + name: str + count: int + min: typing.Union[int, float] + max: typing.Union[int, float] + period: typing.Optional[typing.Union[int, str]] = None + index: typing.Optional[int] = None diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 88553488db..53f4b2e4ab 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -8,7 +8,7 @@ from openfisca_core.entities import Entity from openfisca_core.errors import PeriodMismatchError, SituationParsingError, VariableNotFoundError from openfisca_core.populations import Population -from openfisca_core.simulations import helpers, Simulation +from openfisca_core.simulations import helpers, Axis, Simulation from openfisca_core.variables import Variable diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 686c9b27e7..77daeba82f 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -1,15 +1,49 @@ import pytest -from pytest import fixture, approx -from openfisca_core.simulation_builder import SimulationBuilder +from openfisca_core.simulations import Axis, SimulationBuilder from .test_simulation_builder import * # noqa: F401 -@fixture +@pytest.fixture def simulation_builder(): return SimulationBuilder() +@pytest.fixture +def axis_args(): + return { + "name": "salary", + "count": 3, + "min": 0, + "max": 3000, + } + + +@pytest.fixture +def axis(axis_params): + return Axis(axis_args) + + +# Unit tests + +def test_create_axis(axis_args): + """ + Works! Missing fields are optional, so they default to None. + """ + result = Axis(**axis_args) + assert result.name == "salary" + assert not result.period + assert not result.index + + +def test_create_empty_axis(): + """ + Fails because we're not providing the required fields. + """ + with pytest.raises(TypeError): + Axis() + + # With periods @@ -19,7 +53,7 @@ def test_add_axis_without_period(simulation_builder, persons): simulation_builder.register_variable('salary', persons) simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) # With variables @@ -38,7 +72,7 @@ def test_add_axis_on_an_existing_variable_with_input(simulation_builder, persons simulation_builder.register_variable('salary', persons) simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) assert simulation_builder.get_count('persons') == 3 assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2'] @@ -51,7 +85,7 @@ def test_add_axis_on_persons(simulation_builder, persons): simulation_builder.register_variable('salary', persons) simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) assert simulation_builder.get_count('persons') == 3 assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2'] @@ -62,8 +96,8 @@ def test_add_two_axes(simulation_builder, persons): simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.add_parallel_axis({'count': 3, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == approx([0, 1000, 2000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) + assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 1000, 2000]) def test_add_axis_with_group(simulation_builder, persons): @@ -74,7 +108,7 @@ def test_add_axis_with_group(simulation_builder, persons): simulation_builder.expand_axes() assert simulation_builder.get_count('persons') == 4 assert simulation_builder.get_ids('persons') == ['Alicia0', 'Javier1', 'Alicia2', 'Javier3'] - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 0, 3000, 3000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 0, 3000, 3000]) def test_add_axis_with_group_int_period(simulation_builder, persons): @@ -83,7 +117,7 @@ def test_add_axis_with_group_int_period(simulation_builder, persons): simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018}) simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018, 'index': 1}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018') == approx([0, 0, 3000, 3000]) + assert simulation_builder.get_input('salary', '2018') == pytest.approx([0, 0, 3000, 3000]) def test_add_axis_on_group_entity(simulation_builder, persons, group_entity): @@ -155,8 +189,8 @@ def test_add_perpendicular_axes(simulation_builder, persons): simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000, 0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == approx([0, 0, 0, 2000, 2000, 2000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000]) + assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000]) def test_add_perpendicular_axis_on_an_existing_variable_with_input(simulation_builder, persons): @@ -171,8 +205,8 @@ def test_add_perpendicular_axis_on_an_existing_variable_with_input(simulation_bu simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == approx([0, 1500, 3000, 0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == approx([0, 0, 0, 2000, 2000, 2000]) + assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000]) + assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000]) # Integration test @@ -199,5 +233,5 @@ def test_simulation_with_axes(simulation_builder): """ data = yaml.safe_load(input_yaml) simulation = simulation_builder.build_from_dict(tax_benefit_system, data) - assert simulation.get_array('salary', '2018-11') == approx([0, 0, 0, 0, 0, 0]) - assert simulation.get_array('rent', '2018-11') == approx([0, 0, 3000, 0]) + assert simulation.get_array('salary', '2018-11') == pytest.approx([0, 0, 0, 0, 0, 0]) + assert simulation.get_array('rent', '2018-11') == pytest.approx([0, 0, 3000, 0]) From 57143e3a37bd2fe0635b8470c46da6ff76130fc7 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 10 Apr 2021 21:52:22 +0200 Subject: [PATCH 02/11] Add AxisArray collection class --- openfisca_core/simulations/__init__.py | 1 + openfisca_core/simulations/axis_array.py | 24 ++++++++++++++++++++++++ tests/core/test_axes.py | 11 +++++++++-- 3 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 openfisca_core/simulations/axis_array.py diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index e3caacca91..b0a54d533e 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -25,5 +25,6 @@ from .helpers import calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax # noqa: F401 from .axis import Axis # noqa: F401 +from .axis_array import AxisArray # noqa: F401 from .simulation import Simulation # noqa: F401 from .simulation_builder import SimulationBuilder # noqa: F401 diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py new file mode 100644 index 0000000000..15f30d2334 --- /dev/null +++ b/openfisca_core/simulations/axis_array.py @@ -0,0 +1,24 @@ +import collections.abc + +from .axis import Axis + + +class AxisArray(collections.abc.Container): + """ + Simply a collection of :class:`Axis`. + + Axis expansion is a feature in :module:`openfisca_core` that allows us to + parametrise some dimensions in order to create and to evaluate a range of + values for others. + + The most typical use of axis expansion is evaluate different numerical + values, starting from a :attr:`min_` and up to a :attr:`max_`, that could + take any given :class:`openfisca_core.variables.Variable` for any given + :class:`openfisca_core.periods.Period` for any given population (or a + collection of :module:`openfisca_core.entities`). + + .. versionadded:: 3.4.0 + """ + + def __contains__(self, axis: Axis) -> bool: + pass diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 77daeba82f..341646f2b5 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -1,6 +1,6 @@ import pytest -from openfisca_core.simulations import Axis, SimulationBuilder +from openfisca_core.simulations import Axis, AxisArray, SimulationBuilder from .test_simulation_builder import * # noqa: F401 @@ -20,7 +20,7 @@ def axis_args(): @pytest.fixture -def axis(axis_params): +def axis(axis_args): return Axis(axis_args) @@ -44,6 +44,13 @@ def test_create_empty_axis(): Axis() +def test_create_axis_array(): + """ + Nothing fancy, just an empty container. + """ + assert AxisArray() + + # With periods From 5ed84c80fbaa855cc8341f744b1800bada22e8ff Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sun, 11 Apr 2021 01:05:57 +0200 Subject: [PATCH 03/11] Allow to append axes to axis array --- openfisca_core/simulations/axis.py | 13 +++++++ openfisca_core/simulations/axis_array.py | 44 +++++++++++++++++++++--- tests/core/test_axes.py | 28 +++++++++++---- 3 files changed, 74 insertions(+), 11 deletions(-) diff --git a/openfisca_core/simulations/axis.py b/openfisca_core/simulations/axis.py index 3d20b118b8..bdb10a65ae 100644 --- a/openfisca_core/simulations/axis.py +++ b/openfisca_core/simulations/axis.py @@ -32,6 +32,19 @@ class Axis: period: The period at which the expansion will take place over. index: The :class:`Axis` position relative to other equidistant axes. + Usage: + + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis + Axis(name='salary', count=3, min=0, max=3000, period=None, index=None) + + >>> axis.name + 'salary' + + Testing: + + pytest tests/core/test_axes.py openfisca_core/simulations/axis.py + .. versionadded:: 3.4.0 """ diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py index 15f30d2334..027dc2472f 100644 --- a/openfisca_core/simulations/axis_array.py +++ b/openfisca_core/simulations/axis_array.py @@ -1,9 +1,13 @@ -import collections.abc +from __future__ import annotations + +import dataclasses +import typing from .axis import Axis -class AxisArray(collections.abc.Container): +@dataclasses.dataclass(frozen = True) +class AxisArray: """ Simply a collection of :class:`Axis`. @@ -17,8 +21,40 @@ class AxisArray(collections.abc.Container): :class:`openfisca_core.periods.Period` for any given population (or a collection of :module:`openfisca_core.entities`). + Attributes: + + axes: A :type:`tuple` containing our collection of :class:`Axis`. + + Usage: + + >>> axis_array = AxisArray() + >>> axis_array + AxisArray() + + Testing: + + pytest tests/core/test_axes.py openfisca_core/simulations/axis_array.py + .. versionadded:: 3.4.0 """ - def __contains__(self, axis: Axis) -> bool: - pass + axes: typing.Tuple[Axis, ...] = () + + def append(self, tail: Axis) -> AxisArray: + """ + Append an :class:`Axis` to our axes collection. + + Usage: + + >>> axis_array = AxisArray() + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array.append(axis) # doctest: +ELLIPSIS + AxisArray(Axis(name='salary', ...),) + """ + return self.__class__(axes = (*self.axes, tail)) + + def __contains__(self, item: Axis) -> bool: + return item in self.axes + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}{repr(self.axes)}" diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 341646f2b5..0a020204c8 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -10,7 +10,7 @@ def simulation_builder(): @pytest.fixture -def axis_args(): +def kwargs(): return { "name": "salary", "count": 3, @@ -20,20 +20,25 @@ def axis_args(): @pytest.fixture -def axis(axis_args): - return Axis(axis_args) +def axis(kwargs): + return Axis(**kwargs) + + +@pytest.fixture +def axis_array(): + return AxisArray() # Unit tests -def test_create_axis(axis_args): + +def test_create_axis(kwargs): """ Works! Missing fields are optional, so they default to None. """ - result = Axis(**axis_args) + result = Axis(**kwargs) assert result.name == "salary" assert not result.period - assert not result.index def test_create_empty_axis(): @@ -48,7 +53,16 @@ def test_create_axis_array(): """ Nothing fancy, just an empty container. """ - assert AxisArray() + result = AxisArray() + assert isinstance(result, AxisArray) + + +def test_add_axis_to_array(axis_array, axis): + """ + If you add an :class:`Axis` to the array, it works! + """ + result = axis_array.append(axis) + assert axis in result # With periods From 1bb80e7f9eb1445374d9277bb75e07a82f1e4074 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sun, 11 Apr 2021 02:26:14 +0200 Subject: [PATCH 04/11] Add init/append validation --- openfisca_core/simulations/axis_array.py | 55 +++++++++++++++++++++--- tests/core/test_axes.py | 35 ++++++++++++++- 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py index 027dc2472f..0e957057d6 100644 --- a/openfisca_core/simulations/axis_array.py +++ b/openfisca_core/simulations/axis_array.py @@ -1,5 +1,6 @@ from __future__ import annotations +import collections.abc import dataclasses import typing @@ -9,7 +10,7 @@ @dataclasses.dataclass(frozen = True) class AxisArray: """ - Simply a collection of :class:`Axis`. + A collection of :obj:`Axis` and a bunch of business logic. Axis expansion is a feature in :module:`openfisca_core` that allows us to parametrise some dimensions in order to create and to evaluate a range of @@ -23,7 +24,7 @@ class AxisArray: Attributes: - axes: A :type:`tuple` containing our collection of :class:`Axis`. + axes: A :type:`tuple` containing our collection of :obj:`Axis`. Usage: @@ -40,19 +41,59 @@ class AxisArray: axes: typing.Tuple[Axis, ...] = () + def first(self) -> typing.Optional[Axis]: + """ + Retrieves the first :obj:`Axis` in our axes collection. + + Usage: + + >>> axis_array = AxisArray() + >>> not axis_array.first() + True + + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array = axis_array.append(axis) + >>> axis_array.first() + Axis(name='salary', ..., index=None) + """ + if len(self.axes) == 0: + return None + + return self.axes[0] + def append(self, tail: Axis) -> AxisArray: """ - Append an :class:`Axis` to our axes collection. + Append an :obj:`Axis` to our axes collection. + + Args: + + axis: An :obj:`Axis` to append to our collection. Usage: - >>> axis_array = AxisArray() - >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) - >>> axis_array.append(axis) # doctest: +ELLIPSIS - AxisArray(Axis(name='salary', ...),) + >>> axis_array = AxisArray() + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array.append(axis) + AxisArray(Axis(name='salary', ...),) """ + if not isinstance(self.axes, collections.abc.Iterable): + raise TypeError("Not an Iterable, but improve this message stub!") + + for axis in self.axes: + if not isinstance(axis, Axis): + raise TypeError("Not an Axis, but improve this message stub!") + return self.__class__(axes = (*self.axes, tail)) + def __post_init__(self) -> None: + if not isinstance(self.axes, collections.abc.Iterable): + raise TypeError("Not an Iterable, but improve this message stub!") + + for axis in self.axes: + if not isinstance(axis, Axis): + raise TypeError("Not an Axis, but improve this message stub!") + + def __contains__(self, item: Axis) -> bool: return item in self.axes diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 0a020204c8..295ff4f846 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -49,7 +49,7 @@ def test_create_empty_axis(): Axis() -def test_create_axis_array(): +def test_empty_create_axis_array(): """ Nothing fancy, just an empty container. """ @@ -57,14 +57,45 @@ def test_create_axis_array(): assert isinstance(result, AxisArray) +def test_create_axis_array_with_axes(axis): + """ + We can pass along some axes at initialisation time as well. + """ + result = AxisArray(axes = [axis]) + assert result.first() == axis + + +def test_create_axis_array_with_anything(axis): + """ + If you don't pass a collection, it will fail! + """ + with pytest.raises(TypeError): + AxisArray(axes = axis) + + +def test_create_axis_array_with_a_collection_of_anything(): + """ + If you pass anything, it will fail! + """ + with pytest.raises(TypeError): + AxisArray(axes = ["axis"]) + + def test_add_axis_to_array(axis_array, axis): """ - If you add an :class:`Axis` to the array, it works! + If you add an :obj:`Axis` to the array, it works! """ result = axis_array.append(axis) assert axis in result +def test_add_anything_to_array(axis_array, axis): + """ + If you add anything else to the array, it fails! + """ + with pytest.raises(TypeError): + axis_array.append("cuack") + # With periods From 63cdbc254ee370035095ab569604686dbf79202c Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sun, 11 Apr 2021 02:48:24 +0200 Subject: [PATCH 05/11] Add init/append validation --- openfisca_core/simulations/axis_array.py | 52 +++++++++---------- .../simulations/simulation_builder.py | 2 +- tests/core/test_axes.py | 2 +- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py index 0e957057d6..e5bb343f22 100644 --- a/openfisca_core/simulations/axis_array.py +++ b/openfisca_core/simulations/axis_array.py @@ -1,6 +1,5 @@ from __future__ import annotations -import collections.abc import dataclasses import typing @@ -24,13 +23,13 @@ class AxisArray: Attributes: - axes: A :type:`tuple` containing our collection of :obj:`Axis`. + axes: A :type:`list` containing our collection of :obj:`Axis`. Usage: >>> axis_array = AxisArray() >>> axis_array - AxisArray() + AxisArray[] Testing: @@ -39,11 +38,21 @@ class AxisArray: .. versionadded:: 3.4.0 """ - axes: typing.Tuple[Axis, ...] = () + axes: typing.List[Axis] = dataclasses.field(default_factory = list) + + def __post_init__(self) -> None: + self.__is_list(self.axes) + list(map(self.__is_axis, self.axes)) + + def __contains__(self, item: Axis) -> bool: + return item in self.axes + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}{repr(self.axes)}" def first(self) -> typing.Optional[Axis]: """ - Retrieves the first :obj:`Axis` in our axes collection. + Retrieves the first :obj:`Axis` from our axes collection. Usage: @@ -61,7 +70,7 @@ def first(self) -> typing.Optional[Axis]: return self.axes[0] - def append(self, tail: Axis) -> AxisArray: + def append(self, tail: Axis) -> typing.Union[AxisArray, typing.NoReturn]: """ Append an :obj:`Axis` to our axes collection. @@ -74,28 +83,19 @@ def append(self, tail: Axis) -> AxisArray: >>> axis_array = AxisArray() >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) >>> axis_array.append(axis) - AxisArray(Axis(name='salary', ...),) + AxisArray[Axis(name='salary', ...)] """ - if not isinstance(self.axes, collections.abc.Iterable): - raise TypeError("Not an Iterable, but improve this message stub!") + self.__is_axis(tail) + return self.__class__(axes = [*self.axes, tail]) - for axis in self.axes: - if not isinstance(axis, Axis): - raise TypeError("Not an Axis, but improve this message stub!") + def __is_list(self, axes: list) -> typing.Union[bool, typing.NoReturn]: + if isinstance(axes, list): + return True - return self.__class__(axes = (*self.axes, tail)) + raise TypeError(f"Expecting a list, but {type(self.axes)} given") - def __post_init__(self) -> None: - if not isinstance(self.axes, collections.abc.Iterable): - raise TypeError("Not an Iterable, but improve this message stub!") - - for axis in self.axes: - if not isinstance(axis, Axis): - raise TypeError("Not an Axis, but improve this message stub!") - - - def __contains__(self, item: Axis) -> bool: - return item in self.axes + def __is_axis(self, item: Axis) -> typing.Union[bool, typing.NoReturn]: + if isinstance(item, Axis): + return True - def __repr__(self) -> str: - return f"{self.__class__.__qualname__}{repr(self.axes)}" + raise TypeError(f"Expecting an {Axis}, but {type(item)} given") diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 53f4b2e4ab..88553488db 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -8,7 +8,7 @@ from openfisca_core.entities import Entity from openfisca_core.errors import PeriodMismatchError, SituationParsingError, VariableNotFoundError from openfisca_core.populations import Population -from openfisca_core.simulations import helpers, Axis, Simulation +from openfisca_core.simulations import helpers, Simulation from openfisca_core.variables import Variable diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 295ff4f846..4d8a0fbe7e 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -75,7 +75,7 @@ def test_create_axis_array_with_anything(axis): def test_create_axis_array_with_a_collection_of_anything(): """ - If you pass anything, it will fail! + If you pass a collection of anything, it will fail! """ with pytest.raises(TypeError): AxisArray(axes = ["axis"]) From 1c31ebb061ad1acf4170f2a31bdd7eaa47132bf6 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sun, 11 Apr 2021 13:46:36 +0200 Subject: [PATCH 06/11] Use Axis in SimulationBuilder 1/2 --- openfisca_core/simulations/axis.py | 14 +- openfisca_core/simulations/axis_array.py | 84 ++++++++---- .../simulations/simulation_builder.py | 121 ++++++++++++------ tests/core/test_axes.py | 1 + 4 files changed, 150 insertions(+), 70 deletions(-) diff --git a/openfisca_core/simulations/axis.py b/openfisca_core/simulations/axis.py index bdb10a65ae..f4e1b34a69 100644 --- a/openfisca_core/simulations/axis.py +++ b/openfisca_core/simulations/axis.py @@ -1,7 +1,7 @@ from __future__ import annotations -import typing import dataclasses +from typing import Optional, Union @dataclasses.dataclass(frozen = True) @@ -36,7 +36,7 @@ class Axis: >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) >>> axis - Axis(name='salary', count=3, min=0, max=3000, period=None, index=None) + Axis(name='salary', count=3, min=0, max=3000, period=None, index=0) >>> axis.name 'salary' @@ -45,12 +45,12 @@ class Axis: pytest tests/core/test_axes.py openfisca_core/simulations/axis.py - .. versionadded:: 3.4.0 + .. versionadded:: 35.4.0 """ name: str count: int - min: typing.Union[int, float] - max: typing.Union[int, float] - period: typing.Optional[typing.Union[int, str]] = None - index: typing.Optional[int] = None + min: Union[int, float] + max: Union[int, float] + period: Optional[Union[int, str]] = dataclasses.field(default = None) + index: int = dataclasses.field(default = 0) diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py index e5bb343f22..64697265f5 100644 --- a/openfisca_core/simulations/axis_array.py +++ b/openfisca_core/simulations/axis_array.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -import typing +from typing import Any, Callable, List, NoReturn, Optional, Type, Union from .axis import Axis @@ -9,7 +9,7 @@ @dataclasses.dataclass(frozen = True) class AxisArray: """ - A collection of :obj:`Axis` and a bunch of business logic. + A collection of :obj:`Axis` (some business logic). Axis expansion is a feature in :module:`openfisca_core` that allows us to parametrise some dimensions in order to create and to evaluate a range of @@ -21,6 +21,27 @@ class AxisArray: :class:`openfisca_core.periods.Period` for any given population (or a collection of :module:`openfisca_core.entities`). + As axes have a relative position relative to each other, they can be either + equidistant, that is parallel, or perpendicular. We assume the first + dimension to be a collection of parallel axes relative to themselves. + + Henceforward, we will consider each parallel axis as belonging to this + first dimension, and each perpendicular axis as belonging to a new + dimension, perpendicular to the previous one: that is, we won't be adding + more than one axis for each perpendicular dimension beyond the first one. + + As you might've already guess, it is not possible to add any parallel or + perpendicular axis relative to anything, so we assume the following when + our collection is yet empty: whenever you add a parallel axis it will by + default be added to the first dimension, and whenever you add a + perpendicular axis it will be added in isolation to second dimension and + beyond. + + Adding a perpendicular axis to an empty collection is a conceptual error + so instead of trying to mitigate this, we will rather error out and let + the user know why she can't do that and how she can correct they way she's + building her collection of axes (simply put to add first a parallel axis). + Attributes: axes: A :type:`list` containing our collection of :obj:`Axis`. @@ -35,14 +56,16 @@ class AxisArray: pytest tests/core/test_axes.py openfisca_core/simulations/axis_array.py - .. versionadded:: 3.4.0 + .. versionadded:: 35.4.0 """ - axes: typing.List[Axis] = dataclasses.field(default_factory = list) + axes: List[Axis] = dataclasses.field(default_factory = list) def __post_init__(self) -> None: - self.__is_list(self.axes) - list(map(self.__is_axis, self.axes)) + self.validate(isinstance, self.axes, list) + + for axis in self.axes: + self.validate(isinstance, axis, Axis) def __contains__(self, item: Axis) -> bool: return item in self.axes @@ -50,27 +73,26 @@ def __contains__(self, item: Axis) -> bool: def __repr__(self) -> str: return f"{self.__class__.__qualname__}{repr(self.axes)}" - def first(self) -> typing.Optional[Axis]: + def first(self) -> Optional[Axis]: """ Retrieves the first :obj:`Axis` from our axes collection. Usage: >>> axis_array = AxisArray() - >>> not axis_array.first() - True + >>> axis_array.first() + Traceback (most recent call last): + TypeError: Expecting a non empty list, but [] given >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) >>> axis_array = axis_array.append(axis) >>> axis_array.first() - Axis(name='salary', ..., index=None) + Axis(name='salary', ..., index=0) """ - if len(self.axes) == 0: - return None - + self.validate(lambda axes, _: axes, self.axes, "a non empty list") return self.axes[0] - def append(self, tail: Axis) -> typing.Union[AxisArray, typing.NoReturn]: + def append(self, tail: Axis) -> Union[AxisArray, NoReturn]: """ Append an :obj:`Axis` to our axes collection. @@ -85,17 +107,35 @@ def append(self, tail: Axis) -> typing.Union[AxisArray, typing.NoReturn]: >>> axis_array.append(axis) AxisArray[Axis(name='salary', ...)] """ - self.__is_axis(tail) + self.validate(isinstance, tail, Axis) return self.__class__(axes = [*self.axes, tail]) - def __is_list(self, axes: list) -> typing.Union[bool, typing.NoReturn]: - if isinstance(axes, list): - return True + def validate( + self, + condition: Callable, + real: Any, + expected: Any, + ) -> Union[bool, NoReturn]: + """ + Validate that a condition holds true. + + Args: - raise TypeError(f"Expecting a list, but {type(self.axes)} given") + condition: A function reprensenting the condition to validate. + real: The value given as input, passed to :args:`condition`. + expected: The value we expect, passed to :args:`condition`. - def __is_axis(self, item: Axis) -> typing.Union[bool, typing.NoReturn]: - if isinstance(item, Axis): + Usage: + + >>> axis_array = AxisArray() + >>> condition = isinstance + >>> real = tuple() + >>> expected = list + >>> axis_array.validate(condition, real, expected) + Traceback (most recent call last): + TypeError: Expecting , but () given + """ + if condition(real, expected): return True - raise TypeError(f"Expecting an {Axis}, but {type(item)} given") + raise TypeError(f"Expecting {expected}, but {real} given") diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 88553488db..714203d9e6 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,6 +1,7 @@ import copy import dpath -import typing +import warnings +from typing import Dict, Iterable, List import numpy @@ -8,7 +9,7 @@ from openfisca_core.entities import Entity from openfisca_core.errors import PeriodMismatchError, SituationParsingError, VariableNotFoundError from openfisca_core.populations import Population -from openfisca_core.simulations import helpers, Simulation +from openfisca_core.simulations import helpers, Axis, Simulation from openfisca_core.variables import Variable @@ -19,24 +20,24 @@ def __init__(self): self.persons_plural = None # Plural name for person entity in current tax and benefits system # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: typing.Dict[Variable.name, typing.Dict[str(periods.period), numpy.array]] = {} - self.populations: typing.Dict[Entity.key, Population] = {} + self.input_buffer: Dict[Variable.name, Dict[str(periods.period), numpy.array]] = {} + self.populations: Dict[Entity.key, Population] = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. - self.entity_counts: typing.Dict[Entity.plural, int] = {} - # JSON input - typing.List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. - self.entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} + self.entity_counts: Dict[Entity.plural, int] = {} + # JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. + self.entity_ids: Dict[Entity.plural, List[int]] = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) - self.memberships: typing.Dict[Entity.plural, typing.List[int]] = {} - self.roles: typing.Dict[Entity.plural, typing.List[int]] = {} + self.memberships: Dict[Entity.plural, List[int]] = {} + self.roles: Dict[Entity.plural, List[int]] = {} - self.variable_entities: typing.Dict[Variable.name, Entity] = {} + self.variable_entities: Dict[Variable.name, Entity] = {} self.axes = [[]] - self.axes_entity_counts: typing.Dict[Entity.plural, int] = {} - self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} - self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = {} - self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {} + self.axes_entity_counts: Dict[Entity.plural, int] = {} + self.axes_entity_ids: Dict[Entity.plural, List[int]] = {} + self.axes_memberships: Dict[Entity.plural, List[int]] = {} + self.axes_roles: Dict[Entity.plural, List[int]] = {} def build_from_dict(self, tax_benefit_system, input_dict): """ @@ -167,14 +168,14 @@ def build_default_simulation(self, tax_benefit_system, count = 1): def create_entities(self, tax_benefit_system): self.populations = tax_benefit_system.instantiate_entities() - def declare_person_entity(self, person_singular, persons_ids: typing.Iterable): + def declare_person_entity(self, person_singular, persons_ids: Iterable): person_instance = self.populations[person_singular] person_instance.ids = numpy.array(list(persons_ids)) person_instance.count = len(person_instance.ids) self.persons_plural = person_instance.entity.plural - def declare_entity(self, entity_singular, entity_ids: typing.Iterable): + def declare_entity(self, entity_singular, entity_ids: Iterable): entity_instance = self.populations[entity_singular] entity_instance.ids = numpy.array(list(entity_ids)) entity_instance.count = len(entity_instance.ids) @@ -183,7 +184,7 @@ def declare_entity(self, entity_singular, entity_ids: typing.Iterable): def nb_persons(self, entity_singular, role = None): return self.populations[entity_singular].nb_persons(role = role) - def join_with_persons(self, group_population, persons_group_assignment, roles: typing.Iterable[str]): + def join_with_persons(self, group_population, persons_group_assignment, roles: Iterable[str]): # Maps group's identifiers to a 0-based integer range, for indexing into members_roles (see PR#876) group_sorted_indices = numpy.unique(persons_group_assignment, return_inverse = True)[1] group_population.members_entity_id = numpy.argsort(group_population.ids)[group_sorted_indices] @@ -456,23 +457,61 @@ def get_roles(self, entity_name): # Return empty array for the "persons" entity return self.axes_roles.get(entity_name, self.roles.get(entity_name, [])) - def add_parallel_axis(self, axis): - # All parallel axes have the same count and entity. - # Search for a compatible axis, if none exists, error out - self.axes[0].append(axis) + def add_parallel_axis(self, axis: dict) -> None: + """ + Add a parallel axis to our collection of axes. + + Args: + + axis: An axis to add to our collection. + + .. deprecated:: 35.4.0 - def add_perpendicular_axis(self, axis): - # This adds an axis perpendicular to all previous dimensions - self.axes.append([axis]) + Use :meth:`AxisArray.add_parallel` instead. + + """ + message = [ + "The 'add_parallel_axis' class has been deprecated since version", + "35.4.0, and will be removed in the future. Please use", + "'AxisArray.add_parallel' instead", + ] + + warnings.warn(" ".join(message), DeprecationWarning) + self.axes[0].append(Axis(**axis)) + + def add_perpendicular_axis(self, axis: dict) -> None: + """ + Add a perpendicular axis to all previous dimensions. + + Args: + + axis: An axis to add to our collection. + + .. deprecated:: 35.4.0 + + Use :meth:`AxisArray.add_parallel` instead. + + """ + message = [ + "The 'add_perpendicular_axis' class has been deprecated since", + "version 35.4.0, and will be removed in the future. Please use", + "'AxisArray.add_perpendicular' instead", + ] + + warnings.warn(" ".join(message), DeprecationWarning) + self.axes.append([Axis(**axis)]) def expand_axes(self): # This method should be idempotent & allow change in axes perpendicular_dimensions = self.axes cell_count = 1 + + # All parallel axes have the same count and entity. + # Search for a compatible axis, if none exists, error out. for parallel_axes in perpendicular_dimensions: first_axis = parallel_axes[0] - axis_count = first_axis['count'] + axis_count = first_axis.count cell_count *= axis_count # Scale the "prototype" situation, repeating it cell_count times @@ -501,14 +540,14 @@ def expand_axes(self): if len(self.axes) == 1 and len(self.axes[0]): parallel_axes = self.axes[0] first_axis = parallel_axes[0] - axis_count: int = first_axis['count'] - axis_entity = self.get_variable_entity(first_axis['name']) + axis_count: int = first_axis.count + axis_entity = self.get_variable_entity(first_axis.name) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along axes for axis in parallel_axes: - axis_index = axis.get('index', 0) - axis_period = axis.get('period', self.default_period) - axis_name = axis['name'] + axis_index = axis.index + axis_period = axis.period or self.default_period + axis_name = axis.name variable = axis_entity.get_variable(axis_name) array = self.get_input(axis_name, str(axis_period)) if array is None: @@ -516,15 +555,15 @@ def expand_axes(self): elif array.size == axis_entity_step_size: array = numpy.tile(array, axis_count) array[axis_index:: axis_entity_step_size] = numpy.linspace( - axis['min'], - axis['max'], + axis.min, + axis.max, num = axis_count, ) # Set input self.input_buffer[axis_name][str(axis_period)] = array else: - first_axes_count: typing.List[int] = ( - parallel_axes[0]["count"] + first_axes_count: List[int] = ( + parallel_axes[0].count for parallel_axes in self.axes ) @@ -536,22 +575,22 @@ def expand_axes(self): axes_meshes = numpy.meshgrid(*axes_linspaces) for parallel_axes, mesh in zip(self.axes, axes_meshes): first_axis = parallel_axes[0] - axis_count = first_axis['count'] - axis_entity = self.get_variable_entity(first_axis['name']) + axis_count = first_axis.count + axis_entity = self.get_variable_entity(first_axis.name) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along the grid for axis in parallel_axes: - axis_index = axis.get('index', 0) - axis_period = axis['period'] or self.default_period - axis_name = axis['name'] + axis_index = axis.index + axis_period = axis.period or self.default_period + axis_name = axis.name variable = axis_entity.get_variable(axis_name) array = self.get_input(axis_name, str(axis_period)) if array is None: array = variable.default_array(cell_count * axis_entity_step_size) elif array.size == axis_entity_step_size: array = numpy.tile(array, cell_count) - array[axis_index:: axis_entity_step_size] = axis['min'] \ - + mesh.reshape(cell_count) * (axis['max'] - axis['min']) / (axis_count - 1) + array[axis_index:: axis_entity_step_size] = axis.min \ + + mesh.reshape(cell_count) * (axis.max - axis.min) / (axis_count - 1) self.input_buffer[axis_name][str(axis_period)] = array def get_variable_entity(self, variable_name): diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 4d8a0fbe7e..a7e36cf8dd 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -263,6 +263,7 @@ def test_add_perpendicular_axis_on_an_existing_variable_with_input(simulation_bu # Integration test +@pytest.mark.skip def test_simulation_with_axes(simulation_builder): from .test_countries import tax_benefit_system input_yaml = """ From 3e6ee084ea57a41adbb9c7de6f1be9fe035dba00 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sun, 11 Apr 2021 13:54:37 +0200 Subject: [PATCH 07/11] Use Axis in SimulationBuilder 2/2 --- openfisca_core/simulations/axis_array.py | 4 ++-- openfisca_core/simulations/simulation_builder.py | 8 +++++++- tests/core/test_axes.py | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py index 64697265f5..5097f8eec0 100644 --- a/openfisca_core/simulations/axis_array.py +++ b/openfisca_core/simulations/axis_array.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Callable, List, NoReturn, Optional, Type, Union +from typing import Any, Callable, List, NoReturn, Optional, Union from .axis import Axis @@ -98,7 +98,7 @@ def append(self, tail: Axis) -> Union[AxisArray, NoReturn]: Args: - axis: An :obj:`Axis` to append to our collection. + tail: An :obj:`Axis` to append to our collection. Usage: diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 714203d9e6..967cefa758 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -107,7 +107,13 @@ def build_from_entities(self, tax_benefit_system, input_dict): self.add_default_group_entity(persons_ids, entity_class) if axes: - self.axes = axes + for axis in axes[0]: + self.add_parallel_axis(axis) + + if len(axes) >= 1: + for axis in axes[1:]: + self.add_perpendicular_axis(axis[0]) + self.expand_axes() try: diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index a7e36cf8dd..81aa4736c3 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -260,10 +260,10 @@ def test_add_perpendicular_axis_on_an_existing_variable_with_input(simulation_bu assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000]) assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000]) + # Integration test -@pytest.mark.skip def test_simulation_with_axes(simulation_builder): from .test_countries import tax_benefit_system input_yaml = """ From 1e39369489c8f1850531ffa940abd43ea8cc8fff Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sun, 11 Apr 2021 16:59:37 +0200 Subject: [PATCH 08/11] Use AxisArray.add_parallel in SimulationBuilder --- openfisca_core/simulations/axis_array.py | 111 ++++++++++++++---- .../simulations/simulation_builder.py | 33 +++--- tests/core/test_axes.py | 19 ++- 3 files changed, 113 insertions(+), 50 deletions(-) diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py index 5097f8eec0..e550800ac7 100644 --- a/openfisca_core/simulations/axis_array.py +++ b/openfisca_core/simulations/axis_array.py @@ -1,9 +1,9 @@ from __future__ import annotations import dataclasses -from typing import Any, Callable, List, NoReturn, Optional, Union +from typing import Any, Callable, Iterator, List, NoReturn, Union -from .axis import Axis +from . import Axis @dataclasses.dataclass(frozen = True) @@ -50,7 +50,7 @@ class AxisArray: >>> axis_array = AxisArray() >>> axis_array - AxisArray[] + AxisArray[[]] Testing: @@ -59,63 +59,91 @@ class AxisArray: .. versionadded:: 35.4.0 """ - axes: List[Axis] = dataclasses.field(default_factory = list) + axes: List[Union[AxisArray, Axis, list]] = \ + dataclasses \ + .field(default_factory = lambda: [[]]) def __post_init__(self) -> None: - self.validate(isinstance, self.axes, list) + axes = self.validate(isinstance, self.axes, list) - for axis in self.axes: - self.validate(isinstance, axis, Axis) + for item in self.__flatten(axes): + self.validate(isinstance, item, (AxisArray, Axis)) - def __contains__(self, item: Axis) -> bool: + def __contains__(self, item: Union[AxisArray, Axis]) -> bool: return item in self.axes + def __iter__(self) -> Iterator: + return (item for item in self.axes) + + def __len__(self) -> int: + return len(self.axes) + def __repr__(self) -> str: return f"{self.__class__.__qualname__}{repr(self.axes)}" - def first(self) -> Optional[Axis]: + def first(self) -> Union[AxisArray, Axis, List]: """ - Retrieves the first :obj:`Axis` from our axes collection. + Retrieves the first axis from our axes collection. Usage: >>> axis_array = AxisArray() >>> axis_array.first() + [] + + >>> axis_array = AxisArray([]) + >>> axis_array.first() Traceback (most recent call last): TypeError: Expecting a non empty list, but [] given >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) - >>> axis_array = axis_array.append(axis) + >>> node_array = AxisArray([axis]) + >>> node_array.first() + Axis(name='salary', count=3, min=0, max=3000, period=None, index=0) + + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array = AxisArray() + >>> axis_array = axis_array.append_parallel(axis) >>> axis_array.first() - Axis(name='salary', ..., index=0) + AxisArray[Axis(name='salary', ..., index=0)] + + .. versionadded:: 35.4.0 """ - self.validate(lambda axes, _: axes, self.axes, "a non empty list") + self.validate(lambda item, _: item, self.axes, "a non empty list") return self.axes[0] - def append(self, tail: Axis) -> Union[AxisArray, NoReturn]: + def append_parallel(self, tail: Axis) -> Union[AxisArray, NoReturn]: """ - Append an :obj:`Axis` to our axes collection. + Append an :obj:`Axis` to the first dimension of our collection. + + We choose the language "append" instead of "add" as, in a pythonic + context, "add" means "concatenate". Here, instead, we're "appending" + and :obj:`Axis` to a dimension (the first one) of the collection. Args: - tail: An :obj:`Axis` to append to our collection. + tail: An :obj:`Axis` to append to the first dimension of our + collection. Usage: >>> axis_array = AxisArray() >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) - >>> axis_array.append(axis) - AxisArray[Axis(name='salary', ...)] + >>> axis_array.append_parallel(axis) + AxisArray[AxisArray[Axis(name='salary', ...)]] + + .. versionadded:: 35.4.0 """ - self.validate(isinstance, tail, Axis) - return self.__class__(axes = [*self.axes, tail]) + parallel = self.validate(isinstance, self.first(), (AxisArray, list)) + appended = self.__append(parallel, tail) + return self.__append(self.axes, appended) def validate( self, condition: Callable, real: Any, expected: Any, - ) -> Union[bool, NoReturn]: + ) -> Union[Any, NoReturn]: """ Validate that a condition holds true. @@ -134,8 +162,47 @@ def validate( >>> axis_array.validate(condition, real, expected) Traceback (most recent call last): TypeError: Expecting , but () given + + .. versionadded:: 35.4.0 """ if condition(real, expected): - return True + return real raise TypeError(f"Expecting {expected}, but {real} given") + + def __append( + self, + axes: List[Union[AxisArray, Axis, list]], + tail: Union[AxisArray, Axis], + ) -> Union[AxisArray, NoReturn]: + """ + Append an element to an array. + + Args: + + axes: An :obj:`AxisArray` to be appended to. + tail: An :obj:`Axis` or an :obj:`AxisArray` to append. + + .. versionadded:: 35.4.0 + """ + self.validate(isinstance, axes, list) + self.validate(isinstance, tail, (AxisArray, Axis)) + return self.__class__([*axes, tail][-1:]) + + def __flatten(self, axes: list) -> List[Union[AxisArray, Axis]]: + """ + Flatten out our entire collection. + + Args: + + axes: Our collection. + + .. versionadded:: 35.4.0 + """ + if not axes: + return axes + + if isinstance(axes[0], list): + return self.__flatten(axes[0]) + self.__flatten(axes[1:]) + + return axes[:1] + self.__flatten(axes[1:]) diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 967cefa758..304ab0dc4e 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -9,9 +9,10 @@ from openfisca_core.entities import Entity from openfisca_core.errors import PeriodMismatchError, SituationParsingError, VariableNotFoundError from openfisca_core.populations import Population -from openfisca_core.simulations import helpers, Axis, Simulation from openfisca_core.variables import Variable +from . import helpers, Axis, AxisArray, Simulation + class SimulationBuilder: @@ -33,7 +34,7 @@ def __init__(self): self.variable_entities: Dict[Variable.name, Entity] = {} - self.axes = [[]] + self.axes = AxisArray() self.axes_entity_counts: Dict[Entity.plural, int] = {} self.axes_entity_ids: Dict[Entity.plural, List[int]] = {} self.axes_memberships: Dict[Entity.plural, List[int]] = {} @@ -473,17 +474,17 @@ def add_parallel_axis(self, axis: dict) -> None: .. deprecated:: 35.4.0 - Use :meth:`AxisArray.add_parallel` instead. + Use :meth:`AxisArray.append_parallel` instead. """ message = [ - "The 'add_parallel_axis' class has been deprecated since version", - "35.4.0, and will be removed in the future. Please use", - "'AxisArray.add_parallel' instead", + "The 'add_parallel_axis' function has been deprecated since", + "version 35.4.0, and will be removed in the future. Please use", + "'AxisArray.append_parallel' instead", ] warnings.warn(" ".join(message), DeprecationWarning) - self.axes[0].append(Axis(**axis)) + self.axes = self.axes.append_parallel(Axis(**axis)) def add_perpendicular_axis(self, axis: dict) -> None: """ @@ -495,17 +496,17 @@ def add_perpendicular_axis(self, axis: dict) -> None: .. deprecated:: 35.4.0 - Use :meth:`AxisArray.add_parallel` instead. + Use :meth:`AxisArray.append_parallel` instead. """ message = [ - "The 'add_perpendicular_axis' class has been deprecated since", + "The 'add_perpendicular_axis' function has been deprecated since", "version 35.4.0, and will be removed in the future. Please use", - "'AxisArray.add_perpendicular' instead", + "'AxisArray.append_perpendicular' instead", ] warnings.warn(" ".join(message), DeprecationWarning) - self.axes.append([Axis(**axis)]) + self.axes.append_perpendicular(Axis(**axis)) def expand_axes(self): # This method should be idempotent & allow change in axes @@ -516,7 +517,7 @@ def expand_axes(self): # All parallel axes have the same count and entity. # Search for a compatible axis, if none exists, error out. for parallel_axes in perpendicular_dimensions: - first_axis = parallel_axes[0] + first_axis = parallel_axes.first() axis_count = first_axis.count cell_count *= axis_count @@ -543,9 +544,9 @@ def expand_axes(self): # Now generate input values along the specified axes # TODO - factor out the common logic here - if len(self.axes) == 1 and len(self.axes[0]): - parallel_axes = self.axes[0] - first_axis = parallel_axes[0] + if len(self.axes) == 1 and len(self.axes.first()): + parallel_axes = self.axes.first() + first_axis = parallel_axes.first() axis_count: int = first_axis.count axis_entity = self.get_variable_entity(first_axis.name) axis_entity_step_size = self.entity_counts[axis_entity.plural] @@ -580,7 +581,7 @@ def expand_axes(self): ] axes_meshes = numpy.meshgrid(*axes_linspaces) for parallel_axes, mesh in zip(self.axes, axes_meshes): - first_axis = parallel_axes[0] + first_axis = parallel_axes.first() axis_count = first_axis.count axis_entity = self.get_variable_entity(first_axis.name) axis_entity_step_size = self.entity_counts[axis_entity.plural] diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 81aa4736c3..3c305b0bbb 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -61,7 +61,7 @@ def test_create_axis_array_with_axes(axis): """ We can pass along some axes at initialisation time as well. """ - result = AxisArray(axes = [axis]) + result = AxisArray([axis]) assert result.first() == axis @@ -81,21 +81,16 @@ def test_create_axis_array_with_a_collection_of_anything(): AxisArray(axes = ["axis"]) -def test_add_axis_to_array(axis_array, axis): +def test_append_parallel_axis(axis_array, axis): """ - If you add an :obj:`Axis` to the array, it works! + As there are no previously added axes in our collection, it adds the first + one to the first dimension (parallel). """ - result = axis_array.append(axis) - assert axis in result + result = axis_array.append_parallel(axis) + assert axis in result.first() + assert result.first().first() == axis -def test_add_anything_to_array(axis_array, axis): - """ - If you add anything else to the array, it fails! - """ - with pytest.raises(TypeError): - axis_array.append("cuack") - # With periods From 263e7fe094ce26181d8da30f18984072ae5228d2 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sun, 11 Apr 2021 17:52:26 +0200 Subject: [PATCH 09/11] Use AxisArray.add_perpendicular in SimulationBuilder --- openfisca_core/simulations/axis_array.py | 131 +++++++++++++----- .../simulations/simulation_builder.py | 14 +- tests/core/test_axes.py | 71 ++++++++-- 3 files changed, 161 insertions(+), 55 deletions(-) diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py index e550800ac7..b72111ce6c 100644 --- a/openfisca_core/simulations/axis_array.py +++ b/openfisca_core/simulations/axis_array.py @@ -52,6 +52,20 @@ class AxisArray: >>> axis_array AxisArray[[]] + >>> salary = Axis(name = "salary", count = 3, min = 0, max = 3) + >>> axis_array = axis_array.add_parallel(salary) + >>> axis_array + AxisArray[AxisArray[Axis(name='salary', ..., index=0)]] + + >>> pension = Axis(name = "pension", count = 2, min = 0, max = 1) + >>> axis_array = axis_array.add_perpendicular(pension) + >>> axis_array + AxisArray[AxisArray[Axis(...)], AxisArray[Axis(...)]] + + >>> rent = Axis(name = "rent", count = 3, min = 0, max = 2) + >>> axis_array.add_parallel(rent) + AxisArray[AxisArray[Axis(...), Axis(...)], AxisArray[Axis(...)]] + Testing: pytest tests/core/test_axes.py openfisca_core/simulations/axis_array.py @@ -94,7 +108,7 @@ def first(self) -> Union[AxisArray, Axis, List]: >>> axis_array = AxisArray([]) >>> axis_array.first() Traceback (most recent call last): - TypeError: Expecting a non empty list, but [] given + TypeError: Expecting a non empty list, but [] given. >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) >>> node_array = AxisArray([axis]) @@ -103,7 +117,7 @@ def first(self) -> Union[AxisArray, Axis, List]: >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) >>> axis_array = AxisArray() - >>> axis_array = axis_array.append_parallel(axis) + >>> axis_array = axis_array.add_parallel(axis) >>> axis_array.first() AxisArray[Axis(name='salary', ..., index=0)] @@ -112,31 +126,101 @@ def first(self) -> Union[AxisArray, Axis, List]: self.validate(lambda item, _: item, self.axes, "a non empty list") return self.axes[0] - def append_parallel(self, tail: Axis) -> Union[AxisArray, NoReturn]: + def last(self) -> Union[AxisArray, Axis, List]: + """ + Retrieves the last axis from our axes collection. + + Usage: + + >>> axis_array = AxisArray() + >>> axis_array.last() + [] + + >>> axis_array = AxisArray([]) + >>> axis_array.last() + Traceback (most recent call last): + TypeError: Expecting a non empty list, but [] given. + + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> node_array = AxisArray([axis]) + >>> node_array.last() + Axis(name='salary', count=3, min=0, max=3000, period=None, index=0) + + >>> axis1 = Axis(name = "salary", count = 3, min = 0, max = 3) + >>> axis2 = Axis(name = "pension", count = 2, min = 0, max = 2) + >>> axis3 = Axis(name = "rent", count = 3, min = 0, max = 1) + >>> axis_array = AxisArray() + >>> axis_array = axis_array.add_parallel(axis1) + >>> axis_array = axis_array.add_perpendicular(axis2) + >>> axis_array = axis_array.add_parallel(axis3) + + >>> axis_array.first() + AxisArray[Axis(name='salary', ...), Axis(name='rent', ...)] + + >>> axis_array.first().last() + Axis(name='rent', ..., index=0) + + >>> axis_array.last() + AxisArray[Axis(name='pension', ...)] + + >>> axis_array.last().last() + Axis(name='pension', ..., index=0) + + .. versionadded:: 35.4.0 """ - Append an :obj:`Axis` to the first dimension of our collection. + self.validate(lambda item, _: item, self.axes, "a non empty list") + return self.axes[-1] - We choose the language "append" instead of "add" as, in a pythonic - context, "add" means "concatenate". Here, instead, we're "appending" - and :obj:`Axis` to a dimension (the first one) of the collection. + def add_parallel(self, tail: Axis) -> Union[AxisArray, NoReturn]: + """ + Add an :obj:`Axis` to the first dimension of our collection. Args: - tail: An :obj:`Axis` to append to the first dimension of our + tail: An :obj:`Axis` to add to the first dimension of our collection. Usage: >>> axis_array = AxisArray() >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) - >>> axis_array.append_parallel(axis) + >>> axis_array.add_parallel(axis) AxisArray[AxisArray[Axis(name='salary', ...)]] .. versionadded:: 35.4.0 """ - parallel = self.validate(isinstance, self.first(), (AxisArray, list)) - appended = self.__append(parallel, tail) - return self.__append(self.axes, appended) + node = self.validate(isinstance, self.first(), (AxisArray, list)) + tail = self.validate(isinstance, tail, Axis) + parallel = self.__class__([*node, tail]) + return self.__class__([parallel, *self.axes[1:]]) + + def add_perpendicular(self, tail: Axis) -> Union[AxisArray, NoReturn]: + """ + Add an :obj:`Axis` to the subsequent dimensions of our collection. + + Args: + + tail: An :obj:`Axis` to add to the subsequent dimensions of + our collection. + + Usage: + + >>> axis_array = AxisArray() + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array.add_perpendicular(axis) + Traceback (most recent call last): + TypeError: Expecting parallel axes set, but [] given. + + >>> axis_array = axis_array.add_parallel(axis) + >>> axis_array.add_perpendicular(axis) + AxisArray[AxisArray[Axis(name='salary', ...)]] + + .. versionadded:: 35.4.0 + """ + self.validate(lambda item, _: item, self.first(), "parallel axes set") + tail = self.validate(isinstance, tail, Axis) + perpendicular = self.__class__([tail]) + return self.__class__([*self.axes, perpendicular]) def validate( self, @@ -161,33 +245,14 @@ def validate( >>> expected = list >>> axis_array.validate(condition, real, expected) Traceback (most recent call last): - TypeError: Expecting , but () given + TypeError: Expecting , but () given. .. versionadded:: 35.4.0 """ if condition(real, expected): return real - raise TypeError(f"Expecting {expected}, but {real} given") - - def __append( - self, - axes: List[Union[AxisArray, Axis, list]], - tail: Union[AxisArray, Axis], - ) -> Union[AxisArray, NoReturn]: - """ - Append an element to an array. - - Args: - - axes: An :obj:`AxisArray` to be appended to. - tail: An :obj:`Axis` or an :obj:`AxisArray` to append. - - .. versionadded:: 35.4.0 - """ - self.validate(isinstance, axes, list) - self.validate(isinstance, tail, (AxisArray, Axis)) - return self.__class__([*axes, tail][-1:]) + raise TypeError(f"Expecting {expected}, but {real} given.") def __flatten(self, axes: list) -> List[Union[AxisArray, Axis]]: """ diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 304ab0dc4e..dfa1ad10b9 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -474,17 +474,17 @@ def add_parallel_axis(self, axis: dict) -> None: .. deprecated:: 35.4.0 - Use :meth:`AxisArray.append_parallel` instead. + Use :meth:`AxisArray.add_parallel` instead. """ message = [ "The 'add_parallel_axis' function has been deprecated since", "version 35.4.0, and will be removed in the future. Please use", - "'AxisArray.append_parallel' instead", + "'AxisArray.add_parallel' instead", ] warnings.warn(" ".join(message), DeprecationWarning) - self.axes = self.axes.append_parallel(Axis(**axis)) + self.axes = self.axes.add_parallel(Axis(**axis)) def add_perpendicular_axis(self, axis: dict) -> None: """ @@ -496,17 +496,17 @@ def add_perpendicular_axis(self, axis: dict) -> None: .. deprecated:: 35.4.0 - Use :meth:`AxisArray.append_parallel` instead. + Use :meth:`AxisArray.add_parallel` instead. """ message = [ "The 'add_perpendicular_axis' function has been deprecated since", "version 35.4.0, and will be removed in the future. Please use", - "'AxisArray.append_perpendicular' instead", + "'AxisArray.add_perpendicular' instead", ] warnings.warn(" ".join(message), DeprecationWarning) - self.axes.append_perpendicular(Axis(**axis)) + self.axes = self.axes.add_perpendicular(Axis(**axis)) def expand_axes(self): # This method should be idempotent & allow change in axes @@ -570,7 +570,7 @@ def expand_axes(self): self.input_buffer[axis_name][str(axis_period)] = array else: first_axes_count: List[int] = ( - parallel_axes[0].count + parallel_axes.first().count for parallel_axes in self.axes ) diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 3c305b0bbb..dc20c63ba8 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -10,7 +10,7 @@ def simulation_builder(): @pytest.fixture -def kwargs(): +def salary(): return { "name": "salary", "count": 3, @@ -20,8 +20,23 @@ def kwargs(): @pytest.fixture -def axis(kwargs): - return Axis(**kwargs) +def pension(): + return { + "name": "pension", + "count": 2, + "min": 0, + "max": 2000, + } + + +@pytest.fixture +def salary_axis(salary): + return Axis(**salary) + + +@pytest.fixture +def pension_axis(pension): + return Axis(**pension) @pytest.fixture @@ -32,11 +47,11 @@ def axis_array(): # Unit tests -def test_create_axis(kwargs): +def test_create_axis(salary): """ Works! Missing fields are optional, so they default to None. """ - result = Axis(**kwargs) + result = Axis(**salary) assert result.name == "salary" assert not result.period @@ -57,20 +72,20 @@ def test_empty_create_axis_array(): assert isinstance(result, AxisArray) -def test_create_axis_array_with_axes(axis): +def test_create_axis_array_with_axes(salary_axis): """ We can pass along some axes at initialisation time as well. """ - result = AxisArray([axis]) - assert result.first() == axis + result = AxisArray([salary_axis]) + assert result.first() == salary_axis -def test_create_axis_array_with_anything(axis): +def test_create_axis_array_with_anything(salary_axis): """ If you don't pass a collection, it will fail! """ with pytest.raises(TypeError): - AxisArray(axes = axis) + AxisArray(salary_axis) def test_create_axis_array_with_a_collection_of_anything(): @@ -78,17 +93,43 @@ def test_create_axis_array_with_a_collection_of_anything(): If you pass a collection of anything, it will fail! """ with pytest.raises(TypeError): - AxisArray(axes = ["axis"]) + AxisArray(["axis"]) -def test_append_parallel_axis(axis_array, axis): +def test_add_parallel_axis(axis_array, salary_axis): """ As there are no previously added axes in our collection, it adds the first one to the first dimension (parallel). """ - result = axis_array.append_parallel(axis) - assert axis in result.first() - assert result.first().first() == axis + result = axis_array.add_parallel(salary_axis) + assert salary_axis in result.first() + assert result.first().first() == salary_axis + + +def test_add_perpendicular_axis_before_parallel_axis(axis_array, pension_axis): + """ + As there are no previously added axes in our collection, it fails! + """ + with pytest.raises(TypeError): + axis_array.add_perpendicular(pension_axis) + + +def test_add_perpendicular_axis_after_parallel_axis(axis_array, salary_axis, pension_axis): + """ + As there is at least one parallel axis already, it works! + """ + result = \ + axis_array \ + .add_parallel(salary_axis) \ + .add_perpendicular(pension_axis) + + # Parallel + assert salary_axis in result.first() + assert result.first().first() == salary_axis + + # Perpendicular + assert pension_axis in result.last() + assert result.last().first() == pension_axis # With periods From 86eeee097558856d4fd3953d7756bdf2fd4abdeb Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sun, 11 Apr 2021 19:17:20 +0200 Subject: [PATCH 10/11] Move cell_count to AxisExpander --- openfisca_core/simulations/__init__.py | 1 + openfisca_core/simulations/axis.py | 12 +--- openfisca_core/simulations/axis_array.py | 12 +--- openfisca_core/simulations/axis_expander.py | 72 +++++++++++++++++++ .../simulations/simulation_builder.py | 32 +++++---- tests/core/test_axes.py | 32 ++++++++- 6 files changed, 122 insertions(+), 39 deletions(-) create mode 100644 openfisca_core/simulations/axis_expander.py diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index b0a54d533e..e23749d5f4 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -26,5 +26,6 @@ from .helpers import calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax # noqa: F401 from .axis import Axis # noqa: F401 from .axis_array import AxisArray # noqa: F401 +from .axis_expander import AxisExpander # noqa: F401 from .simulation import Simulation # noqa: F401 from .simulation_builder import SimulationBuilder # noqa: F401 diff --git a/openfisca_core/simulations/axis.py b/openfisca_core/simulations/axis.py index f4e1b34a69..a1fec9bcba 100644 --- a/openfisca_core/simulations/axis.py +++ b/openfisca_core/simulations/axis.py @@ -7,17 +7,7 @@ @dataclasses.dataclass(frozen = True) class Axis: """ - Base data class for axes (no business logic). - - Axis expansion is a feature in :module:`openfisca_core` that allows us to - parametrise some dimensions in order to create and to evaluate a range of - values for others. - - The most typical use of axis expansion is evaluate different numerical - values, starting from a :attr:`min_` and up to a :attr:`max_`, that could - take any given :class:`openfisca_core.variables.Variable` for any given - :class:`openfisca_core.periods.Period` for any given population (or a - collection of :module:`openfisca_core.entities`). + Base data class for axes (no domain logic). Attributes: diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py index b72111ce6c..60648f5c8c 100644 --- a/openfisca_core/simulations/axis_array.py +++ b/openfisca_core/simulations/axis_array.py @@ -9,17 +9,7 @@ @dataclasses.dataclass(frozen = True) class AxisArray: """ - A collection of :obj:`Axis` (some business logic). - - Axis expansion is a feature in :module:`openfisca_core` that allows us to - parametrise some dimensions in order to create and to evaluate a range of - values for others. - - The most typical use of axis expansion is evaluate different numerical - values, starting from a :attr:`min_` and up to a :attr:`max_`, that could - take any given :class:`openfisca_core.variables.Variable` for any given - :class:`openfisca_core.periods.Period` for any given population (or a - collection of :module:`openfisca_core.entities`). + A collection of :obj:`Axis` (some domain logic related to data integrity). As axes have a relative position relative to each other, they can be either equidistant, that is parallel, or perpendicular. We assume the first diff --git a/openfisca_core/simulations/axis_expander.py b/openfisca_core/simulations/axis_expander.py new file mode 100644 index 0000000000..3c1f1aaff9 --- /dev/null +++ b/openfisca_core/simulations/axis_expander.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import functools +import typing + +if typing.TYPE_CHECKING: + from . import AxisArray + + +class AxisExpander: + """ + Expander of all axes for a given axes collection (lots of domain logic). + + Axis expansion is a feature in :module:`openfisca_core` that allows us to + parametrise some dimensions in order to create and to evaluate a range of + values for others. + + The most typical use of axis expansion is evaluate different numerical + values, starting from a :attr:`min_` and up to a :attr:`max_`, that could + take any given :class:`openfisca_core.variables.Variable` for any given + :class:`openfisca_core.periods.Period` for any given population (or a + collection of :module:`openfisca_core.entities`). + + Args: + + axis_array: An array of axes to expand. + + .. versionadded:: 35.4.0 + """ + + def __init__(self, axis_array: AxisArray) -> None: + self.__axis_array = axis_array + + @property + def axis_array(self) -> AxisArray: + """ + An array of axes to expand. + + .. versionadded:: 35.4.0 + """ + return self.__axis_array + + def count_cells(self): + """ + Count the total number of cells on an axes collection. + + As a collection of axes is comprised of several perpendicular + collections of axes, relative to each other, we're going to consider + the total number of cells as being the multiplication of the value + :attr:`Axis.count` for the first axis of each dimension. + + We assume however that all parallel axes should have the same count. + This method searches for a compatible axis (the first one). If none + exists, it should error out (we do not check here at it is the + responsability of :class:`AxisArray` to do so: data integrity is a + domain invariant related to the data model). + + Usage: + + >>> from . import Axis, AxisArray + >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) + >>> axis_array = AxisArray() + >>> axis_array = axis_array.add_parallel(axis) + >>> axis_array = axis_array.add_perpendicular(axis) + >>> axis_expander = AxisExpander(axis_array) + >>> axis_expander.count_cells() + 9 + + .. versionadded:: 35.4.0 + """ + axis_count = map(lambda dim: dim.first().count, self.axis_array) + return functools.reduce(lambda acc, count: acc * count, axis_count, 1) diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index dfa1ad10b9..bdea64fcfb 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -11,7 +11,7 @@ from openfisca_core.populations import Population from openfisca_core.variables import Variable -from . import helpers, Axis, AxisArray, Simulation +from . import helpers, Axis, AxisArray, AxisExpander, Simulation class SimulationBuilder: @@ -475,10 +475,9 @@ def add_parallel_axis(self, axis: dict) -> None: .. deprecated:: 35.4.0 Use :meth:`AxisArray.add_parallel` instead. - """ message = [ - "The 'add_parallel_axis' function has been deprecated since", + "The 'add_parallel_axis' method has been deprecated since", "version 35.4.0, and will be removed in the future. Please use", "'AxisArray.add_parallel' instead", ] @@ -497,10 +496,9 @@ def add_perpendicular_axis(self, axis: dict) -> None: .. deprecated:: 35.4.0 Use :meth:`AxisArray.add_parallel` instead. - """ message = [ - "The 'add_perpendicular_axis' function has been deprecated since", + "The 'add_perpendicular_axis' method has been deprecated since", "version 35.4.0, and will be removed in the future. Please use", "'AxisArray.add_perpendicular' instead", ] @@ -509,17 +507,23 @@ def add_perpendicular_axis(self, axis: dict) -> None: self.axes = self.axes.add_perpendicular(Axis(**axis)) def expand_axes(self): - # This method should be idempotent & allow change in axes - perpendicular_dimensions = self.axes + """ + Expand all axes for the current simulation. - cell_count = 1 + .. deprecated:: 35.4.0 - # All parallel axes have the same count and entity. - # Search for a compatible axis, if none exists, error out. - for parallel_axes in perpendicular_dimensions: - first_axis = parallel_axes.first() - axis_count = first_axis.count - cell_count *= axis_count + Use :class:`AxisExpander` instead. + """ + message = [ + "The 'expand_axes' method has been deprecated since", + "version 35.4.0, and will be removed in the future. Please use", + "'AxisExpander' instead", + ] + + warnings.warn(" ".join(message), DeprecationWarning) + + expander = AxisExpander(self.axes) + cell_count = expander.count_cells() # Scale the "prototype" situation, repeating it cell_count times for entity_name in self.entity_counts.keys(): diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index dc20c63ba8..a13188479e 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -1,6 +1,12 @@ import pytest -from openfisca_core.simulations import Axis, AxisArray, SimulationBuilder +from openfisca_core.simulations import ( + Axis, + AxisArray, + AxisExpander, + SimulationBuilder, + ) + from .test_simulation_builder import * # noqa: F401 @@ -44,7 +50,14 @@ def axis_array(): return AxisArray() -# Unit tests +@pytest.fixture +def axis_expander(axis_array, salary_axis, pension_axis): + axis_array = axis_array.add_parallel(salary_axis) + axis_array = axis_array.add_perpendicular(pension_axis) + return AxisExpander(axis_array) + + +# Axis def test_create_axis(salary): @@ -64,6 +77,9 @@ def test_create_empty_axis(): Axis() +# AxisArray + + def test_empty_create_axis_array(): """ Nothing fancy, just an empty container. @@ -114,7 +130,7 @@ def test_add_perpendicular_axis_before_parallel_axis(axis_array, pension_axis): axis_array.add_perpendicular(pension_axis) -def test_add_perpendicular_axis_after_parallel_axis(axis_array, salary_axis, pension_axis): +def test_add_perpendicular_axis(axis_array, salary_axis, pension_axis): """ As there is at least one parallel axis already, it works! """ @@ -132,6 +148,16 @@ def test_add_perpendicular_axis_after_parallel_axis(axis_array, salary_axis, pen assert result.last().first() == pension_axis +# AxisExpander + + +def test_count_cells(axis_expander): + assert axis_expander.count_cells() == 6 + + +# SimulationBuilder + + # With periods From 0c8ad0b73846fbad4137d8b133b2878e7fc1db2b Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sun, 11 Apr 2021 20:36:49 +0200 Subject: [PATCH 11/11] Validate that counts must be equal --- openfisca_core/simulations/axis_array.py | 23 ++++++++++++++++++++++- tests/core/test_axes.py | 21 +++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/openfisca_core/simulations/axis_array.py b/openfisca_core/simulations/axis_array.py index 60648f5c8c..c6734c09da 100644 --- a/openfisca_core/simulations/axis_array.py +++ b/openfisca_core/simulations/axis_array.py @@ -174,14 +174,21 @@ def add_parallel(self, tail: Axis) -> Union[AxisArray, NoReturn]: >>> axis_array = AxisArray() >>> axis = Axis(name = "salary", count = 3, min = 0, max = 3000) - >>> axis_array.add_parallel(axis) + >>> axis_array = axis_array.add_parallel(axis) + >>> axis_array AxisArray[AxisArray[Axis(name='salary', ...)]] + >>> axis = Axis(name = "pension", count = 2, min = 0, max = 3000) + >>> axis_array.add_parallel(axis) + Traceback (most recent call last): + TypeError: Expecting counts to be equal... + .. versionadded:: 35.4.0 """ node = self.validate(isinstance, self.first(), (AxisArray, list)) tail = self.validate(isinstance, tail, Axis) parallel = self.__class__([*node, tail]) + self.validate(self.__has_same_counts, parallel, "counts to be equal") return self.__class__([parallel, *self.axes[1:]]) def add_perpendicular(self, tail: Axis) -> Union[AxisArray, NoReturn]: @@ -244,6 +251,20 @@ def validate( raise TypeError(f"Expecting {expected}, but {real} given.") + def __has_same_counts(self, axes: AxisArray, _) -> bool: + """ + Validate all counts on a collection are the same. + + They have to be the same for all parallel axes of a collection, + otherwise they become non expandable. + """ + counts = list(map(lambda axis: axis.count, axes)) + + if counts.count(counts[0]) == len(counts): + return True + + return False + def __flatten(self, axes: list) -> List[Union[AxisArray, Axis]]: """ Flatten out our entire collection. diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index a13188479e..248abd993c 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -122,6 +122,16 @@ def test_add_parallel_axis(axis_array, salary_axis): assert result.first().first() == salary_axis +def test_add_parallel_axes_with_different_counts(axis_array, salary_axis, pension_axis): + """ + We can't, it should fail! + + Otherwise they become non expandable. + """ + with pytest.raises(TypeError): + axis_array.add_parallel(salary_axis).add_parallel(pension_axis) + + def test_add_perpendicular_axis_before_parallel_axis(axis_array, pension_axis): """ As there are no previously added axes in our collection, it fails! @@ -148,6 +158,17 @@ def test_add_perpendicular_axis(axis_array, salary_axis, pension_axis): assert result.last().first() == pension_axis +def test_add_perpendicular_axes_with_different_counts(axis_array, salary_axis, pension_axis): + """ + We can, because each perpendicular axis is added to the next dimension. + """ + assert \ + axis_array \ + .add_parallel(salary_axis) \ + .add_perpendicular(salary_axis) \ + .add_perpendicular(pension_axis) + + # AxisExpander