
Commit 0728312
bug fixes and linting
1 parent 3f0ec83 commit 0728312

1 file changed: machine_learning/q_learning.py (54 additions & 24 deletions)
@@ -8,8 +8,11 @@
 See: https://en.wikipedia.org/wiki/Q-learning
 """

-from collections import defaultdict
 import random
+from collections import defaultdict
+
+# Type alias for state
+type State = tuple[int, int]

 # Hyperparameters for Q-Learning
 LEARNING_RATE = 0.1
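
Note: "type State = tuple[int, int]" uses the PEP 695 type statement, which
requires Python 3.12 or newer; on older interpreters importing this module
fails with a SyntaxError. A rough pre-3.12 equivalent (an assumption for
illustration, not part of this commit):

    # Fallback sketch for Python < 3.12 (not what this commit ships):
    # typing.TypeAlias, available since Python 3.10, has the same meaning.
    from typing import TypeAlias

    State: TypeAlias = tuple[int, int]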
@@ -19,28 +22,30 @@
 EPSILON_MIN = 0.01

 # Global Q-table to store state-action values
-q_table = defaultdict(lambda: defaultdict(float))
+q_table: dict[State, dict[int, float]] = defaultdict(lambda: defaultdict(float))

 # Environment variables for simple grid world
 SIZE = 4
 GOAL = (SIZE - 1, SIZE - 1)
 current_state = (0, 0)


-def get_q_value(state, action):
+def get_q_value(state: State, action: int) -> float:
     """
     Get Q-value for a given state-action pair.

+    >>> q_table.clear()
     >>> get_q_value((0, 0), 2)
     0.0
     """
     return q_table[state][action]


-def get_best_action(state, available_actions):
+def get_best_action(state: State, available_actions: list[int]) -> int:
     """
     Get the action with maximum Q-value in the given state.

+    >>> q_table.clear()
     >>> q_table[(0, 0)][1] = 0.7
     >>> q_table[(0, 0)][2] = 0.7
     >>> q_table[(0, 0)][3] = 0.5
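
The nested defaultdict is what makes the get_q_value doctest return 0.0 for a
never-visited pair: the first lookup silently creates the entry instead of
raising KeyError. A quick standalone illustration of that behaviour:

    from collections import defaultdict

    q: dict[tuple[int, int], dict[int, float]] = defaultdict(
        lambda: defaultdict(float)
    )
    print(q[(0, 0)][2])  # 0.0 -- the inner defaultdict(float) defaults to 0.0
    print(len(q))        # 1   -- the read itself inserted the (0, 0) row

That read-side mutation is also why the updated doctests call q_table.clear()
before asserting anything.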
@@ -54,14 +59,18 @@ def get_best_action(state, available_actions):
     return random.choice(best)


-def choose_action(state, available_actions):
+def choose_action(state: State, available_actions: list[int]) -> int:
     """
     Choose action using epsilon-greedy policy.

+    >>> q_table.clear()
+    >>> old_epsilon = EPSILON
     >>> EPSILON = 0.0
     >>> q_table[(0, 0)][1] = 1.0
     >>> q_table[(0, 0)][2] = 0.5
-    >>> choose_action((0, 0), [1, 2])
+    >>> result = choose_action((0, 0), [1, 2])
+    >>> EPSILON = old_epsilon  # Restore
+    >>> result
     1
     """
     global EPSILON
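
For reference, the epsilon-greedy rule that choose_action implements explores
uniformly at random with probability EPSILON and otherwise exploits the best
known action. A minimal self-contained sketch of the policy (hypothetical
names, not the committed code):

    import random

    def epsilon_greedy(q_values: dict[int, float], epsilon: float) -> int:
        """Pick a random action with probability epsilon, else the greedy one."""
        actions = list(q_values)
        if random.random() < epsilon:
            return random.choice(actions)  # explore
        return max(actions, key=lambda a: q_values[a])  # exploit

With epsilon = 0.0 the policy is fully greedy, which is what the doctest above
relies on when it expects action 1 (Q = 1.0) over action 2 (Q = 0.5).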
@@ -72,64 +81,84 @@ def choose_action(state, available_actions):
     return get_best_action(state, available_actions)


-def update(state, action, reward, next_state, next_available_actions, done=False):
+def update(
+    state: State,
+    action: int,
+    reward: float,
+    next_state: State,
+    next_available_actions: list[int],
+    done: bool = False,
+    alpha: float | None = None,
+    gamma: float | None = None,
+) -> None:
     """
     Perform Q-value update for a transition using the Q-learning rule.

     Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))

-    >>> LEARNING_RATE = 0.5
-    >>> DISCOUNT_FACTOR = 0.9
-    >>> update((0,0), 1, 1.0, (0,1), [1,2], done=True)
-    >>> get_q_value((0,0), 1)
+    >>> q_table.clear()
+    >>> update((0, 0), 1, 1.0, (0, 1), [1, 2], done=True, alpha=0.5, gamma=0.9)
+    >>> get_q_value((0, 0), 1)
     0.5
     """
     global LEARNING_RATE, DISCOUNT_FACTOR
+    alpha = alpha if alpha is not None else LEARNING_RATE
+    gamma = gamma if gamma is not None else DISCOUNT_FACTOR
     max_q_next = (
         0.0
         if done or not next_available_actions
         else max(get_q_value(next_state, a) for a in next_available_actions)
     )
     old_q = get_q_value(state, action)
-    new_q = (1 - LEARNING_RATE) * old_q + LEARNING_RATE * (
-        reward + DISCOUNT_FACTOR * max_q_next
+    new_q = (1 - alpha) * old_q + alpha * (
+        reward + gamma * max_q_next
     )
     q_table[state][action] = new_q


-def get_policy():
+def get_policy() -> dict[State, int]:
     """
     Extract a deterministic policy from the Q-table.

-    >>> q_table[(1,2)][1] = 2.0
-    >>> q_table[(1,2)][2] = 1.0
-    >>> get_policy()[(1,2)]
+    >>> q_table.clear()
+    >>> q_table[(1, 2)][1] = 2.0
+    >>> q_table[(1, 2)][2] = 1.0
+    >>> get_policy()[(1, 2)]
     1
     """
-    policy = {}
+    policy: dict[State, int] = {}
     for s, a_dict in q_table.items():
         if a_dict:
             policy[s] = max(a_dict, key=a_dict.get)
     return policy


-def reset_env():
+def reset_env() -> State:
     """
     Reset the environment to initial state.
+
+    >>> old_state = current_state
+    >>> current_state = (1, 1)  # Simulate non-initial state
+    >>> result = reset_env()
+    >>> current_state = old_state  # Restore for other tests
+    >>> result
+    (0, 0)
     """
     global current_state
     current_state = (0, 0)
     return current_state


-def get_available_actions_env():
+def get_available_actions_env() -> list[int]:
     """
     Get available actions in the current environment state.
     """
-    return [0, 1, 2, 3]
+    return [0, 1, 2, 3]  # 0: up, 1: right, 2: down, 3: left


-def step_env(action):
+def step_env(action: int) -> tuple[State, float, bool]:
     """
     Take a step in the environment with the given action.
     """
@@ -150,13 +179,13 @@ def step_env(action):
     return next_state, reward, done


-def run_q_learning():
+def run_q_learning() -> None:
     """
     Run Q-Learning on the simple grid world environment.
     """
     global EPSILON
     episodes = 200
-    for episode in range(episodes):
+    for _ in range(episodes):
         state = reset_env()
         done = False
         while not done:
@@ -178,3 +207,4 @@ def run_q_learning():

     doctest.testmod()
     run_q_learning()
+
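
The diff elides the body of the "while not done:" loop; based on the function
signatures above, a typical episode body (a sketch under that assumption, not
the committed code) would be:

    # Hypothetical inner loop wired from the functions in this diff:
    while not done:
        actions = get_available_actions_env()
        action = choose_action(state, actions)
        next_state, reward, done = step_env(action)
        update(state, action, reward, next_state,
               get_available_actions_env(), done=done)
        state = next_state
    # EPSILON is then typically decayed toward EPSILON_MIN after each episode.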
