-
Notifications
You must be signed in to change notification settings - Fork 20
1393 improve parameters io for graph ode based models #1430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1959d3b
664689f
561f657
a1d73a7
d7ce1f1
c553af3
511956b
d054a5e
0d2adcf
db7d684
0e0c23a
3b3c628
9461e4d
bfe89f2
356b84d
d7cf12c
e5de2e2
4e58516
460db72
15beb32
cade4eb
50e5d6f
d188bc9
16e43e8
adcee20
28dc354
c348614
d6f9906
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -103,7 +103,60 @@ get_holidays(StateId state); | |||||
| Range<std::pair<std::vector<std::pair<Date, Date>>::const_iterator, std::vector<std::pair<Date, Date>>::const_iterator>> | ||||||
| get_holidays(StateId state, Date start_date, Date end_date); | ||||||
|
|
||||||
| namespace de | ||||||
| { | ||||||
|
|
||||||
| struct EpidataFilenames { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add documentation above the struct. |
||||||
| private: | ||||||
| EpidataFilenames(std::string& pydata) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
same for the others below |
||||||
| : population_data_path(mio::path_join(pydata, "county_current_population.json")) | ||||||
| { | ||||||
| } | ||||||
|
|
||||||
| public: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. overall, this is all very static. At least the moving average should be adaptive. |
||||||
| static EpidataFilenames county(std::string& pydata) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A function that validates all paths and checks which ones are missing would be great here, but it is optional (only if you want to add this). |
||||||
| { | ||||||
| EpidataFilenames s(pydata); | ||||||
|
|
||||||
| s.case_data_path = mio::path_join(pydata, "cases_all_county_age_ma7.json"); | ||||||
| s.divi_data_path = mio::path_join(pydata, "county_divi_ma7.json"); | ||||||
| s.vaccination_data_path = mio::path_join(pydata, "vacc_county_ageinf_ma7.json"); | ||||||
|
|
||||||
| return s; | ||||||
| } | ||||||
|
|
||||||
| static EpidataFilenames states(std::string& pydata) | ||||||
| { | ||||||
| EpidataFilenames s(pydata); | ||||||
|
|
||||||
| s.case_data_path = mio::path_join(pydata, "cases_all_state_age_ma7.json"); | ||||||
| s.divi_data_path = mio::path_join(pydata, "state_divi_ma7.json"); | ||||||
| s.vaccination_data_path = mio::path_join(pydata, "vacc_state_ageinf_ma7.json"); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The vaccination data is the same for states and counties. Is this intended? If yes, add a comment. |
||||||
|
|
||||||
| return s; | ||||||
| } | ||||||
|
|
||||||
| static EpidataFilenames country(std::string& pydata) | ||||||
| { | ||||||
| EpidataFilenames s(pydata); | ||||||
|
|
||||||
| s.case_data_path = mio::path_join(pydata, "cases_all_age_ma7.json"); | ||||||
| s.divi_data_path = mio::path_join(pydata, "germany_divi_ma7.json"); | ||||||
| s.vaccination_data_path = mio::path_join(pydata, "vacc_ageinf_ma7.json"); | ||||||
|
|
||||||
| return s; | ||||||
| } | ||||||
|
|
||||||
| std::string population_data_path; | ||||||
| std::string case_data_path; | ||||||
| std::string divi_data_path; | ||||||
| std::string vaccination_data_path; | ||||||
| }; | ||||||
|
|
||||||
| } // namespace de | ||||||
|
|
||||||
| } // namespace regions | ||||||
|
|
||||||
| } // namespace mio | ||||||
|
|
||||||
| #endif //MIO_EPI_REGIONS_H | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| */ | ||
|
|
||
| #include "memilio/config.h" | ||
| #include "memilio/io/parameters_io.h" | ||
|
|
||
| #ifdef MEMILIO_HAS_JSONCPP | ||
|
|
||
|
|
@@ -30,6 +31,45 @@ | |
|
|
||
| namespace mio | ||
| { | ||
| IOResult<std::vector<ScalarType>> compute_divi_data(const std::vector<DiviEntry>& divi_data, | ||
| const std::vector<int>& vregion, Date date) | ||
| { | ||
|
|
||
| auto max_date_entry = std::max_element(divi_data.begin(), divi_data.end(), [](auto&& a, auto&& b) { | ||
| return a.date < b.date; | ||
| }); | ||
| if (max_date_entry == divi_data.end()) { | ||
| log_error("DIVI data is empty."); | ||
| return failure(StatusCode::InvalidValue, "DIVI data is empty."); | ||
| } | ||
| auto max_date = max_date_entry->date; | ||
| if (max_date < date) { | ||
| log_error("DIVI data does not contain the specified date."); | ||
| return failure(StatusCode::OutOfRange, "DIVI data does not contain the specified date."); | ||
| } | ||
|
|
||
| std::vector<ScalarType> vnum_icu(vregion.size(), 0.0); | ||
|
|
||
| for (auto&& entry : divi_data) { | ||
| auto it = std::find_if(vregion.begin(), vregion.end(), [&entry](auto r) { | ||
| return r == 0 || r == get_region_id(entry); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add a comment here that the 0 represents a state. |
||
| }); | ||
| auto date_df = entry.date; | ||
| if (it != vregion.end() && date_df == date) { | ||
| auto region_idx = size_t(it - vregion.begin()); | ||
| vnum_icu[region_idx] = entry.num_icu; | ||
| } | ||
| } | ||
|
|
||
| return success(vnum_icu); | ||
| } | ||
|
|
||
| IOResult<std::vector<ScalarType>> read_divi_data(const std::string& path, const std::vector<int>& vregion, Date date) | ||
| { | ||
| BOOST_OUTCOME_TRY(auto&& divi_data, mio::read_divi_data(path)); | ||
| return compute_divi_data(divi_data, vregion, date); | ||
| } | ||
|
|
||
| IOResult<std::vector<std::vector<ScalarType>>> | ||
| read_population_data(const std::vector<PopulationDataEntry>& population_data, const std::vector<int>& vregion) | ||
| { | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -52,67 +52,26 @@ int get_region_id(const EpiDataEntry& data_entry) | |||||
|
|
||||||
| /** | ||||||
| * @brief Extracts the number of individuals in critical condition (ICU) for each region | ||||||
| * on a specified date from the provided DIVI data. | ||||||
| * | ||||||
| * @tparam FP Floating point type (default: double). | ||||||
| * on a specified date from the provided DIVI data. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| * | ||||||
| * @param[in] divi_data Vector of DIVI data entries containing date, region, and ICU information. | ||||||
| * @param[in] vregion Vector of region IDs for which the data is computed. | ||||||
| * @param[in] date Date for which the ICU data is computed. | ||||||
| * @param[in, out] vnum_icu Output vector containing the number of ICU cases for each region. | ||||||
| * | ||||||
| * @return An IOResult indicating success or failure. | ||||||
| * @return An IOResult containing a vector with the number of ICU cases for each region, or an | ||||||
| * error if the function fails. | ||||||
| */ | ||||||
| template <typename FP = ScalarType> | ||||||
| IOResult<void> compute_divi_data(const std::vector<DiviEntry>& divi_data, const std::vector<int>& vregion, Date date, | ||||||
| std::vector<FP>& vnum_icu) | ||||||
| { | ||||||
| auto max_date_entry = std::max_element(divi_data.begin(), divi_data.end(), [](auto&& a, auto&& b) { | ||||||
| return a.date < b.date; | ||||||
| }); | ||||||
| if (max_date_entry == divi_data.end()) { | ||||||
| log_error("DIVI data is empty."); | ||||||
| return failure(StatusCode::InvalidValue, "DIVI data is empty."); | ||||||
| } | ||||||
| auto max_date = max_date_entry->date; | ||||||
| if (max_date < date) { | ||||||
| log_error("DIVI data does not contain the specified date."); | ||||||
| return failure(StatusCode::OutOfRange, "DIVI data does not contain the specified date."); | ||||||
| } | ||||||
|
|
||||||
| for (auto&& entry : divi_data) { | ||||||
| auto it = std::find_if(vregion.begin(), vregion.end(), [&entry](auto r) { | ||||||
| return r == 0 || r == get_region_id(entry); | ||||||
| }); | ||||||
| auto date_df = entry.date; | ||||||
| if (it != vregion.end() && date_df == date) { | ||||||
| auto region_idx = size_t(it - vregion.begin()); | ||||||
| vnum_icu[region_idx] = entry.num_icu; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| return success(); | ||||||
| } | ||||||
| IOResult<std::vector<ScalarType>> compute_divi_data(const std::vector<DiviEntry>& divi_data, const std::vector<int>& vregion, Date date); | ||||||
|
|
||||||
| /** | ||||||
| * @brief Reads DIVI data from a file and computes the ICU data for specified regions and date. | ||||||
| * | ||||||
| * @tparam FP Floating point type (default: double). | ||||||
| * | ||||||
| * @param[in] path Path to the file containing DIVI data. | ||||||
| * @param[in] vregion Vector of region IDs for which the data is computed. | ||||||
| * @param[in] date Date for which the ICU data is computed. | ||||||
| * @param[in, out] vnum_icu Output vector containing the number of ICU cases for each region. | ||||||
| * | ||||||
| * @return An IOResult indicating success or failure. | ||||||
| * @return An IOResult containing a vector with the number of ICU cases for each region, or an | ||||||
| * error if the function fails. | ||||||
| */ | ||||||
| template <typename FP = ScalarType> | ||||||
| IOResult<void> read_divi_data(const std::string& path, const std::vector<int>& vregion, Date date, | ||||||
| std::vector<FP>& vnum_icu) | ||||||
| { | ||||||
| BOOST_OUTCOME_TRY(auto&& divi_data, mio::read_divi_data(path)); | ||||||
| return compute_divi_data(divi_data, vregion, date, vnum_icu); | ||||||
| } | ||||||
| IOResult<std::vector<ScalarType>> read_divi_data(const std::string& path, const std::vector<int>& vregion, Date date); | ||||||
|
|
||||||
| /** | ||||||
| * @brief Reads population data from a vector of population data entries. | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -305,88 +305,73 @@ class Graph | |
| private: | ||
| std::vector<Node<NodePropertyT>> m_nodes; | ||
| std::vector<Edge<EdgePropertyT>> m_edges; | ||
| }; // namespace mio | ||
| }; | ||
|
|
||
| /** | ||
| * @brief Sets the graph nodes for counties or districts. | ||
| * Reads the node ids which could refer to districts or counties and the epidemiological | ||
| * data from json files and creates one node for each id. Every node contains a model. | ||
| * @param[in] params Model Parameters that are used for every node. | ||
| * @param[in] start_date Start date for which the data should be read. | ||
| * @param[in] end_date End date for which the data should be read. | ||
| * @param[in] data_dir Directory that contains the data files. | ||
| * @param[in] population_data_path Path to json file containing the population data. | ||
| * @param[in] is_node_for_county Specifies whether the node ids should be county ids (true) or district ids (false). | ||
| * @param[in, out] params_graph Graph whose nodes are set by the function. | ||
| * @param[in] read_func Function that reads input data for german counties and sets Model compartments. | ||
| * @param[in] node_func Function that returns the county ids. | ||
| * @param[in] scaling_factor_inf Factor of confirmed cases to account for undetected cases in each county. | ||
| * @param[in] scaling_factor_icu Factor of ICU cases to account for underreporting. | ||
| * @brief Set test and trace capacity with uncertainty for the given models. | ||
| * | ||
| * @param[in, out] nodes VectorRange of Node%s each containing a Model in which the data is set. | ||
| * @param[in] tnt_capacity_factor Factor for test and trace capacity. | ||
| * @param[in] num_days Number of days to be simulated; required to load data for vaccinations during the simulation. | ||
| * @param[in] export_time_series If true, reads data for each day of simulation and writes it in the same directory as the input files. | ||
| * @param[in] rki_age_groups Specifies whether rki-age_groups should be used. | ||
| */ | ||
| template <typename FP, class TestAndTrace, class ContactPattern, class Model, class MobilityParams, class Parameters, | ||
| class ReadFunction, class NodeIdFunction> | ||
| IOResult<void> set_nodes(const Parameters& params, Date start_date, Date end_date, const fs::path& data_dir, | ||
| const std::string& population_data_path, bool is_node_for_county, | ||
| Graph<Model, MobilityParams>& params_graph, ReadFunction&& read_func, | ||
| NodeIdFunction&& node_func, const std::vector<FP>& scaling_factor_inf, FP scaling_factor_icu, | ||
| FP tnt_capacity_factor, int num_days = 0, bool export_time_series = false, | ||
| bool rki_age_groups = true) | ||
|
|
||
| template <class Model, class TestAndTrace> | ||
| void set_test_and_trace_capacity(const mio::VectorRange<Node<Model>>& nodes, ScalarType tnt_capacity_factor) | ||
| { | ||
| BOOST_OUTCOME_TRY(auto&& node_ids, node_func(population_data_path, is_node_for_county, rki_age_groups)); | ||
| std::vector<Model> nodes(node_ids.size(), Model(int(size_t(params.get_num_groups())))); | ||
| for (auto& node : nodes) { | ||
| node.parameters = params; | ||
| } | ||
|
|
||
| BOOST_OUTCOME_TRY(read_func(nodes, start_date, node_ids, scaling_factor_inf, scaling_factor_icu, data_dir.string(), | ||
| num_days, export_time_series)); | ||
|
|
||
| for (size_t node_idx = 0; node_idx < nodes.size(); ++node_idx) { | ||
|
|
||
| auto tnt_capacity = nodes[node_idx].populations.get_total() * tnt_capacity_factor; | ||
| auto tnt_capacity = nodes[node_idx].property.populations.get_total() * tnt_capacity_factor; | ||
|
|
||
| //local parameters | ||
| auto& tnt_value = nodes[node_idx].parameters.template get<TestAndTrace>(); | ||
| tnt_value = UncertainValue<FP>(0.5 * (1.2 * tnt_capacity + 0.8 * tnt_capacity)); | ||
| auto& tnt_value = nodes[node_idx].property.parameters.template get<TestAndTrace>(); | ||
| tnt_value = UncertainValue<ScalarType>(tnt_capacity); | ||
| tnt_value.set_distribution(mio::ParameterDistributionUniform(0.8 * tnt_capacity, 1.2 * tnt_capacity)); | ||
| } | ||
| } | ||
|
|
||
| auto id = 0; | ||
| if (is_node_for_county) { | ||
| id = int(regions::CountyId(node_ids[node_idx])); | ||
| } | ||
| else { | ||
| id = int(regions::DistrictId(node_ids[node_idx])); | ||
| } | ||
| //holiday periods | ||
| auto holiday_periods = regions::get_holidays(regions::get_state_id(id), start_date, end_date); | ||
| auto& contacts = nodes[node_idx].parameters.template get<ContactPattern>(); | ||
| /** | ||
| * @brief Set German state holidays for the given nodes. | ||
| * | ||
| * Works for nodes of a graph depicting German counties or states. | ||
| * | ||
| * @param[in, out] nodes VectorRange of Node%s each containing a Model in which the data is set. | ||
| * @param[in] start_date Date at the beginning of the simulation. | ||
| * @param[in] end_date Date at the end of the simulation. | ||
| */ | ||
| template <class FP, class Model, class ContactPattern> | ||
| void set_german_holidays(const mio::VectorRange<Node<Model>>& nodes, const mio::Date& start_date, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In general, we should try to keep this generic and add an overall function set_holidays which allows a country specifier, e.g., for Germany. |
||
| const mio::Date& end_date) | ||
| { | ||
| for (size_t node_idx = 0; node_idx < nodes.size(); ++node_idx) { | ||
| auto state_id = regions::get_state_id(nodes[node_idx].id); | ||
| auto holiday_periods = regions::get_holidays(state_id, start_date, end_date); | ||
|
|
||
| auto& contacts = nodes[node_idx].property.parameters.template get<ContactPattern>(); | ||
| contacts.get_school_holidays() = | ||
| std::vector<std::pair<mio::SimulationTime<FP>, mio::SimulationTime<FP>>>(holiday_periods.size()); | ||
| std::transform( | ||
| holiday_periods.begin(), holiday_periods.end(), contacts.get_school_holidays().begin(), [=](auto& period) { | ||
| return std::make_pair(mio::SimulationTime<FP>(mio::get_offset_in_days(period.first, start_date)), | ||
| mio::SimulationTime<FP>(mio::get_offset_in_days(period.second, start_date))); | ||
| }); | ||
| } | ||
| } | ||
|
|
||
| //uncertainty in populations | ||
| for (auto i = mio::AgeGroup(0); i < params.get_num_groups(); i++) { | ||
| /** | ||
| * @brief Add uncertainty to the population of the given nodes. | ||
| * | ||
| * @param[in, out] nodes VectorRange of Node%s each containing a Model in which the data is set. | ||
| */ | ||
| template <class Model> | ||
| void set_uncertainty_on_population(const mio::VectorRange<Node<Model>>& nodes) | ||
| { | ||
| for (size_t node_idx = 0; node_idx < nodes.size(); ++node_idx) { | ||
| for (auto i = mio::AgeGroup(0); i < nodes[0].property.parameters.get_num_groups(); i++) { | ||
| for (auto j = Index<typename Model::Compartments>(0); j < Model::Compartments::Count; ++j) { | ||
| auto& compartment_value = nodes[node_idx].populations[{i, j}]; | ||
| compartment_value = | ||
| UncertainValue<FP>(0.5 * (1.1 * compartment_value.value() + 0.9 * compartment_value.value())); | ||
| auto& compartment_value = nodes[node_idx].property.populations[{i, j}]; | ||
| compartment_value = UncertainValue<ScalarType>(compartment_value.value()); | ||
| compartment_value.set_distribution(mio::ParameterDistributionUniform(0.9 * compartment_value.value(), | ||
| 1.1 * compartment_value.value())); | ||
| } | ||
| } | ||
|
|
||
| params_graph.add_node(node_ids[node_idx], nodes[node_idx]); | ||
| } | ||
| return success(); | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -426,7 +411,7 @@ IOResult<void> set_edges(const fs::path& mobility_data_file, Graph<Model, Mobili | |
| commuting_weights = | ||
| (commuting_weights.size() == 0 ? std::vector<FP>(num_age_groups, 1.0) : commuting_weights); | ||
| //commuters | ||
| auto working_population = 0.0; | ||
| FP working_population = 0.0; | ||
| for (auto age = AgeGroup(0); age < populations.template size<mio::AgeGroup>(); ++age) { | ||
| working_population += populations.get_group_total(age) * commuting_weights[size_t(age)]; | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.