Adding Q-Learning
To add Q-learning to the game, we need to subclass our game. Assuming the final code from the last part is in a file called snake.py, we can create our outline as follows in qsnake.py:
import snake
import pygame

class QGame(snake.Game):
    def __init__(self):
        super().__init__(60, 800, 600)
        self.snake = Snake(self, 40)

def main():
    game = QGame()
    game.play()

class Snake(snake.Snake):
    def __init__(self, game, size):
        super().__init__(game, size)

if __name__ == "__main__":
    main()
Now we can launch our program with python3 qsnake.py and we will get the game that we have just finished. This allows us to edit our new file, qsnake.py, and also to edit the original snake.py file and see the effects in qsnake.py.
Now we can begin to set up the Q-learning aspect of our snake. The core of the Q-learning algorithm is the Q-table, which the algorithm constantly references in order to make decisions. A snake in the middle of a game can, at any time, take one of four possible actions: the direction to change to. Looking at our snake.py, the only thing the user controls is the direction of the snake. This can be represented along the lines of the following table.
| | UP | DOWN | LEFT | RIGHT |
|---|---|---|---|---|
| State | 0 | .5 | -5 | 3 |
Here we can see that the table has one column per direction and that a row can be looked up by the current state. In this state, the value for going down is .5, and going right has the highest value, 3. Therefore, if we consulted this table to choose our next action, we would choose to go right.
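As a quick illustration of that lookup, the row above can be modelled as a plain dictionary. This is only a sketch: the names `row` and `best_action` are made up for the example and are not part of snake.py or qsnake.py.

```python
# One row of the Q-table: the Q-value of each action in a single state.
row = {"UP": 0, "DOWN": 0.5, "LEFT": -5, "RIGHT": 3}

def best_action(q_row):
    """Return the action with the highest Q-value in the row."""
    return max(q_row, key=q_row.get)

print(best_action(row))  # RIGHT
```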
If we want to represent this in Python, a good choice is pandas. Using pandas requires us to install it so Python can access it:
$ pip3 install pandas
This allows us to use a DataFrame to represent our table exactly as it's displayed above. However, our QTable will be a very specific form of DataFrame, so it should be subclassed.
#In qsnake.py
import pandas as pd

class QTable(pd.DataFrame):
    def __init__(self, learning_rate=.1, discount_factor=.9):
        super().__init__(columns=["UP", "DOWN", "LEFT", "RIGHT"])
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
Here we are creating a pandas DataFrame with the columns UP, DOWN, LEFT, and RIGHT. Then, we set two instance variables, learning_rate and discount_factor. These are important to the Q-learning equation.
Q_new(state, action) = Q_old(state, action) + learning_rate * (reward + discount_factor * maxQ(state + 1) - Q_old(state, action))
The learning_rate represents to what degree new information overwrites old information.
The discount_factor represents how we balance current vs future rewards.
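To make the roles of the two parameters concrete, here is a small standalone sketch of the standard Q-learning update for a single cell. The function name `q_update` and its arguments are assumptions for illustration only; they do not appear in qsnake.py.

```python
def q_update(q_old, reward, max_q_next, learning_rate=0.1, discount_factor=0.9):
    """Standard Q-learning update for a single (state, action) cell."""
    # Temporal-difference error: how far the old estimate was from the
    # reward plus the discounted best value of the next state.
    td_error = reward + discount_factor * max_q_next - q_old
    return q_old + learning_rate * td_error

# Old value 1, reward 2, best next-state value 4:
print(round(q_update(1, 2, 4), 2))  # 1.46
```

With learning_rate = 0 the value never changes, and with learning_rate = 1 the old value is replaced entirely by reward + discount_factor * max_q_next.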
Note that once we have the QTable class, we need to create an instance of it inside QGame (for example, self.table = QTable() in QGame.__init__()).
The first thing we need to be able to do is return a row from the QTable. The DataFrame.loc[] indexer takes the index label of a row. In the above table, calling qtable.loc['State'] would return the row that contains [0, .5, -5, 3]. However, if we try to read a row whose label doesn't exist, a KeyError is raised. That way, if we catch a KeyError, we know we don't have that entry in the QTable. To add a row, we can also use DataFrame.loc[]: assigning to a label that doesn't exist creates the row. So we can try self.loc on the index we want and, if it raises a KeyError, add the row to the QTable. Finally, we return the row whether it was just added or not.
#In qsnake.py
#In class QTable
    def getRow(self, index):
        try:
            return self.loc[index]
        except KeyError:
            #The state has not been seen yet, so add a row of zeros.
            self.loc[index] = [0, 0, 0, 0]
            return self.loc[index]
Here, getRow() will try to return the existing row, or, if it doesn't exist, will add the row and then return the added row.
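The same get-or-create behaviour can be sketched without pandas, which may help when reasoning about it. The class `DictQTable` below is a hypothetical stand-in, not the QTable used in this post; it swaps the DataFrame for a plain dictionary.

```python
ACTIONS = ["UP", "DOWN", "LEFT", "RIGHT"]

class DictQTable:
    """Hypothetical dictionary-backed Q-table (not the pandas version)."""

    def __init__(self):
        self.rows = {}

    def get_row(self, state):
        # Mirror the try/except idea: a missing state raises KeyError,
        # so we create a zeroed row and then return it.
        try:
            return self.rows[state]
        except KeyError:
            self.rows[state] = {action: 0 for action in ACTIONS}
            return self.rows[state]

table = DictQTable()
print(table.get_row("a"))  # {'UP': 0, 'DOWN': 0, 'LEFT': 0, 'RIGHT': 0}
```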
We can test this in our main.
#in main
table = QTable()
print(table.getRow("a"))
#OUTPUT:
UP 0
DOWN 0
LEFT 0
RIGHT 0
Name: a, dtype: object
Now that we can create and access rows, we need to be able to create the indices we will use to access rows. In our QTable, the indices will be the state encoding. This is the most important part of our algorithm, because the encoding determines all the information that the AI will have. If there are too many possible states, the AI will rarely be in the same state twice and will not be able to learn from trying different actions. However, if the AI doesn't have enough information, then it may treat two states the same way when they should be treated differently. The initial encoding (which isn't perfect) will be a bitmap:
$ pip3 install bitmap
The overall idea will be:
direction(2), quadrantOfFood(2), immediateSurrounding(8).
The four directions the snake could be travelling can be represented using 2 bits:
00 - UP
01 - LEFT
10 - DOWN
11 - RIGHT
If we imagine the board cut into quadrants with the origin at the head, then we can encode the position of the food relative to the head by which quadrant it falls in:
00 - I
01 - II
10 - III
11 - IV
Note that we have to decide what to do if the food falls on the same axis as the head. In this case, I consider it to fall in the next quadrant. For example, if the food is on the boundary between quadrants I and II, then I encode it as being in quadrant II.
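The tie-breaking rule is easier to check with a small standalone function. This is a sketch under the assumption that positions are (x, y) tuples in screen coordinates, where y grows downward; `food_quadrant` is a made-up helper name.

```python
def food_quadrant(head, food):
    """Return the quadrant ("I".."IV") of food relative to head.

    Both arguments are (x, y) tuples in screen coordinates, where y
    grows downward, so "above" means a smaller y.
    """
    hx, hy = head
    fx, fy = food
    if fx > hx and fy <= hy:
        return "I"    # right of the head, above or level with it
    if fx <= hx and fy < hy:
        return "II"   # left of or level with the head, above it
    if fx < hx and fy >= hy:
        return "III"  # left of the head, below or level with it
    return "IV"       # right of or level with the head, below it

# Food directly above the head sits between I and II, so it counts as II.
print(food_quadrant((400, 280), (400, 120)))  # II
```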
For the 8 surrounding bits, each of the 8 grid squares around the head of the snake will be analyzed. If there is an obstacle in that spot, then that bit will be a 1. If there is not an obstacle, the bit will be a 0. The order of this encoding will be: TopLeft, Left, BottomLeft, Top, Bottom, TopRight, Right, BottomRight.
If we put all of this together, a snake going up, with the food to the North East of the head, with no obstacles, would be encoded as [00][00][00000000].
A snake going DOWN with the food to the South East of the head, and an obstacle directly above and below the head would be encoded as [10][11][00011000].
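The two worked examples can be reproduced with a few lines of plain Python. This sketch builds the string by concatenation, following the order as listed in the text; the names `encode`, `DIRECTION_BITS`, `QUADRANT_BITS`, and `NEIGHBOURS` are assumptions for illustration, and a bitmap-based implementation may place the individual surrounding bits differently.

```python
DIRECTION_BITS = {"UP": "00", "LEFT": "01", "DOWN": "10", "RIGHT": "11"}
QUADRANT_BITS = {"I": "00", "II": "01", "III": "10", "IV": "11"}
# Order of the eight neighbours as listed in the text.
NEIGHBOURS = ["TopLeft", "Left", "BottomLeft", "Top",
              "Bottom", "TopRight", "Right", "BottomRight"]

def encode(direction, quadrant, obstacles):
    """Build the 12-character state string from its three parts.

    `obstacles` is a set of neighbour names that are blocked.
    """
    surrounding = "".join("1" if name in obstacles else "0" for name in NEIGHBOURS)
    return DIRECTION_BITS[direction] + QUADRANT_BITS[quadrant] + surrounding

print(encode("UP", "I", set()))                 # 000000000000
print(encode("DOWN", "IV", {"Top", "Bottom"}))  # 101100011000
```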
Time to convert this into code:
#In qsnake.py
import bitmap
#in class QTable
    def encodeState(self, snake_obj, food):
        encoded_map = bitmap.BitMap(12)
        bit_position = 0
        leftBoundry = 0
        rightBoundry = 800
        bottomBoundry = 600
        topBoundry = 0
        #Encode the surrounding
        #Go over the columns from left to right
        for x in range(snake_obj.x - snake_obj.width, snake_obj.x + snake_obj.width * 2, snake_obj.width):
            #Go over the squares from top to bottom
            for y in range(snake_obj.y - snake_obj.width, snake_obj.y + snake_obj.width * 2, snake_obj.width):
                if (x, y) == (snake_obj.x, snake_obj.y):
                    continue
                #Loop over the tail
                for block in [snake_obj] + snake_obj.tail:
                    #Check if that square has a tail block or lies outside the walls
                    if snake.Block(snake_obj.width, x, y).colliderect(block) or y >= bottomBoundry or \
                            y < topBoundry or x < leftBoundry or x >= rightBoundry:
                        encoded_map.set(bit_position)
                        break
                bit_position += 1
        #Encode Quadrants
        #Food is on the right of the head and above or equal to the head
        if food.x > snake_obj.x and food.y <= snake_obj.y:
            bit_position += 2 #00
        #Food is on the left or equal to the head and above the head
        elif food.x <= snake_obj.x and food.y < snake_obj.y:
            encoded_map.set(bit_position)
            bit_position += 2
        #Food is on the left of the head and below or equal to the head
        elif food.x < snake_obj.x and food.y >= snake_obj.y:
            bit_position += 1
            encoded_map.set(bit_position)
            bit_position += 1
        #Food is on the right of or equal to the head and below the head
        else:
            encoded_map.set(bit_position)
            encoded_map.set(bit_position + 1)
            bit_position += 2
        #Encode Direction
        if snake_obj.direction.name == "UP":
            bit_position += 2
        elif snake_obj.direction.name == "LEFT":
            encoded_map.set(bit_position)
            bit_position += 2
        elif snake_obj.direction.name == "DOWN":
            bit_position += 1
            encoded_map.set(bit_position)
            bit_position += 1
        else:
            encoded_map.set(bit_position)
            encoded_map.set(bit_position + 1)
            bit_position += 2
        return encoded_map.tostring()[4:] #We only need 12 bits, but the bitmap stores whole bytes (16 bits).
In order to test this to make sure it functions as intended, we will play the game, but as the game is played, we will stop it at certain intervals to examine the output and see if it matches.
To help visualize this, I added pause functionality so that at any point while paused, the encoded state can be seen in the console.
#In snake.py
#In Game.__init__()
        self.pause = False
#In Game.play()
        while not self.done:
            for #...
            while self.pause:
                for event in pygame.event.get():
                    if event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_SPACE:
                            self.pause = False
                    elif event.type == pygame.QUIT:
                        self.done = True
                        self.pause = False
            if input_buffer and #...
In order to print the encoded state, we must override the play function in QGame and print the encoded state when the snake moves.
#In qsnake.py
    def play(self):
        #Note, all of the code is exactly the same except...
        timer = 0
        speed = 10
        input_buffer = []
        moved = False
        while not self.done:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.done = True
                if event.type == pygame.KEYDOWN:
                    if moved:
                        self.userInput(event.key)
                        moved = False
                    else:
                        input_buffer.append(event.key)
            while self.pause:
                for event in pygame.event.get():
                    if event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_SPACE:
                            self.pause = False
                    elif event.type == pygame.QUIT:
                        self.done = True
                        self.pause = False
            if input_buffer and moved:
                self.userInput(input_buffer.pop(0))
                moved = False
            if timer * speed > 1:
                timer = 0
                self.snake.move()
                if self.snake.checkDead():
                    self.gameOver()
                moved = True
                ###################
                ####This Line######
                print(self.table.encodeState(self.snake, self.food))
                ###################
            self.snake.checkEat()
            self.screen.fill((0, 0, 0))
            self.score.draw()
            pygame.draw.rect(self.screen, snake.Snake.color, self.snake)
            pygame.draw.rect(self.screen, snake.Food.color, self.food)
            for block in self.snake.tail:
                pygame.draw.rect(self.screen, snake.Block.color, block)
            pygame.display.flip()
            timer += self.clock.tick(self.fps) / 1000
Now if we pause the game at random points we can check if our encoded state matches what we expect it to be.
What would the state of this one be?
000001101000
What about this one?
010100000010
The next step will be writing an updateQValue function that takes in the current state, the next state, the action taken, and the reward in order to update the Q-value inside the QTable.
The mathematical formula for updating Q-values is:
Q_new(state, action) = Q_old(state, action) + learning_rate * (reward + discount_factor * maxQ(state + 1) - Q_old(state, action))
To accomplish this, we need to be able to access individual cells in our QTable and update them in place. Luckily, we already have our getRow function, which returns a pandas.Series. When we have a Series, we can use its Series.at[] accessor to read a value by label, in the same way rows were accessed by label in our getRow function.
Converting this to code yields the following:
#In qsnake.py
#In class QTable
    def updateQValue(self, current_state, next_state, action, reward):
        currentRow = self.getRow(current_state)
        nextRow = self.getRow(next_state)
        value = currentRow.at[action] #Q_old
        newValue = reward + self.discount_factor * nextRow.max() - value
        #Write through the table itself so the change is not lost on a temporary row.
        self.loc[current_state, action] = value + self.learning_rate * newValue
An important note: I recently realized that getting a row, adding another row, and then changing the first row through the earlier reference does not update the Q-table.
However, if you edit the Q-table by indexing it directly with .loc, it produces the desired result:
row = self.getRow("I exist")
newRow = self.getRow("I Don't")
self.loc["I exist", "UP"] = 5 #WORKS
This can be tested with arbitrary inputs in the main function.
#In main
row = game.table.getRow("h")
row.at["LEFT"] = 1
newRow = game.table.getRow("p")
newRow.at["UP"] = 4
game.table.updateQValue("h", "p", "RIGHT", 2)
print(game.table.getRow("h"))
#OUTPUT:
UP 0
DOWN 0
LEFT 1
RIGHT 0.56
Name: h, dtype: object
Everything is looking pretty good! So far we have a QTable that can return a row, encode a state, and update a Q-value using the correct formula. Now, we need to update the Q-table at every movement of the snake, and then use this value to choose an action.
To choose an action, the algorithm will choose the action that has the highest Q-value in the table. If there are multiple actions with the same, maximum value, then the action is chosen randomly.
Luckily, pandas.Series has helpful methods that cover some of the functionality we are looking for. pandas.Series.max() returns the maximum value in the Series. pandas.Series.items() returns an iterable of (index, value) pairs. Using these methods, we can collect every action that has the maximum value and then use random.randint() to pick a random index in that list of actions.
It is important to keep our current direction in mind. If we are going Direction.RIGHT, we can't go Direction.LEFT: our changeDirection function will not allow the reversal and will simply return without changing anything.
#In qsnake.py
import random
#in class QTable
    def chooseAction(self, current_state):
        currentRow = self.getRow(current_state)
        #List comprehension gives every index (action) that equals the max of the row of the current state we are in.
        max_actions = [index for index, value in currentRow.items() if value == currentRow.max()]
        next_action = max_actions[random.randint(0, len(max_actions) - 1)]
        return next_action
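For comparison, here is a pandas-free sketch of the same selection that also skips the reverse of the current direction before picking. This is a variation, not what chooseAction() does: the `choose_action` function, the `OPPOSITE` mapping, and the `current_direction` parameter are all hypothetical names introduced for the example.

```python
import random

OPPOSITE = {"UP": "DOWN", "DOWN": "UP", "LEFT": "RIGHT", "RIGHT": "LEFT"}

def choose_action(q_row, current_direction=None):
    """Pick a random action among those tied for the highest Q-value.

    If current_direction is given, its opposite is excluded first,
    since the snake cannot reverse onto itself.
    """
    allowed = {a: q for a, q in q_row.items()
               if a != OPPOSITE.get(current_direction)}
    best = max(allowed.values())
    return random.choice([a for a, q in allowed.items() if q == best])

row = {"UP": 0, "DOWN": 0, "LEFT": 2, "RIGHT": 2}
# LEFT is excluded while travelling RIGHT, so RIGHT is the only maximum.
print(choose_action(row, current_direction="RIGHT"))  # RIGHT
```

Filtering before taking the maximum avoids spending a move on an action that changeDirection would ignore anyway.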
In order to test this, we can add it to the same location where we printed our encoded state. This encoded state can be considered our current_state.
#In qsnake.py
#in QGame.play()
                #...other code
                moved = True
                self.current_state = self.table.encodeState(self.snake, self.food)
                print(self.table.chooseAction(self.current_state))
Note that we are using self.current_state, but reading it will raise an AttributeError if it is undefined at the time of access (which it is). To fix this, we will initialize it in the constructor of QGame.
#In qsnake.py
#In QGame.__init__()
        self.current_state = self.table.encodeState(self.snake, self.food)
While this does work, it is hard to tell because we aren’t updating the Q-value so the algorithm is always choosing a random direction because all actions are at 0.
Now, to update the Q-table, we will have to call QTable.updateQValue() after the next action is chosen. The issue is that it requires a reward to be supplied as a parameter, and we haven't defined how the snake will be rewarded.
A simple idea is to reward the snake when it eats, and do nothing otherwise.
To check if the snake eats, we will add an instance variable to the Snake class in snake.py.
#In snake.py
#In Snake.__init__():
        self.ate = False

    def checkEat(self):
        if self.colliderect(self.game.food):
            self.game.food.relocate()
            self.growTail()
            self.game.score.changeScore(1)
            self.ate = True
        else:
            self.ate = False
Then our code to update Q-values after every move:
#In qsnake.py
#in QGame.play()
                #...other code
                moved = True
                old_state = self.current_state
                self.current_state = self.table.encodeState(self.snake, self.food)
                new_action = self.table.chooseAction(self.current_state)
                reward = 1 if self.snake.ate else 0
                if reward == 1: #Only print the row if it is changing.
                    print(self.table.getRow(old_state))
                    self.table.updateQValue(old_state, self.current_state, new_action, reward)
                    print(self.table.getRow(old_state))
Finally, we can update QGame.play() so that a change in direction is no longer caused by user input, but by the computed action from QTable.chooseAction(). Also, since we are only using the computed action, we can get rid of all the code that detects user-controlled input.
#In qsnake.py
#In QGame.play()
            if timer *#...:
                old_state = self.current_state
                new_action = self.table.chooseAction(self.current_state)
                self.snake.changeDirection(snake.Direction[new_action])
                self.snake.move()
                if self.snake.checkDead():
                    self.gameOver()
                self.current_state = self.table.encodeState(self.snake, self.food)
                reward = 1 if self.snake.ate else 0
                if reward == 1: #Only print the row if it is changing.
                    print(self.table.getRow(old_state))
                    self.table.updateQValue(old_state, self.current_state, new_action, reward)
                    print(self.table.getRow(old_state))
In the end, we will have the following:
#snake.py
import pygame
from enum import Enum
from random import randint
from math import floor

class Game():
    def __init__(self, fps, width, height):
        pygame.init()
        self.done = False
        self.fps = fps
        self.clock = pygame.time.Clock()
        self.screen = pygame.display.set_mode((width, height))
        self.snake = Snake(self, 40)
        self.food = Food(self, 40)
        self.score = Score(self.screen, (10, 10))
        self.pause = False

    def checkPause(self, key):
        if key == pygame.K_SPACE:
            #Toggle with "not"; the bitwise "~" would leave the value truthy.
            self.pause = not self.pause

    def play(self):
        timer = 0
        speed = 10
        input_buffer = []
        moved = False
        while not self.done:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.done = True
                if event.type == pygame.KEYDOWN:
                    if moved:
                        self.userInput(event.key)
                        moved = False
                    else:
                        input_buffer.append(event.key)
            while self.pause:
                for event in pygame.event.get():
                    if event.type == pygame.KEYDOWN:
                        self.checkPause(event.key)
            if input_buffer and moved:
                self.userInput(input_buffer.pop(0))
                moved = False
            if timer * speed > 1:
                timer = 0
                self.snake.move()
                if self.snake.checkDead():
                    self.gameOver()
                moved = True
            self.snake.checkEat()
            self.screen.fill((0, 0, 0))
            self.score.draw()
            pygame.draw.rect(self.screen, Snake.color, self.snake)
            pygame.draw.rect(self.screen, Food.color, self.food)
            for block in self.snake.tail:
                pygame.draw.rect(self.screen, Block.color, block)
            pygame.display.flip()
            timer += self.clock.tick(self.fps) / 1000

    def userInput(self, key):
        if key == pygame.K_UP:
            self.snake.changeDirection(Direction.UP)
        elif key == pygame.K_DOWN:
            self.snake.changeDirection(Direction.DOWN)
        elif key == pygame.K_LEFT:
            self.snake.changeDirection(Direction.LEFT)
        elif key == pygame.K_RIGHT:
            self.snake.changeDirection(Direction.RIGHT)
        elif key == pygame.K_SPACE:
            self.pause = not self.pause

    def gameOver(self):
        print(f"You have died! You ate {len(self.snake.tail)} pieces of food!")
        self.done = True
class Block(pygame.Rect):
    color = (0, 128, 255) # blue

    def __init__(self, size, x, y):
        self.width = size
        self.height = size
        self.x = x
        self.y = y
        self.dim = size

class Direction(Enum):
    UP = [0, -40]
    DOWN = [0, 40]
    LEFT = [-40, 0]
    RIGHT = [40, 0]
    NONE = [0, 0]

class Snake(Block):
    color = (124, 252, 0)

    def __init__(self, game, size):
        Block.__init__(self, size,
                       floor((pygame.display.get_window_size()[0]/size - 1)/2) * size,
                       floor((pygame.display.get_window_size()[1]/size - 1)/2) * size)
        self.direction = Direction.NONE
        self.game = game
        self.tail = []
        self.ate = False #Set by checkEat(); added earlier in this part.

    def move(self):
        if self.tail:
            self.tail = self.tail[1:]
            self.tail.append(Block(self.dim, self.x, self.y))
        self.x += self.direction.value[0]
        self.y += self.direction.value[1]
        self.checkEat()
        self.hitWall()
        self.hitSelf()

    def changeDirection(self, newDirection):
        if newDirection.value == [-x for x in self.direction.value]:
            return
        self.direction = newDirection

    def checkEat(self):
        if self.colliderect(self.game.food):
            self.game.food.relocate()
            self.growTail()
            self.game.score.changeScore(1)
            self.ate = True
        else:
            self.ate = False

    def growTail(self):
        self.tail.append(Block(
            self.dim, self.x - self.direction.value[0], self.y - self.direction.value[1]))

    def hitWall(self):
        if (self.x < 0 or self.y < 0
                or self.x > pygame.display.get_window_size()[0] - self.dim
                or self.y > pygame.display.get_window_size()[1] - self.dim
                ):
            self.hit_wall = True
            # Push back
            self.x -= self.direction.value[0]
            self.y -= self.direction.value[1]
        else:
            self.hit_wall = False

    def hitSelf(self):
        for block in self.tail:
            if block.colliderect(self):
                self.hit_self = True
                return
        self.hit_self = False

    def checkDead(self):
        if self.hit_self or self.hit_wall:
            return True
        else:
            return False
class Food(Block):
    color = (255, 0, 100)

    def __init__(self, game, size):
        Block.__init__(self, size, 0, 0)
        self.game = game
        self.relocate()

    def relocate(self):
        cols = floor(pygame.display.get_window_size()[0]/self.dim - 1)
        rows = floor(pygame.display.get_window_size()[1]/self.dim - 1)
        self.x, self.y = randint(0, cols) * \
            self.dim, randint(0, rows) * self.dim
        if not self.isSafe():
            self.relocate()

    def isSafe(self):
        if self.colliderect(self.game.snake):
            return False
        else:
            return True

class Score():
    def __init__(self, screen, location):
        self.font = pygame.font.Font(None, 36)
        self.value = 0
        self.location = location
        self.screen = screen

    def draw(self):
        text = self.font.render(f"Score: {self.value}", 1, (255, 255, 255))
        self.screen.blit(text, self.location)

    def changeScore(self, change):
        self.value += change

    def reset(self):
        self.value = 0

def main():
    game = Game(60, 800, 600)
    game.play()

if __name__ == "__main__":
    main()
#qsnake.py
import snake
import pygame
import pandas as pd
import bitmap
import random

class QGame(snake.Game):
    def __init__(self):
        super().__init__(60, 800, 600)
        self.table = QTable()
        self.current_state = self.table.encodeState(self.snake, self.food)

    def play(self):
        timer = 0
        speed = 10
        input_buffer = []
        moved = False
        while not self.done:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.done = True
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        self.pause = True
            while self.pause:
                for event in pygame.event.get():
                    if event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_SPACE:
                            self.pause = False
                    elif event.type == pygame.QUIT:
                        self.done = True
                        self.pause = False
            if timer * speed > 1:
                timer = 0
                old_state = self.current_state
                new_action = self.table.chooseAction(self.current_state)
                self.snake.changeDirection(snake.Direction[new_action])
                self.snake.move()
                if self.snake.checkDead():
                    self.gameOver()
                self.current_state = self.table.encodeState(self.snake, self.food)
                reward = 1 if self.snake.ate else 0
                if reward == 1: #Only print the row if it is changing.
                    print(self.table.getRow(old_state))
                    self.table.updateQValue(old_state, self.current_state, new_action, reward)
                    print(self.table.getRow(old_state))
            self.snake.checkEat()
            self.screen.fill((0, 0, 0))
            self.score.draw()
            pygame.draw.rect(self.screen, snake.Snake.color, self.snake)
            pygame.draw.rect(self.screen, snake.Food.color, self.food)
            for block in self.snake.tail:
                pygame.draw.rect(self.screen, snake.Block.color, block)
            pygame.display.flip()
            timer += self.clock.tick(self.fps) / 1000
class QTable(pd.DataFrame):
    def __init__(self, learning_rate=.1, discount_factor=.9):
        super().__init__(columns=["UP", "DOWN", "LEFT", "RIGHT"])
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor

    def getRow(self, index):
        try:
            return self.loc[index]
        except KeyError:
            #The state has not been seen yet, so add a row of zeros.
            self.loc[index] = [0, 0, 0, 0]
            return self.loc[index]

    def updateQValue(self, current_state, next_state, action, reward):
        currentRow = self.getRow(current_state)
        nextRow = self.getRow(next_state)
        value = currentRow.at[action] #Q_old
        newValue = reward + self.discount_factor * nextRow.max() - value
        #Write through the table itself so the change is not lost on a temporary row.
        self.loc[current_state, action] = value + self.learning_rate * newValue

    def chooseAction(self, current_state):
        currentRow = self.getRow(current_state)
        #List comprehension gives every index (action) that equals the max of the row of the current state we are in.
        max_actions = [index for index, value in currentRow.items() if value == currentRow.max()]
        next_action = max_actions[random.randint(0, len(max_actions) - 1)]
        return next_action
    def encodeState(self, snake_obj, food):
        encoded_map = bitmap.BitMap(12)
        bit_position = 0
        leftBoundry = 0
        topBoundry = 0
        rightBoundry, bottomBoundry = pygame.display.get_window_size()
        #Encode the surrounding
        #Go over the columns from left to right
        for x in range(snake_obj.x - snake_obj.width, snake_obj.x + snake_obj.width * 2, snake_obj.width):
            #Go over the squares from top to bottom
            for y in range(snake_obj.y - snake_obj.width, snake_obj.y + snake_obj.width * 2, snake_obj.width):
                if (x, y) == (snake_obj.x, snake_obj.y): #Don't count the snake head.
                    continue
                #Loop over the head and tail (the head keeps the wall check running when the tail is empty)
                for block in [snake_obj] + snake_obj.tail:
                    #Check if that square has a tail block or lies outside the walls
                    if snake.Block(snake_obj.width, x, y).colliderect(block) or x < leftBoundry or y < topBoundry or \
                            x >= rightBoundry or y >= bottomBoundry:
                        encoded_map.set(bit_position)
                        break
                bit_position += 1
        #Encode Quadrants
        #Food is on the right of the head and above or equal to the head
        if food.x > snake_obj.x and food.y <= snake_obj.y:
            bit_position += 2 #00
        #Food is on the left or equal to the head and above the head
        elif food.x <= snake_obj.x and food.y < snake_obj.y:
            encoded_map.set(bit_position)
            bit_position += 2
        #Food is on the left of the head and below or equal to the head
        elif food.x < snake_obj.x and food.y >= snake_obj.y:
            bit_position += 1
            encoded_map.set(bit_position)
            bit_position += 1
        #Food is on the right of or equal to the head and below the head
        else:
            encoded_map.set(bit_position)
            encoded_map.set(bit_position + 1)
            bit_position += 2
        #Encode Direction
        if snake_obj.direction.name == "UP":
            bit_position += 2
        elif snake_obj.direction.name == "LEFT":
            encoded_map.set(bit_position)
            bit_position += 2
        elif snake_obj.direction.name == "DOWN":
            bit_position += 1
            encoded_map.set(bit_position)
            bit_position += 1
        else:
            encoded_map.set(bit_position)
            encoded_map.set(bit_position + 1)
            bit_position += 2
        return encoded_map.tostring()[4:] #We only need 12 bits, but the bitmap stores whole bytes (16 bits).

def main():
    game = QGame()
    #table = QTable()
    game.play()

if __name__ == "__main__":
    main()
This does work, but it can be difficult to see. If we think about how the snake learns, the only time learning happens is when food is eaten. That means the snake will only learn when it randomly bumps into the food. One idea to fix this is to give a reward of -.1 for every move on which the snake doesn't eat the food, but then we have the same problem. The snake doesn't know where the food is; it can't be expected to move towards the food without this key information.
The idea I settled on was to give a reward of -.2 for every move that takes the snake farther away from the food. However, if the snake gets closer to the food, it will get a reward of .1. This way, not only is the snake punished for moving farther away from the food, but it is rewarded for moving towards the food. It is also important that the negative reward outweighs the positive reward.
Imagine if the snake constantly went in a square. This is technically moving two squares closer to the food and two squares away from the food. If the two rewards had the same magnitude, the net reward would be 0, but we want to punish that type of activity. If the negative reward is greater, going in circles will lead to a net negative reward.
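The arithmetic for one lap around such a square can be checked directly. The constant names below are made up for this sketch; the values are the ones described above.

```python
CLOSER_REWARD = 0.1    # reward for a move that reduces the distance
FARTHER_REWARD = -0.2  # penalty for a move that increases it

# One lap around a square: two moves toward the food, two away.
lap = ["closer", "closer", "farther", "farther"]
total = sum(CLOSER_REWARD if move == "closer" else FARTHER_REWARD for move in lap)
print(round(total, 2))  # -0.2
```

With equal magnitudes (for example .1 and -.1) the same lap would sum to 0, so circling would cost the snake nothing.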
In order to implement this, we need a function that can return the distance between the snake and the food.
#In qsnake.py
#In class QGame
    def getDistanceBetweenFoodAndSnake(self):
        x = abs(self.food.x - self.snake.x)
        y = abs(self.food.y - self.snake.y)
        return x + y
We take the absolute value of the difference between the head and the food because we don't care whether the food is to the left or the right; we only want the raw length. If we imagine a coordinate plane with the origin at the head and the food at point (-2, -2), the horizontal distance between (0, 0) and (-2, 0) is 2, not -2. We add the x and y values together to get the overall distance the snake has to travel to reach the food, since it can only move horizontally or vertically.
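As a standalone check of that reasoning, the sketch below computes the same sum of absolute differences for plain (x, y) tuples. The function name `manhattan_distance` and the optional `block_size` argument are assumptions for illustration; dividing by the block size (40 in this post) converts pixels to grid squares.

```python
def manhattan_distance(head, food, block_size=1):
    """Distance between two (x, y) points when only axis-aligned moves are allowed.

    block_size converts pixel coordinates into grid squares.
    """
    dx = abs(food[0] - head[0])
    dy = abs(food[1] - head[1])
    return (dx + dy) // block_size

print(manhattan_distance((0, 0), (-2, -2)))            # 4
print(manhattan_distance((400, 280), (320, 200), 40))  # 4
```

Since the reward only compares the distance before and after a move, it does not matter whether the value is measured in pixels or in squares.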
Now that we can get the distance, we need to create a function to return the correct reward based on what move the snake just made.
#In qsnake.py
#In class Snake
    def getReward(self):
        last_distance = self.distance
        self.distance = self.game.getDistanceBetweenFoodAndSnake()
        if self.ate:
            reward = 1
        elif self.distance < last_distance:
            #The snake moved closer to the food.
            reward = .1
        else:
            #The snake moved farther away from the food.
            reward = -.2
        return reward
Notice that we now have a new variable, self.distance. We can't forget to initialize it in the constructor.
Now, we need to replace the line reward = 1 if self.snake.ate else 0 with reward = self.snake.getReward(). Because we are now getting a reward on every move, this also means we need to update the Q-table on every move.
#In qsnake.py
#In QGame.play()
                old_state = self.current_state
                #Choose the action from the state we are in before moving.
                new_action = self.table.chooseAction(self.current_state)
                self.snake.changeDirection(snake.Direction[new_action])
                self.snake.move()
                if self.snake.checkDead():
                    self.gameOver()
                self.current_state = self.table.encodeState(self.snake, self.food)
                reward = self.snake.getReward()
                self.table.updateQValue(old_state, self.current_state, new_action, reward)
Now, if we run our code, we can start to see a moving snake. Let's change the gameOver() function to print out the current Q-table.
This will lead to this type of output:
101000000000 0 0 0.01 0
011100000000 0 0.01 0 0
101100000000 0 0.109 0 0
100000001000 0 -0.02 -0.02 0
010001000000 0.01 0 0 0
000000010000 0 0 0 0.01
110000000010 0 0 0 0.01
110100000010 0.01 0 0 0
000100010000 0.11791 0 0 0
001100010000 0 0 0 0.01
111100000110 0 0.01 0 0
101100001001 0 0.01 0 0
101100001000 0 0.0454636 0 0
100010011100 0 0 0 0.01
110010010111 0 0 0 0
Here, we can see that the snake has been updating the Q-table based on the actions it has been taking. For example, if the snake is going UP and the food is in the NW section, then it has received rewards so that it thinks the optimal move is to go UP. If we think about this, that should be a good move, because it brings the snake closer to quadrant II.
If we run over multiple iterations, we can see that the QTable will get larger and larger. First, we have to create a QGame.reset() function. This way we can keep the QTable and just reset the food and the snake.
#In qsnake.py
#In class QGame
    def reset(self):
        self.food = snake.Food(self, 40)
        self.snake = Snake(self, 40)
        self.current_state = self.table.encodeState(self.snake, self.food)
        self.score.reset()
        self.done = False
Now, once we call our reset, we can play the game again.
This can lead to a Q-Table like the following:
UP DOWN LEFT RIGHT
111000000000 0 0.01 0 0
111100000000 0 0 0 0.00940108
111000000000 -0.02 0.019 0 -0.02
001000000000 0 0 0.019 0
011100000000 0 0.0223058 -0.02 0
101100000000 0 0.15031 0 0
... ... ... ... ...
000010010000 0 0 0 0.01
110000010110 -0.02 0 0 0
001100010100 0 0 -0.02 0
011111010000 0 0.01 0 0
100001101000 0 0 0 0
(65 rows x 4 columns)
This means that the snake encountered 65 states. The maximum number of possible states in the qtable can be seen as 2^number of bits = 2^12 = 4,096 possible states.
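The count can be broken down by the three parts of the encoding, which also shows how small a share of the state space 65 rows is. The variable names are made up for this sketch.

```python
direction_states = 2 ** 2      # 2 bits for the direction
quadrant_states = 2 ** 2       # 2 bits for the food quadrant
surrounding_states = 2 ** 8    # 8 bits, one per neighbouring square

total_states = direction_states * quadrant_states * surrounding_states
print(total_states)                # 4096
print(f"{65 / total_states:.1%}")  # 1.6%
```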
One thing that we are missing is a negative reward upon death. If we give a -100 reward upon death (we really don't want to die), then our reward function changes slightly:
#In qsnake.py
#In Snake.getReward()
        new_distance = self.game.getDistanceBetweenFoodAndSnake()
        if self.ate:
            reward = 1
            self.distance = new_distance
        elif self.checkDead():
            reward = -100
        else:
            if new_distance < self.distance:
                reward = .1
            else:
                reward = -.2
            self.distance = new_distance
        return reward
Now, we may want to train our snake over many trials in order to see how it performs. The overall idea is to let the snake train for a number of trial games. After that many trials, we play one "real" game and record the final score. This entire process is repeated for a certain number of replications.
For example, if replications = 10 and trial_set = [1], then the snake will train for 1 game and then play a real game, and it will repeat that pair of games 10 times. All of the real game scores are recorded. We will want to keep track of this data for later, so we will offer the option to save this information to a file.
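The protocol can be tried out without pygame by plugging in a stand-in game. Everything here is hypothetical: `run_experiment`, the `make_game` factory, and `FakeGame` exist only to show the order of the train, reset, and record steps.

```python
def run_experiment(replications, trials, make_game):
    """Train for `trials` games, then record one final game, per replication.

    make_game is a hypothetical factory returning an object with
    play(), reset() and a numeric score attribute.
    """
    final_scores = []
    for _ in range(replications):
        game = make_game()
        for _ in range(trials):
            game.play()
            game.reset()
        game.play()  # the "real" game whose score is recorded
        final_scores.append(game.score)
    return final_scores

class FakeGame:
    """Stand-in game whose score equals the number of games played so far."""
    def __init__(self):
        self.games_played = 0
        self.score = 0

    def play(self):
        self.games_played += 1
        self.score = self.games_played

    def reset(self):
        self.score = 0

print(run_experiment(replications=3, trials=2, make_game=FakeGame))  # [3, 3, 3]
```

Each replication starts from a fresh game object, so nothing learned in one replication carries over to the next.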
#In qsnake.py
from scipy.stats import describe #pip3 install scipy
import json

#Not in any class
def __experiment(replications, trials):
    final_scores = []
    for replication in range(replications):
        print(f"replication = {replication}")
        game = QGame()
        for trial in range(trials):
            print(f"trial = {trial}")
            game.play()
            game.reset()
        game.play()
        final_scores += [game.score.value]
    return final_scores

def train(replications, trial_set, out_file_name=None):
    records = []
    for trial in trial_set:
        final_scores = __experiment(replications, trial)
        record = {
            'trials': trial,
            'replications': replications,
            'final_scores': final_scores
        }
        records += [record]
    if out_file_name:
        with open(out_file_name, 'a') as out_file:
            #Dump the list itself; the newline is written separately below.
            json.dump(records, out_file)
            out_file.write('\n')
    else:
        for record in records:
            print(f"Trials: {record['trials']}; Replications: {replications}")
            print(describe(record['final_scores']))
Here, there are two new functions that will help us train the snake. __experiment() actually runs the games, while train() records the data we get from the experiment.
Now we can call our train() function from the main method, for example with 2 replications, a trial set of [5], and “training_data.txt” as the output file.
This would produce the following output in “training_data.txt”:
[{"trials": 5, "replications": 2, "final_scores": [4, 2]}]
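The output above is consistent with a call along the lines of `train(2, [5], "training_data.txt")`. Because train() opens the file in append mode and writes one JSON list per call, the file ends up with one list per line. The sketch below shows one way to read such a file back; `load_training_records` is a hypothetical helper, and the demo writes to a temporary file shaped like the output shown rather than running the game.

```python
import json
import os
import tempfile


def load_training_records(path):
    """Read a file that holds one JSON list of records per line."""
    records = []
    with open(path) as in_file:
        for line in in_file:
            line = line.strip()
            if line:                      # skip blank lines
                records.extend(json.loads(line))
    return records


# Demo data copied from the output shown above (not a new result).
sample = [{"trials": 5, "replications": 2, "final_scores": [4, 2]}]

with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "training_data.txt")
    for _ in range(2):                    # simulate two train() calls appending
        with open(demo_path, "a") as out_file:
            json.dump(sample, out_file)
            out_file.write("\n")
    loaded = load_training_records(demo_path)

scores = loaded[0]["final_scores"]
mean_score = sum(scores) / len(scores)
print(len(loaded), mean_score)
```

Reading line by line matters here: calling `json.load` on the whole file would fail once a second call to train() has appended another list.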
Now, it won’t take long to find a major flaw with this system: it is incredibly slow. Try 100 replications with 2 training games each and it will take a very long time. There are a couple of things we could do to speed it up. Instead of updating based on a timer, we could update the game immediately whenever possible. And if we don’t care about watching the snake, why not skip drawing the screen entirely and allow even faster updating?
Timing results before the change:
real 0m41.832s
user 0m3.184s
sys 0m1.024s
Let’s change our code to not show any screen if we specify that we are only training.
#In qsnake.py
#Lines are either commented out or indented differently
#In QGame.play()
#In while not self.done
#Other code untouched
#if timer * speed > 1:
#timer = 0
old_state = self.current_state
new_action = self.table.chooseAction(self.current_state, self.snake.direction)
self.snake.changeDirection(snake.Direction[new_action])
self.snake.move()
self.snake.checkEat()
if self.snake.checkDead():
    self.gameOver()
self.current_state = self.table.encodeState(self.snake, self.food)
reward = self.snake.getReward()
self.table.updateQValue(old_state, self.current_state, new_action, reward)
self.clock.tick()
#self.screen.fill((0, 0, 0))
#self.score.draw()
#pygame.draw.rect(self.screen, snake.Snake.color, self.snake)
#pygame.draw.rect(self.screen, snake.Food.color, self.food)
#for block in self.snake.tail:
#    pygame.draw.rect(self.screen, snake.Block.color, block)
#pygame.display.flip()
#timer += self.clock.tick(self.fps) / 1000
However, if you still wanted to play the game normally, the quick fix would be to pass an argument to the QGame.__init__() function.
#In qsnake.py
#In class QGame
def __init__(self, training=False):
    #Other code...
    self.training = training
Don’t forget to change this in __experiment(): game = QGame(training=True)
Then, in the play function, a simple check.
#In qsnake.py
#In QGame.play()
def play(self):
    timer = 0
    speed = 10
    moved = False
    while not self.done:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.done = True
            if event.type == pygame.KEYDOWN:
                if moved:
                    self.userInput(event.key)
                    moved = False
                else:
                    input_buffer.append(event.key)
        while self.pause:
            for event in pygame.event.get():
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        self.pause = False
                elif event.type == pygame.QUIT:
                    self.done = True
                    self.pause = False
        if self.training:
            #Training: update immediately and never draw
            timer = 0
            old_state = self.current_state
            new_action = self.table.chooseAction(self.current_state, self.snake.direction)
            self.snake.changeDirection(snake.Direction[new_action])
            self.snake.move()
            self.snake.checkEat()
            if self.snake.checkDead():
                self.gameOver()
            self.current_state = self.table.encodeState(self.snake, self.food)
            reward = self.snake.getReward()
            self.table.updateQValue(old_state, self.current_state, new_action, reward)
            self.clock.tick()
        else:
            #Normal play: update on the timer and draw every frame
            if timer * speed > 1:
                timer = 0
                old_state = self.current_state
                new_action = self.table.chooseAction(self.current_state, self.snake.direction)
                self.snake.changeDirection(snake.Direction[new_action])
                self.snake.move()
                self.snake.checkEat()
                if self.snake.checkDead():
                    self.gameOver()
                self.current_state = self.table.encodeState(self.snake, self.food)
                reward = self.snake.getReward()
                self.table.updateQValue(old_state, self.current_state, new_action, reward)
            self.screen.fill((0, 0, 0))
            self.score.draw()
            pygame.draw.rect(self.screen, snake.Snake.color, self.snake)
            pygame.draw.rect(self.screen, snake.Food.color, self.food)
            for block in self.snake.tail:
                pygame.draw.rect(self.screen, snake.Block.color, block)
            pygame.display.flip()
            timer += self.clock.tick(self.fps) / 1000
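The two branches in play() repeat the same update code, so one possible tidy-up is to move that code into a single helper and let the `training` flag decide only whether the timer gate and the drawing apply. The sketch below models just that control flow without pygame; `LoopSketch`, `step()`, `draw()`, and `tick()` are hypothetical names and the bodies are placeholders, not the real game logic.

```python
class LoopSketch:
    """Stand-in for QGame that models only the control flow of play()."""

    def __init__(self, training=False, speed=10):
        self.training = training
        self.speed = speed
        self.timer = 0.0
        self.steps = 0    # how many times the snake "moved"
        self.frames = 0   # how many times the screen was "drawn"

    def step(self):
        # Placeholder for: choose action, move, check eat/dead, update Q-value.
        self.steps += 1

    def draw(self):
        # Placeholder for the pygame drawing calls.
        self.frames += 1

    def tick(self, elapsed_seconds):
        """One pass through the body of the main loop."""
        if self.training:
            self.step()                    # no timer gate, no drawing
            return
        if self.timer * self.speed > 1:    # same gate as in play()
            self.timer = 0.0
            self.step()
        self.draw()
        self.timer += elapsed_seconds


training_loop = LoopSketch(training=True)
normal_loop = LoopSketch(training=False)
for _ in range(20):
    training_loop.tick(0.03)   # pretend each pass takes 30 ms
    normal_loop.tick(0.03)

print("training:", training_loop.steps, "steps,", training_loop.frames, "frames")
print("normal:  ", normal_loop.steps, "steps,", normal_loop.frames, "frames")
```

In the same number of loop passes, the training loop performs a step every pass while the normal loop spends most passes only drawing, which is where the speedup comes from.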
Results of new training:
real 0m1.933s
user 0m2.051s
sys 0m0.334s
This is much, much faster: from about 41 seconds down to about 2 seconds. However, it can still be very slow if we change it to, say, 100 replications and [10, 20] trials:
real 3m46.570s
user 3m39.304s
sys 0m7.374s
That took almost 4 minutes! We can make this even faster if we introduce multiprocessing into our program. Why does this make sense? If we think about our program, it is very CPU intensive (computing things over and over again) and not very I/O intensive. Also, each trial in trial_set is independent; it uses its own game each time. This allows us to do the work for each entry in a separate process and then combine the results into the file afterwards.
I originally thought about multithreading, but that didn’t work, for two reasons. First, it is best to update a game only from the main thread, and here each thread would be trying to act as the main thread for its own game. Second, Python (CPython) uses something called the GIL (Global Interpreter Lock), which lets only one thread execute Python code at a time. This causes multithreaded applications that are CPU intensive (like ours) to run much like a single-threaded application. Multithreading is mainly effective when threads spend time waiting for I/O. If there is no waiting, the threads are effectively run one after another.
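The effect of the GIL can be observed with a small experiment that has nothing to do with the snake: run the same CPU-bound function several times in a row, then run it in several threads, and compare the timings. This is only a sketch; `count_down` and `timed` are made-up helpers, the numbers depend on the machine, and free-threaded Python builds behave differently.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def count_down(n):
    """Pure-Python CPU-bound loop with no I/O to wait on."""
    while n > 0:
        n -= 1
    return n


def timed(label, func):
    """Run func(), print how long it took, and return its result."""
    start = time.perf_counter()
    result = func()
    print(f"{label}: {time.perf_counter() - start:.3f}s")
    return result


N = 300_000
TASKS = 4


def run_sequential():
    return [count_down(N) for _ in range(TASKS)]


def run_threaded():
    with ThreadPoolExecutor(max_workers=TASKS) as pool:
        return list(pool.map(count_down, [N] * TASKS))


sequential = timed("sequential", run_sequential)
threaded = timed(f"{TASKS} threads", run_threaded)
```

On a standard CPython build the two timings are usually close to each other, because the threads take turns holding the GIL instead of running at the same time.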
Our multiprocess application becomes:
#In qsnake.py
import multiprocessing
#In train()
def train(replications, trial_set, out_file_name=None):
    records = []
    #One argument list per entry in trial_set
    formatted_input = []
    for trial in trial_set:
        formatted_input.append([replications, trial])
    #Each entry runs __experiment() in its own worker process
    with multiprocessing.Pool() as pool:
        results = pool.starmap(__experiment, formatted_input)
    for final_scores, trial in zip(results, trial_set):
        record = {
            'trials': trial,
            'replications': replications,
            'final_scores': final_scores
        }
        records += [record]
    if out_file_name:
        with open(out_file_name, 'a') as out_file:
            json.dump(records, out_file)
            out_file.write('\n')
    else:
        for record in records:
            print(f"Trials: {record['trials']}; Replications: {replications}")
            print(describe(record['final_scores']))
Results:
real 2m47.453s
user 4m21.260s
sys 0m11.165s
We were able to cut almost a full minute off the wall-clock time! It is interesting to note that the time our code actually spent executing on the CPU (the user time) increased from 3m39s to 4m21s. This reflects the overhead of running the work in separate processes on the OS.
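One reason the gain is modest: the pool receives one task per entry in trial_set, so with [10, 20] at most two worker processes are busy, and the run lasts as long as the slower of the two. Since every replication builds its own QGame, a possible refinement (not something the code above does) is to submit one task per replication instead. The sketch below shows the bookkeeping; `run_one_replication` is a hypothetical stand-in that returns a fake score, and it assumes the trial sizes in trial_set are distinct.

```python
import multiprocessing
import random


def run_one_replication(trials, seed):
    """Hypothetical stand-in for a single replication.

    A real version would create QGame(training=True), play `trials`
    training games, and return the score of one final game.
    """
    rng = random.Random(seed)
    return rng.randint(0, trials)   # fake score, for illustration only


def build_tasks(replications, trial_set):
    """One (trials, seed) task per replication, not one per trial size."""
    return [(trials, rep) for trials in trial_set for rep in range(replications)]


def group_scores(tasks, scores, trial_set):
    """Collect the flat list of scores back into one list per trial size."""
    grouped = {trials: [] for trials in trial_set}
    for (trials, _seed), score in zip(tasks, scores):
        grouped[trials].append(score)
    return grouped


def train_parallel(replications, trial_set):
    """Spread all replications over the pool, then regroup the results."""
    tasks = build_tasks(replications, trial_set)
    with multiprocessing.Pool() as pool:
        scores = pool.starmap(run_one_replication, tasks)   # keeps task order
    return group_scores(tasks, scores, trial_set)
```

Calling `train_parallel(100, [10, 20])` from under the `if __name__ == "__main__":` guard would hand the pool 200 small tasks rather than 2 large ones, which gives every CPU core something to do.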
Now we have a basic Q-Learning snake and the ability to test over many replications and trials. This will make our analysis of the snake easier. In the next post, we will go in depth into the results from our training. Then, we will work on adding other obstacles to the game.