Skip to content
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
eba6406
adds acquisition mapping
micahwoodard Mar 30, 2026
441ad11
adds rig schema
micahwoodard Mar 30, 2026
a32a842
adds data description
micahwoodard Mar 30, 2026
3fef59b
updates __init__
micahwoodard Mar 30, 2026
062b917
fixes init
micahwoodard Mar 30, 2026
6d0f998
Merge branch 'feat-adding-curriculum' into feat-metadata-mapping
micahwoodard Mar 31, 2026
6dbe0f0
adds cli for mapping
micahwoodard Mar 31, 2026
91febd4
Merge branch 'feat-adding-curriculum' into feat-metadata-mapping
micahwoodard Apr 1, 2026
79dc6f9
refrences current metrics in mapping
micahwoodard Apr 1, 2026
9e03163
dumps metrics
micahwoodard Apr 1, 2026
963292f
prints json
micahwoodard Apr 1, 2026
062aa2f
removes return
micahwoodard Apr 1, 2026
1f5a2fc
lints
micahwoodard Apr 1, 2026
ba96b5f
adds logging
micahwoodard Apr 2, 2026
2f0a446
converts end condition time to seconds
micahwoodard Apr 2, 2026
382ae85
streams logs to stdout
micahwoodard Apr 2, 2026
f160fbd
moves logging before import
micahwoodard Apr 2, 2026
f1fff6d
log fixes
micahwoodard Apr 2, 2026
f7502b1
push fixes for baiting
micahwoodard Apr 7, 2026
37f9c84
add cli
micahwoodard Apr 8, 2026
690982d
adds project script
micahwoodard Apr 8, 2026
c9ee84a
updates args
micahwoodard Apr 8, 2026
e68fe44
removes suffix
micahwoodard Apr 8, 2026
a107aa8
lints
micahwoodard Apr 9, 2026
6ef9ed3
removes data description
micahwoodard Apr 13, 2026
0a7e2f4
configures logger to json and removes cluttered log mesages
micahwoodard Apr 16, 2026
2c411f9
lints
micahwoodard Apr 16, 2026
96ea485
adds name to log schema
micahwoodard Apr 16, 2026
4a305d1
cleans up baiting logic
micahwoodard Apr 17, 2026
80080af
removes test
micahwoodard Apr 17, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions scripts/walk_through_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os

from aind_behavior_dynamic_foraging.data_contract import dataset as df_foraging_dataset
from aind_behavior_dynamic_foraging.task_logic.trial_generators.warmup_trial_generator import WarmupTrialGeneratorSpec
from aind_behavior_dynamic_foraging.task_logic.trial_generators import (
CoupledTrialGeneratorSpec,
)
from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome

