-
Notifications
You must be signed in to change notification settings - Fork 24
1503 add python bindings for ABM #1515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 11 commits
52163a8
4f49e04
e18e5b2
cc3ddb2
efd5e65
5851eec
7e352b1
9e7c6aa
ef29c59
cf15f80
43cb996
10607ae
d442c86
1b364cf
8f6a2fb
7cfe952
c30c597
0521f44
cc1171d
b8fdc58
ee1ecc4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,10 +27,17 @@ | |
| #include "parameter_distributions.h" | ||
| #include <memory> | ||
| #include <string> | ||
| #include <concepts> | ||
|
|
||
| namespace mio | ||
| { | ||
|
|
||
| template <class T> | ||
| concept HasSampleFunction = requires(T t) { | ||
| { t.get_sample(std::declval<RandomNumberGenerator&>()) } -> std::convertible_to<ScalarType>; | ||
| { t.get_sample(std::declval<abm::PersonalRandomNumberGenerator&>()) } -> std::convertible_to<ScalarType>; | ||
| }; | ||
|
Comment on lines
+35
to
+38
|
||
|
|
||
| /** | ||
| * @brief This class represents an arbitrary ParameterDistribution. | ||
| * @see mio::ParameterDistribution | ||
|
|
@@ -44,7 +51,7 @@ class AbstractParameterDistribution | |
| * The implementation handed to the constructor should have get_sample function | ||
| * overloaded with mio::RandomNumberGenerator and mio::abm::PersonalRandomNumberGenerator as input arguments | ||
| */ | ||
| template <class Impl> | ||
| template <HasSampleFunction Impl> | ||
| AbstractParameterDistribution(Impl&& dist) | ||
| : m_dist(std::make_shared<Impl>(std::move(dist))) | ||
| , sample_impl1([](void* d, RandomNumberGenerator& rng) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,195 @@ | ||
| ############################################################################# | ||
| # Copyright (C) 2020-2026 MEmilio | ||
| # | ||
| # Authors: Carlotta Gerstein | ||
| # | ||
| # Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de> | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| ############################################################################# | ||
|
|
||
| from memilio.simulation import AgeGroup | ||
| import memilio.simulation.abm as abm | ||
| import memilio.simulation as mio | ||
|
|
||
| import numpy as np | ||
| import random | ||
|
|
||
| num_age_groups = 4 | ||
|
|
||
| model = abm.Model(num_age_groups) | ||
|
|
||
| # Set parameters | ||
|
|
||
| for age_group in range(num_age_groups): | ||
| model.parameters.TimeExposedToNoSymptoms[abm.VirusVariant.Wildtype, AgeGroup(age_group)] = mio.AbstractParameterDistribution(mio.ParameterDistributionLogNormal( | ||
| 4., 1.)) | ||
|
|
||
| model.parameters.AgeGroupGotoSchool[AgeGroup(age_group)] = False | ||
| model.parameters.AgeGroupGotoWork[AgeGroup(age_group)] = False | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As the default is false, these line could be removed, unless you have an intention to explicitly write them down.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. true, removed also from the cpp version |
||
|
|
||
| model.parameters.AgeGroupGotoSchool[AgeGroup(1)] = True | ||
| model.parameters.AgeGroupGotoWork[AgeGroup(2)] = True | ||
| model.parameters.AgeGroupGotoWork[AgeGroup(3)] = True | ||
|
|
||
| for age in range(num_age_groups): | ||
| model.parameters.InfectionProtectionFactor[abm.ProtectionType.GenericVaccine, AgeGroup( | ||
| age), abm.VirusVariant.Wildtype] = mio.TimeSeriesFunctor( | ||
| [[0, 0.0], [14, 0.67], [180, 0.4]]) | ||
|
|
||
| model.parameters.SeverityProtectionFactor[abm.ProtectionType.GenericVaccine, AgeGroup( | ||
| age), abm.VirusVariant.Wildtype] = mio.TimeSeriesFunctor( | ||
| [[0, 0.0], [14, 0.85], [180, 0.7]]) | ||
|
Comment on lines
+42
to
+49
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part is not in the .cpp file. We should probably stay as close as possible to the .cpp example. If we want to add this, then we should also add it in the .cpp example.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the cpp version this is divided into the abm_minimal and the abm_vaccination example. Agree that we should probably stay close, so either merge them in the cpp version or split it here. Do you have a favorite solution? |
||
|
|
||
| model.parameters.check_constraints() | ||
|
|
||
| # Set populations | ||
|
|
||
| n_households = 10 | ||
|
|
||
| child = abm.HouseholdMember(num_age_groups) | ||
| child.age_weights[AgeGroup(0)] = 1. | ||
| child.age_weights[AgeGroup(1)] = 1. | ||
|
|
||
| parent = abm.HouseholdMember(num_age_groups) | ||
| parent.age_weights[AgeGroup(2)] = 1. | ||
| parent.age_weights[AgeGroup(3)] = 1. | ||
|
|
||
| twoPersonHousehold_group = abm.HouseholdGroup() | ||
| twoPersonHousehold_full = abm.Household() | ||
| twoPersonHousehold_full.add_members(child, 1) | ||
| twoPersonHousehold_full.add_members(parent, 1) | ||
| twoPersonHousehold_group.add_households(twoPersonHousehold_full, n_households) | ||
| abm.add_household_group_to_model(model, twoPersonHousehold_group) | ||
|
|
||
| threePersonHousehold_group = abm.HouseholdGroup() | ||
| threePersonHousehold_full = abm.Household() | ||
| threePersonHousehold_full.add_members(child, 1) | ||
| threePersonHousehold_full.add_members(parent, 2) | ||
| threePersonHousehold_group.add_households( | ||
| threePersonHousehold_full, n_households) | ||
| abm.add_household_group_to_model(model, threePersonHousehold_group) | ||
|
|
||
| # Set locations | ||
|
|
||
| event = model.add_location(abm.LocationType.SocialEvent) | ||
| model.get_location(event).infection_parameters.MaximumContacts = 5 | ||
|
|
||
| hospital = model.add_location(abm.LocationType.Hospital) | ||
| model.get_location(hospital).infection_parameters.MaximumContacts = 5 | ||
| icu = model.add_location(abm.LocationType.ICU) | ||
| model.get_location(icu).infection_parameters.MaximumContacts = 5 | ||
|
|
||
| shop = model.add_location(abm.LocationType.BasicsShop) | ||
| model.get_location(shop).infection_parameters.MaximumContacts = 20 | ||
|
|
||
| school = model.add_location(abm.LocationType.School) | ||
| model.get_location(school).infection_parameters.MaximumContacts = 20 | ||
|
|
||
| work = model.add_location(abm.LocationType.Work) | ||
| model.get_location(work).infection_parameters.MaximumContacts = 20 | ||
|
|
||
| model.parameters.AerosolTransmissionRates[abm.VirusVariant.Wildtype] = 10 | ||
|
|
||
| contacts = np.zeros((num_age_groups, num_age_groups)) | ||
| contacts[2, 3] = 10 | ||
|
|
||
| model.get_location( | ||
| work).infection_parameters.ContactRates.baseline = contacts | ||
|
|
||
| # Testing Schemes | ||
|
|
||
| validity_period = abm.days(1) | ||
| probability = 0.5 | ||
| start_date = abm.TimePoint(0) | ||
| end_date = abm.TimePoint(0) + abm.days(10) | ||
| test_type = abm.TestType.Antigen | ||
| test_parameters = model.parameters.TestData[test_type] | ||
|
|
||
| testing_criteria_work = abm.TestingCriteria() | ||
| testing_scheme_work = abm.TestingScheme( | ||
| testing_criteria_work, validity_period, start_date, end_date, test_parameters, probability) | ||
|
|
||
| model.testing_strategy.add_scheme( | ||
| abm.LocationType.Work, testing_scheme_work) | ||
|
|
||
| # Seed infections | ||
|
|
||
| infection_distribution = [0.5, 0.3, 0.05, 0.05, 0.05, 0.05, 0.0, 0.0] | ||
| rng = np.random.default_rng() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the availability of the MEmilio RNG / Personal RNG we should use that one here instead of the numpy rng. The discrete distribution would be suitable.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. exchanged |
||
| for person in model.persons: | ||
| infection_state = abm.InfectionState(rng.choice( | ||
| len(infection_distribution), p=infection_distribution)) | ||
|
|
||
| if infection_state != abm.InfectionState.Susceptible: | ||
| person.add_new_infection(model, abm.VirusVariant.Wildtype, | ||
| person.age, model.parameters, start_date, infection_state) | ||
|
|
||
| # Assign locations | ||
|
|
||
| for person in model.persons: | ||
| id = person.id | ||
|
|
||
| model.assign_location(id, event) | ||
| model.assign_location(id, shop) | ||
|
|
||
| model.assign_location(id, hospital) | ||
| model.assign_location(id, icu) | ||
|
|
||
| if person.age == AgeGroup(1): | ||
| model.assign_location(id, school) | ||
|
|
||
| if person.age == AgeGroup(2) or person.age == AgeGroup(3): | ||
| model.assign_location(id, work) | ||
|
charlie0614 marked this conversation as resolved.
Outdated
|
||
|
|
||
| # Vaccinations | ||
|
|
||
| vacc_rate = 0.7 | ||
| vaccination_priority = [AgeGroup(3), AgeGroup(2), AgeGroup(1)] | ||
| vaccination_time = start_date - abm.days(20) | ||
|
|
||
| persons_by_age = [[] for _ in range(num_age_groups)] | ||
| for idx, person in enumerate(model.persons): | ||
| persons_by_age[person.age.get()].append(idx) | ||
|
|
||
| for age in vaccination_priority: | ||
| indices = persons_by_age[age.get()] | ||
|
|
||
| random.shuffle(indices) | ||
|
|
||
| temp = vacc_rate * len(indices) | ||
| n_to_vaccinate = int(np.round(vacc_rate * len(indices))) | ||
|
|
||
| count = 0 | ||
|
charlie0614 marked this conversation as resolved.
Outdated
|
||
| for i in range(n_to_vaccinate): | ||
| person = model.persons[indices[i]] | ||
| if person.get_infection_state(vaccination_time) == abm.InfectionState.Susceptible: | ||
| person.add_new_vaccination( | ||
| abm.ProtectionType.GenericVaccine, vaccination_time) | ||
|
|
||
| # Simulate | ||
|
|
||
| t_lockdown = start_date + abm.days(10) | ||
| abm.close_social_events(t_lockdown, 0.9, model.parameters) | ||
|
|
||
| t0 = start_date | ||
| tmax = t0 + abm.days(10) | ||
| sim = abm.Simulation(t0, model) | ||
|
|
||
|
|
||
| history = abm.TimeSeriesWriterLogInfectionStateHistory( | ||
| mio.TimeSeries(len(abm.InfectionState.values()))) | ||
|
|
||
| sim.advance(tmax, history) | ||
|
|
||
| history.get_log().print_table() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| /* | ||
| * Copyright (C) 2020-2026 MEmilio | ||
| * | ||
| * Authors: Carlotta Gerstein | ||
| * | ||
|
charlie0614 marked this conversation as resolved.
|
||
| * Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de> | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include "memilio/math/time_series_functor.h" | ||
| #include "memilio/math/interpolation.h" | ||
|
|
||
| #include "pybind_util.h" | ||
| #include "pybind11/pybind11.h" | ||
|
|
||
| namespace py = pybind11; | ||
|
|
||
| namespace pymio | ||
| { | ||
|
|
||
| void bind_time_series_functor(py::module_& m, std::string const& name) | ||
| { | ||
| bind_class<mio::TimeSeriesFunctor<double>, EnablePickling::Never>(m, name.c_str()) | ||
| .def(py::init()) | ||
| .def(py::init<mio::TimeSeriesFunctorType, mio::TimeSeries<double>>()) | ||
| .def(py::init([](const mio::TimeSeries<double>& data) { | ||
| return mio::TimeSeriesFunctor(mio::TimeSeriesFunctorType::LinearInterpolation, data); | ||
| })) | ||
| .def(py::init([](std::vector<std::vector<double>>&& table) { | ||
| return mio::TimeSeriesFunctor<double>(mio::TimeSeriesFunctorType::LinearInterpolation, table); | ||
| })) | ||
|
charlie0614 marked this conversation as resolved.
|
||
| .def("__call__", [](mio::TimeSeriesFunctor<double>& self, double time) { | ||
| return self(time); | ||
| }); | ||
| } | ||
|
|
||
| } // namespace pymio | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The abstract parameter disctribution is in utils/ and should remain model-independent. An
abm::PersonalRandomNumberGeneratoris only known when there is knowledge about the ABM..Is this required at this point? Requiring an abm::PRNG makes this unusable for other models.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed