Skip to content

Commit fbaf0eb

Browse files
HenrZuMaxBetzDLR
andauthored
1379 Add read_graph to python bindings (#1380)
- Add bindings for the read_graph function in secir and secirvvs - Simple test to assure that write_graph and read_graph are working as intended Co-authored-by: MaxBetz <104758467+MaxBetzDLR@users.noreply.github.com>
1 parent a1d8af8 commit fbaf0eb

File tree

4 files changed

+66
-15
lines changed

4 files changed

+66
-15
lines changed

pycode/memilio-simulation/memilio/simulation/bindings/io/mobility_io.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "memilio/mobility/metapopulation_mobility_instant.h"
2424
#include "memilio/io/mobility_io.h"
25+
#include "pybind_util.h"
2526

2627
#include "pybind11/pybind11.h"
2728
#include <cstddef>
@@ -32,11 +33,27 @@ namespace pymio
3233
template <class Model>
3334
void bind_write_graph(pybind11::module_& m)
3435
{
35-
m.def("write_graph",
36-
[&](const mio::Graph<Model, mio::MobilityParameters<double>>& graph, const std::string& directory) {
37-
int ioflags = mio::IOF_None;
38-
auto ioresult = mio::write_graph<double, Model>(graph, directory, ioflags);
39-
});
36+
m.def(
37+
"write_graph",
38+
[&](const mio::Graph<Model, mio::MobilityParameters<double>>& graph, const std::string& directory) {
39+
int ioflags = mio::IOF_None;
40+
auto ioresult = mio::write_graph<double, Model>(graph, directory, ioflags);
41+
},
42+
"Write a graph (nodes and edges) as JSON files to the given directory.", pybind11::arg("graph"),
43+
pybind11::arg("directory"));
44+
}
45+
46+
template <class Model>
47+
void bind_read_graph(pybind11::module_& m)
48+
{
49+
m.def(
50+
"read_graph",
51+
[&](const std::string& directory) {
52+
auto result = mio::read_graph<double, Model>(directory, 0, true);
53+
return pymio::check_and_throw(result);
54+
},
55+
"Read a graph from JSON files in the given directory (see write_graph).", pybind11::arg("directory"),
56+
pybind11::return_value_policy::move);
4057
}
4158

4259
} // namespace pymio

pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,11 @@ PYBIND11_MODULE(_simulation_osecir, m)
287287
mio::osecir::InfectionState::InfectedSymptoms, mio::osecir::InfectionState::Recovered};
288288
auto weights = std::vector<ScalarType>{0., 0., 1.0, 1.0, 0.33, 0., 0.};
289289
auto result = mio::set_edges<double, // FP
290-
ContactLocation, mio::osecir::Model<double>, mio::MobilityParameters<double>,
291-
mio::MobilityCoefficientGroup<double>, mio::osecir::InfectionState,
292-
decltype(mio::read_mobility_plain)>(mobility_data_file, params_graph,
293-
mobile_comp, contact_locations_size,
294-
mio::read_mobility_plain, weights);
290+
ContactLocation, mio::osecir::Model<double>, mio::MobilityParameters<double>,
291+
mio::MobilityCoefficientGroup<double>, mio::osecir::InfectionState,
292+
decltype(mio::read_mobility_plain)>(mobility_data_file, params_graph,
293+
mobile_comp, contact_locations_size,
294+
mio::read_mobility_plain, weights);
295295
return pymio::check_and_throw(result);
296296
},
297297
py::return_value_policy::move);
@@ -302,6 +302,7 @@ PYBIND11_MODULE(_simulation_osecir, m)
302302

