diff --git a/docs/source/workflows/code_examples.rst b/docs/source/workflows/code_examples.rst index bbe85d8da..9b1ec4926 100644 --- a/docs/source/workflows/code_examples.rst +++ b/docs/source/workflows/code_examples.rst @@ -66,7 +66,7 @@ This pattern is also extremely useful for performing optimization over complex o } ) - physical_properties_predictor = AutoMLModel( + physical_properties_predictor = AutoMLPredictor( name = 'physical properties model', inputs = [ wheat_flour_quantity, diff --git a/docs/source/workflows/predictors.rst b/docs/source/workflows/predictors.rst index f6185860b..999964c3b 100644 --- a/docs/source/workflows/predictors.rst +++ b/docs/source/workflows/predictors.rst @@ -251,16 +251,16 @@ The following example demonstrates how to use a :class:`~citrine.informatics.pre ml_predictor = AutoMLPredictor( name='ML Model for Density', description='Predict the density, given molecular features of the solvent', - inputs = features, - output = [output_desc] + inputs=features, + outputs=[output_desc] ) # use a graph predictor to wrap together the featurizer and the machine learning model graph_predictor = GraphPredictor( name='Density from solvent molecular structure', description='Predict the density from the solvent molecular structure using molecular structure features.', - predictors = [featurizer, ml_predictor], - training_data = [GemTableDataSource(table_id=training_data_table_uid, table_version=training_data_table_version)] # training data shared by all sub-predictors + predictors=[featurizer, ml_predictor], + training_data=[GemTableDataSource(table_id=training_data_table_uid, table_version=training_data_table_version)] # training data shared by all sub-predictors ) # register or update predictor by name diff --git a/src/citrine/informatics/predictor_evaluator.py b/src/citrine/informatics/predictor_evaluator.py index 2283e69ad..8c71e884b 100644 --- a/src/citrine/informatics/predictor_evaluator.py +++ b/src/citrine/informatics/predictor_evaluator.py @@ -1,13 +1,10 @@ from citrine._serialization import properties from citrine._serialization.polymorphic_serializable import PolymorphicSerializable from citrine._serialization.serializable import Serializable -from citrine.informatics.predictor_evaluation_metrics import PredictorEvaluationMetric from citrine.informatics.data_sources import DataSource +from citrine.informatics.predictor_evaluation_metrics import PredictorEvaluationMetric -__all__ = ['PredictorEvaluator', - 'CrossValidationEvaluator', - 'HoldoutSetEvaluator' - ] +__all__ = ["PredictorEvaluator", "CrossValidationEvaluator", "HoldoutSetEvaluator"] class PredictorEvaluator(PolymorphicSerializable["PredictorEvaluator"]): @@ -18,7 +15,7 @@ def get_type(cls, data) -> type[Serializable]: """Return the subtype.""" return { "CrossValidationEvaluator": CrossValidationEvaluator, - "HoldoutSetEvaluator": HoldoutSetEvaluator + "HoldoutSetEvaluator": HoldoutSetEvaluator, }[data["type"]] def __eq__(self, other): @@ -26,13 +23,15 @@ def __eq__(self, other): self_dict = self.dump() other_dict = other.dump() - self_dict['responses'] = set(self_dict.get('responses', [])) - self_dict['metrics'] = frozenset( - frozenset((k, v) for k, v in dct.items()) for dct in self_dict.get('metrics', []) + self_dict["responses"] = set(self_dict.get("responses", [])) + self_dict["metrics"] = frozenset( + frozenset((k, v) for k, v in dct.items()) + for dct in self_dict.get("metrics", []) ) - other_dict['responses'] = set(other_dict.get('responses', [])) - other_dict['metrics'] = frozenset( - frozenset((k, v) for k, v in dct.items()) for dct in other_dict.get('metrics', []) + other_dict["responses"] = set(other_dict.get("responses", [])) + other_dict["metrics"] = frozenset( + frozenset((k, v) for k, v in dct.items()) + for dct in other_dict.get("metrics", []) ) return self_dict == other_dict @@ -55,13 +54,15 @@ def name(self) -> str: A name is required by all evaluators because it is used as the top-level key in the results returned by a - :class:`citrine.informatics.workflows.PredictorEvaluationWorkflow`. + :class:`citrine.informatics.executions.predictor_evaluation.PredictorEvaluation`. As such, the names of all evaluators within a single workflow must be unique. """ raise NotImplementedError # pragma: no cover -class CrossValidationEvaluator(Serializable["CrossValidationEvaluator"], PredictorEvaluator): +class CrossValidationEvaluator( + Serializable["CrossValidationEvaluator"], PredictorEvaluator +): """Evaluate a predictor via cross validation. Performs cross-validation on requested predictor responses and computes the requested metrics @@ -103,21 +104,27 @@ class CrossValidationEvaluator(Serializable["CrossValidationEvaluator"], Predict _responses = properties.Set(properties.String, "responses") n_folds = properties.Integer("n_folds") n_trials = properties.Integer("n_trials") - _metrics = properties.Optional(properties.Set(properties.Object(PredictorEvaluationMetric)), - "metrics") - ignore_when_grouping = properties.Optional(properties.Set(properties.String), - "ignore_when_grouping") - typ = properties.String("type", default="CrossValidationEvaluator", deserializable=False) - - def __init__(self, - name: str, - *, - description: str = "", - responses: set[str], - n_folds: int = 5, - n_trials: int = 3, - metrics: set[PredictorEvaluationMetric] | None = None, - ignore_when_grouping: set[str] | None = None): + _metrics = properties.Optional( + properties.Set(properties.Object(PredictorEvaluationMetric)), "metrics" + ) + ignore_when_grouping = properties.Optional( + properties.Set(properties.String), "ignore_when_grouping" + ) + typ = properties.String( + "type", default="CrossValidationEvaluator", deserializable=False + ) + + def __init__( + self, + name: str, + *, + description: str = "", + responses: set[str], + n_folds: int = 5, + n_trials: int = 3, + metrics: set[PredictorEvaluationMetric] | None = None, + ignore_when_grouping: set[str] | None = None, + ): self.name: str = name self.description: str = description self._responses: set[str] = responses @@ -161,16 +168,20 @@ class HoldoutSetEvaluator(Serializable["HoldoutSetEvaluator"], PredictorEvaluato description = properties.String("description") _responses = properties.Set(properties.String, "responses") data_source = properties.Object(DataSource, "data_source") - _metrics = properties.Optional(properties.Set(properties.Object(PredictorEvaluationMetric)), - "metrics") + _metrics = properties.Optional( + properties.Set(properties.Object(PredictorEvaluationMetric)), "metrics" + ) typ = properties.String("type", default="HoldoutSetEvaluator", deserializable=False) - def __init__(self, - name: str, *, - description: str = "", - responses: set[str], - data_source: DataSource, - metrics: set[PredictorEvaluationMetric] | None = None): + def __init__( + self, + name: str, + *, + description: str = "", + responses: set[str], + data_source: DataSource, + metrics: set[PredictorEvaluationMetric] | None = None, + ): self.name: str = name self.description: str = description self._responses: set[str] = responses diff --git a/tests/conftest.py b/tests/conftest.py index bbcba1583..147403487 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,34 +6,29 @@ from citrine.informatics.predictors import AutoMLEstimator from citrine.resources.status_detail import StatusDetail, StatusLevelEnum -from tests.utils.factories import (PredictorEntityDataFactory, PredictorDataDataFactory, - PredictorMetadataDataFactory, StatusDataFactory) +from tests.utils.factories import ( + PredictorDataDataFactory, + PredictorEntityDataFactory, + PredictorMetadataDataFactory, + StatusDataFactory, +) def build_predictor_entity(instance, status_name="READY", status_detail=[]): user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return dict( id=str(uuid.uuid4()), data=dict( name=instance.get("name"), description=instance.get("description"), - instance=instance + instance=instance, ), metadata=dict( - status=dict( - name=status_name, - detail=status_detail - ), - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ) - ) + status=dict(name=status_name, detail=status_detail), + created=dict(user=user, time=time), + updated=dict(user=user, time=time), + ), ) @@ -41,78 +36,70 @@ def build_predictor_entity(instance, status_name="READY", status_detail=[]): def valid_product_design_space_data(): """Produce valid product design space data.""" from citrine.informatics.descriptors import FormulationDescriptor + user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return dict( id=str(uuid.uuid4()), data=dict( - name='my design space', - description='does some things', + name="my design space", + description="does some things", instance=dict( - type='ProductDesignSpace', - name='my design space', - description='does some things', + type="ProductDesignSpace", + name="my design space", + description="does some things", subspaces=[ dict( - type='FormulationDesignSpace', - name='first subspace', - description='', + type="FormulationDesignSpace", + name="first subspace", + description="", formulation_descriptor=FormulationDescriptor.hierarchical().dump(), - ingredients=['foo'], - labels={'bar': ['foo']}, + ingredients=["foo"], + labels={"bar": ["foo"]}, constraints=[], - resolution=0.1 + resolution=0.1, ), dict( - type='FormulationDesignSpace', - name='second subspace', - description='formulates some things', + type="FormulationDesignSpace", + name="second subspace", + description="formulates some things", formulation_descriptor=FormulationDescriptor.hierarchical().dump(), - ingredients=['baz'], + ingredients=["baz"], labels={}, constraints=[], - resolution=0.1 - ) + resolution=0.1, + ), ], dimensions=[ dict( - type='ContinuousDimension', + type="ContinuousDimension", descriptor=dict( - type='Real', - descriptor_key='alpha', - units='', + type="Real", + descriptor_key="alpha", + units="", lower_bound=5.0, upper_bound=10.0, ), lower_bound=6.0, - upper_bound=7.0 + upper_bound=7.0, ), dict( - type='EnumeratedDimension', + type="EnumeratedDimension", descriptor=dict( - type='Categorical', - descriptor_key='color', - descriptor_values=['blue', 'green', 'red'], + type="Categorical", + descriptor_key="color", + descriptor_values=["blue", "green", "red"], ), - list=['red'] - ) - ] - ) + list=["red"], + ), + ], + ), ), metadata=dict( - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ), - status=dict( - name='VALIDATING', - detail=[] - ) - ) + created=dict(user=user, time=time), + updated=dict(user=user, time=time), + status=dict(name="VALIDATING", detail=[]), + ), ) @@ -121,102 +108,90 @@ def valid_formulation_design_space_data(): """Produce valid formulation design space data.""" from citrine.informatics.constraints import IngredientCountConstraint from citrine.informatics.descriptors import FormulationDescriptor + descriptor = FormulationDescriptor.hierarchical() - constraint = IngredientCountConstraint(formulation_descriptor=descriptor, min=0, max=1) + constraint = IngredientCountConstraint( + formulation_descriptor=descriptor, min=0, max=1 + ) return dict( - type='FormulationDesignSpace', - name='formulation design space', - description='formulates some things', + type="FormulationDesignSpace", + name="formulation design space", + description="formulates some things", formulation_descriptor=descriptor.dump(), - ingredients=['foo'], - labels={'bar': ['foo']}, + ingredients=["foo"], + labels={"bar": ["foo"]}, constraints=[constraint.dump()], - resolution=0.1 + resolution=0.1, ) @pytest.fixture def valid_hierarchical_design_space_data( - valid_material_node_definition_data, - valid_gem_data_source_dict + valid_material_node_definition_data, valid_gem_data_source_dict ): """Produce valid hierarchical design space data.""" import copy - name = 'hierarchical design space' - description = 'does things but in levels' + + name = "hierarchical design space" + description = "does things but in levels" user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return dict( id=str(uuid.uuid4()), data=dict( name=name, description=description, instance=dict( - type='HierarchicalDesignSpace', + type="HierarchicalDesignSpace", name=name, description=description, root=copy.deepcopy(valid_material_node_definition_data), subspaces=[copy.deepcopy(valid_material_node_definition_data)], - data_sources=[valid_gem_data_source_dict] - ) + data_sources=[valid_gem_data_source_dict], + ), ), metadata=dict( - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ), - archived=dict( - user=user, - time=time - ), - status=dict( - name='VALIDATING', - detail=[] - ) - ) + created=dict(user=user, time=time), + updated=dict(user=user, time=time), + archived=dict(user=user, time=time), + status=dict(name="VALIDATING", detail=[]), + ), ) @pytest.fixture def valid_material_node_definition_data(valid_formulation_design_space_data): return dict( - identifier=dict( - id=f"Material Node-{uuid.uuid4()}", - scope="Custom Scope" - ), + identifier=dict(id=f"Material Node-{uuid.uuid4()}", scope="Custom Scope"), attributes=[ dict( - type='ContinuousDimension', + type="ContinuousDimension", descriptor=dict( - type='Real', - descriptor_key='alpha', - units='', + type="Real", + descriptor_key="alpha", + units="", lower_bound=5.0, upper_bound=10.0, ), lower_bound=6.0, - upper_bound=7.0 + upper_bound=7.0, ), dict( - type='EnumeratedDimension', + type="EnumeratedDimension", descriptor=dict( - type='Categorical', - descriptor_key='color', - descriptor_values=['blue', 'green', 'red'], + type="Categorical", + descriptor_key="color", + descriptor_values=["blue", "green", "red"], ), - list=['red'] - ) + list=["red"], + ), ], formulation=valid_formulation_design_space_data, template=dict( material_template=str(uuid.uuid4()), process_template=str(uuid.uuid4()), ), - display_name="Material Node" + display_name="Material Node", ) @@ -224,8 +199,8 @@ def valid_material_node_definition_data(valid_formulation_design_space_data): def valid_gem_data_source_dict(): return { "type": "hosted_table_data_source", - "table_id": 'e5c51369-8e71-4ec6-b027-1f92bdc14762', - "table_version": 2 + "table_id": "e5c51369-8e71-4ec6-b027-1f92bdc14762", + "table_version": 2, } @@ -233,39 +208,43 @@ def valid_gem_data_source_dict(): def valid_auto_ml_predictor_data(valid_gem_data_source_dict): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor + x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="") z = RealDescriptor("z", lower_bound=0, upper_bound=100, units="") return dict( - type='AutoML', - name='AutoML predictor', - description='Predicts z from input x', + type="AutoML", + name="AutoML predictor", + description="Predicts z from input x", inputs=[x.dump()], outputs=[z.dump()], - estimators=[AutoMLEstimator.RANDOM_FOREST.value] + estimators=[AutoMLEstimator.RANDOM_FOREST.value], ) @pytest.fixture def valid_graph_predictor_data( - valid_simple_mixture_predictor_data, - valid_label_fractions_predictor_data, - valid_expression_predictor_data, - valid_mean_property_predictor_data, - valid_auto_ml_predictor_data + valid_simple_mixture_predictor_data, + valid_label_fractions_predictor_data, + valid_expression_predictor_data, + valid_mean_property_predictor_data, + valid_auto_ml_predictor_data, ): """Produce valid data used for tests.""" from citrine.informatics.data_sources import GemTableDataSource + instance = dict( - name='Graph predictor', - description='description', + name="Graph predictor", + description="description", predictors=[ valid_simple_mixture_predictor_data, valid_label_fractions_predictor_data, valid_expression_predictor_data, valid_mean_property_predictor_data, - valid_auto_ml_predictor_data + valid_auto_ml_predictor_data, + ], + training_data=[ + GemTableDataSource(table_id=uuid.uuid4(), table_version=0).dump() ], - training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=0).dump()] ) return PredictorEntityDataFactory(data=PredictorDataDataFactory(instance=instance)) @@ -274,11 +253,11 @@ def valid_graph_predictor_data( def valid_graph_predictor_data_empty(): """Another predictor valid data used for tests.""" instance = dict( - type='Graph', - name='Empty Graph predictor', - description='description', + type="Graph", + name="Empty Graph predictor", + description="description", predictors=[], - training_data=[] + training_data=[], ) return PredictorEntityDataFactory(data=PredictorDataDataFactory(instance=instance)) @@ -287,17 +266,20 @@ def valid_graph_predictor_data_empty(): def valid_deprecated_expression_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor - shear_modulus = RealDescriptor('Property~Shear modulus', lower_bound=0, upper_bound=100, units='GPa') + + shear_modulus = RealDescriptor( + "Property~Shear modulus", lower_bound=0, upper_bound=100, units="GPa" + ) return dict( - type='Expression', - name='Expression predictor', - description='Computes shear modulus from Youngs modulus and Poissons ratio', - expression='Y / (2 * (1 + v))', + type="Expression", + name="Expression predictor", + description="Computes shear modulus from Youngs modulus and Poissons ratio", + expression="Y / (2 * (1 + v))", output=shear_modulus.dump(), aliases={ - 'Y': "Property~Young's modulus", - 'v': "Property~Poisson's ratio", - } + "Y": "Property~Young's modulus", + "v": "Property~Poisson's ratio", + }, ) @@ -305,19 +287,26 @@ def valid_deprecated_expression_predictor_data(): def valid_expression_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor - shear_modulus = RealDescriptor('Property~Shear modulus', lower_bound=0, upper_bound=100, units='GPa') - youngs_modulus = RealDescriptor('Property~Young\'s modulus', lower_bound=0, upper_bound=100, units='GPa') - poissons_ratio = RealDescriptor('Property~Poisson\'s ratio', lower_bound=-1, upper_bound=0.5, units='') + + shear_modulus = RealDescriptor( + "Property~Shear modulus", lower_bound=0, upper_bound=100, units="GPa" + ) + youngs_modulus = RealDescriptor( + "Property~Young's modulus", lower_bound=0, upper_bound=100, units="GPa" + ) + poissons_ratio = RealDescriptor( + "Property~Poisson's ratio", lower_bound=-1, upper_bound=0.5, units="" + ) return dict( - type='AnalyticExpression', - name='Expression predictor', - description='Computes shear modulus from Youngs modulus and Poissons ratio', - expression='Y / (2 * (1 + v))', + type="AnalyticExpression", + name="Expression predictor", + description="Computes shear modulus from Youngs modulus and Poissons ratio", + expression="Y / (2 * (1 + v))", output=shear_modulus.dump(), aliases={ - 'Y': youngs_modulus.dump(), - 'v': poissons_ratio.dump(), - } + "Y": youngs_modulus.dump(), + "v": poissons_ratio.dump(), + }, ) @@ -325,40 +314,39 @@ def valid_expression_predictor_data(): def valid_predictor_report_data(example_categorical_pva_metrics, example_f1_metrics): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor + x = RealDescriptor("x", lower_bound=0, upper_bound=1, units="") y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="") z = RealDescriptor("z", lower_bound=0, upper_bound=101, units="") return dict( - id='7c2dda5d-675a-41b6-829c-e485163f0e43', - module_id='31c7f311-6f3d-4a93-9387-94cc877f170c', - status='OK', - create_time='2020-04-23T15:46:26Z', - update_time='2020-04-23T15:46:26Z', + id="7c2dda5d-675a-41b6-829c-e485163f0e43", + module_id="31c7f311-6f3d-4a93-9387-94cc877f170c", + status="OK", + create_time="2020-04-23T15:46:26Z", + update_time="2020-04-23T15:46:26Z", report=dict( models=[ dict( - name='GeneralLoloModel_1', - type='ML Model', + name="GeneralLoloModel_1", + type="ML Model", inputs=[x.key], outputs=[y.key], - display_name='ML Model', + display_name="ML Model", model_settings=[ dict( - name='Algorithm', - value='Ensemble of non-linear estimators', + name="Algorithm", + value="Ensemble of non-linear estimators", children=[ - dict(name='Number of estimators', value=64, children=[]), - dict(name='Leaf model', value='Mean', children=[]), - dict(name='Use jackknife', value=True, children=[]) - ] + dict( + name="Number of estimators", value=64, children=[] + ), + dict(name="Leaf model", value="Mean", children=[]), + dict(name="Use jackknife", value=True, children=[]), + ], ) ], feature_importances=[ - dict( - response_key='y', - importances=dict(x=1.00), - top_features=5 - ) + dict(response_key="y", importances=dict(x=1.00), top_features=5) ], selection_summary=dict( n_folds=4, @@ -366,48 +354,56 @@ def valid_predictor_report_data(example_categorical_pva_metrics, example_f1_metr dict( model_settings=[ dict( - name='Algorithm', - value='Ensemble of non-linear estimators', + name="Algorithm", + value="Ensemble of non-linear estimators", children=[ - dict(name='Number of estimators', value=64, children=[]), - dict(name='Leaf model', value='Mean', children=[]), - dict(name='Use jackknife', value=True, children=[]) - ] + dict( + name="Number of estimators", + value=64, + children=[], + ), + dict( + name="Leaf model", + value="Mean", + children=[], + ), + dict( + name="Use jackknife", + value=True, + children=[], + ), + ], ) ], response_results=dict( response_name=dict( metrics=dict( predicted_vs_actual=example_categorical_pva_metrics, - f1=example_f1_metrics + f1=example_f1_metrics, ) ) - ) + ), ) - ] + ], ), - predictor_configuration_name="Predict y from x with ML" + predictor_configuration_name="Predict y from x with ML", ), dict( - name='GeneralLosslessModel_2', - type='Analytic Model', + name="GeneralLosslessModel_2", + type="Analytic Model", inputs=[x.key, y.key], outputs=[z.key], - display_name='GeneralLosslessModel_2', + display_name="GeneralLosslessModel_2", model_settings=[ - dict( - name="Expression", - value="(z) <- (x + y)", - children=[] - ) + dict(name="Expression", value="(z) <- (x + y)", children=[]) ], feature_importances=[], predictor_configuration_name="Expression for z", - predictor_configuration_uid="249bf32c-6f3d-4a93-9387-94cc877f170c" - ) + predictor_configuration_uid="249bf32c-6f3d-4a93-9387-94cc877f170c", + ), ], - descriptors=[x.dump(), y.dump(), z.dump()] - ) + descriptors=[x.dump(), y.dump(), z.dump()], + ), ) @@ -415,18 +411,23 @@ def valid_predictor_report_data(example_categorical_pva_metrics, example_f1_metr def valid_ing_formulation_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor + return dict( - type='IngredientsToSimpleMixture', - name='Ingredients to formulation predictor', - description='Constructs mixtures from ingredients', + type="IngredientsToSimpleMixture", + name="Ingredients to formulation predictor", + description="Constructs mixtures from ingredients", id_to_quantity={ - 'water': RealDescriptor('water quantity', lower_bound=0, upper_bound=1, units="").dump(), - 'salt': RealDescriptor('salt quantity', lower_bound=0, upper_bound=1, units="").dump() + "water": RealDescriptor( + "water quantity", lower_bound=0, upper_bound=1, units="" + ).dump(), + "salt": RealDescriptor( + "salt quantity", lower_bound=0, upper_bound=1, units="" + ).dump(), }, labels={ - 'solvent': ['water'], - 'solute': ['salt'], - } + "solvent": ["water"], + "solute": ["salt"], + }, ) @@ -434,17 +435,18 @@ def valid_ing_formulation_predictor_data(): def valid_generalized_mean_property_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import FormulationDescriptor + formulation_descriptor = FormulationDescriptor.hierarchical() return dict( - type='GeneralizedMeanProperty', - name='Mean property predictor', - description='Computes mean ingredient properties', + type="GeneralizedMeanProperty", + name="Mean property predictor", + description="Computes mean ingredient properties", input=formulation_descriptor.dump(), - properties=['density'], + properties=["density"], p=2, impute_properties=True, - default_properties={'density': 1.0}, - label='solvent' + default_properties={"density": 1.0}, + label="solvent", ) @@ -452,18 +454,21 @@ def valid_generalized_mean_property_predictor_data(): def valid_mean_property_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import FormulationDescriptor, RealDescriptor + formulation_descriptor = FormulationDescriptor.flat() - density = RealDescriptor(key='density', lower_bound=0, upper_bound=100, units='g/cm^3') + density = RealDescriptor( + key="density", lower_bound=0, upper_bound=100, units="g/cm^3" + ) return dict( - type='MeanProperty', - name='Mean property predictor', - description='Computes mean ingredient properties', + type="MeanProperty", + name="Mean property predictor", + description="Computes mean ingredient properties", input=formulation_descriptor.dump(), properties=[density.dump()], p=2.0, impute_properties=True, - default_properties={'density': 1.0}, - label='solvent' + default_properties={"density": 1.0}, + label="solvent", ) @@ -471,12 +476,13 @@ def valid_mean_property_predictor_data(): def valid_label_fractions_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import FormulationDescriptor + return dict( - type='LabelFractions', - name='Label fractions predictor', - description='Computes relative proportions of labeled ingredients', + type="LabelFractions", + name="Label fractions predictor", + description="Computes relative proportions of labeled ingredients", input=FormulationDescriptor.hierarchical().dump(), - labels=['solvent'] + labels=["solvent"], ) @@ -484,12 +490,13 @@ def valid_label_fractions_predictor_data(): def valid_ingredient_fractions_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import FormulationDescriptor + return dict( - type='IngredientFractions', - name='Ingredient fractions predictor', - description='Computes ingredient fractions', + type="IngredientFractions", + name="Ingredient fractions predictor", + description="Computes ingredient fractions", input=FormulationDescriptor.hierarchical().dump(), - ingredients=['Blue dye', 'Red dye'] + ingredients=["Blue dye", "Red dye"], ) @@ -499,7 +506,7 @@ def valid_data_source_design_space_dict(valid_gem_data_source_dict): type="DataSourceDesignSpace", name="Example valid data source design space", description="Example valid data source design space based on a GEM Table Data Source.", - data_source=valid_gem_data_source_dict + data_source=valid_gem_data_source_dict, ) @@ -507,15 +514,16 @@ def valid_data_source_design_space_dict(valid_gem_data_source_dict): def invalid_predictor_node_data(): """Produce invalid valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor + x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="") y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="") z = RealDescriptor("z", lower_bound=0, upper_bound=100, units="") return dict( - type='invalid', - name='my predictor', - description='does some things', + type="invalid", + name="my predictor", + description="does some things", inputs=[x.dump(), y.dump()], - output=z.dump() + output=z.dump(), ) @@ -523,23 +531,24 @@ def invalid_predictor_node_data(): def invalid_graph_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor + x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="") y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="") z = RealDescriptor("z", lower_bound=0, upper_bound=100, units="") instance = dict( - type='invalid', - name='my predictor', - description='does some things badly', + type="invalid", + name="my predictor", + description="does some things badly", predictors=[x.dump(), y.dump()], ) detail = [ - StatusDetail(level=StatusLevelEnum.WARNING, msg='Something is wrong'), - StatusDetail(level="Error", msg='Very wrong') + StatusDetail(level=StatusLevelEnum.WARNING, msg="Something is wrong"), + StatusDetail(level="Error", msg="Very wrong"), ] - status = StatusDataFactory(name='INVALID', detail=detail) + status = StatusDataFactory(name="INVALID", detail=detail) return PredictorEntityDataFactory( data=PredictorDataDataFactory(instance=instance), - meatadata=PredictorMetadataDataFactory(status=status) + meatadata=PredictorMetadataDataFactory(status=status), ) @@ -547,11 +556,11 @@ def invalid_graph_predictor_data(): def invalid_design_subspace_data(): """Produce invalid valid data used for tests.""" return dict( - type='invalid', - name='my design space', - description='does some things', + type="invalid", + name="my design space", + description="does some things", subspaces=[], - dimensions=[] + dimensions=[], ) @@ -559,9 +568,9 @@ def invalid_design_subspace_data(): def valid_simple_mixture_predictor_data(): """Produce valid data used for tests.""" return dict( - type='SimpleMixture', - name='Simple mixture predictor', - description='simple mixture description' + type="SimpleMixture", + name="Simple mixture predictor", + description="simple mixture description", ) @@ -574,10 +583,8 @@ def example_cv_evaluator_dict(): "responses": ["salt?", "saltiness"], "n_folds": 6, "n_trials": 8, - "metrics": [ - {"type": "PVA"}, {"type": "RMSE"}, {"type": "F1"} - ], - "ignore_when_grouping": ["temperature"] + "metrics": [{"type": "PVA"}, {"type": "RMSE"}, {"type": "F1"}], + "ignore_when_grouping": ["temperature"], } @@ -589,24 +596,18 @@ def example_holdout_evaluator_dict(valid_gem_data_source_dict): "description": "", "responses": ["sweetness"], "data_source": valid_gem_data_source_dict, - "metrics": [{"type": "RMSE"}] + "metrics": [{"type": "RMSE"}], } + @pytest.fixture() def example_rmse_metrics(): - return { - "type": "RealMetricValue", - "mean": 0.4, - "standard_error": 0.12 - } + return {"type": "RealMetricValue", "mean": 0.4, "standard_error": 0.12} @pytest.fixture def example_f1_metrics(): - return { - "type": "RealMetricValue", - "mean": 0.3 - } + return {"type": "RealMetricValue", "mean": 0.3} @pytest.fixture @@ -622,15 +623,15 @@ def example_real_pva_metrics(): "predicted": { "type": "RealMetricValue", "mean": 1.0, - "standard_error": 0.12 + "standard_error": 0.12, }, "actual": { "type": "RealMetricValue", "mean": 1.2, - "standard_error": 0.0 - } + "standard_error": 0.0, + }, } - ] + ], } @@ -644,20 +645,21 @@ def example_categorical_pva_metrics(): "identifiers": ["Foo", "Bar"], "trial": 1, "fold": 3, - "predicted": { - "salt": 0.3, - "not salt": 0.7 - }, - "actual": { - "not salt": 1.0 - } + "predicted": {"salt": 0.3, "not salt": 0.7}, + "actual": {"not salt": 1.0}, } - ] + ], } @pytest.fixture() -def example_cv_result_dict(example_cv_evaluator_dict, example_rmse_metrics, example_categorical_pva_metrics, example_f1_metrics, example_real_pva_metrics): +def example_cv_result_dict( + example_cv_evaluator_dict, + example_rmse_metrics, + example_categorical_pva_metrics, + example_f1_metrics, + example_real_pva_metrics, +): return { "type": "CrossValidationResult", "evaluator": example_cv_evaluator_dict, @@ -665,16 +667,16 @@ def example_cv_result_dict(example_cv_evaluator_dict, example_rmse_metrics, exam "salt?": { "metrics": { "predicted_vs_actual": example_categorical_pva_metrics, - "f1": example_f1_metrics + "f1": example_f1_metrics, } }, "saltiness": { "metrics": { "predicted_vs_actual": example_real_pva_metrics, - "rmse": example_rmse_metrics + "rmse": example_rmse_metrics, } - } - } + }, + }, } @@ -683,13 +685,7 @@ def example_holdout_result_dict(example_holdout_evaluator_dict, example_rmse_met return { "type": "HoldoutSetResult", "evaluator": example_holdout_evaluator_dict, - "response_results": { - "sweetness": { - "metrics": { - "rmse": example_rmse_metrics - } - } - } + "response_results": {"sweetness": {"metrics": {"rmse": example_rmse_metrics}}}, } @@ -702,8 +698,8 @@ def sample_design_space_execution_dict(generic_entity): "status": { "major": ret.get("status"), "minor": ret.get("status_description"), - "detail": ret.get("status_detail") - } + "detail": ret.get("status_detail"), + }, } ) return ret @@ -712,30 +708,31 @@ def sample_design_space_execution_dict(generic_entity): @pytest.fixture() def example_design_material(): return { - 'vars': { - 'Temperature': {'type': 'R', 'm': 475.8, 's': 0}, - 'Flour': {'type': 'C', 'cp': {'flour': 100.0}}, - 'Water': {'type': 'M', 'q': {'water': 72.5}, 'l': {}}, - 'Salt': {'type': 'F', 'f': 'NaCl'}, - 'Yeast': {'type': 'S', 's': 'O1C=2C=C(C=3SC=C4C=CNC43)CC2C=5C=CC=6C=CNC6C15'} + "vars": { + "Temperature": {"type": "R", "m": 475.8, "s": 0}, + "Flour": {"type": "C", "cp": {"flour": 100.0}}, + "Water": {"type": "M", "q": {"water": 72.5}, "l": {}}, + "Salt": {"type": "F", "f": "NaCl"}, + "Yeast": { + "type": "S", + "s": "O1C=2C=C(C=3SC=C4C=CNC43)CC2C=5C=CC=6C=CNC6C15", + }, + }, + "identifiers": { + "id": str(uuid.uuid4()), + "identifiers": [], + "material_template": str(uuid.uuid4()), + "process_template": str(uuid.uuid4()), }, - 'identifiers': { - 'id': str(uuid.uuid4()), - 'identifiers': [], - 'material_template': str(uuid.uuid4()), - 'process_template': str(uuid.uuid4()) - } } @pytest.fixture() def example_hierarchical_design_material(example_design_material): return { - 'terminal': example_design_material, - 'sub_materials': [example_design_material], - 'mixtures': { - str(uuid.uuid4()): {'q': {'A': 0.5, 'B': 0.5}, 'l': {}} - } + "terminal": example_design_material, + "sub_materials": [example_design_material], + "mixtures": {str(uuid.uuid4()): {"q": {"A": 0.5, "B": 0.5}, "l": {}}}, } @@ -744,48 +741,53 @@ def example_hierarchical_candidates(example_hierarchical_design_material): return { "page": 2, "per_page": 4, - "response": [{ - "id": str(uuid.uuid4()), - "primary_score": 0, - "rank": 1, - "material": example_hierarchical_design_material, - "name": "Example candidate", - "hidden": True, - "comments": [ - { - "message": "a message", - "created": { - "user": str(uuid.uuid4()), - "time": '2025-02-20T10:46:26Z' + "response": [ + { + "id": str(uuid.uuid4()), + "primary_score": 0, + "rank": 1, + "material": example_hierarchical_design_material, + "name": "Example candidate", + "hidden": True, + "comments": [ + { + "message": "a message", + "created": { + "user": str(uuid.uuid4()), + "time": "2025-02-20T10:46:26Z", + }, } - } - ] - }] + ], + } + ], } + @pytest.fixture() def example_candidates(example_design_material): return { "page": 2, "per_page": 4, - "response": [{ - "id": str(uuid.uuid4()), - "material_id": str(uuid.uuid4()), - "identifiers": [], - "primary_score": 0, - "material": example_design_material, - "name": "Example candidate", - "hidden": True, - "comments": [ - { - "message": "a message", - "created": { - "user": str(uuid.uuid4()), - "time": '2025-02-20T10:46:26Z' + "response": [ + { + "id": str(uuid.uuid4()), + "material_id": str(uuid.uuid4()), + "identifiers": [], + "primary_score": 0, + "material": example_design_material, + "name": "Example candidate", + "hidden": True, + "comments": [ + { + "message": "a message", + "created": { + "user": str(uuid.uuid4()), + "time": "2025-02-20T10:46:26Z", + }, } - } - ] - }] + ], + } + ], } @@ -793,15 +795,16 @@ def example_candidates(example_design_material): def example_sample_design_space_response(example_hierarchical_design_material): return { "per_page": 4, - "response": [{ - "id": str(uuid.uuid4()), - "execution_id": str(uuid.uuid4()), - "material": example_hierarchical_design_material - }] + "response": [ + { + "id": str(uuid.uuid4()), + "execution_id": str(uuid.uuid4()), + "material": example_hierarchical_design_material, + } + ], } - @pytest.fixture def generic_entity(): user = str(uuid.uuid4()) @@ -812,8 +815,8 @@ def generic_entity(): "status_detail": [{"level": "Info", "msg": "System processing"}], "experimental": False, "experimental_reasons": [], - "create_time": '2020-04-23T15:46:26Z', - "update_time": '2020-04-23T15:46:26Z', + "create_time": "2020-04-23T15:46:26Z", + "update_time": "2020-04-23T15:46:26Z", "created_by": user, "updated_by": user, } @@ -822,31 +825,35 @@ def generic_entity(): @pytest.fixture def predictor_evaluation_execution_dict(generic_entity): ret = deepcopy(generic_entity) - ret.update({ - "workflow_id": str(uuid.uuid4()), - "predictor_id": str(uuid.uuid4()), - "predictor_version": random.randint(1, 10), - "evaluator_names": ["Example evaluator"] - }) + ret.update( + { + "workflow_id": str(uuid.uuid4()), + "predictor_id": str(uuid.uuid4()), + "predictor_version": random.randint(1, 10), + "evaluator_names": ["Example evaluator"], + } + ) return ret @pytest.fixture def design_execution_dict(generic_entity): ret = generic_entity.copy() - ret.update({ - "workflow_id": str(uuid.uuid4()), - "version_number": 2, - "score": { - "type": "MLI", - "baselines": [], - "constraints": [], - "objectives": [], - "name": "score", - "description": "" - }, - "descriptors": [] - }) + ret.update( + { + "workflow_id": str(uuid.uuid4()), + "version_number": 2, + "score": { + "type": "MLI", + "baselines": [], + "constraints": [], + "objectives": [], + "name": "score", + "description": "", + }, + "descriptors": [], + } + ) return ret @@ -861,26 +868,16 @@ def example_generation_results(): return { "page": 1, "per_page": 4, - "response": [{ - "id": str(uuid.uuid4()), - "execution_id": str(uuid.uuid4()), - "result": { - "seed": "CCCCO", - "mutated": "CCCN", - "fingerprint_similarity": 0.41, - "fingerprint_type": "ECFP4", + "response": [ + { + "id": str(uuid.uuid4()), + "execution_id": str(uuid.uuid4()), + "result": { + "seed": "CCCCO", + "mutated": "CCCN", + "fingerprint_similarity": 0.41, + "fingerprint_type": "ECFP4", + }, } - }] + ], } - - - -@pytest.fixture -def predictor_evaluation_workflow_dict(generic_entity, example_cv_evaluator_dict, example_holdout_evaluator_dict): - ret = deepcopy(generic_entity) - ret.update({ - "name": "Example PEW", - "description": "Example PEW for testing", - "evaluators": [example_cv_evaluator_dict, example_holdout_evaluator_dict] - }) - return ret diff --git a/tests/utils/factories.py b/tests/utils/factories.py index dc7f96399..489118cd9 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -4,15 +4,18 @@ # Naming convention here is to use "*DataFactory" for dictionaries used as API input/out, and # Factory for the domain objects themselves +from random import randint, random +from typing import Optional, Set + import arrow import factory from faker.providers.date_time import Provider -from random import random, randint -from typing import Set, Optional +from gemd import EmpiricalFormula, FileLink, LinkByUID +from gemd.enumeration import SampleType -from citrine.gemd_queries.gemd_query import * from citrine.gemd_queries.criteria import * from citrine.gemd_queries.filter import * +from citrine.gemd_queries.gemd_query import * from citrine.informatics.scores import LIScore from citrine.informatics.workflows import DesignWorkflow from citrine.jobs.job import JobStatus @@ -25,9 +28,6 @@ from citrine.resources.process_template import ProcessTemplate from citrine.resources.table_config import TableConfigInitiator -from gemd import LinkByUID, EmpiricalFormula, FileLink -from gemd.enumeration import SampleType - class AugmentedProvider(Provider): def random_formula(self, count: int = None, elements: Set[str] = None) -> str: @@ -38,7 +38,9 @@ def random_formula(self, count: int = None, elements: Set[str] = None) -> str: count = self.generator.random.randrange(1, 5) components = sorted(self.generator.random.sample(elements, count)) # Use weights to bias toward looking more real-ish - amounts = self.generator.random.choices([1, 2, 3, 4, 5], weights=[40, 40, 10, 10, 2], k=count) + amounts = self.generator.random.choices( + [1, 2, 3, 4, 5], weights=[40, 40, 10, 10, 2], k=count + ) return "".join(f"({c}){a}" for c, a in zip(components, amounts)) def random_smiles(self) -> str: @@ -53,7 +55,7 @@ def random_smiles(self) -> str: "F": 1, "Cl": 1, "Br": 1, - "I": 1 + "I": 1, } valence = { "B": 3, @@ -65,9 +67,9 @@ def random_smiles(self) -> str: "F": 1, "Cl": 1, "Br": 1, - "I": 1 + "I": 1, } - bonds = ['', '=', '#', '$'] + bonds = ["", "=", "#", "$"] elements = list(element_weights) weights = list(element_weights.values()) @@ -83,12 +85,17 @@ def random_smiles(self) -> str: else: atom = self.generator.random.choices(elements, weights=weights)[0] max_bond = max(valence[atom], remain[-1]) - bond = 1 + self.generator.random.choices( - range(max_bond), - weights=[0.1 ** i for i in range(max_bond)] - )[0] + bond = ( + 1 + + self.generator.random.choices( + range(max_bond), weights=[0.1**i for i in range(max_bond)] + )[0] + ) remain[-1] -= bond - if remain[-1] > 1 and self.generator.random.randrange(3 ** len(remain)) == 0: + if ( + remain[-1] > 1 + and self.generator.random.randrange(3 ** len(remain)) == 0 + ): # Branch remain.append(None) smiles += "(" @@ -98,9 +105,9 @@ def random_smiles(self) -> str: return smiles[:-1] # Always has a superfluous ) at the end def unix_milliseconds( - self, - end_milliseconds: Optional[int] = None, - start_milliseconds: Optional[int] = None, + self, + end_milliseconds: Optional[int] = None, + start_milliseconds: Optional[int] = None, ) -> float: """ Get a timestamp in milliseconds between January 1, 1970 and now, unless @@ -131,19 +138,19 @@ class UserTimestampDataFactory(factory.DictFactory): class TeamDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') - name = factory.Faker('company') - description = factory.Faker('catch_phrase') + id = factory.Faker("uuid4") + name = factory.Faker("company") + description = factory.Faker("catch_phrase") created_at = factory.Faker("unix_milliseconds") class ProjectDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') - name = factory.Faker('company') - description = factory.Faker('catch_phrase') - status = 'CREATED' + id = factory.Faker("uuid4") + name = factory.Faker("company") + description = factory.Faker("catch_phrase") + status = "CREATED" created_at = factory.Faker("unix_milliseconds") - team_id = factory.Faker('uuid4') + team_id = factory.Faker("uuid4") class DataVersionUpdateFactory(factory.DictFactory): @@ -152,8 +159,8 @@ class DataVersionUpdateFactory(factory.DictFactory): class PredictorRefFactory(factory.DictFactory): - predictor_id = factory.Faker('uuid4') - predictor_version = factory.Faker('random_digit_not_null') + predictor_id = factory.Faker("uuid4") + predictor_version = factory.Faker("random_digit_not_null") class BranchDataUpdateFactory(factory.DictFactory): @@ -167,47 +174,47 @@ class NextBranchVersionFactory(factory.DictFactory): class BranchDataFieldFactory(factory.DictFactory): - name = factory.Faker('company') + name = factory.Faker("company") class BranchMetadataFieldFactory(factory.DictFactory): - root_id = factory.Faker('uuid4') - archived = factory.Faker('boolean') - version = factory.Faker('random_digit_not_null') + root_id = factory.Faker("uuid4") + archived = factory.Faker("boolean") + version = factory.Faker("random_digit_not_null") created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) class BranchDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(BranchDataFieldFactory) metadata = factory.SubFactory(BranchMetadataFieldFactory) class BranchVersionRefFactory(factory.DictFactory): - id = factory.Faker('uuid4') - version = factory.Faker('random_digit_not_null') + id = factory.Faker("uuid4") + version = factory.Faker("random_digit_not_null") class BranchRootMetadataFieldFactory(factory.DictFactory): latest_branch_version = factory.SubFactory(BranchVersionRefFactory) - archived = factory.Faker('boolean') + archived = factory.Faker("boolean") created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) class BranchRootDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(BranchDataFieldFactory) metadata = factory.SubFactory(BranchRootMetadataFieldFactory) class UserDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') - screen_name = factory.Faker('name') - position = factory.Faker('job') - email = factory.Faker('email') - is_admin = factory.Faker('boolean') + id = factory.Faker("uuid4") + screen_name = factory.Faker("name") + position = factory.Faker("job") + email = factory.Faker("email") + is_admin = factory.Faker("boolean") class GemTableDataFactory(factory.DictFactory): @@ -218,9 +225,9 @@ class GemTableDataFactory(factory.DictFactory): * table-configs/{table_config_uid_str}/gem-tables """ - id = factory.Faker('uuid4') - version = factory.Faker('random_digit_not_null') - signed_download_url = factory.Faker('uri') + id = factory.Faker("uuid4") + version = factory.Faker("random_digit_not_null") + signed_download_url = factory.Faker("uri") class ListGemTableVersionsDataFactory(factory.DictFactory): @@ -230,17 +237,20 @@ class ListGemTableVersionsDataFactory(factory.DictFactory): * gem-tables/ * gem-tables/{table_identity_id} """ + # Explicitly set version numbers so that they are distinct - tables = factory.List([ - factory.SubFactory(GemTableDataFactory, version=1), - factory.SubFactory(GemTableDataFactory, version=4), - factory.SubFactory(GemTableDataFactory, version=2), - ]) + tables = factory.List( + [ + factory.SubFactory(GemTableDataFactory, version=1), + factory.SubFactory(GemTableDataFactory, version=4), + factory.SubFactory(GemTableDataFactory, version=2), + ] + ) class RealFilterDataFactory(factory.DictFactory): type = AllRealFilter.typ - unit = 'dimensionless' + unit = "dimensionless" lower = factory.LazyAttribute(lambda o: min(0, 2 * o.upper) + random() * o.upper) upper = factory.Faker("pyfloat") @@ -253,12 +263,12 @@ class IntegerFilterDataFactory(factory.DictFactory): class CategoryFilterDataFactory(factory.DictFactory): type = NominalCategoricalFilter.typ - categories = factory.Faker('words', unique=True) + categories = factory.Faker("words", unique=True) class PropertiesCriteriaDataFactory(factory.DictFactory): type = PropertiesCriteria.typ - property_templates_filter = factory.List([factory.Faker('uuid4')]) + property_templates_filter = factory.List([factory.Faker("uuid4")]) value_type_filter = factory.SubFactory(RealFilterDataFactory) class Params: @@ -272,95 +282,101 @@ class Params: class NameCriteriaDataFactory(factory.DictFactory): type = NameCriteria.typ - name = factory.Faker('word') - search_type = factory.Faker('enum', enum_cls=TextSearchType) + name = factory.Faker("word") + search_type = factory.Faker("enum", enum_cls=TextSearchType) class MaterialRunClassificationCriteriaDataFactory(factory.DictFactory): type = MaterialRunClassificationCriteria.typ classifications = factory.Faker( - 'random_elements', + "random_elements", elements=[str(x) for x in MaterialClassification], - unique=True + unique=True, ) class MaterialTemplatesCriteriaDataFactory(factory.DictFactory): type = MaterialTemplatesCriteria.typ - material_templates_identifiers = factory.List([factory.Faker('uuid4')]) - tag_filters = factory.Faker('words', unique=True) + material_templates_identifiers = factory.List([factory.Faker("uuid4")]) + tag_filters = factory.Faker("words", unique=True) class ConnectivityClassCriteriaDataFactory(factory.DictFactory): type = ConnectivityClassCriteria.typ - is_consumed = factory.Faker('boolean') - is_produced = factory.Faker('boolean') + is_consumed = factory.Faker("boolean") + is_produced = factory.Faker("boolean") class TagsCriteriaDataFactory(factory.DictFactory): type = TagsCriteria.typ - tags = factory.Faker('words', unique=True) - filter_type = factory.Faker('enum', enum_cls=TagFilterType) + tags = factory.Faker("words", unique=True) + filter_type = factory.Faker("enum", enum_cls=TagFilterType) class AndOperatorCriteriaDataFactory(factory.DictFactory): type = AndOperator.typ - criteria = factory.List([ - factory.SubFactory(NameCriteriaDataFactory), - factory.SubFactory(MaterialRunClassificationCriteriaDataFactory), - factory.SubFactory(MaterialTemplatesCriteriaDataFactory) - ]) + criteria = factory.List( + [ + factory.SubFactory(NameCriteriaDataFactory), + factory.SubFactory(MaterialRunClassificationCriteriaDataFactory), + factory.SubFactory(MaterialTemplatesCriteriaDataFactory), + ] + ) class OrOperatorCriteriaDataFactory(factory.DictFactory): type = OrOperator.typ - criteria = factory.List([ - factory.SubFactory(PropertiesCriteriaDataFactory), - factory.SubFactory(PropertiesCriteriaDataFactory, integer=True), - factory.SubFactory(PropertiesCriteriaDataFactory, category=True), - factory.SubFactory(AndOperatorCriteriaDataFactory) - ]) + criteria = factory.List( + [ + factory.SubFactory(PropertiesCriteriaDataFactory), + factory.SubFactory(PropertiesCriteriaDataFactory, integer=True), + factory.SubFactory(PropertiesCriteriaDataFactory, category=True), + factory.SubFactory(AndOperatorCriteriaDataFactory), + ] + ) class GemdQueryDataFactory(factory.DictFactory): criteria = factory.List([factory.SubFactory(OrOperatorCriteriaDataFactory)]) - datasets = factory.List([factory.Faker('uuid4')]) + datasets = factory.List([factory.Faker("uuid4")]) object_types = factory.List([str(x) for x in GemdObjectType]) schema_version = 1 class TableConfigMainMetaDataDataFactory(factory.DictFactory): """This is the metadata for the primary definition ID of the TableConfig.""" - id = factory.Faker('uuid4') + + id = factory.Faker("uuid4") deleted = False create_time = factory.Faker("unix_milliseconds") - created_by = factory.Faker('uuid4') + created_by = factory.Faker("uuid4") update_time = factory.Faker("unix_milliseconds") - updated_by = factory.Faker('uuid4') + updated_by = factory.Faker("uuid4") class TableConfigDataFactory(factory.DictFactory): """This is simply the Blob stored in a Table Config Version.""" + name = factory.Faker("company") - description = factory.Faker('bs') + description = factory.Faker("bs") # TODO Create factories for definitions rows = [] columns = [] variables = [] - datasets = factory.List([factory.Faker('uuid4')]) + datasets = factory.List([factory.Faker("uuid4")]) gemd_query = factory.SubFactory(GemdQueryDataFactory) class TableConfigVersionMetaDataDataFactory(factory.DictFactory): ara_definition = factory.SubFactory(TableConfigDataFactory) - id = factory.Faker('uuid4') - definition_id = factory.Faker('uuid4') - version_number = factory.Faker('random_digit_not_null') + id = factory.Faker("uuid4") + definition_id = factory.Faker("uuid4") + version_number = factory.Faker("random_digit_not_null") deleted = False create_time = factory.Faker("unix_milliseconds") - created_by = factory.Faker('uuid4') + created_by = factory.Faker("uuid4") update_time = factory.Faker("unix_milliseconds") - updated_by = factory.Faker('uuid4') + updated_by = factory.Faker("uuid4") initiator = str(TableConfigInitiator.CITRINE_PYTHON) @@ -371,28 +387,34 @@ class TableConfigResponseDataFactory(factory.DictFactory): * projects/{project_id}/display-tables/{uid}/versions/{version}/definition """ + definition = factory.SubFactory(TableConfigMainMetaDataDataFactory) version = factory.SubFactory(TableConfigVersionMetaDataDataFactory) class ListTableConfigResponseDataFactory(factory.DictFactory): """This encapsulates all of the versions of a table config object.""" + definition = factory.SubFactory(TableConfigMainMetaDataDataFactory) # Explicitly set version numbers so that they are distinct - versions = factory.List([ - factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=1), - factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=4), - factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=2), - ]) + versions = factory.List( + [ + factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=1), + factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=4), + factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=2), + ] + ) class TableDataSourceDataFactory(factory.DictFactory): type = "hosted_table_data_source" table_id = factory.Faker("uuid4") - table_version = factory.Faker('random_digit_not_null') + table_version = factory.Faker("random_digit_not_null") + from citrine.informatics.data_sources import GemTableDataSource + class TableDataSourceFactory(factory.Factory): class Meta: model = GemTableDataSource @@ -436,7 +458,7 @@ class PredictorDataDataFactory(factory.DictFactory): class PredictorEntityDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(PredictorDataDataFactory) metadata = factory.SubFactory(PredictorMetadataDataFactory) @@ -458,7 +480,7 @@ class AsyncDefaultPredictorResponseDataFactory(factory.DictFactory): class AsyncDefaultPredictorResponseFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") metadata = factory.SubFactory(AsyncDefaultPredictorResponseMetadataFactory) data = factory.SubFactory(AsyncDefaultPredictorResponseDataFactory) @@ -493,9 +515,9 @@ class AreaUnderROCFactory(factory.DictFactory): class CoverageProbabilityFactory(factory.DictFactory): class Meta: - exclude = ("_level", ) + exclude = ("_level",) - _level = factory.Faker('pyfloat', max_value=1, min_value=0) + _level = factory.Faker("pyfloat", max_value=1, min_value=0) coverage_level = factory.LazyAttribute(lambda o: str(o._level)) type = "CoverageProbability" @@ -503,57 +525,50 @@ class Meta: class CrossValidationEvaluatorFactory(factory.DictFactory): name = factory.Faker("company") description = factory.Faker("catch_phrase") - responses = factory.List(3 * [factory.Faker('company')]) - n_folds = factory.Faker('random_digit_not_null') - n_trials = factory.Faker('random_digit_not_null') - metrics = factory.List([factory.SubFactory(RMSEFactory), - factory.SubFactory(NDMEFactory), - factory.SubFactory(RSquaredFactory), - factory.SubFactory(StandardRMSEFactory), - factory.SubFactory(PVALFactory), - factory.SubFactory(F1Factory), - factory.SubFactory(AreaUnderROCFactory), - factory.SubFactory(CoverageProbabilityFactory)]) + responses = factory.List(3 * [factory.Faker("company")]) + n_folds = factory.Faker("random_digit_not_null") + n_trials = factory.Faker("random_digit_not_null") + metrics = factory.List( + [ + factory.SubFactory(RMSEFactory), + factory.SubFactory(NDMEFactory), + factory.SubFactory(RSquaredFactory), + factory.SubFactory(StandardRMSEFactory), + factory.SubFactory(PVALFactory), + factory.SubFactory(F1Factory), + factory.SubFactory(AreaUnderROCFactory), + factory.SubFactory(CoverageProbabilityFactory), + ] + ) type = "CrossValidationEvaluator" -class PredictorEvaluationWorkflowFactory(factory.DictFactory): - id = factory.Faker('uuid4') - name = factory.Faker("company") - description = factory.Faker("catch_phrase") - archived = False - evaluators = factory.List([factory.SubFactory(CrossValidationEvaluatorFactory)]) - type = "PredictorEvaluationWorkflow" - # TODO Create Trait and status_detail content - status = "SUCCEEDED" - status_description = "READY" - status_detail = [] - - class PredictorEvaluationDataFactory(factory.DictFactory): evaluators = factory.List([factory.SubFactory(CrossValidationEvaluatorFactory)]) class PredictorEvaluationMetadataFactory(factory.DictFactory): class Meta: - exclude = ('is_archived', ) + exclude = ("is_archived",) created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) - archived = factory.Maybe('is_archived', factory.SubFactory(UserTimestampDataFactory), None) + archived = factory.Maybe( + "is_archived", factory.SubFactory(UserTimestampDataFactory), None + ) predictor_id = factory.Faker("uuid4") predictor_version = factory.Faker("random_digit_not_null") status = {"major": "SUCCEEDED", "minor": "READY", "detail": []} class PredictorEvaluationFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(PredictorEvaluationDataFactory) metadata = factory.SubFactory(PredictorEvaluationMetadataFactory) class DesignSpaceConfigDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") name = factory.Faker("company") descriptor = factory.Faker("catch_phrase") subspaces = [] # TODO Create SubspaceDataFactory @@ -563,7 +578,7 @@ class DesignSpaceConfigDataFactory(factory.DictFactory): class DesignSpaceDataFactory(factory.DictFactory): config = factory.SubFactory(DesignSpaceConfigDataFactory) - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") display_name = factory.Faker("company") archived = False module_type = "DESIGN_SPACE" @@ -578,28 +593,28 @@ class Params: branch = factory.SubFactory(BranchDataFactory) times = factory.List([factory.Faker("unix_milliseconds") for i in range(3)]) register = factory.Trait( - id = factory.Faker('uuid4'), - branch_id = factory.LazyAttribute(lambda o: o.branch["id"]), - created_by = factory.Faker('uuid4'), - updated_by = factory.LazyAttribute(lambda o: o.created_by), - create_time = factory.LazyAttribute(lambda o: sorted(o.times)[0]), - update_time = factory.LazyAttribute(lambda o: sorted(o.times)[0]), + id=factory.Faker("uuid4"), + branch_id=factory.LazyAttribute(lambda o: o.branch["id"]), + created_by=factory.Faker("uuid4"), + updated_by=factory.LazyAttribute(lambda o: o.created_by), + create_time=factory.LazyAttribute(lambda o: sorted(o.times)[0]), + update_time=factory.LazyAttribute(lambda o: sorted(o.times)[0]), # TODO: Create a Trait for statuses - status = "SUCCEEDED", - status_description = "READY", - status_info = [], - status_detail = [] + status="SUCCEEDED", + status_description="READY", + status_info=[], + status_detail=[], ) update = factory.Trait( - register = True, - updated_by = factory.Faker('uuid4'), - update_time = factory.LazyAttribute(lambda o: sorted(o.times)[1]) + register=True, + updated_by=factory.Faker("uuid4"), + update_time=factory.LazyAttribute(lambda o: sorted(o.times)[1]), ) archive = factory.Trait( - update = True, - archived = True, - archived_by = factory.Faker('uuid4'), - archive_time = factory.LazyAttribute(lambda o: sorted(o.times)[2]), + update=True, + archived=True, + archived_by=factory.Faker("uuid4"), + archive_time=factory.LazyAttribute(lambda o: sorted(o.times)[2]), ) type = DesignWorkflow.typ @@ -612,41 +627,39 @@ class Params: branch_root_id = factory.LazyAttribute(lambda o: o.branch["metadata"]["root_id"]) branch_version = factory.LazyAttribute(lambda o: o.branch["metadata"]["version"]) archived = False - status_description = "" # TODO: Should be None, but property not defined as Optional + status_description = ( + "" # TODO: Should be None, but property not defined as Optional + ) class IngestFilesResponseDataFactory(factory.DictFactory): - team_id = factory.Faker('uuid4') - dataset_id = factory.Faker('uuid4') - ingestion_id = factory.Faker('uuid4') + team_id = factory.Faker("uuid4") + dataset_id = factory.Faker("uuid4") + ingestion_id = factory.Faker("uuid4") class IngestionStatusResponseDataFactory(factory.DictFactory): - ingestion_id = factory.Faker('uuid4') + ingestion_id = factory.Faker("uuid4") status = IngestionStatusType.INGESTION_CREATED errors = factory.List([]) class JobSubmissionResponseDataFactory(factory.DictFactory): - job_id = factory.Faker('uuid4') + job_id = factory.Faker("uuid4") class TaskNodeDataFactory(factory.DictFactory): class Params: failure = False - id = factory.Faker('uuid4') - task_type = factory.Faker('word') + id = factory.Faker("uuid4") + task_type = factory.Faker("word") status = factory.Maybe( - "failure", - yes_declaration=JobStatus.FAILURE, - no_declaration=JobStatus.SUCCESS + "failure", yes_declaration=JobStatus.FAILURE, no_declaration=JobStatus.SUCCESS ) dependencies = factory.List([]) failure_reason = factory.Maybe( - "failure", - yes_declaration=factory.Faker('sentence'), - no_declaration=None + "failure", yes_declaration=factory.Faker("sentence"), no_declaration=None ) @@ -654,15 +667,17 @@ class JobStatusResponseDataFactory(factory.DictFactory): class Params: failure = False - job_type = factory.Faker('word') + job_type = factory.Faker("word") status = factory.Maybe( - "failure", - yes_declaration=JobStatus.FAILURE, - no_declaration=JobStatus.SUCCESS + "failure", yes_declaration=JobStatus.FAILURE, no_declaration=JobStatus.SUCCESS + ) + tasks = factory.List( + [ + factory.RelatedFactory( + TaskNodeDataFactory, failure=factory.SelfAttribute("...failure") + ) + ] ) - tasks = factory.List([ - factory.RelatedFactory(TaskNodeDataFactory, failure=factory.SelfAttribute('...failure')) - ]) output = factory.Dict({}) @@ -670,14 +685,14 @@ class DatasetDataFactory(factory.DictFactory): class Params: times = factory.List([factory.Faker("unix_milliseconds") for i in range(3)]) - id = factory.Faker('uuid4') - name = factory.Faker('company') - summary = factory.Faker('catch_phrase') - description = factory.Faker('bs') + id = factory.Faker("uuid4") + name = factory.Faker("company") + summary = factory.Faker("catch_phrase") + description = factory.Faker("bs") deleted = False - created_by = factory.Faker('uuid4') - updated_by = factory.Faker('uuid4') - deleted_by = factory.Faker('uuid4') + created_by = factory.Faker("uuid4") + updated_by = factory.Faker("uuid4") + deleted_by = factory.Faker("uuid4") unique_name = None # TODO Update tests to include unique_name create_time = factory.LazyAttribute(lambda o: sorted(o.times)[0]) update_time = factory.LazyAttribute(lambda o: sorted(o.times)[1]) @@ -686,23 +701,23 @@ class Params: class IDDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") class LinkByUIDFactory(factory.Factory): class Meta: model = LinkByUID - scope = 'id' - id = factory.Faker('uuid4') + scope = "id" + id = factory.Faker("uuid4") class FileLinkFactory(factory.Factory): class Meta: model = FileLink - url = factory.Faker('uri') - filename = factory.Faker('file_name') + url = factory.Faker("uri") + filename = factory.Faker("file_name") class ProcessTemplateFactory(factory.Factory): @@ -710,9 +725,9 @@ class Meta: model = ProcessTemplate uids = factory.SubFactory(IDDataFactory) - name = factory.Faker('color_name') - tags = factory.List([factory.Faker('color_name'), factory.Faker('color_name')]) - description = factory.Faker('catch_phrase') + name = factory.Faker("color_name") + tags = factory.List([factory.Faker("color_name"), factory.Faker("color_name")]) + description = factory.Faker("catch_phrase") conditions = [] # TODO make a ConditionsTemplateFactory parameters = [] # TODO make a ParametersTemplateFactory @@ -722,10 +737,10 @@ class Meta: model = MaterialTemplate uids = factory.SubFactory(IDDataFactory) - name = factory.Faker('color_name') - tags = factory.List([factory.Faker('color_name'), factory.Faker('color_name')]) + name = factory.Faker("color_name") + tags = factory.List([factory.Faker("color_name"), factory.Faker("color_name")]) properties = [] # TODO make a PropertiesTemplateFactory - description = factory.Faker('catch_phrase') + description = factory.Faker("catch_phrase") class MaterialSpecFactory(factory.Factory): @@ -733,9 +748,9 @@ class Meta: model = MaterialSpec uids = factory.SubFactory(IDDataFactory) - name = factory.Faker('color_name') - tags = factory.List([factory.Faker('color_name'), factory.Faker('color_name')]) - notes = factory.Faker('catch_phrase') + name = factory.Faker("color_name") + tags = factory.List([factory.Faker("color_name"), factory.Faker("color_name")]) + notes = factory.Faker("catch_phrase") process = factory.SubFactory(LinkByUIDFactory) file_links = factory.List([factory.SubFactory(FileLinkFactory)]) template = factory.SubFactory(LinkByUIDFactory) @@ -747,9 +762,9 @@ class Meta: model = MaterialRun uids = factory.SubFactory(IDDataFactory) - name = factory.Faker('color_name') - tags = factory.List([factory.Faker('color_name'), factory.Faker('color_name')]) - notes = factory.Faker('catch_phrase') + name = factory.Faker("color_name") + tags = factory.List([factory.Faker("color_name"), factory.Faker("color_name")]) + notes = factory.Faker("catch_phrase") process = factory.SubFactory(LinkByUIDFactory) sample_type = factory.Faker("enum", enum_cls=SampleType) spec = factory.SubFactory(LinkByUIDFactory) @@ -759,13 +774,13 @@ class Meta: class LinkByUIDDataFactory(factory.DictFactory): id = LinkByUIDFactory.id scope = LinkByUIDFactory.scope - type = 'link_by_uid' + type = "link_by_uid" class FileLinkDataFactory(factory.DictFactory): url = FileLinkFactory.url filename = FileLinkFactory.filename - type = 'file_link' + type = "file_link" class MaterialSpecDataFactory(factory.DictFactory): @@ -777,7 +792,7 @@ class MaterialSpecDataFactory(factory.DictFactory): file_links = factory.List([factory.SubFactory(FileLinkDataFactory)]) template = factory.SubFactory(LinkByUIDDataFactory) properties = [] # TODO make a PropertiesDataFactory - type = 'material_spec' + type = "material_spec" class MaterialRunDataFactory(factory.DictFactory): @@ -789,16 +804,16 @@ class MaterialRunDataFactory(factory.DictFactory): sample_type = MaterialRunFactory.sample_type spec = factory.SubFactory(LinkByUIDDataFactory) file_links = factory.List([factory.SubFactory(FileLinkDataFactory)]) - type = 'material_run' + type = "material_run" class DatasetFactory(factory.Factory): class Meta: model = Dataset - name = factory.Faker('company') - summary = factory.Faker('catch_phrase') - description = factory.Faker('bs') + name = factory.Faker("company") + summary = factory.Faker("catch_phrase") + description = factory.Faker("bs") unique_name = None # TODO Update tests to include unique_name @@ -809,14 +824,14 @@ class Meta: # TODO Bring _Uploader in line with other library concepts @factory.post_generation def assign_values(obj, create, extracted): - obj.bucket = 'citrine-datasvc' - obj.object_key = '334455' - obj.upload_id = 'dea3a-555' - obj.region_name = 'us-west' - obj.aws_access_key_id = 'dkfjiejkcm' - obj.aws_secret_access_key = 'ifeemkdsfjeijie8759235u2wjr388' - obj.aws_session_token = 'fafjeijfi87834j87woa' - obj.s3_version = '2' + obj.bucket = "citrine-datasvc" + obj.object_key = "334455" + obj.upload_id = "dea3a-555" + obj.region_name = "us-west" + obj.aws_access_key_id = "dkfjiejkcm" + obj.aws_secret_access_key = "ifeemkdsfjeijie8759235u2wjr388" + obj.aws_session_token = "fafjeijfi87834j87woa" + obj.s3_version = "2" class MLIScoreFactory(factory.Factory): @@ -830,17 +845,17 @@ class Meta: class CategoricalExperimentValueDataFactory(factory.DictFactory): type = "CategoricalValue" - value = factory.Faker('company') + value = factory.Faker("company") class ChemicalFormulaExperimentValueDataFactory(factory.DictFactory): type = "InorganicValue" - value = factory.Faker('random_formula') + value = factory.Faker("random_formula") class IntegerExperimentValueDataFactory(factory.DictFactory): type = "IntegerValue" - value = factory.Faker('random_int', min=1, max=99) + value = factory.Faker("random_int", min=1, max=99) class MixtureExperimentValueDataFactory(factory.DictFactory): @@ -850,127 +865,170 @@ class MixtureExperimentValueDataFactory(factory.DictFactory): class MolecularStructureExperimentValueDataFactory(factory.DictFactory): type = "OrganicValue" - value = factory.Faker('random_smiles') + value = factory.Faker("random_smiles") class RealExperimentValueDataFactory(factory.DictFactory): type = "RealValue" - value = factory.Faker('pyfloat', min_value=0, max_value=100) + value = factory.Faker("pyfloat", min_value=0, max_value=100) class CandidateExperimentSnapshotDataFactory(factory.DictFactory): - experiment_id = factory.Faker('uuid4') - candidate_id = factory.Faker('uuid4') - workflow_id = factory.Faker('uuid4') - name = factory.Faker('company') - description = factory.Faker('company') + experiment_id = factory.Faker("uuid4") + candidate_id = factory.Faker("uuid4") + workflow_id = factory.Faker("uuid4") + name = factory.Faker("company") + description = factory.Faker("company") updated_time = factory.Faker("unix_milliseconds") # TODO Generate Experiment keys randomly but uniquely - overrides = factory.Dict({ - "ingredient1": factory.SubFactory(CategoricalExperimentValueDataFactory), - "ingredient2": factory.SubFactory(ChemicalFormulaExperimentValueDataFactory), - "ingredient3": factory.SubFactory(IntegerExperimentValueDataFactory), - "Formulation": factory.SubFactory(MixtureExperimentValueDataFactory), - "ingredient4": factory.SubFactory(MolecularStructureExperimentValueDataFactory), - "ingredient5": factory.SubFactory(RealExperimentValueDataFactory) - }) + overrides = factory.Dict( + { + "ingredient1": factory.SubFactory(CategoricalExperimentValueDataFactory), + "ingredient2": factory.SubFactory( + ChemicalFormulaExperimentValueDataFactory + ), + "ingredient3": factory.SubFactory(IntegerExperimentValueDataFactory), + "Formulation": factory.SubFactory(MixtureExperimentValueDataFactory), + "ingredient4": factory.SubFactory( + MolecularStructureExperimentValueDataFactory + ), + "ingredient5": factory.SubFactory(RealExperimentValueDataFactory), + } + ) class ExperimentDataSourceDataDataFactory(factory.DictFactory): - experiments = factory.List([factory.SubFactory(CandidateExperimentSnapshotDataFactory)]) + experiments = factory.List( + [factory.SubFactory(CandidateExperimentSnapshotDataFactory)] + ) class ExperimentDataSourceMetadataDataFactory(factory.DictFactory): - branch_root_id = factory.Faker('uuid4') - version = factory.Faker('random_digit_not_null') + branch_root_id = factory.Faker("uuid4") + version = factory.Faker("random_digit_not_null") created = factory.SubFactory(UserTimestampDataFactory) class ExperimentDataSourceDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(ExperimentDataSourceDataDataFactory) metadata = factory.SubFactory(ExperimentDataSourceMetadataDataFactory) class AnalysisPlotMetadataDataFactory(factory.DictFactory): - rank = factory.Faker('random_int', min=1, max=10) + rank = factory.Faker("random_int", min=1, max=10) created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) class AnalysisPlotDataDataFactory(factory.DictFactory): - name = factory.Faker('company') - description = factory.Faker('catch_phrase') - plot_type = factory.Faker('random_element', elements=('SCATTER', 'VIOLIN')) + name = factory.Faker("company") + description = factory.Faker("catch_phrase") + plot_type = factory.Faker("random_element", elements=("SCATTER", "VIOLIN")) config = {} class AnalysisPlotEntityDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(AnalysisPlotDataDataFactory) metadata = factory.SubFactory(AnalysisPlotMetadataDataFactory) class LatestBuildDataFactory(factory.DictFactory): class Params: - is_failed = factory.LazyAttribute(lambda o: o.status == 'FAILED') + is_failed = factory.LazyAttribute(lambda o: o.status == "FAILED") - status = factory.Faker('random_element', elements=('INPROGRESS', 'SUCCEEDED', 'FAILED')) - failure_reason = factory.Maybe('is_failed', ['This is a test failure message'], []) + status = factory.Faker( + "random_element", elements=("INPROGRESS", "SUCCEEDED", "FAILED") + ) + failure_reason = factory.Maybe("is_failed", ["This is a test failure message"], []) query = factory.SubFactory(GemdQueryDataFactory) class AnalysisWorkflowMetadataDataFactory(factory.DictFactory): class Meta: - exclude = ('is_archived', 'has_build') + exclude = ("is_archived", "has_build") created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) - archived = factory.Maybe('is_archived', factory.SubFactory(UserTimestampDataFactory), None) - latest_build = factory.Maybe('has_build', factory.SubFactory(LatestBuildDataFactory), None) + archived = factory.Maybe( + "is_archived", factory.SubFactory(UserTimestampDataFactory), None + ) + latest_build = factory.Maybe( + "has_build", factory.SubFactory(LatestBuildDataFactory), None + ) class AnalysisWorkflowDataDataFactory(factory.DictFactory): class Meta: - exclude = ('has_snapshot', 'plot_count') + exclude = ("has_snapshot", "plot_count") class Params: plot_count = 1 - name = factory.Faker('company') - description = factory.Faker('catch_phrase') - snapshot_id = factory.Maybe('has_snapshot', factory.Faker('uuid4'), None) - plots = factory.LazyAttribute(lambda self: [AnalysisPlotEntityDataFactory() for _ in range(self.plot_count)]) + name = factory.Faker("company") + description = factory.Faker("catch_phrase") + snapshot_id = factory.Maybe("has_snapshot", factory.Faker("uuid4"), None) + plots = factory.LazyAttribute( + lambda self: [AnalysisPlotEntityDataFactory() for _ in range(self.plot_count)] + ) class AnalysisWorkflowEntityDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(AnalysisWorkflowDataDataFactory) metadata = factory.SubFactory(AnalysisWorkflowMetadataDataFactory) class FeatureEffectsResponseResultFactory(factory.DictFactory): - materials = factory.List([ - factory.Faker('uuid4', cast_to=None), - factory.Faker('uuid4', cast_to=None), - factory.Faker('uuid4', cast_to=None) - ]) - outputs = factory.Dict({ - "output1": factory.Dict({ - "feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]) - }), - "output2": factory.Dict({ - "feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]), - "feature2": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]) - }) - }) + materials = factory.List( + [ + factory.Faker("uuid4", cast_to=None), + factory.Faker("uuid4", cast_to=None), + factory.Faker("uuid4", cast_to=None), + ] + ) + outputs = factory.Dict( + { + "output1": factory.Dict( + { + "feature1": factory.List( + [ + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + ] + ) + } + ), + "output2": factory.Dict( + { + "feature1": factory.List( + [ + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + ] + ), + "feature2": factory.List( + [ + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + ] + ), + } + ), + } + ) + class FeatureEffectsMetadataFactory(factory.DictFactory): - predictor_id = factory.Faker('uuid4') - predictor_version = factory.Faker('random_digit_not_null') + predictor_id = factory.Faker("uuid4") + predictor_version = factory.Faker("random_digit_not_null") created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) - status = 'SUCCEEDED' + status = "SUCCEEDED" class FeatureEffectsResponseFactory(factory.DictFactory):