Skip to content

Commit 7233d5a

Browse files
committed
Edited reversal learning example to try and make logic for getting trial outcome and handling reversals more readable.
1 parent 403a661 commit 7233d5a

File tree

1 file changed

+21
-25
lines changed

1 file changed

+21
-25
lines changed

tasks/example/reversal_learning.py

+21-25
Original file line numberDiff line numberDiff line change
@@ -42,52 +42,48 @@
4242
v.n_rewards = 0 # Number of rewards obtained.
4343
v.n_trials = 0 # Number of trials recieved.
4444
v.n_blocks = 0 # Number of reversals.
45-
v.state = withprob(0.5) # Which side is currently good: True: left, False: right
46-
v.mov_ave = exp_mov_ave(tau=v.tau, init_value = 0.5) # Moving average of choices.
45+
v.good_side = choice(['left', 'right']) # Which side is currently good.
46+
v.correct_mov_ave = exp_mov_ave(tau=v.tau, init_value = 0.5) # Moving average of correct/incorrect choices
4747
v.threshold_crossed = False # Whether performance threshold has been crossed.
4848
v.trials_till_reversal = 0 # Used after threshold crossing to trigger reversal.
4949

5050
#-------------------------------------------------------------------------
5151
# Non-state machine code.
5252
#-------------------------------------------------------------------------
5353

54-
def trial_outcome(choice):
54+
def get_trial_outcome(chosen_side):
5555
# Function called after choice is made which determines trial outcome,
5656
# controls when reversals happen, and prints trial information.
5757

5858
# Determine trial outcome.
59-
if choice: # Subject chose left.
60-
if v.state:
61-
reward_prob = v.good_prob
62-
else:
63-
reward_prob = v.bad_prob
64-
else: # Subject chose right
65-
if v.state:
66-
reward_prob = v.bad_prob
67-
else:
68-
reward_prob = v.good_prob
69-
outcome = withprob(reward_prob) # Whether trial is rewarded or not.
59+
if chosen_side == v.good_side: # Subject choose good side.
60+
v.outcome = withprob(v.good_prob)
61+
v.correct_mov_ave.update(1)
62+
63+
else:
64+
v.outcome = withprob(v.bad_prob)
65+
v.correct_mov_ave.update(0)
7066

7167
# Determine when reversal occurs.
72-
v.mov_ave.update(choice) # Update moving average of choices.
7368
if v.threshold_crossed: # Subject has already crossed threshold.
7469
v.trials_till_reversal -= 1
7570
if v.trials_till_reversal == 0: # Trigger reversal.
76-
v.state = not v.state
71+
v.good_side = 'left' if (v.good_side == 'right') else 'right'
72+
v.correct_mov_ave.value = 1 - v.correct_mov_ave.value
7773
v.threshold_crossed = False
7874
v.n_blocks += 1
7975
else: # Check for threshold crossing.
80-
if (( v.state and (v.mov_ave.value > v.threshold)) or
81-
(not v.state and (v.mov_ave.value < (1- v.threshold)))):
76+
if v.correct_mov_ave.value > v.threshold:
8277
v.threshold_crossed = True
8378
v.trials_till_reversal = randint(*v.trials_post_threshold)
8479

8580
# Print trial information.
8681
v.n_trials +=1
87-
v.n_rewards += outcome
88-
print_variables(['n_trials', 'n_rewards', 'n_blocks', 'choice', 'outcome', 'state'])
89-
90-
return outcome
82+
v.n_rewards += v.outcome
83+
v.choice = chosen_side
84+
v.ave_correct = v.correct_mov_ave.value
85+
print_variables(['n_trials', 'n_rewards', 'n_blocks', 'good_side', 'choice', 'outcome', 'ave_correct'])
86+
return v.outcome
9187

9288
#-------------------------------------------------------------------------
9389
# State machine code.
@@ -116,20 +112,20 @@ def init_trial(event):
116112
goto_state('choice_state')
117113

118114
def choice_state(event):
119-
# Wait for left or right choice, evaluate if reward is delivered using trial_outcome function.
115+
# Wait for left or right choice, evaluate if reward is delivered using get_trial_outcome function.
120116
if event == 'entry':
121117
hw.left_poke.LED.on()
122118
hw.right_poke.LED.on()
123119
elif event == 'exit':
124120
hw.left_poke.LED.off()
125121
hw.right_poke.LED.off()
126122
elif event == 'left_poke':
127-
if trial_outcome(True):
123+
if get_trial_outcome('left'):
128124
goto_state('left_reward')
129125
else:
130126
goto_state('inter_trial_interval')
131127
elif event == 'right_poke':
132-
if trial_outcome(False):
128+
if get_trial_outcome('right'):
133129
goto_state('right_reward')
134130
else:
135131
goto_state('inter_trial_interval')

0 commit comments

Comments
 (0)