diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index 43be2c9..22cfc95 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -434,6 +434,34 @@ "title": "AuditorySecondaryReinforcer", "type": "object" }, + "AutoWaterParameters": { + "properties": { + "min_ignored_trials": { + "default": 3, + "description": "Minimum consecutive ignored trials before auto water is triggered.", + "minimum": 0, + "title": "Min Ignored Trials", + "type": "integer" + }, + "min_unrewarded_trials": { + "default": 3, + "description": "Minimum consecutive unrewarded trials before auto water is triggered.", + "minimum": 0, + "title": "Min Unrewarded Trials", + "type": "integer" + }, + "reward_fraction": { + "default": 0.8, + "description": "Fraction of full reward volume delivered during auto water (0=none, 1=full).", + "maximum": 1, + "minimum": 0, + "title": "Reward Fraction", + "type": "number" + } + }, + "title": "AutoWaterParameters", + "type": "object" + }, "Axis": { "description": "Motor axis available", "enum": [ @@ -965,6 +993,22 @@ }, "description": "Parameters defining the reward probability structure." }, + "autowater_parameters": { + "default": { + "min_ignored_trials": 3, + "min_unrewarded_trials": 3, + "reward_fraction": 0.8 + }, + "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "oneOf": [ + { + "$ref": "#/$defs/AutoWaterParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "default": false, "description": "Whether uncollected rewards carry over to the next trial.", @@ -3235,6 +3279,22 @@ }, "description": "Parameters defining the reward probability structure." }, + "autowater_parameters": { + "default": { + "min_ignored_trials": 3, + "min_unrewarded_trials": 3, + "reward_fraction": 0.8 + }, + "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "oneOf": [ + { + "$ref": "#/$defs/AutoWaterParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "const": true, "default": true, diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index c3dfd91..a2b7e14 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -769,6 +769,116 @@ public override string ToString() } + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] + [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] + [Bonsai.CombinatorAttribute(MethodName="Generate")] + public partial class AutoWaterParameters + { + + private int _minIgnoredTrials; + + private int _minUnrewardedTrials; + + private double _rewardFraction; + + public AutoWaterParameters() + { + _minIgnoredTrials = 3; + _minUnrewardedTrials = 3; + _rewardFraction = 0.8D; + } + + protected AutoWaterParameters(AutoWaterParameters other) + { + _minIgnoredTrials = other._minIgnoredTrials; + _minUnrewardedTrials = other._minUnrewardedTrials; + _rewardFraction = other._rewardFraction; + } + + /// + /// Minimum consecutive ignored trials before auto water is triggered. + /// + [Newtonsoft.Json.JsonPropertyAttribute("min_ignored_trials")] + [System.ComponentModel.DescriptionAttribute("Minimum consecutive ignored trials before auto water is triggered.")] + public int MinIgnoredTrials + { + get + { + return _minIgnoredTrials; + } + set + { + _minIgnoredTrials = value; + } + } + + /// + /// Minimum consecutive unrewarded trials before auto water is triggered. + /// + [Newtonsoft.Json.JsonPropertyAttribute("min_unrewarded_trials")] + [System.ComponentModel.DescriptionAttribute("Minimum consecutive unrewarded trials before auto water is triggered.")] + public int MinUnrewardedTrials + { + get + { + return _minUnrewardedTrials; + } + set + { + _minUnrewardedTrials = value; + } + } + + /// + /// Fraction of full reward volume delivered during auto water (0=none, 1=full). + /// + [Newtonsoft.Json.JsonPropertyAttribute("reward_fraction")] + [System.ComponentModel.DescriptionAttribute("Fraction of full reward volume delivered during auto water (0=none, 1=full).")] + public double RewardFraction + { + get + { + return _rewardFraction; + } + set + { + _rewardFraction = value; + } + } + + public System.IObservable Generate() + { + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new AutoWaterParameters(this))); + } + + public System.IObservable Generate(System.IObservable source) + { + return System.Reactive.Linq.Observable.Select(source, _ => new AutoWaterParameters(this)); + } + + protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) + { + stringBuilder.Append("MinIgnoredTrials = " + _minIgnoredTrials + ", "); + stringBuilder.Append("MinUnrewardedTrials = " + _minUnrewardedTrials + ", "); + stringBuilder.Append("RewardFraction = " + _rewardFraction); + return true; + } + + public override string ToString() + { + System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append(GetType().Name); + stringBuilder.Append(" { "); + if (PrintMembers(stringBuilder)) + { + stringBuilder.Append(" "); + } + stringBuilder.Append("}"); + return stringBuilder.ToString(); + } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] [Bonsai.CombinatorAttribute(MethodName="Generate")] @@ -1455,6 +1565,8 @@ public partial class CoupledTrialGeneratorSpec : TrialGeneratorSpec private RewardProbabilityParameters _rewardProbabilityParameters; + private AutoWaterParameters _autowaterParameters; + private bool _isBaiting; private CoupledTrialGenerationEndConditions _trialGenerationEndParameters; @@ -1473,6 +1585,7 @@ public CoupledTrialGeneratorSpec() _minBlockReward = 1; _kernelSize = 2; _rewardProbabilityParameters = new RewardProbabilityParameters(); + _autowaterParameters = new AutoWaterParameters(); _isBaiting = false; _trialGenerationEndParameters = new CoupledTrialGenerationEndConditions(); _behaviorStabilityParameters = new BehaviorStabilityParameters(); @@ -1490,6 +1603,7 @@ protected CoupledTrialGeneratorSpec(CoupledTrialGeneratorSpec other) : _minBlockReward = other._minBlockReward; _kernelSize = other._kernelSize; _rewardProbabilityParameters = other._rewardProbabilityParameters; + _autowaterParameters = other._autowaterParameters; _isBaiting = other._isBaiting; _trialGenerationEndParameters = other._trialGenerationEndParameters; _behaviorStabilityParameters = other._behaviorStabilityParameters; @@ -1633,6 +1747,25 @@ public RewardProbabilityParameters RewardProbabilityParameters } } + /// + /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] + [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + + "ignored or unrewarded trial thresholds.")] + public AutoWaterParameters AutowaterParameters + { + get + { + return _autowaterParameters; + } + set + { + _autowaterParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -1729,6 +1862,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("MinBlockReward = " + _minBlockReward + ", "); stringBuilder.Append("KernelSize = " + _kernelSize + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", "); + stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters + ", "); stringBuilder.Append("BehaviorStabilityParameters = " + _behaviorStabilityParameters + ", "); @@ -5369,6 +5503,8 @@ public partial class WarmupTrialGeneratorSpec : TrialGeneratorSpec private RewardProbabilityParameters _rewardProbabilityParameters; + private AutoWaterParameters _autowaterParameters; + private bool _isBaiting; private WarmupTrialGenerationEndConditions _trialGenerationEndParameters; @@ -5383,6 +5519,7 @@ public WarmupTrialGeneratorSpec() _minBlockReward = 1; _kernelSize = 2; _rewardProbabilityParameters = new RewardProbabilityParameters(); + _autowaterParameters = new AutoWaterParameters(); _isBaiting = true; _trialGenerationEndParameters = new WarmupTrialGenerationEndConditions(); } @@ -5398,6 +5535,7 @@ protected WarmupTrialGeneratorSpec(WarmupTrialGeneratorSpec other) : _minBlockReward = other._minBlockReward; _kernelSize = other._kernelSize; _rewardProbabilityParameters = other._rewardProbabilityParameters; + _autowaterParameters = other._autowaterParameters; _isBaiting = other._isBaiting; _trialGenerationEndParameters = other._trialGenerationEndParameters; } @@ -5539,6 +5677,25 @@ public RewardProbabilityParameters RewardProbabilityParameters } } + /// + /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] + [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + + "ignored or unrewarded trial thresholds.")] + public AutoWaterParameters AutowaterParameters + { + get + { + return _autowaterParameters; + } + set + { + _autowaterParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -5598,6 +5755,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("MinBlockReward = " + _minBlockReward + ", "); stringBuilder.Append("KernelSize = " + _kernelSize + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", "); + stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters); return true; @@ -6453,6 +6611,11 @@ public System.IObservable Process(System.IObservable(source); } + public System.IObservable Process(System.IObservable source) + { + return Process(source); + } + public System.IObservable Process(System.IObservable source) { return Process(source); @@ -6641,6 +6804,7 @@ public System.IObservable Process(System.IObservable source) [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 0112c95..bb46a4c 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -20,6 +20,21 @@ logger = logging.getLogger(__name__) +class AutoWaterParameters(BaseModel): + min_ignored_trials: int = Field( + default=3, ge=0, description="Minimum consecutive ignored trials before auto water is triggered." + ) + min_unrewarded_trials: int = Field( + default=3, ge=0, description="Minimum consecutive unrewarded trials before auto water is triggered." + ) + reward_fraction: float = Field( + default=0.8, + ge=0, + le=1, + description="Fraction of full reward volume delivered during auto water (0=none, 1=full).", + ) + + class RewardProbabilityParameters(BaseModel): """Defines the reward probability structure for a dynamic foraging task. @@ -93,6 +108,12 @@ class BlockBasedTrialGeneratorSpec(BaseTrialGeneratorSpecModel): validate_default=True, ) + autowater_parameters: Optional[AutoWaterParameters] = Field( + default=AutoWaterParameters(), + validate_default=True, + description="Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + ) + is_baiting: bool = Field(default=False, description="Whether uncollected rewards carry over to the next trial.") def create_generator(self) -> "BlockBasedTrialGenerator": @@ -157,29 +178,53 @@ def next(self) -> Trial | None: iti = draw_sample(self.spec.inter_trial_interval_duration) quiescent = draw_sample(self.spec.quiescent_duration) - p_reward_left = self.block.p_left_reward - p_reward_right = self.block.p_right_reward - + # determine baiting if self.spec.is_baiting: random_numbers = np.random.random(2) - is_left_baited = self.block.p_left_reward > random_numbers[0] or self.is_left_baited - logger.debug(f"Left baited: {is_left_baited}") - p_reward_left = 1 if is_left_baited else p_reward_left + self.is_left_baited = self.block.p_left_reward > random_numbers[0] or self.is_left_baited + logger.debug(f"Left baited: {self.is_left_baited}") - is_right_baited = self.block.p_right_reward > random_numbers[1] or self.is_right_baited - logger.debug(f"Right baited: {is_left_baited}") - p_reward_right = 1 if is_right_baited else p_reward_right + self.is_right_baited = self.block.p_right_reward > random_numbers[1] or self.is_right_baited + logger.debug(f"Right baited: {self.is_right_baited}") + + # determine autowater + if self._are_autowater_conditions_met(): + is_right_autowater = True if self.block.p_right_reward > self.block.p_left_reward else False return Trial( - p_reward_left=p_reward_left, - p_reward_right=p_reward_right, + p_reward_left=1 if (self.is_left_baited and self.spec.is_baiting) else self.block.p_left_reward, + p_reward_right=1 if (self.is_right_baited and self.spec.is_baiting) else self.block.p_right_reward, reward_consumption_duration=self.spec.reward_consumption_duration, response_deadline_duration=self.spec.response_duration, quiescence_period_duration=quiescent, inter_trial_interval_duration=iti, + is_auto_response_right=is_right_autowater, ) + def _are_autowater_conditions_met(self) -> bool: + """Checks whether autowater should be given. + + Returns: + True if autowater conditions are met, False otherwise. + """ + + if self.spec.autowater_parameters is None: # autowater disabled + return False + + min_ignore = self.spec.autowater_parameters.min_ignored_trials + min_unreward = self.spec.autowater_parameters.min_unrewarded_trials + + is_ignored = [choice is None for choice in self.is_right_choice_history] + if all(is_ignored[-min_ignore:]): + return True + + is_unrewarded = [not reward for reward in self.reward_history] + if all(is_unrewarded[-min_unreward:]): + return True + + return False + @abstractmethod def _are_end_conditions_met(self) -> bool: """Checks whether the session should end. @@ -195,7 +240,7 @@ def _generate_next_block( reward_pairs: list[list[float, float]], base_reward_sum: float, block_len: Union[UniformDistribution, ExponentialDistribution], - current_block: Optional[None] = None, + current_block: Optional[Block] = None, ) -> Block: """Generates the next block, avoiding repeating the current block's side bias. diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py index 0b5daae..0943e38 100644 --- a/tests/trial_generators/test_block_based_trial_generator.py +++ b/tests/trial_generators/test_block_based_trial_generator.py @@ -115,18 +115,6 @@ def test_next_returns_correct_reward_probs(self): self.assertEqual(trial.p_reward_left, self.generator.block.p_left_reward) self.assertEqual(trial.p_reward_right, self.generator.block.p_right_reward) - #### Test unbaited #### - - def test_baiting_disabled_reward_prob_unchanged(self): - """Without baiting, reward probs should equal block probs exactly.""" - self.generator.block = Block(p_right_reward=0.8, p_left_reward=0.2, min_length=10) - self.generator.is_left_baited = True - self.generator.is_right_baited = True - trial = self.generator.next() - - self.assertEqual(trial.p_reward_right, 0.8) - self.assertEqual(trial.p_reward_left, 0.2) - class TestBlockBaseBaitingTrialGenerator(unittest.TestCase): def setUp(self):