-
Notifications
You must be signed in to change notification settings - Fork 0
/
qlearning.py
158 lines (122 loc) · 5.77 KB
/
qlearning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Classes implementing the general q-learning algorithm.
"""
import random as rand
from random import random
class QTable(object):
"""A table of state-action qvalues, with the learning logic.
"""
DEFAULT_QVALUE = 0
def __init__(self, states, actions, learning_rate,
discount_factor, default_qvalue=None):
"""Initializes the table and checks parameter restrictions.
In case a parameter restriction isn't met, a ValueError is raised:
- 0 <= learning_rate <= 1
- 0 <= discount_factor < 1
"""
# Verify and initialize learning_rate
if 0 <= learning_rate and learning_rate <= 1:
self.__learning_rate = learning_rate
else:
raise ValueError("Invalid learning rate of %d not contained in [0,1]."
% learning_rate)
# Verify and initialize discount_factor
if 0 <= discount_factor and discount_factor < 1:
self.__discount_factor = discount_factor
else:
raise ValueError("Invalid discount factor of %d not contained in [0,1)."
% discount_factor)
# Initialize the table of qvalues
if default_qvalue is None:
default_qvalue = self.DEFAULT_QVALUE
self.__table = dict( (s, dict((a, default_qvalue) for a in actions))
for s in states )
self.__states = tuple(states)
self.__actions = tuple(actions)
def __getitem__(self, pair):
return self.__table[pair[0]][pair[1]]
def subtable(self, state):
"""Obtain the part of the table describing the given state.
The result is a dict from actions to their q-value
on the given state.
"""
return self.__table[state]
@property
def states(self):
"""Return a tuple with all possible states."""
return self.__states
@property
def actions(self):
"""Return a tuple witha all possible actions."""
return self.__actions
def observe(self, state, action, new_state, reward):
"""Update q-values according to the observed behavior."""
max_future = max( self[new_state, new_action]
for new_action in self.__actions )
old_val = self[state, action]
change = reward + (self.__discount_factor * max_future) - old_val
self.__table[state][action] = old_val + (self.__learning_rate * change)
def act(self, state):
"""Return the recommended action for this state.
The choice of action may include random exploration.
"""
raise NotImplementedError("The basic QTable has no policy.")
class EpsilonGreedyQTable(QTable):
"""QTable with epsilon-greedy strategy.
With a given probability, called curiosity, chooses a random action.
Otherwise, does a greedy choice (action with greatest qvalue, in case
of a tie randomly picks one of the best).
The curiosity decays a given percentage after each random choice.
"""
DEFAULT_CURIOSITY_DECAY = 0
def __init__(self, states, actions, learning_rate, discount_factor, curiosity,
curiosity_decay=None, default_qvalue=None):
"""Initializes the table and checks for parameter restrictions.
In case a parameter restriction isn't met, a ValueError is raised:
- all restrictions from QTable apply
- 0 <= curiosity_decay < 1
"""
super(EpsilonGreedyQTable, self).__init__(states, actions, learning_rate,
discount_factor, default_qvalue)
self.__curiosity = curiosity
# Verify and initialize curiosity_decay
if curiosity_decay is None:
curiosity_decay = self.DEFAULT_CURIOSITY_DECAY
if 0 <= curiosity_decay and curiosity_decay < 1:
self.__curiosity_factor = 1 - curiosity_decay
else:
raise ValueError("Invalid curiosity decay %d not contained in [0,1)."
% curiosity_decay)
def act(self, state):
if random() < self.__curiosity:
return self.random_choice(state)
else:
return self.greedy_choice(state)
def random_choice(self, state):
"""Makes a uniformly random choice between all actions for this state."""
self.__curiosity *= self.__curiosity_factor
return rand.choice(self.actions)
def greedy_choice(self, state):
"""Makes a uniformly random choice between the actions with best q-value.
The analysed q-values are those related with actions on this state."""
best_qval = max(self.subtable(state).values())
best_actions = [act for (act, qval) in self.subtable(state).items()
if qval == best_qval]
return rand.choice(best_actions)
class EpsilonFirstQTable(EpsilonGreedyQTable):
"""QTable with epsilon-greedy strategy limited to the first N actions."""
def __init__(self, states, actions, learning_rate, discount_factor, curiosity,
exploration_period, curiosity_decay=None, default_qvalue=None):
super(EpsilonFirstQTable, self).__init__(states, actions, learning_rate,
discount_factor, curiosity,
curiosity_decay, default_qvalue)
if exploration_period > 0:
self.__remaining_exploration = exploration_period
else:
raise ValueError("Invalid non-positive exploration period %d."
% exploration_period)
def act(self, state):
if self.__remaining_exploration > 0:
self.__remaining_exploration -= 1
return super(EpsilonFirstQTable, self).act(state)
else:
return self.greedy_choice(state)