|
42 | 42 | v.n_rewards = 0 # Number of rewards obtained.
|
43 | 43 | v.n_trials = 0 # Number of trials recieved.
|
44 | 44 | v.n_blocks = 0 # Number of reversals.
|
45 |
| -v.state = withprob(0.5) # Which side is currently good: True: left, False: right |
46 |
| -v.mov_ave = exp_mov_ave(tau=v.tau, init_value = 0.5) # Moving average of choices. |
| 45 | +v.good_side = choice(['left', 'right']) # Which side is currently good. |
| 46 | +v.correct_mov_ave = exp_mov_ave(tau=v.tau, init_value = 0.5) # Moving average of correct/incorrect choices |
47 | 47 | v.threshold_crossed = False # Whether performance threshold has been crossed.
|
48 | 48 | v.trials_till_reversal = 0 # Used after threshold crossing to trigger reversal.
|
49 | 49 |
|
50 | 50 | #-------------------------------------------------------------------------
|
51 | 51 | # Non-state machine code.
|
52 | 52 | #-------------------------------------------------------------------------
|
53 | 53 |
|
54 |
| -def trial_outcome(choice): |
| 54 | +def get_trial_outcome(chosen_side): |
55 | 55 | # Function called after choice is made which determines trial outcome,
|
56 | 56 | # controls when reversals happen, and prints trial information.
|
57 | 57 |
|
58 | 58 | # Determine trial outcome.
|
59 |
| - if choice: # Subject chose left. |
60 |
| - if v.state: |
61 |
| - reward_prob = v.good_prob |
62 |
| - else: |
63 |
| - reward_prob = v.bad_prob |
64 |
| - else: # Subject chose right |
65 |
| - if v.state: |
66 |
| - reward_prob = v.bad_prob |
67 |
| - else: |
68 |
| - reward_prob = v.good_prob |
69 |
| - outcome = withprob(reward_prob) # Whether trial is rewarded or not. |
| 59 | + if chosen_side == v.good_side: # Subject choose good side. |
| 60 | + v.outcome = withprob(v.good_prob) |
| 61 | + v.correct_mov_ave.update(1) |
| 62 | + |
| 63 | + else: |
| 64 | + v.outcome = withprob(v.bad_prob) |
| 65 | + v.correct_mov_ave.update(0) |
70 | 66 |
|
71 | 67 | # Determine when reversal occurs.
|
72 |
| - v.mov_ave.update(choice) # Update moving average of choices. |
73 | 68 | if v.threshold_crossed: # Subject has already crossed threshold.
|
74 | 69 | v.trials_till_reversal -= 1
|
75 | 70 | if v.trials_till_reversal == 0: # Trigger reversal.
|
76 |
| - v.state = not v.state |
| 71 | + v.good_side = 'left' if (v.good_side == 'right') else 'right' |
| 72 | + v.correct_mov_ave.value = 1 - v.correct_mov_ave.value |
77 | 73 | v.threshold_crossed = False
|
78 | 74 | v.n_blocks += 1
|
79 | 75 | else: # Check for threshold crossing.
|
80 |
| - if (( v.state and (v.mov_ave.value > v.threshold)) or |
81 |
| - (not v.state and (v.mov_ave.value < (1- v.threshold)))): |
| 76 | + if v.correct_mov_ave.value > v.threshold: |
82 | 77 | v.threshold_crossed = True
|
83 | 78 | v.trials_till_reversal = randint(*v.trials_post_threshold)
|
84 | 79 |
|
85 | 80 | # Print trial information.
|
86 | 81 | v.n_trials +=1
|
87 |
| - v.n_rewards += outcome |
88 |
| - print_variables(['n_trials', 'n_rewards', 'n_blocks', 'choice', 'outcome', 'state']) |
89 |
| - |
90 |
| - return outcome |
| 82 | + v.n_rewards += v.outcome |
| 83 | + v.choice = chosen_side |
| 84 | + v.ave_correct = v.correct_mov_ave.value |
| 85 | + print_variables(['n_trials', 'n_rewards', 'n_blocks', 'good_side', 'choice', 'outcome', 'ave_correct']) |
| 86 | + return v.outcome |
91 | 87 |
|
92 | 88 | #-------------------------------------------------------------------------
|
93 | 89 | # State machine code.
|
@@ -116,20 +112,20 @@ def init_trial(event):
|
116 | 112 | goto_state('choice_state')
|
117 | 113 |
|
118 | 114 | def choice_state(event):
|
119 |
| - # Wait for left or right choice, evaluate if reward is delivered using trial_outcome function. |
| 115 | + # Wait for left or right choice, evaluate if reward is delivered using get_trial_outcome function. |
120 | 116 | if event == 'entry':
|
121 | 117 | hw.left_poke.LED.on()
|
122 | 118 | hw.right_poke.LED.on()
|
123 | 119 | elif event == 'exit':
|
124 | 120 | hw.left_poke.LED.off()
|
125 | 121 | hw.right_poke.LED.off()
|
126 | 122 | elif event == 'left_poke':
|
127 |
| - if trial_outcome(True): |
| 123 | + if get_trial_outcome('left'): |
128 | 124 | goto_state('left_reward')
|
129 | 125 | else:
|
130 | 126 | goto_state('inter_trial_interval')
|
131 | 127 | elif event == 'right_poke':
|
132 |
| - if trial_outcome(False): |
| 128 | + if get_trial_outcome('right'): |
133 | 129 | goto_state('right_reward')
|
134 | 130 | else:
|
135 | 131 | goto_state('inter_trial_interval')
|
|
0 commit comments