diff --git a/scripts/walk_through_session.py b/scripts/walk_through_session.py index f6288ac..fae67ee 100644 --- a/scripts/walk_through_session.py +++ b/scripts/walk_through_session.py @@ -2,7 +2,9 @@ import os from aind_behavior_dynamic_foraging.data_contract import dataset as df_foraging_dataset -from aind_behavior_dynamic_foraging.task_logic.trial_generators.warmup_trial_generator import WarmupTrialGeneratorSpec +from aind_behavior_dynamic_foraging.task_logic.trial_generators import ( + CoupledTrialGeneratorSpec, +) from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome logging.basicConfig( @@ -17,10 +19,10 @@ def walk_through_session(data_directory: os.PathLike): software_events.load_all() trial_outcomes = software_events["TrialOutcome"].data["data"].iloc - warmup_trial_generator = WarmupTrialGeneratorSpec().create_generator() + trial_generator = CoupledTrialGeneratorSpec().create_generator() for i, outcome in enumerate(trial_outcomes): - warmup_trial_generator.update(TrialOutcome.model_validate(outcome)) - trial = warmup_trial_generator.next() + trial_generator.update(TrialOutcome.model_validate(outcome)) + trial = trial_generator.next() if not trial: print(f"Session finished at trial {i}") diff --git a/src/Extensions/bonsai.py b/src/Extensions/bonsai.py index 5fab0e0..6c0c281 100644 --- a/src/Extensions/bonsai.py +++ b/src/Extensions/bonsai.py @@ -1,8 +1,16 @@ +import logging +import sys from typing import TYPE_CHECKING from pydantic import TypeAdapter -from aind_behavior_dynamic_foraging.task_logic import TrialGeneratorSpec +logging.basicConfig( + stream=sys.stdout, + level=logging.DEBUG, + format='{"name": "%(name)s", "level": %(levelno)d, "msg": "%(message)s"}', +) + +from aind_behavior_dynamic_foraging.task_logic import TrialGeneratorSpec # noqa if TYPE_CHECKING: from aind_behavior_dynamic_foraging.task_logic.trial_generators._base import ITrialGenerator diff --git a/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py 
b/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py index 3f2e238..b88f99b 100644 --- a/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py +++ b/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py @@ -1,6 +1,6 @@ from pathlib import Path -from aind_behavior_curriculum import TrainerState +from aind_behavior_curriculum import Metrics, TrainerState from aind_behavior_services.session import Session from contraqctor.contract import Dataset, DataStreamCollection from contraqctor.contract.camera import Camera @@ -61,10 +61,18 @@ def make_dataset( data_streams=[ Json( name="PreviousMetrics", - reader_params=Json.make_params( + reader_params=PydanticModel.make_params( + model=Metrics, path=root_path / "behavior/previous_metrics.json", ), ), + PydanticModel( + name="Metrics", + reader_params=PydanticModel.make_params( + model=Metrics, + path=root_path / "behavior/metrics.json", + ), + ), PydanticModel( name="TrainerState", reader_params=PydanticModel.make_params( diff --git a/src/aind_behavior_dynamic_foraging/data_contract/utils.py b/src/aind_behavior_dynamic_foraging/data_contract/utils.py new file mode 100644 index 0000000..ab7819d --- /dev/null +++ b/src/aind_behavior_dynamic_foraging/data_contract/utils.py @@ -0,0 +1,34 @@ +import os +from typing import Optional + +from aind_behavior_dynamic_foraging.data_contract import dataset +from aind_behavior_dynamic_foraging.task_logic import AindDynamicForagingTaskLogic + + +def calculate_consumed_water(session_path: os.PathLike) -> Optional[float]: + """Calculate the total volume of water consumed during a session. + + Args: + session_path (os.PathLike): Path to the session directory. + + Returns: + Optional[float]: Total volume of water consumed in milliliters, or None if unavailable. 
+ """ + + trial_outcomes = dataset(session_path)["Behavior"]["SoftwareEvents"]["TrialOutcome"].load().data["data"] + is_right_choice = [to["is_right_choice"] for to in trial_outcomes] + is_rewarded = [to["is_rewarded"] for to in trial_outcomes] + + task_logic_data = dataset(session_path)["Behavior"]["InputSchemas"]["TaskLogic"].load().data + task_logic = AindDynamicForagingTaskLogic.model_validate(task_logic_data) + right_reward_size = task_logic.task_parameters.reward_size.right_value_volume + left_reward_size = task_logic.task_parameters.reward_size.left_value_volume + + total = 0 + for choice, rewarded in zip(is_right_choice, is_rewarded): + if rewarded: + if choice is True: + total += right_reward_size * 1e-3 + if choice is False: + total += left_reward_size * 1e-3 + return total diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 0112c95..2e093b7 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -157,23 +157,18 @@ def next(self) -> Trial | None: iti = draw_sample(self.spec.inter_trial_interval_duration) quiescent = draw_sample(self.spec.quiescent_duration) - p_reward_left = self.block.p_left_reward - p_reward_right = self.block.p_right_reward - if self.spec.is_baiting: random_numbers = np.random.random(2) - is_left_baited = self.block.p_left_reward > random_numbers[0] or self.is_left_baited - logger.debug(f"Left baited: {is_left_baited}") - p_reward_left = 1 if is_left_baited else p_reward_left + self.is_left_baited = self.block.p_left_reward > random_numbers[0] or self.is_left_baited + logger.debug(f"Left baited: {self.is_left_baited}") - is_right_baited = self.block.p_right_reward > random_numbers[1] or self.is_right_baited - logger.debug(f"Right 
baited: {is_left_baited}") - p_reward_right = 1 if is_right_baited else p_reward_right + self.is_right_baited = self.block.p_right_reward > random_numbers[1] or self.is_right_baited + logger.debug(f"Right baited: {self.is_right_baited}") return Trial( - p_reward_left=p_reward_left, - p_reward_right=p_reward_right, + p_reward_left=1 if self.is_left_baited else self.block.p_left_reward, + p_reward_right=1 if self.is_right_baited else self.block.p_right_reward, reward_consumption_duration=self.spec.reward_consumption_duration, response_deadline_duration=self.spec.response_duration, quiescence_period_duration=quiescent, @@ -195,7 +190,7 @@ def _generate_next_block( reward_pairs: list[list[float, float]], base_reward_sum: float, block_len: Union[UniformDistribution, ExponentialDistribution], - current_block: Optional[None] = None, + current_block: Optional[Block] = None, ) -> Block: """Generates the next block, avoiding repeating the current block's side bias. diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/coupled_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/coupled_trial_generator.py index 93c37db..34ba90f 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/coupled_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/coupled_trial_generator.py @@ -134,6 +134,13 @@ def _are_end_conditions_met(self) -> bool: logger.debug("Maximum trial count exceeded.") return True + logger.debug( + "Trial generation end conditions are not met: " + f"total trials={len(self.is_right_choice_history)}, " + f"time elapsed={time_elapsed}," + f"ignored trial={choice_history[-win:].count(None)}," + ) + return False def update(self, outcome: TrialOutcome | str) -> None: @@ -147,8 +154,6 @@ def update(self, outcome: TrialOutcome | str) -> None: outcome: The TrialOutcome from the most recently completed trial. 
""" - logger.info(f"Updating coupled trial generator with trial outcome of {outcome}") - if isinstance(outcome, str): outcome = TrialOutcome.model_validate_json(outcome) diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py index 0b5daae..0943e38 100644 --- a/tests/trial_generators/test_block_based_trial_generator.py +++ b/tests/trial_generators/test_block_based_trial_generator.py @@ -115,18 +115,6 @@ def test_next_returns_correct_reward_probs(self): self.assertEqual(trial.p_reward_left, self.generator.block.p_left_reward) self.assertEqual(trial.p_reward_right, self.generator.block.p_right_reward) - #### Test unbaited #### - - def test_baiting_disabled_reward_prob_unchanged(self): - """Without baiting, reward probs should equal block probs exactly.""" - self.generator.block = Block(p_right_reward=0.8, p_left_reward=0.2, min_length=10) - self.generator.is_left_baited = True - self.generator.is_right_baited = True - trial = self.generator.next() - - self.assertEqual(trial.p_reward_right, 0.8) - self.assertEqual(trial.p_reward_left, 0.2) - class TestBlockBaseBaitingTrialGenerator(unittest.TestCase): def setUp(self): diff --git a/uv.lock b/uv.lock index 65f6e74..c776f3d 100644 --- a/uv.lock +++ b/uv.lock @@ -17,6 +17,7 @@ resolution-markers = [ members = [ "aind-behavior-dynamic-foraging", "aind-behavior-dynamic-foraging-curricula", + "aind-behavior-dynamic-foraging-metadata-mapper", ] [[package]] @@ -149,6 +150,57 @@ docs = [ { name = "ruff" }, ] +[[package]] +name = "aind-behavior-dynamic-foraging-metadata-mapper" +version = "0.0.1" +source = { editable = "workspace/aind_behavior_dynamic_foraging_metadata_mapper" } +dependencies = [ + { name = "aind-behavior-dynamic-foraging" }, + { name = "aind-data-schema" }, + { name = "cyclopts" }, + { name = "numpy" }, + { name = "pydantic-settings" }, +] + +[package.dev-dependencies] +dev = [ + { name = "codespell" }, + { name = "pytest" }, + { name 
= "pytest-cov" }, + { name = "ruff" }, +] +docs = [ + { name = "mkdocs" }, + { name = "mkdocs-material" }, + { name = "mkdocstrings", extra = ["python"] }, + { name = "pymdown-extensions" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "aind-behavior-dynamic-foraging", editable = "." }, + { name = "aind-data-schema", specifier = ">=2.6.0" }, + { name = "cyclopts", specifier = ">=4.10.0" }, + { name = "numpy", specifier = ">=2.4.2" }, + { name = "pydantic-settings" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "codespell" }, + { name = "pytest" }, + { name = "pytest-cov" }, + { name = "ruff" }, +] +docs = [ + { name = "mkdocs" }, + { name = "mkdocs-material" }, + { name = "mkdocstrings", extras = ["python"] }, + { name = "pymdown-extensions" }, + { name = "ruff" }, +] + [[package]] name = "aind-behavior-services" version = "0.13.5" @@ -325,6 +377,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/b9/f8d6fa329ab25128b7e98fd83a3cb34d9db5b059a9847eddb840a0af45dd/argon2_cffi_bindings-25.1.0-cp39-abi3-win_arm64.whl", hash = "sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94", size = 27149, upload-time = "2025-07-30T10:01:59.329Z" }, ] +[[package]] +name = "attrs" +version = "26.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/8e/82a0fe20a541c03148528be8cac2408564a6c9a0cc7e9171802bc1d26985/attrs-26.1.0.tar.gz", hash = "sha256:d03ceb89cb322a8fd706d4fb91940737b6642aa36998fe130a9bc96c985eff32", size = 952055, upload-time = "2026-03-19T14:22:25.026Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" }, +] + [[package]] name = "autodoc-pydantic" version = "2.2.0" @@ -865,6 
+926,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321, upload-time = "2023-10-07T05:32:16.783Z" }, ] +[[package]] +name = "cyclopts" +version = "4.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "docstring-parser" }, + { name = "rich" }, + { name = "rich-rst" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6c/c4/2ce2ca1451487dc7d59f09334c3fa1182c46cfcf0a2d5f19f9b26d53ac74/cyclopts-4.10.1.tar.gz", hash = "sha256:ad4e4bb90576412d32276b14a76f55d43353753d16217f2c3cd5bdceba7f15a0", size = 166623, upload-time = "2026-03-23T14:43:01.098Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/0b/2261922126b2e50c601fe22d7ff5194e0a4d50e654836260c0665e24d862/cyclopts-4.10.1-py3-none-any.whl", hash = "sha256:35f37257139380a386d9fe4475e1e7c87ca7795765ef4f31abba579fcfcb6ecd", size = 204331, upload-time = "2026-03-23T14:43:02.625Z" }, +] + [[package]] name = "dnspython" version = "2.8.0" @@ -874,6 +950,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, ] +[[package]] +name = "docstring-parser" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, +] + [[package]] name = "docutils" version = "0.22.4" @@ -1642,9 +1727,9 @@ name = "msal" version = "1.35.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cryptography" }, - { name = "pyjwt", extra = ["crypto"] }, - { name = "requests" }, + { name = "cryptography", marker = "sys_platform == 'win32'" }, + { name = "pyjwt", extra = ["crypto"], marker = "sys_platform == 'win32'" }, + { name = "requests", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3c/aa/5a646093ac218e4a329391d5a31e5092a89db7d2ef1637a90b82cd0b6f94/msal-1.35.1.tar.gz", hash = "sha256:70cac18ab80a053bff86219ba64cfe3da1f307c74b009e2da57ef040eb1b5656", size = 165658, upload-time = "2026-03-04T23:38:51.812Z" } wheels = [ @@ -2195,7 +2280,7 @@ wheels = [ [package.optional-dependencies] crypto = [ - { name = "cryptography" }, + { name = "cryptography", marker = "sys_platform == 'win32'" }, ] [[package]] @@ -2412,6 +2497,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, ] +[[package]] +name = "rich-rst" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "rich" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bc/6d/a506aaa4a9eaa945ed8ab2b7347859f53593864289853c5d6d62b77246e0/rich_rst-1.3.2.tar.gz", hash = "sha256:a1196fdddf1e364b02ec68a05e8ff8f6914fee10fbca2e6b6735f166bb0da8d4", size = 14936, upload-time 
= "2025-10-14T16:49:45.332Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/2f/b4530fbf948867702d0a3f27de4a6aab1d156f406d72852ab902c4d04de9/rich_rst-1.3.2-py3-none-any.whl", hash = "sha256:a99b4907cbe118cf9d18b0b44de272efa61f15117c61e39ebdc431baf5df722a", size = 12567, upload-time = "2025-10-14T16:49:42.953Z" }, +] + [[package]] name = "roman-numerals" version = "4.1.0" diff --git a/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/coupled_baiting/stages.py b/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/coupled_baiting/stages.py index 66ee583..6050287 100644 --- a/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/coupled_baiting/stages.py +++ b/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/coupled_baiting/stages.py @@ -66,8 +66,8 @@ def make_s_stage_1_warmup(): CoupledTrialGeneratorSpec( trial_generation_end_parameters=CoupledTrialGenerationEndConditions( max_trial=1000, - max_time=75, - min_time=30, + max_time=4500, + min_time=1800, ignore_win=20000, ignore_ratio_threshold=1, ), @@ -113,8 +113,8 @@ def make_s_stage_1(): trial_generator=CoupledTrialGeneratorSpec( trial_generation_end_parameters=CoupledTrialGenerationEndConditions( max_trial=1000, - max_time=75, - min_time=30, + max_time=4500, + min_time=1800, ignore_win=20000, ignore_ratio_threshold=1, ), @@ -158,8 +158,8 @@ def make_s_stage_2(): trial_generator=CoupledTrialGeneratorSpec( trial_generation_end_parameters=CoupledTrialGenerationEndConditions( max_trial=1000, - max_time=75, - min_time=30, + max_time=4500, + min_time=1800, ignore_win=30, ignore_ratio_threshold=0.83, ), @@ -203,8 +203,8 @@ def make_s_stage_3(): trial_generator=CoupledTrialGeneratorSpec( trial_generation_end_parameters=CoupledTrialGenerationEndConditions( max_trial=1000, - max_time=75, - min_time=30, + max_time=4500, + 
min_time=1800, ignore_win=30, ignore_ratio_threshold=0.83, ), @@ -248,8 +248,8 @@ def make_s_stage_final(): trial_generator=CoupledTrialGeneratorSpec( trial_generation_end_parameters=CoupledTrialGenerationEndConditions( max_trial=1000, - max_time=75, - min_time=30, + max_time=4500, + min_time=1800, ignore_win=30, ignore_ratio_threshold=0.83, ), @@ -289,8 +289,8 @@ def make_s_stage_graduated(): trial_generator=CoupledTrialGeneratorSpec( trial_generation_end_parameters=CoupledTrialGenerationEndConditions( max_trial=1000, - max_time=75, - min_time=30, + max_time=4500, + min_time=1800, ignore_win=30, ignore_ratio_threshold=0.83, ), diff --git a/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/metrics.py b/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/metrics.py index 76f0922..942926f 100644 --- a/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/metrics.py +++ b/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/metrics.py @@ -71,7 +71,7 @@ def metrics_from_dataset( logger.debug(f"Calculated foraging efficiency as {foraging_efficiency}") try: - prev_metrics = DynamicForagingMetrics(**dataset["Behavior"]["PreviousMetrics"].data) + prev_metrics = DynamicForagingMetrics.model_validate(dataset["Behavior"]["PreviousMetrics"].data) prev_stage = prev_metrics.stage_name except FileNotFoundError: logger.info("No previous metrics found.") @@ -138,7 +138,7 @@ def compute_foraging_efficiency( if not is_baiting: logger.debug("Calculated non baiting foraging efficiency.") - optimal_rewards_per_session = np.nanmean(np.max([p_right_reward], axis=0)) * len(p_left_reward) + optimal_rewards_per_session = np.nanmean(np.max([p_right_reward, p_left_reward], axis=0)) * len(p_left_reward) else: logger.debug("Calculated baiting foraging efficiency.") p_max = np.maximum(p_left_reward, p_right_reward) diff 
--git a/workspace/aind_behavior_dynamic_foraging_metadata_mapper/README.md b/workspace/aind_behavior_dynamic_foraging_metadata_mapper/README.md new file mode 100644 index 0000000..e69de29 diff --git a/workspace/aind_behavior_dynamic_foraging_metadata_mapper/pyproject.toml b/workspace/aind_behavior_dynamic_foraging_metadata_mapper/pyproject.toml new file mode 100644 index 0000000..7571e5d --- /dev/null +++ b/workspace/aind_behavior_dynamic_foraging_metadata_mapper/pyproject.toml @@ -0,0 +1,78 @@ +[build-system] +requires = ["uv_build>=0.8.22"] +build-backend = "uv_build" + +[project] +name = "aind-behavior-dynamic-foraging-metadata-mapper" +description = "A library of mapping for the Dynamic Foraging task." +authors = [ + {name = "Bruno Cruz", email = "bruno.cruz@alleninstitute.org"}, + {name = "Micah Woodard", email = "micah.woodard@alleninstitute.org"} + ] +license = "MIT" +version = "0.0.1" +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Operating System :: Microsoft :: Windows", +] +readme = {file = "README.md", content-type = "text/markdown"} + +dependencies = [ + "numpy>=2.4.2", + "pydantic-settings", + "aind-behavior-dynamic-foraging==0.0.2rc24", + "aind-data-schema>=2.6.0", + "cyclopts>=4.10.0" +] + +[tool.uv.sources] +aind-behavior-dynamic-foraging = { workspace = true } + +[dependency-groups] + +dev = [ + 'ruff', + 'pytest', + 'pytest-cov', + 'codespell', +] + +docs = [ + 'mkdocs', + 'mkdocs-material', + 'mkdocstrings[python]', + 'pymdown-extensions', + 'ruff', +] + +[tool.uv] +default-groups = ['dev'] + +[tool.ruff] +line-length = 120 +target-version = 'py311' + +[tool.ruff.lint] +extend-select = ['Q', 'RUF100', 'C90', 'I'] +extend-ignore = [] +mccabe = { max-complexity = 14 } +pydocstyle = { convention = 'google' } + +[tool.codespell] +skip = '.git,*.pdf,*.svg,uv.lock' +ignore-words-list = 'nd' + +[tool.pytest.ini_options] 
+addopts = "--strict-markers --tb=short --cov=src --cov-report=term-missing --cov-fail-under=70" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] + +[project.scripts] +acquisition = "aind_behavior_dynamic_foraging_metadata_mapper.acquisition:app" +instrument = "aind_behavior_dynamic_foraging_metadata_mapper.instrument:app" +mapper = "aind_behavior_dynamic_foraging_metadata_mapper.cli:main" diff --git a/workspace/aind_behavior_dynamic_foraging_metadata_mapper/src/aind_behavior_dynamic_foraging_metadata_mapper/__init__.py b/workspace/aind_behavior_dynamic_foraging_metadata_mapper/src/aind_behavior_dynamic_foraging_metadata_mapper/__init__.py new file mode 100644 index 0000000..938adb6 --- /dev/null +++ b/workspace/aind_behavior_dynamic_foraging_metadata_mapper/src/aind_behavior_dynamic_foraging_metadata_mapper/__init__.py @@ -0,0 +1,7 @@ +from .acquisition import acqusition_from_dataset +from .instrument import instrument_from_dataset + +__all__ = [ + "acqusition_from_dataset", + "instrument_from_dataset", +] diff --git a/workspace/aind_behavior_dynamic_foraging_metadata_mapper/src/aind_behavior_dynamic_foraging_metadata_mapper/acquisition.py b/workspace/aind_behavior_dynamic_foraging_metadata_mapper/src/aind_behavior_dynamic_foraging_metadata_mapper/acquisition.py new file mode 100644 index 0000000..b64b5dd --- /dev/null +++ b/workspace/aind_behavior_dynamic_foraging_metadata_mapper/src/aind_behavior_dynamic_foraging_metadata_mapper/acquisition.py @@ -0,0 +1,227 @@ +import logging +import os +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import List, Optional + +import git +from aind_behavior_dynamic_foraging.data_contract import dataset as df_foraging_dataset +from aind_behavior_dynamic_foraging.data_contract.utils import calculate_consumed_water +from aind_behavior_dynamic_foraging.rig import AindDynamicForagingRig +from 
aind_behavior_dynamic_foraging.task_logic import AindDynamicForagingTaskLogic +from aind_behavior_services.rig import Device as AbsDevice +from aind_behavior_services.rig import cameras as abs_camera +from aind_behavior_services.rig import water_valve as abs_water_valve +from aind_behavior_services.session import Session +from aind_behavior_services.utils import get_fields_of_type, utcnow +from aind_data_schema.components.configs import TriggerType +from aind_data_schema.components.measurements import CalibrationFit, FitType, GenericModel, VolumeCalibration +from aind_data_schema.core.acquisition import ( + Acquisition, + AcquisitionSubjectDetails, + Code, + DataStream, + DetectorConfig, + PerformanceMetrics, + StimulusEpoch, + StimulusModality, +) +from aind_data_schema_models import units +from aind_data_schema_models.modalities import Modality +from clabe.data_mapper import helpers as data_mapper_helpers +from cyclopts import App + +logger = logging.getLogger(__name__) + +app = App() + + +@app.default +def acqusition_from_dataset( + data_directory: Path, repo_path: os.PathLike, end_time: Optional[datetime] = None +) -> Acquisition: + """ + Create acquisition model for completed session. + + Args: + data_directory (os.PathLike): + Path to the directory containing the dataset to analyze. This + directory is expected to include all required behavioral data files. + + repo_path (os.PathLike): + Path to github repository. + + end_time: Optional[datetime]: + End time of acquisition. If None, current time will be used. + + Returns: + Acquisition: + Acquisition model for session + + Raises: + FileNotFoundError: + If the specified data directory or required files do not exist. + + ValueError: + If the dataset is malformed or missing required fields for + computing metrics. 
+ """ + dataset = df_foraging_dataset(data_directory) + input_schemas = dataset["Behavior"]["InputSchemas"] + session_model = Session.model_validate(input_schemas["Session"].data) + rig_model = AindDynamicForagingRig.model_validate(input_schemas["Rig"].data) + task_logic_model = AindDynamicForagingTaskLogic.model_validate(input_schemas["TaskLogic"].data) + repository = git.Repo(repo_path) + + if end_time is None: + logger.warning("Session end time is not set. Using current time as end time.") + acquisition_end_time = datetime.now(tz=timezone.utc) + + bonsai_code = _get_bonsai_as_code(repository) + python_code = _get_python_as_code(repository) + + cameras = data_mapper_helpers.get_cameras(rig_model, exclude_without_video_writer=True) + camera_configs = [_get_cameras_config(k, v, repository) for k, v in cameras.items()] + + # construct data stream + modalities: list[Modality] = [getattr(Modality, "BEHAVIOR")] + if len(camera_configs) > 0: + modalities.append(getattr(Modality, "BEHAVIOR_VIDEOS")) + modalities = list(set(modalities)) + + active_devices = [ + _device[0] + for _device in get_fields_of_type(rig_model, AbsDevice, stop_recursion_on_type=False) + if _device[0] is not None and not isinstance(_device[1], abs_camera.CameraController) + ] + + data_streams = [ + DataStream( + stream_start_time=session_model.date, + stream_end_time=acquisition_end_time, + code=[bonsai_code, python_code], + active_devices=active_devices, + modalities=modalities, + configurations=camera_configs, + notes=session_model.notes, + ) + ] + + # populate behavior epoch + metrics = dataset["Behavior"]["Metrics"].data + trainer_state = dataset["Behavior"]["TrainerState"].data + performance_metrics = PerformanceMetrics(output_parameters=metrics.model_dump()) + + stimulus_epoch = StimulusEpoch( + stimulus_start_time=session_model.date, + stimulus_end_time=acquisition_end_time, + stimulus_name="GoCue", + code=bonsai_code, + stimulus_modalities=[StimulusModality.AUDITORY], + 
performance_metrics=performance_metrics, + curriculum_status=trainer_state.stage.name, + ) + + # Construct aind-data-schema session + return Acquisition( + subject_id=session_model.subject, + subject_details=_get_subject_details(data_directory), + instrument_id=rig_model.rig_name, + acquisition_end_time=acquisition_end_time, + acquisition_start_time=session_model.date, + experimenters=session_model.experimenter, + acquisition_type=session_model.experiment or task_logic_model.name, + coordinate_system=None, + data_streams=data_streams, + calibrations=_get_water_calibration(rig_model), + stimulus_epochs=[stimulus_epoch], + ) + + +def _get_subject_details(data_directory: os.PathLike) -> AcquisitionSubjectDetails: + return AcquisitionSubjectDetails( + mouse_platform_name="tube", + reward_consumed_total=calculate_consumed_water(data_directory), + reward_consumed_unit=units.VolumeUnit.ML, + ) + + +def _get_water_calibration(rig_model: AindDynamicForagingRig) -> List[VolumeCalibration]: + + water_calibrations = get_fields_of_type(rig_model, abs_water_valve.WaterValveCalibration) + vol_cal = [] + for device_name, water_calibration in water_calibrations: + c = water_calibration + vol_cal.append( + VolumeCalibration( + device_name=device_name, + calibration_date=water_calibration.date if water_calibration.date else utcnow(), + input=list(c.interval_average.keys()), + output=list(c.interval_average.values()), + input_unit=units.TimeUnit.S, + output_unit=units.VolumeUnit.ML, + fit=CalibrationFit( + fit_type=FitType.LINEAR, + fit_parameters=GenericModel.model_validate(c.model_dump()), + ), + ) + ) + return vol_cal + + +def _get_cameras_config(name: str, camera: abs_camera.CameraTypes, repository: git.Repo) -> List[DetectorConfig]: + + if isinstance(camera.video_writer, abs_camera.VideoWriterFfmpeg): + compression = Code( + url="https://ffmpeg.org/", + name="FFMPEG", + parameters=GenericModel.model_validate(camera.video_writer.model_dump()), + ) + elif 
isinstance(camera.video_writer, abs_camera.VideoWriterOpenCv): + bonsai = _get_bonsai_as_code(repository) + bonsai.parameters = GenericModel.model_validate(camera.video_writer.model_dump()) + compression = bonsai + else: + raise ValueError("Camera does not have a valid video writer configured.") + + camera = DetectorConfig( + device_name=name, + exposure_time=getattr(camera, "exposure", -1), + exposure_time_unit=units.TimeUnit.US, + trigger_type=TriggerType.EXTERNAL, + compression=compression(camera.video_writer), + ) + + cameras = data_mapper_helpers.get_cameras(AindDynamicForagingTaskLogic, exclude_without_video_writer=True) + + return list(map(camera, cameras.keys(), cameras.values())) + + +def _get_bonsai_as_code(repository: git.Repo) -> Code: + bonsai_folder = Path(Path(repository.working_tree_dir) / "bonsai" / "bonsai.exe").parent + bonsai_env = data_mapper_helpers.snapshot_bonsai_environment(bonsai_folder / "bonsai.config") + bonsai_version = bonsai_env.get("Bonsai", "unknown") + assert isinstance(repository, git.Repo) + + return Code( + url=repository.remote().url, + name="Aind.Behavior.DynamicForaging", + version=repository.head.commit.hexsha, + language="Bonsai", + language_version=bonsai_version, + ) + + +def _get_python_as_code(repository: git.Repo) -> Code: + v = sys.version_info + semver = f"{v.major}.{v.minor}.{v.micro}" + if v.releaselevel != "final": + semver += f"-{v.releaselevel}.{v.serial}" + return Code( + url=repository.remote().url, + name="aind-behavior-dynamic-foraging", + version=repository.head.commit.hexsha, + language="Python", + language_version=semver, + ) diff --git a/workspace/aind_behavior_dynamic_foraging_metadata_mapper/src/aind_behavior_dynamic_foraging_metadata_mapper/cli.py b/workspace/aind_behavior_dynamic_foraging_metadata_mapper/src/aind_behavior_dynamic_foraging_metadata_mapper/cli.py new file mode 100644 index 0000000..4217652 --- /dev/null +++ 
# --- cli.py ---
import logging
import os
import typing as t
from pathlib import Path

from pydantic import AwareDatetime, Field
from pydantic_settings import BaseSettings, CliApp

logger = logging.getLogger(__name__)


class DataMapperCli(BaseSettings, cli_kebab_case=True):
    """CLI that maps a Dynamic Foraging session directory to aind-data-schema metadata files."""

    data_directory: os.PathLike = Field(description="Path to the session data directory.")
    repo_path: os.PathLike = Field(
        default=Path("."), description="Path to the repository. By default it will use the current directory."
    )
    session_end_time: AwareDatetime | None = Field(
        default=None,
        description="End time of the session in ISO format. If not provided, will use the time the data mapping is run.",
    )
    suffix: t.Optional[str] = Field(default="", description="Suffix to append to the output filenames.")

    def cli_cmd(self) -> None:
        """Generate aind-data-schema metadata for the Dynamic Foraging dataset located at the specified path."""
        # Imported lazily so that constructing the CLI (e.g. ``--help``) does not
        # pay for the heavy mapper imports.
        # NOTE(review): ``acqusition_from_dataset`` is the name actually exported
        # by ``.acquisition`` (typo included) — fix upstream before renaming here.
        from .acquisition import acqusition_from_dataset
        from .instrument import instrument_from_dataset

        data_dir = Path(self.data_directory)
        acquisition = acqusition_from_dataset(
            data_directory=data_dir,
            repo_path=Path(self.repo_path),
            end_time=self.session_end_time,
        )
        instrument = instrument_from_dataset(data_directory=data_dir)

        # Both models are written next to the session data itself.
        acquisition.write_standard_file(output_directory=data_dir, filename_suffix=self.suffix)
        instrument.write_standard_file(output_directory=data_dir, filename_suffix=self.suffix)

        logger.info(
            "Mapping completed! Saved acquisition.json, instrument.json to %s",
            self.data_directory,
        )


def main():
    """Console-script entry point: parse CLI args and run the mapper."""
    CliApp.run(DataMapperCli)


if __name__ == "__main__":
    main()


# --- instrument.py ---
from datetime import date
from pathlib import Path

from aind_behavior_dynamic_foraging.data_contract import dataset as df_foraging_dataset
from aind_behavior_dynamic_foraging.rig import AindDynamicForagingRig
from aind_data_schema.components.connections import Connection
from aind_data_schema.components.coordinates import Axis, AxisName, CoordinateSystem, Direction, Origin
from aind_data_schema.components.devices import (
    AnatomicalRelative,
    Camera,
    CameraAssembly,
    CameraTarget,
    DataInterface,
    HarpDevice,
    HarpDeviceType,
    Lens,
    MotorizedStage,
    SizeUnit,
)
from aind_data_schema.core.instrument import Instrument
from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.organizations import Organization
from cyclopts import App

app = App()


@app.default
def instrument_from_dataset(
    data_directory: Path,
) -> Instrument:
    """
    Create the Instrument model for a completed session.

    The rig configuration is read from the session's ``Behavior/InputSchemas/Rig``
    stream and translated device-by-device into ``Instrument`` components.

    Args:
        data_directory (Path):
            Path to the directory containing the dataset to analyze. This
            directory is expected to include the behavior input-schema files.

    Returns:
        Instrument:
            Instrument model for the session, describing cameras, Harp
            devices, the manipulator, and their connections.

    Raises:
        FileNotFoundError:
            If the specified data directory or required files do not exist.
        ValueError:
            If the rig schema is malformed or missing required fields.
    """
    dataset = df_foraging_dataset(data_directory)
    input_schemas = dataset["Behavior"]["InputSchemas"]
    rig = AindDynamicForagingRig.model_validate(input_schemas["Rig"].data)

    components = []

    # Cameras: "Body" cameras target the body from the animal's right;
    # everything else is treated as a face camera mounted above.
    for name, cam in rig.triggered_camera_controller.cameras.items():
        is_body = "Body" in name
        camera = Camera(
            name=name,
            serial_number=cam.serial_number,
            manufacturer=Organization.SPINNAKER,
            data_interface=DataInterface.COAX,
        )
        components.append(
            CameraAssembly(
                name=f"{name}Assembly",
                camera=camera,
                target=CameraTarget.BODY if is_body else CameraTarget.FACE,
                lens=Lens(name="Lens A", manufacturer=Organization.FUJINON),
                relative_position=[AnatomicalRelative.RIGHT if is_body else AnatomicalRelative.SUPERIOR],
            )
        )

    # Required Harp devices.
    components.append(
        HarpDevice(
            name="BehaviorBoard",
            harp_device_type=HarpDeviceType.BEHAVIOR,
            serial_number=rig.harp_behavior.serial_number,
            manufacturer=Organization.CHAMPALIMAUD,
            is_clock_generator=False,
        )
    )
    components.append(
        HarpDevice(
            name="ClockGenerator",
            harp_device_type=HarpDeviceType.WHITERABBIT,
            serial_number=rig.harp_clock_generator.serial_number,
            is_clock_generator=True,
        )
    )
    components.append(
        HarpDevice(
            name="SoundCard",
            harp_device_type=HarpDeviceType.SOUNDCARD,
            serial_number=rig.harp_sound_card.serial_number,
            manufacturer=Organization.CHAMPALIMAUD,
            is_clock_generator=False,
        )
    )

    # Optional Harp devices share the same construction shape; add only the
    # ones actually present on this rig.
    optional_harp_devices = (
        (rig.harp_lickometer_left, "LickometerLeft", HarpDeviceType.LICKETYSPLIT),
        (rig.harp_lickometer_right, "LickometerRight", HarpDeviceType.LICKETYSPLIT),
        (rig.harp_sniff_detector, "SniffDetector", HarpDeviceType.SNIFFDETECTOR),
        (rig.harp_environment_sensor, "EnvironmentSensor", HarpDeviceType.ENVIRONMENTSENSOR),
    )
    for device, device_name, device_type in optional_harp_devices:
        if device:
            components.append(
                HarpDevice(
                    name=device_name,
                    harp_device_type=device_type,
                    serial_number=device.serial_number,
                    is_clock_generator=False,
                )
            )

    # Manipulator stage.
    # NOTE(review): travel is hard-coded to 0.0 — confirm this is intentional.
    components.append(MotorizedStage(name="Manipulator", serial_number=rig.manipulator.serial_number, travel=0.0))

    # Every camera is triggered by the behavior board.
    connections = [
        Connection(
            source_device="BehaviorBoard",
            target_device=name,
        )
        for name in rig.triggered_camera_controller.cameras
    ]

    return Instrument(
        instrument_id=rig.rig_name,
        modification_date=date.today(),
        modalities=[Modality.BEHAVIOR, Modality.BEHAVIOR_VIDEOS],
        coordinate_system=CoordinateSystem(
            name="RigCoordinateSystem",
            origin=Origin.ORIGIN,
            axes=[
                Axis(name=AxisName.X, direction=Direction.LR),
                Axis(name=AxisName.Y, direction=Direction.FB),
                Axis(name=AxisName.Z, direction=Direction.DU),
            ],
            axis_unit=SizeUnit.MM,
        ),
        components=components,
        connections=connections,
    )