logging.basicConfig(
Expand All @@ -17,10 +19,10 @@ def walk_through_session(data_directory: os.PathLike):
software_events.load_all()

trial_outcomes = software_events["TrialOutcome"].data["data"].iloc
warmup_trial_generator = WarmupTrialGeneratorSpec().create_generator()
trial_generator = CoupledTrialGeneratorSpec().create_generator()
for i, outcome in enumerate(trial_outcomes):
warmup_trial_generator.update(TrialOutcome.model_validate(outcome))
trial = warmup_trial_generator.next()
trial_generator.update(TrialOutcome.model_validate(outcome))
trial = trial_generator.next()

if not trial:
print(f"Session finished at trial {i}")
Expand Down
6 changes: 5 additions & 1 deletion src/Extensions/bonsai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import logging
import sys
from typing import TYPE_CHECKING

from pydantic import TypeAdapter

from aind_behavior_dynamic_foraging.task_logic import TrialGeneratorSpec
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

from aind_behavior_dynamic_foraging.task_logic import TrialGeneratorSpec # noqa
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the #noqa?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm remembering correctly, it was because I was failing the linting tests since the imports weren't grouped at the top of the file and I had to set up logging before importing the trail generators


if TYPE_CHECKING:
from aind_behavior_dynamic_foraging.task_logic.trial_generators._base import ITrialGenerator
Expand Down
12 changes: 10 additions & 2 deletions src/aind_behavior_dynamic_foraging/data_contract/_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

from aind_behavior_curriculum import TrainerState
from aind_behavior_curriculum import Metrics, TrainerState
from aind_behavior_services.session import Session
from contraqctor.contract import Dataset, DataStreamCollection
from contraqctor.contract.camera import Camera
Expand Down Expand Up @@ -61,10 +61,18 @@ def make_dataset(
data_streams=[
Json(
name="PreviousMetrics",
reader_params=Json.make_params(
reader_params=PydanticModel.make_params(
model=Metrics,
Comment on lines 61 to +65
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine (and even a good practice.) However, keep in mind this requires you not to make breaking changes to the Metrics model. IF you break it, the contract will not be backwards compatible anymore.

path=root_path / "behavior/previous_metrics.json",
),
),
PydanticModel(
name="Metrics",
reader_params=PydanticModel.make_params(
model=Metrics,
path=root_path / "behavior/metrics.json",
),
),
PydanticModel(
name="TrainerState",
reader_params=PydanticModel.make_params(
Expand Down
34 changes: 34 additions & 0 deletions src/aind_behavior_dynamic_foraging/data_contract/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from typing import Optional

from aind_behavior_dynamic_foraging.data_contract import dataset
from aind_behavior_dynamic_foraging.task_logic import AindDynamicForagingTaskLogic


def calculate_consumed_water(session_path: os.PathLike) -> Optional[float]:
"""Calculate the total volume of water consumed during a session.

Args:
session_path (os.PathLike): Path to the session directory.

Returns:
Optional[float]: Total volume of water consumed in milliliters, or None if unavailable.
"""

trial_outcomes = dataset(session_path)["Behavior"]["SoftwareEvents"]["TrialOutcome"].load().data["data"]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can validate this with the TrialOutcome pydantic class to make your life easier

is_right_choice = [to["is_right_choice"] for to in trial_outcomes]
is_rewarded = [to["is_rewarded"] for to in trial_outcomes]

task_logic_data = dataset(session_path)["Behavior"]["InputSchemas"]["TaskLogic"].load().data
task_logic = AindDynamicForagingTaskLogic.model_validate(task_logic_data)
right_reward_size = task_logic.task_parameters.reward_size.right_value_volume
left_reward_size = task_logic.task_parameters.reward_size.left_value_volume

total = 0
for choice, rewarded in zip(is_right_choice, is_rewarded):
if rewarded:
if choice is True:
total += right_reward_size * 1e-3
if choice is False:
total += left_reward_size * 1e-3
return total
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ def next(self) -> Trial | None:
if self.spec.is_baiting:
random_numbers = np.random.random(2)

is_left_baited = self.block.p_left_reward > random_numbers[0] or self.is_left_baited
logger.debug(f"Left baited: {is_left_baited}")
p_reward_left = 1 if is_left_baited else p_reward_left
self.is_left_baited = self.block.p_left_reward > random_numbers[0] or self.is_left_baited
logger.debug(f"Left baited: {self.is_left_baited}")
p_reward_left = 1 if self.is_left_baited else p_reward_left

is_right_baited = self.block.p_right_reward > random_numbers[1] or self.is_right_baited
logger.debug(f"Right baited: {is_left_baited}")
p_reward_right = 1 if is_right_baited else p_reward_right
self.is_right_baited = self.block.p_right_reward > random_numbers[1] or self.is_right_baited
logger.debug(f"Right baited: {self.is_right_baited}")
p_reward_right = 1 if self.is_right_baited else p_reward_right

return Trial(
p_reward_left=p_reward_left,
Expand All @@ -195,7 +195,7 @@ def _generate_next_block(
reward_pairs: list[list[float, float]],
base_reward_sum: float,
block_len: Union[UniformDistribution, ExponentialDistribution],
current_block: Optional[None] = None,
current_block: Optional[Block] = None,
) -> Block:
"""Generates the next block, avoiding repeating the current block's side bias.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def _are_end_conditions_met(self) -> bool:
logger.debug("Maximum trial count exceeded.")
return True

logger.debug(
"Trial generation end conditions are not met: "
f"total trials={len(self.is_right_choice_history)}, "
f"time elapsed={time_elapsed},"
f"ignored trial={choice_history[-win:].count(None)},"
)

return False

def update(self, outcome: TrialOutcome | str) -> None:
Expand Down
106 changes: 102 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def make_s_stage_1_warmup():
CoupledTrialGeneratorSpec(
trial_generation_end_parameters=CoupledTrialGenerationEndConditions(
max_trial=1000,
max_time=75,
min_time=30,
max_time=4500,
min_time=1800,
ignore_win=20000,
ignore_ratio_threshold=1,
),
Expand Down Expand Up @@ -113,8 +113,8 @@ def make_s_stage_1():
trial_generator=CoupledTrialGeneratorSpec(
trial_generation_end_parameters=CoupledTrialGenerationEndConditions(
max_trial=1000,
max_time=75,
min_time=30,
max_time=4500,
min_time=1800,
ignore_win=20000,
ignore_ratio_threshold=1,
),
Expand Down Expand Up @@ -158,8 +158,8 @@ def make_s_stage_2():
trial_generator=CoupledTrialGeneratorSpec(
trial_generation_end_parameters=CoupledTrialGenerationEndConditions(
max_trial=1000,
max_time=75,
min_time=30,
max_time=4500,
min_time=1800,
ignore_win=30,
ignore_ratio_threshold=0.83,
),
Expand Down Expand Up @@ -203,8 +203,8 @@ def make_s_stage_3():
trial_generator=CoupledTrialGeneratorSpec(
trial_generation_end_parameters=CoupledTrialGenerationEndConditions(
max_trial=1000,
max_time=75,
min_time=30,
max_time=4500,
min_time=1800,
ignore_win=30,
ignore_ratio_threshold=0.83,
),
Expand Down Expand Up @@ -248,8 +248,8 @@ def make_s_stage_final():
trial_generator=CoupledTrialGeneratorSpec(
trial_generation_end_parameters=CoupledTrialGenerationEndConditions(
max_trial=1000,
max_time=75,
min_time=30,
max_time=4500,
min_time=1800,
ignore_win=30,
ignore_ratio_threshold=0.83,
),
Expand Down Expand Up @@ -289,8 +289,8 @@ def make_s_stage_graduated():
trial_generator=CoupledTrialGeneratorSpec(
trial_generation_end_parameters=CoupledTrialGenerationEndConditions(
max_trial=1000,
max_time=75,
min_time=30,
max_time=4500,
min_time=1800,
ignore_win=30,
ignore_ratio_threshold=0.83,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def metrics_from_dataset(
logger.debug(f"Calculated foraging efficiency as {foraging_efficiency}")

try:
prev_metrics = DynamicForagingMetrics(**dataset["Behavior"]["PreviousMetrics"].data)
prev_metrics = DynamicForagingMetrics.model_validate(dataset["Behavior"]["PreviousMetrics"].data)
prev_stage = prev_metrics.stage_name
except FileNotFoundError:
logger.info("No previous metrics found.")
Expand Down Expand Up @@ -138,7 +138,7 @@ def compute_foraging_efficiency(

if not is_baiting:
logger.debug("Calculated non baiting foraging efficiency.")
optimal_rewards_per_session = np.nanmean(np.max([p_right_reward], axis=0)) * len(p_left_reward)
optimal_rewards_per_session = np.nanmean(np.max([p_right_reward, p_left_reward], axis=0)) * len(p_left_reward)
else:
logger.debug("Calculated baiting foraging efficiency.")
p_max = np.maximum(p_left_reward, p_right_reward)
Expand Down
Empty file.
Loading
Loading