diff --git a/simpeg_drivers/components/factories/entity_factory.py b/simpeg_drivers/components/factories/entity_factory.py index 6ad64801..f87abd39 100644 --- a/simpeg_drivers/components/factories/entity_factory.py +++ b/simpeg_drivers/components/factories/entity_factory.py @@ -109,7 +109,8 @@ def _build(self, inversion_data: InversionData): cells = self.params.data_object.transmitters.cells if getattr(self.params.data_object, "tx_id_property", None) is not None: - self.params.data_object.tx_id_property.copy(parent=entity) + tx_id = self.params.data_object.tx_id_property.copy(parent=entity) + entity.tx_id_property = tx_id if isinstance( self.params.data_object.transmitters, @@ -119,13 +120,25 @@ def _build(self, inversion_data: InversionData): self.params.data_object.transmitters ) - entity.transmitters = self.params.data_object.transmitters.copy( + transmitters = self.params.data_object.transmitters.copy( copy_complement=False, vertices=vertices, cells=cells, parent=self.params.out_group, + copy_children=False, ) + if ( + getattr(self.params.data_object.transmitters, "tx_id_property", None) + is not None + ): + tx_id = self.params.data_object.transmitters.tx_id_property.copy( + parent=transmitters + ) + transmitters.tx_id_property = tx_id + + entity.transmitters = transmitters + tx_freq = self.params.data_object.transmitters.get_data("Tx frequency") if tx_freq: tx_freq[0].copy(parent=entity.transmitters) diff --git a/tests/run_tests/driver_ground_tem_test.py b/tests/run_tests/driver_ground_tem_test.py index 0374ead8..354465e0 100644 --- a/tests/run_tests/driver_ground_tem_test.py +++ b/tests/run_tests/driver_ground_tem_test.py @@ -130,6 +130,7 @@ def test_ground_tem_fwr_run( fwr_driver = TimeDomainElectromagneticsDriver(params) survey.transmitters.remove_cells([15]) + survey.tx_id_property.name = "tx_id" assert fwr_driver.inversion_data.survey.source_list[0].n_segments == 16 if pytest: @@ -232,6 +233,7 @@ def test_ground_tem_run(tmp_path: Path, max_iterations=1, pytest=True): output = get_inversion_output( driver.params.geoh5.h5file, driver.params.out_group.uid ) + assert driver.inversion_data.entity.tx_id_property.name == "tx_id" output["data"] = orig_dBzdt if pytest: check_target(output, target_run, tolerance=0.5)