303303
#ifdef MEMILIO_HAS_JSONCPP
304304
pymio::bind_write_graph<mio::osecir::Model<double>>(m);
305+
pymio::bind_read_graph<mio::osecir::Model<double>>(m);
305306
m.def(
306307
"read_input_data_county",
307308
[](std::vector<mio::osecir::Model<double>>& model, mio::Date date, const std::vector<int>& county,
@@ -314,8 +315,9 @@ PYBIND11_MODULE(_simulation_osecir, m)
314315
py::return_value_policy::move);
315316
#endif // MEMILIO_HAS_JSONCPP
316317

317-
m.def("interpolate_simulation_result", py::overload_cast<const MobilityGraph&>(
318-
&mio::interpolate_simulation_result<double, mio::osecir::Simulation<double>>));
318+
m.def("interpolate_simulation_result",
319+
py::overload_cast<const MobilityGraph&>(
320+
&mio::interpolate_simulation_result<double, mio::osecir::Simulation<double>>));
319321

320322
m.def("interpolate_ensemble_results", &mio::interpolate_ensemble_results<MobilityGraph>);
321323

pycode/memilio-simulation/memilio/simulation/bindings/models/osecirvvs.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,9 @@ PYBIND11_MODULE(_simulation_osecirvvs, m)
340340
mio::osecirvvs::InfectionState::InfectedSymptomsImprovedImmunity};
341341
auto weights = std::vector<ScalarType>{0., 0., 1.0, 1.0, 0.33, 0., 0.};
342342
auto result = mio::set_edges<double, // FP,
343-
ContactLocation, mio::osecirvvs::Model<double>,
344-
mio::MobilityParameters<double>, mio::MobilityCoefficientGroup<double>,
345-
mio::osecirvvs::InfectionState, decltype(mio::read_mobility_plain)>(
343+
ContactLocation, mio::osecirvvs::Model<double>,
344+
mio::MobilityParameters<double>, mio::MobilityCoefficientGroup<double>,
345+
mio::osecirvvs::InfectionState, decltype(mio::read_mobility_plain)>(
346346
mobility_data_file, params_graph, mobile_comp, contact_locations_size, mio::read_mobility_plain,
347347
weights);
348348
return pymio::check_and_throw(result);
@@ -355,6 +355,7 @@ PYBIND11_MODULE(_simulation_osecirvvs, m)
355355

356356
#ifdef MEMILIO_HAS_JSONCPP
357357
pymio::bind_write_graph<mio::osecirvvs::Model<double>>(m);
358+
pymio::bind_read_graph<mio::osecirvvs::Model<double>>(m);
358359
m.def(
359360
"read_input_data_county",
360361
[](std::vector<mio::osecirvvs::Model<double>>& model, mio::Date date, const std::vector<int>& county,

pycode/memilio-simulation/memilio/simulation_test/test_mobility.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# limitations under the License.
1919
#############################################################################
2020
import unittest
21+
import tempfile
2122

2223
import numpy as np
2324

@@ -75,6 +76,36 @@ def test_mobility_sim(self):
7576
self.assertGreaterEqual(sim.graph.get_node(
7677
0).property.result.get_num_time_points(), 3)
7778

79+
def test_write_read_graph_simple(self):
80+
# build a simple model graph
81+
model = osecir.Model(1)
82+
model.parameters.TestAndTraceCapacity.value = 42
83+
model.apply_constraints()
84+
85+
graph = osecir.ModelGraph()
86+
graph.add_node(0, model)
87+
graph.add_node(1, model)
88+
89+
num_compartments = 10
90+
graph.add_edge(0, 1, 0.1 * np.ones(num_compartments))
91+
graph.add_edge(1, 0, 0.1 * np.ones(num_compartments))
92+
93+
with tempfile.TemporaryDirectory() as tmpdir:
94+
# save graph
95+
osecir.write_graph(graph, tmpdir)
96+
# read graph back
97+
g_read = osecir.read_graph(tmpdir)
98+
99+
# basic structure
100+
self.assertEqual(graph.num_nodes, g_read.num_nodes)
101+
self.assertEqual(graph.num_edges, g_read.num_edges)
102+
103+
# check one parameter
104+
self.assertEqual(
105+
g_read.get_node(0).property.parameters.TestAndTraceCapacity.value,
106+
42,
107+
)
108+
78109

79110
if __name__ == '__main__':
80111
unittest.main()

0 commit comments

Comments
 (0)