Adding Q-Learning

In order to add Q-learning to the game, we need to subclass our game. Assuming the finished code from the last part is saved in a file called snake.py, we can create the outline of qsnake.py as follows:

import snake
import pygame

class QGame(snake.Game):
    def __init__(self):
        super().__init__(60, 800, 600)
        self.snake = Snake(self, 40)

def main():
    game = QGame()
    game.play()

class Snake(snake.Snake):
    def __init__(self, game, size):
        super().__init__(game, size)

if __name__ == "__main__":
    main()

Now we can launch the program with python3 qsnake.py and get the game we just finished. This setup lets us edit both the new file qsnake.py and the original snake.py, and see the effect of either change when we run qsnake.py.

Now we can begin to set up the Q-learning part of our snake. The core of the Q-learning algorithm is the Q-table, which the algorithm consults every time it makes a decision. If we imagine a snake in the middle of a game, it can, at any time, take one of four possible actions: change direction to UP, DOWN, LEFT, or RIGHT. Looking at snake.py, the direction of the snake is the only thing the player controls. For a single state, this can be represented with a table like the following:

        UP    DOWN    LEFT    RIGHT
State   0     0.5     -5      3

This table has one column for each of the 4 possible directions, and each row is looked up by the current state. In this state, the Q-value (the expected reward) for going DOWN is 0.5, and going RIGHT has the highest value, 3. Therefore, if we looked up this row to choose our next action, we would choose to go RIGHT.
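To make the lookup concrete, here is a small sketch that stores the example row above in a plain Python dictionary (an assumption for illustration only; the tutorial itself uses pandas next) and picks the action with the highest value.

```python
# One row of a Q-table: the value of each action in a single state.
# The numbers are the example values from the table above.
q_row = {"UP": 0, "DOWN": 0.5, "LEFT": -5, "RIGHT": 3}

def best_action(row):
    """Return the action (key) whose Q-value is largest."""
    return max(row, key=row.get)

print(best_action(q_row))  # RIGHT
```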

If we want to represent this in Python, pandas is a good choice. It has to be installed first:

$ pip3 install pandas

pandas gives us a DataFrame, which can represent our table exactly as it is displayed above. However, our Q-table will be a very specific kind of DataFrame, so we will subclass it.

#In qsnake.py
import pandas as pd

class QTable(pd.DataFrame):
    def __init__(self, learning_rate=.1, discount_factor=.9):
        #One column per action; rows (states) are added as they are discovered
        super().__init__(columns=["UP", "DOWN", "LEFT", "RIGHT"])
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor

Here we are creating a pandas DataFrame with the columns UP, DOWN, LEFT, and RIGHT. Then we set two instance variables, learning_rate and discount_factor. These are the two parameters of the Q-learning update rule:

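Q_new(state, action) = Q_old(state, action) + learning_rate * (reward + discount_factor * maxQ(next_state) - Q_old(state, action))

Here maxQ(next_state) is the largest Q-value in the row of the state the snake ends up in after taking the action.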

The learning_rate represents to what degree new information overwrites old information.

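The discount_factor represents how we balance immediate and future rewards: a value near 0 makes the snake care only about the next reward, while a value near 1 makes future rewards count almost as much as immediate ones.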
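As a worked example of the update rule, the sketch below applies it once by hand. The function name q_update and the next-state value are made up for illustration; only the formula comes from the text.

```python
import math

def q_update(q_old, reward, max_next_q, learning_rate=0.1, discount_factor=0.9):
    """One Q-learning update: move q_old toward reward + discount * best next value."""
    target = reward + discount_factor * max_next_q
    return q_old + learning_rate * (target - q_old)

# Suppose we took RIGHT (current Q-value 3) from the example state,
# received a reward of 1, and the best Q-value in the next state is 2
# (a hypothetical value).
new_q = q_update(3, 1, 2)
print(round(new_q, 4))  # 2.98
```

With a learning_rate of 0 the value would stay at 3; with a learning_rate of 1 it would be replaced outright by the target, 1 + 0.9 * 2 = 2.8.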

Once we have the QTable class, we need to create an instance of it inside QGame:

#In qsnake.py
#In QGame.__init__()
self.table = QTable()

The first thing we need is a way to get a row from the Q-table. DataFrame.loc[] takes the index label of a row: in the table above, qtable.loc['State'] would return the row containing [0, .5, -5, 3]. If we try to access a label that doesn't exist, pandas raises a KeyError, so catching a KeyError tells us that the state isn't in the Q-table yet. DataFrame.loc[] can also add rows: assigning to a label that doesn't exist, as in self.loc[index] = [0, 0, 0, 0], creates that row. So we try to read the row with self.loc, and if that raises a KeyError, we add a row of zeros. Either way, we then return the row.

#In qsnake.py
#In class QTable
def getRow(self, index):
    try:
        return self.loc[index]
    except KeyError:
        #First time we see this state: add a row of zeros for it
        self.loc[index] = [0, 0, 0, 0]
        return self.loc[index]

Here, getRow() will try to return the existing row, or, if it doesn’t exist, will add the row and then return the added row.

We can test this in our main.

#in main
table = QTable()
print(table.getRow("a"))

#OUTPUT:
UP      0
DOWN    0
LEFT    0
RIGHT   0
Name: a, dtype: object
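The same get-or-create idea can be sketched without pandas using collections.defaultdict, which builds a missing entry the first time it is looked up. This is a swapped-in technique for comparison, not what the tutorial uses.

```python
from collections import defaultdict

ACTIONS = ["UP", "DOWN", "LEFT", "RIGHT"]

# Every unseen state starts with a Q-value of 0 for each action.
q_table = defaultdict(lambda: {action: 0 for action in ACTIONS})

row = q_table["a"]          # "a" did not exist, so a zero row is created
print(row)                  # {'UP': 0, 'DOWN': 0, 'LEFT': 0, 'RIGHT': 0}
row["UP"] = 5               # the row is the stored dict itself, so this edit sticks
print(q_table["a"]["UP"])   # 5
```

Unlike a row returned by DataFrame.loc, the inner dict here is the object stored in the table, so changes made through it are kept.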

Now that we can create and access rows, we need to create the indices we will use to access them. In our Q-table, the indices will be encodings of the game state. This is the most important part of the algorithm, because the encoding determines all the information the AI has. If there are too many possible states, the AI will rarely see the same state twice, so it gets few chances to learn which action works best there. However, if the AI doesn't have enough information, it may treat two states the same way when they should be treated differently. The initial encoding (which isn't perfect) will be a bitmap, built with the bitmap package:

$ pip3 install bitmap

The overall layout will be:

direction(2), quadrantOfFood(2), immediateSurrounding(8)

where the number in parentheses is the number of bits in each field, for 12 bits in total.

The four directions the snake could be travelling in can be represented using 2 bits:

00 - UP
01 - LEFT
10 - DOWN
11 - RIGHT

If we imagine the board cut into quadrants with the origin at the head, we can encode the position of the food relative to the head by the quadrant it falls in (quadrant I is up and to the right of the head, and the numbering continues counterclockwise):

00 - I
01 - II
10 - III
11 - IV

Note that we have to decide what to do if the food lies on one of the axes through the head. In that case, I assign it to the next quadrant counterclockwise. For example, if the food is directly above the head, on the line between quadrants I and II, it is encoded as quadrant II.
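The quadrant rule, including the tie-breaking for food that lies on an axis, can be written as a small standalone function. The name food_quadrant is hypothetical; the conditions mirror the ones used in encodeState below. Keep in mind that in pygame y grows downward, so "above" means a smaller y.

```python
def food_quadrant(head_x, head_y, food_x, food_y):
    """Return the 2-bit quadrant code of the food relative to the head.

    Screen coordinates: y grows downward, so food above the head has a
    smaller y. Food on an axis goes to the next quadrant counterclockwise.
    """
    if food_x > head_x and food_y <= head_y:
        return "00"  # I: to the right, above or level with the head
    if food_x <= head_x and food_y < head_y:
        return "01"  # II: to the left or directly above, and above
    if food_x < head_x and food_y >= head_y:
        return "10"  # III: to the left, below or level with the head
    return "11"      # IV: to the right or directly below, and below

print(food_quadrant(200, 200, 200, 120))  # 01 (directly above counts as II)
```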

For the 8 surrounding bits, each of the 8 grid squares around the head of the snake is checked. If there is an obstacle (a wall or part of the snake) in that square, its bit is 1; otherwise it is 0. The order of this encoding is: TopLeft, Left, BottomLeft, Top, Bottom, TopRight, Right, BottomRight.

If we put all of this together, a snake going up, with the food to the North East of the head, with no obstacles, would be encoded as [00][00][00000000].

A snake going DOWN with the food to the South East of the head, and an obstacle directly above and below the head would be encoded as [10][11][00011000].
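To check the layout by hand, here is a plain-string sketch that concatenates the three fields in the order described above. It does not use the bitmap package, and build_state and its arguments are hypothetical names; the neighbour flags are taken in the order listed (TopLeft, Left, BottomLeft, Top, Bottom, TopRight, Right, BottomRight). It is a hand-checking aid, not a replacement for encodeState.

```python
DIRECTION_BITS = {"UP": "00", "LEFT": "01", "DOWN": "10", "RIGHT": "11"}
QUADRANT_BITS = {"I": "00", "II": "01", "III": "10", "IV": "11"}

def build_state(direction, quadrant, obstacles):
    """Concatenate direction (2 bits), food quadrant (2) and 8 neighbour bits.

    obstacles is a list of 8 booleans, one per surrounding square.
    """
    if len(obstacles) != 8:
        raise ValueError("expected 8 neighbour flags")
    surrounding = "".join("1" if blocked else "0" for blocked in obstacles)
    return DIRECTION_BITS[direction] + QUADRANT_BITS[quadrant] + surrounding

# The two examples from the text:
print(build_state("UP", "I", [False] * 8))  # 000000000000
blocked_above_and_below = [False, False, False, True, True, False, False, False]
print(build_state("DOWN", "IV", blocked_above_and_below))  # 101100011000
```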

Time to convert this into code:

#In qsnake.py
import bitmap
#In class QTable
def encodeState(self, snake_obj, food):
    encoded_map = bitmap.BitMap(12)
    bit_position = 0
    size = snake_obj.width
    #The board covers x in [0, width) and y in [0, height)
    rightBoundary, bottomBoundary = pygame.display.get_window_size()

    #Encode the surrounding squares
    #Go over the columns from left to right
    for x in range(snake_obj.x - size, snake_obj.x + size * 2, size):
        #Go over the squares from top to bottom
        for y in range(snake_obj.y - size, snake_obj.y + size * 2, size):
            if (x, y) == (snake_obj.x, snake_obj.y):
                continue #Skip the head itself
            if x < 0 or y < 0 or x >= rightBoundary or y >= bottomBoundary:
                #The square is off the board, so it counts as a wall
                encoded_map.set(bit_position)
            else:
                #Check whether a tail block occupies the square
                square = snake.Block(size, x, y)
                for block in snake_obj.tail:
                    if square.colliderect(block):
                        encoded_map.set(bit_position)
                        break
            bit_position += 1

    #Encode Quadrants
    #Food is on the right of the head and above or level with the head
    if food.x > snake_obj.x and food.y <= snake_obj.y:
        bit_position += 2 #00
    #Food is on the left of or directly above the head, and above the head
    elif food.x <= snake_obj.x and food.y < snake_obj.y:
        encoded_map.set(bit_position)
        bit_position += 2
    #Food is on the left of the head and below or level with the head
    elif food.x < snake_obj.x and food.y >= snake_obj.y:
        bit_position += 1
        encoded_map.set(bit_position)
        bit_position += 1
    #Food is on the right of or directly below the head, and below the head
    else:
        encoded_map.set(bit_position)
        encoded_map.set(bit_position + 1)
        bit_position += 2

    #Encode Direction
    if snake_obj.direction.name == "UP":
        bit_position += 2
    elif snake_obj.direction.name == "LEFT":
        encoded_map.set(bit_position)
        bit_position += 2
    elif snake_obj.direction.name == "DOWN":
        bit_position += 1
        encoded_map.set(bit_position)
        bit_position += 1
    else: #RIGHT (and NONE, before the snake's first move)
        encoded_map.set(bit_position)
        encoded_map.set(bit_position + 1)
        bit_position += 2

    #BitMap stores whole bytes, so our 12 bits are padded to 16; drop the 4 padding bits
    return encoded_map.tostring()[4:]

To test that this works as intended, we will play the game and pause it at certain points to check that the printed state matches what we see on screen.

To help visualize this, I added pause functionality so that at any point while paused, the encoded state can be seen in the console.

#In snake.py
#In Game.__init__()
self.pause = False

#In Game.play()
while not self.done:
    for #...
    while self.pause:
        for event in pygame.event.get():
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_SPACE:
                    self.pause = False
            elif event.type == pygame.QUIT:
                self.done = True
                self.pause = False
    if input_buffer and #...

In order to print the encoded state, we must override the play function in QGame and print the encoded state when the snake moves.

#In qsnake.py
def play(self):
#Note, all of the code is exactly the same except...
    timer = 0
    speed = 10
    input_buffer = []
    moved = False
    while not self.done:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.done = True

            if event.type == pygame.KEYDOWN:
                if moved:
                    self.userInput(event.key)
                    moved = False
                else:
                    input_buffer.append(event.key)

        while self.pause:
            for event in pygame.event.get():
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        self.pause = False
                elif event.type == pygame.QUIT:
                    self.done = True
                    self.pause = False

        if input_buffer and moved:
            self.userInput(input_buffer.pop(0))
            moved = False

        if timer * speed > 1:
            timer = 0
            self.snake.move()
            if self.snake.checkDead():
                self.gameOver()
            moved = True
            ###################
            ####This Line######
            print(self.table.encodeState(self.snake, self.food))
            ###################

        self.snake.checkEat()
        self.screen.fill((0, 0, 0))
        self.score.draw()
        pygame.draw.rect(self.screen, snake.Snake.color, self.snake)
        pygame.draw.rect(self.screen, snake.Food.color, self.food)

        for block in self.snake.tail:
            pygame.draw.rect(self.screen, snake.Block.color, block)

        pygame.display.flip()
        timer += self.clock.tick(self.fps) / 1000

Now if we pause the game at random points we can check if our encoded state matches what we expect it to be.

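For example, here are two states printed while the game was paused at different points (the screenshots that went with them are not included here). Try to work out what each one describes:

000001101000

010100000010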
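To read states like these, the first four characters can be decoded back into names. This sketch (decode_state is a hypothetical helper) assumes the 12-character layout described earlier; it decodes the direction and quadrant fields and returns the eight neighbour bits unchanged.

```python
DIRECTIONS = {"00": "UP", "01": "LEFT", "10": "DOWN", "11": "RIGHT"}
QUADRANTS = {"00": "I", "01": "II", "10": "III", "11": "IV"}

def decode_state(state):
    """Split a 12-character state string into its three fields."""
    if len(state) != 12 or set(state) - {"0", "1"}:
        raise ValueError(f"not a 12-bit state: {state!r}")
    direction = DIRECTIONS[state[0:2]]
    quadrant = QUADRANTS[state[2:4]]
    neighbours = state[4:]  # the 8 obstacle bits, left as-is
    return direction, quadrant, neighbours

print(decode_state("000001101000"))  # ('UP', 'I', '01101000')
print(decode_state("010100000010"))  # ('LEFT', 'II', '00000010')
```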

The next step will be writing an updateQValue function that will take in the current state, the next state, the action taken, and the reward in order to update the Q-value inside the QTable. The function for updating Q-values in math is:

Q_new(state, action) = Q_old(state, action) + learning_rate * (reward + discount_factor * maxQ(next_state) - Q_old(state, action))

To do this, we need to read individual cells of our Q-table and update them. Our getRow function already returns a pandas.Series, and Series.at[] reads a single value from it by label (for example, row.at["UP"]). To write a value, though, we index the table itself with both labels, self.at[state, action], so that the change is stored in the Q-table rather than in a copy of the row.

Converting this to code yields the following:

#In qsnake.py
#In class QTable
def updateQValue(self, current_state, next_state, action, reward):
    currentRow = self.getRow(current_state)
    nextRow = self.getRow(next_state)
    value = currentRow.at[action] #Q_old

    #reward + discount_factor * maxQ(next_state) - Q_old
    newValue = reward + self.discount_factor * nextRow.max() - value
    #Write through the table itself (one indexing call) so the change is stored
    self.at[current_state, action] = value + self.learning_rate * newValue

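An important note: getting a row, then adding a new row, and then changing the first row does not update the Q-table. Adding a row makes pandas rebuild the table's storage, so the Series we got earlier no longer refers to the table's data.

Ex:

row = self.getRow("I exist")
newRow = self.getRow("I Don't")
row.at["UP"] = 5 #DOES NOTHING to the Q-table

More generally, a row returned by .loc may be a copy, so writing through it, or through chained indexing such as self.loc["I exist"].at["UP"], is not guaranteed to change the table (with pandas' Copy-on-Write mode it never does). Indexing the Q-table itself with a single call produces the desired result:

row = self.getRow("I exist")
newRow = self.getRow("I Don't")
self.at["I exist", "UP"] = 5 #WORKS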

This can be tested with arbitrary inputs in the main function.

#In main
row = game.table.getRow("h")
game.table.at["h", "LEFT"] = 1
newRow = game.table.getRow("p")
game.table.at["p", "UP"] = 4
game.table.updateQValue("h", "p", "LEFT", 2)
print(game.table.getRow("h"))

#EXPECTED: UP, DOWN and RIGHT stay 0, while LEFT becomes
#1 + 0.1 * (2 + 0.9 * 4 - 1) = 1.46 (possibly printed with some floating-point noise)
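For comparison, the same test can be reproduced with a plain dictionary-backed table. DictQTable is a hypothetical class written only for this check, not part of the tutorial; it makes the expected number easy to verify: LEFT should end up at 1 + 0.1 * (2 + 0.9 * 4 - 1) = 1.46.

```python
import math

class DictQTable:
    """Minimal Q-table stored as {state: {action: value}} (illustrative only)."""
    ACTIONS = ("UP", "DOWN", "LEFT", "RIGHT")

    def __init__(self, learning_rate=0.1, discount_factor=0.9):
        self.rows = {}
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor

    def getRow(self, state):
        # Create a zero row the first time a state is seen.
        return self.rows.setdefault(state, {a: 0 for a in self.ACTIONS})

    def updateQValue(self, current_state, next_state, action, reward):
        value = self.getRow(current_state)[action]
        next_max = max(self.getRow(next_state).values())
        target = reward + self.discount_factor * next_max
        self.rows[current_state][action] = value + self.learning_rate * (target - value)

table = DictQTable()
table.getRow("h")["LEFT"] = 1
table.getRow("p")["UP"] = 4
table.updateQValue("h", "p", "LEFT", 2)
print(table.getRow("h"))
```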

Everything is looking pretty good! So far we have a Q-table that can return a row, encode a state, and update a Q-value using the Q-learning formula. Now we need to update the Q-table every time the snake moves, and use its values to choose an action. To choose an action, the algorithm picks the action with the highest Q-value for the current state. If several actions share that maximum value, one of them is chosen at random. pandas.Series has helpful methods for this: Series.max() returns the maximum value in the series, and Series.items() iterates over (index, value) pairs. We combine these with random.randint() to pick a random index into the list of best actions.

We also have to keep the current direction in mind. If the snake is moving Direction.RIGHT, it can't turn to Direction.LEFT: changeDirection accepts the call but simply returns without changing direction, so choosing the reverse direction has the same effect as continuing straight.

#In qsnake.py
import random
#in class QTable
def chooseAction(self, current_state):
    currentRow = self.getRow(current_state)
    #List comprehension gives every index (action) that equals the max of the row of the current state we are in.
    max_actions = [index for index, value in currentRow.items() if value == currentRow.max()]
    next_action = max_actions[random.randint(0, len(max_actions) - 1)]
    return next_action
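Because changeDirection silently ignores a reversal, the agent can "choose" a move that never happens. One possible refinement, sketched here with plain dictionaries and the hypothetical helper choose_action, is to drop the opposite of the current direction before breaking ties with random.choice. This is an optional variation, not part of the tutorial's code.

```python
import random

OPPOSITE = {"UP": "DOWN", "DOWN": "UP", "LEFT": "RIGHT", "RIGHT": "LEFT"}

def choose_action(row, current_direction=None, rng=random):
    """Pick a highest-valued action, breaking ties at random.

    row maps action names to Q-values. If current_direction is given,
    its opposite is excluded because the snake cannot reverse.
    """
    excluded = OPPOSITE.get(current_direction)
    allowed = {a: v for a, v in row.items() if a != excluded}
    best = max(allowed.values())
    ties = [a for a, v in allowed.items() if v == best]
    return rng.choice(ties)

row = {"UP": 0, "DOWN": 0, "LEFT": 0.5, "RIGHT": 0.5}
print(choose_action(row, current_direction="RIGHT"))  # RIGHT (LEFT is excluded)
```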

In order to test this, we can add it to the same location where we printed our encoded state. This encoded state can be considered our current_state.

#In qsnake.py
#in QGame.play()
#...other code
moved = True
self.current_state = self.table.encodeState(self.snake, self.food)
print(self.table.chooseAction(self.current_state))

Note that we are using self.current_state. If it is read before it has ever been assigned (as it will be once we start remembering the previous state), Python raises an AttributeError. To avoid this, we initialize it in the constructor of QGame.

#In qsnake.py
#In QGame.__init__(), after self.table = QTable()
self.current_state = self.table.encodeState(self.snake, self.food)

While this does work, it is hard to tell because we aren’t updating the Q-value so the algorithm is always choosing a random direction because all actions are at 0. Now, to update the Q-table, we will have to call QTable.updateQValue() after the next_action is chosen. The issue is that it requires a reward to be supplied as a parameter, and we haven’t defined how the snake will be rewarded. A simple idea is to reward the snake when he eats, and do nothing else otherwise. To check if the snake eats, we will add an instance variable to the Snake class in snake.py.

#In snake.py
#In Snake.__init__():
self.ate = False

def checkEat(self):
    if self.colliderect(self.game.food):
        self.game.food.relocate()
        self.growTail()
        self.game.score.changeScore(1)
        self.ate = True
    else:
        self.ate = False

Then our code to update Q-values after every move:

#In qsnake.py
#in QGame.play()
#...other code
moved = True
old_state = self.current_state
self.current_state = self.table.encodeState(self.snake, self.food)
new_action = self.table.chooseAction(self.current_state)
reward = 1 if self.snake.ate else 0
if reward == 1: #Only print the row if it is changing.
    print(self.table.getRow(old_state))
    self.table.updateQValue(old_state, self.current_state, new_action, reward)
    print(self.table.getRow(old_state))

Finally, we can update QGame.play() so that direction changes come from the action computed by QTable.chooseAction() instead of user input. Since we only use the computed action, we can also remove the code that handles arrow-key input.

#In qsnake.py
#In QGame.play()
if timer *#...:
    old_state = self.current_state
    new_action = self.table.chooseAction(self.current_state)
    self.snake.changeDirection(snake.Direction[new_action])
    self.snake.move()
    if self.snake.checkDead():
        self.gameOver()
    self.current_state = self.table.encodeState(self.snake, self.food)
    reward = 1 if self.snake.ate else 0
    if reward == 1: #Only print the row if it is changing.
        print(self.table.getRow(old_state))
        self.table.updateQValue(old_state, self.current_state, new_action, reward)
        print(self.table.getRow(old_state))
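The loop above (remember the old state, choose, move, encode the new state, reward, update) can be tried on a much smaller problem. The sketch below is a toy one-dimensional corridor, not the snake game: the agent starts at cell 0, the food sits at cell 3, the state only says which side the food is on, and the reward is 1 for eating and 0 otherwise. All names are hypothetical, and treating "eat" as the end of an episode is an assumption made to keep the example short.

```python
import random

ACTIONS = ("LEFT", "RIGHT")
LEARNING_RATE = 0.1
DISCOUNT_FACTOR = 0.9

def encode(pos, food):
    """Coarse state, similar in spirit to the food-quadrant bits."""
    return "R" if food > pos else "L"

def get_row(q, state):
    """Get-or-create a row of zeros, like QTable.getRow."""
    return q.setdefault(state, {a: 0.0 for a in ACTIONS})

def choose(q, state, rng):
    """Greedy choice with random tie-breaking, like QTable.chooseAction."""
    row = get_row(q, state)
    best = max(row.values())
    return rng.choice([a for a, v in row.items() if v == best])

def train(episodes=200, max_steps=100, size=5, food=3, seed=0):
    rng = random.Random(seed)
    q = {}
    for _ in range(episodes):
        pos = 0
        state = encode(pos, food)
        for _ in range(max_steps):
            old_state = state
            action = choose(q, old_state, rng)
            # Move one cell; a wall pushes the agent back, like Snake.hitWall.
            step = 1 if action == "RIGHT" else -1
            pos = min(size - 1, max(0, pos + step))
            ate = pos == food
            reward = 1 if ate else 0
            state = encode(pos, food)
            # Assumption: eating ends the episode, so there is no future value.
            next_max = 0.0 if ate else max(get_row(q, state).values())
            old_value = q[old_state][action]
            q[old_state][action] = old_value + LEARNING_RATE * (
                reward + DISCOUNT_FACTOR * next_max - old_value)
            if ate:
                break
    return q

q = train()
print(q["R"])
```

Once the agent stumbles onto the food for the first time, RIGHT gets a positive value in state "R" and the greedy choice keeps taking it, which is the same "learn only after a lucky bump" behaviour discussed below.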

In the end, we will have the following:

#snake.py
import pygame
from enum import Enum
from random import randint
from math import floor

class Game():
    def __init__(self, fps, width, height):
        pygame.init()
        self.done = False
        self.fps = fps
        self.clock = pygame.time.Clock()
        self.screen = pygame.display.set_mode((width, height))
        self.snake = Snake(self, 40)
        self.food = Food(self, 40) 
        self.score = Score(self.screen, (10, 10))
        self.pause = False

    def checkPause(self, key):
        if key == pygame.K_SPACE:
            self.pause = not self.pause

    def play(self):
        timer = 0
        speed = 10
        input_buffer = []
        moved = False
        while not self.done:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.done = True

                if event.type == pygame.KEYDOWN:
                    if moved:
                        self.userInput(event.key)
                        moved = False
                    else:
                        input_buffer.append(event.key)

            while self.pause:
                for event in pygame.event.get():
                    if event.type == pygame.KEYDOWN:
                        self.checkPause(event.key)

            if input_buffer and moved:
                self.userInput(input_buffer.pop(0))
                moved = False

            if timer * speed > 1:
                timer = 0
                self.snake.move()
                if self.snake.checkDead():
                    self.gameOver()
                moved = True

            self.snake.checkEat()
            self.screen.fill((0, 0, 0))
            self.score.draw()
            pygame.draw.rect(self.screen, Snake.color, self.snake)
            pygame.draw.rect(self.screen, Food.color, self.food)

            for block in self.snake.tail:
                pygame.draw.rect(self.screen, Block.color, block)

            pygame.display.flip()
            timer += self.clock.tick(self.fps) / 1000

    def userInput(self, key):
        if key == pygame.K_UP:
            self.snake.changeDirection(Direction.UP)
        elif key == pygame.K_DOWN:
            self.snake.changeDirection(Direction.DOWN)
        elif key == pygame.K_LEFT:
            self.snake.changeDirection(Direction.LEFT)
        elif key == pygame.K_RIGHT:
            self.snake.changeDirection(Direction.RIGHT)
        elif key == pygame.K_SPACE:
            self.pause = not self.pause

    def gameOver(self):
        print(f"You have died! You ate {len(self.snake.tail)} pieces of food!")
        self.done = True


class Block(pygame.Rect):
    color = (0, 128, 255)  # blue

    def __init__(self, size, x, y):
        self.width = size
        self.height = size
        self.x = x
        self.y = y
        self.dim = size


class Direction(Enum):
    UP = [0, -40]
    DOWN = [0, 40]
    LEFT = [-40, 0]
    RIGHT = [40, 0]
    NONE = [0, 0]


class Snake(Block):
    color = (124, 252, 0)

    def __init__(self, game, size):
        Block.__init__(self, size,
                       floor((pygame.display.get_window_size()
                              [0]/size - 1)/2) * size,
                       floor((pygame.display.get_window_size()
                              [1]/size - 1)/2) * size)
        self.direction = Direction.NONE
        self.game = game
        self.tail = []
        self.ate = False

    def move(self):
        if self.tail:
            self.tail = self.tail[1:]
            self.tail.append(Block(self.dim, self.x, self.y))

        self.x += self.direction.value[0]
        self.y += self.direction.value[1]

        self.checkEat()
        self.hitWall()
        self.hitSelf()
        

    def changeDirection(self, newDirection):
        if newDirection.value == [-x for x in self.direction.value]:
            return
        self.direction = newDirection

    def checkEat(self):
        if self.colliderect(self.game.food):
            self.game.food.relocate()
            self.growTail()
            self.game.score.changeScore(1)
            self.ate = True
        else:
            self.ate = False

    def growTail(self):
        self.tail.append(Block(
            self.dim, self.x - self.direction.value[0], self.y - self.direction.value[1]))

    def hitWall(self):
        if (self.x < 0 or self.y < 0
                    or self.x > pygame.display.get_window_size()[0] - self.dim
                    or self.y > pygame.display.get_window_size()[1] - self.dim
                ):
            self.hit_wall = True
            # Push back
            self.x -= self.direction.value[0]
            self.y -= self.direction.value[1]
        else:
            self.hit_wall = False

    def hitSelf(self):
        for block in self.tail:
            if block.colliderect(self):
                self.hit_self = True
                return
        self.hit_self = False

    def checkDead(self):
        if self.hit_self or self.hit_wall:
            return True
        else:
            return False


class Food(Block):
    color = (255, 0, 100)

    def __init__(self, game, size):
        Block.__init__(self, size, 0, 0)
        self.game = game
        self.relocate()

    def relocate(self):
        cols = floor(pygame.display.get_window_size()[0]/self.dim - 1)
        rows = floor(pygame.display.get_window_size()[1]/self.dim - 1)

        self.x, self.y = randint(0, cols) * \
            self.dim, randint(0, rows) * self.dim

        if not self.isSafe():
            self.relocate()

    def isSafe(self):
        if self.colliderect(self.game.snake):
            return False
        else:
            return True

class Score():
    def __init__(self, screen, location):
        self.font = pygame.font.Font(None, 36)
        self.value = 0
        self.location = location
        self.screen = screen

    def draw(self):
        text = self.font.render(f"Score: {self.value}", 1, (255, 255, 255))
        self.screen.blit(text, self.location)

    def changeScore(self, change):
        self.value += change

    def reset(self):
        self.value = 0



def main():
    game = Game(60, 800, 600)
    game.play()


if __name__ == "__main__":
    main()

#qsnake.py
import snake
import pygame
import pandas as pd
import bitmap
import random

class QGame(snake.Game):
    def __init__(self):
        super().__init__(60, 800, 600)
        self.table = QTable()
        self.current_state = self.table.encodeState(self.snake, self.food)

    def play(self):
        timer = 0
        speed = 10
        while not self.done:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.done = True

                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        self.pause = True

            while self.pause:
                for event in pygame.event.get():
                    if event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_SPACE:
                            self.pause = False
                    elif event.type == pygame.QUIT:
                        self.done = True
                        self.pause = False

            if timer * speed > 1:
                timer = 0
                old_state = self.current_state
                new_action = self.table.chooseAction(self.current_state)
                self.snake.changeDirection(snake.Direction[new_action])
                self.snake.move()
                if self.snake.checkDead():
                    self.gameOver()
                self.current_state = self.table.encodeState(self.snake, self.food)
                reward = 1 if self.snake.ate else 0
                if reward == 1: #Only print the row if it is changing.
                    print(self.table.getRow(old_state))
                    self.table.updateQValue(old_state, self.current_state, new_action, reward)
                    print(self.table.getRow(old_state))

            self.snake.checkEat()
            self.screen.fill((0, 0, 0))
            self.score.draw()
            pygame.draw.rect(self.screen, snake.Snake.color, self.snake)
            pygame.draw.rect(self.screen, snake.Food.color, self.food)

            for block in self.snake.tail:
                pygame.draw.rect(self.screen, snake.Block.color, block)

            pygame.display.flip()
            timer += self.clock.tick(self.fps) / 1000

class QTable(pd.DataFrame):
    def __init__(self, learning_rate=.1, discount_factor=.9):
        super().__init__(columns=["UP", "DOWN", "LEFT", "RIGHT"])
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor

    def getRow(self, index):
        try:
            return self.loc[index]
        except KeyError:
            #First time we see this state: add a row of zeros for it
            self.loc[index] = [0, 0, 0, 0]
            return self.loc[index]

    def updateQValue(self, current_state, next_state, action, reward):
        currentRow = self.getRow(current_state)
        nextRow = self.getRow(next_state)
        value = currentRow.at[action] #Q_old

        newValue = reward + self.discount_factor * nextRow.max() - value
        #Write through the table itself so the change is stored
        self.at[current_state, action] = value + self.learning_rate * newValue

    def chooseAction(self, current_state):
        currentRow = self.getRow(current_state)
        #List comprehension gives every index (action) that equals the max of the row of the current state we are in.
        max_actions = [index for index, value in currentRow.items() if value == currentRow.max()]
        next_action = max_actions[random.randint(0, len(max_actions) - 1)]
        return next_action

    def encodeState(self, snake_obj, food):
        encoded_map = bitmap.BitMap(12)
        bit_position = 0
        size = snake_obj.width
        #The board covers x in [0, width) and y in [0, height)
        rightBoundary, bottomBoundary = pygame.display.get_window_size()

        #Encode the surrounding squares
        #Go over the columns from left to right
        for x in range(snake_obj.x - size, snake_obj.x + size * 2, size):
            #Go over the squares from top to bottom
            for y in range(snake_obj.y - size, snake_obj.y + size * 2, size):
                if (x, y) == (snake_obj.x, snake_obj.y): #Don't count the snake head.
                    continue
                if x < 0 or y < 0 or x >= rightBoundary or y >= bottomBoundary:
                    #The square is off the board, so it counts as a wall
                    encoded_map.set(bit_position)
                else:
                    #Check whether a tail block occupies the square
                    square = snake.Block(size, x, y)
                    for block in snake_obj.tail:
                        if square.colliderect(block):
                            encoded_map.set(bit_position)
                            break
                bit_position += 1

        #Encode Quadrants
        #Food is on the right of the head and above or level with the head
        if food.x > snake_obj.x and food.y <= snake_obj.y:
            bit_position += 2 #00
        #Food is on the left of or directly above the head, and above the head
        elif food.x <= snake_obj.x and food.y < snake_obj.y:
            encoded_map.set(bit_position)
            bit_position += 2
        #Food is on the left of the head and below or level with the head
        elif food.x < snake_obj.x and food.y >= snake_obj.y:
            bit_position += 1
            encoded_map.set(bit_position)
            bit_position += 1
        #Food is on the right of or directly below the head, and below the head
        else:
            encoded_map.set(bit_position)
            encoded_map.set(bit_position + 1)
            bit_position += 2

        #Encode Direction
        if snake_obj.direction.name == "UP":
            bit_position += 2
        elif snake_obj.direction.name == "LEFT":
            encoded_map.set(bit_position)
            bit_position += 2
        elif snake_obj.direction.name == "DOWN":
            bit_position += 1
            encoded_map.set(bit_position)
            bit_position += 1
        else: #RIGHT (and NONE, before the snake's first move)
            encoded_map.set(bit_position)
            encoded_map.set(bit_position + 1)
            bit_position += 2

        #BitMap stores whole bytes, so our 12 bits are padded to 16; drop the 4 padding bits
        return encoded_map.tostring()[4:]

def main():
    game = QGame()
    #table = QTable()
    game.play()

if __name__ == "__main__":
    main()

This does work, but it is hard to see the snake learning. The only time learning happens is when food is eaten, which means the snake only learns when it randomly bumps into the food. One idea to fix this is to give a reward of -.1 for every move in which the snake doesn't eat the food, but that runs into the same problem: the snake doesn't know where the food is, so it can't be expected to move towards the food without this key information.

The idea I settled on was to give a reward of -.2 for every move that takes the snake farther away from the food, and a reward of .1 for every move that brings it closer. This way, the snake is not only punished for moving away from the food, but also rewarded for moving towards it. It is also important that the negative reward outweighs the positive reward.

Imagine the snake constantly moving in a square. Each lap is technically two moves closer to the food and two moves away from it. If the two rewards had the same size, the net reward would be 0, but we want to punish that kind of behavior. Because the negative reward is larger, going in circles leads to a net negative reward.
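The arithmetic behind this argument is quick to check. The sketch below sums the rewards for one lap of the square, using the two values chosen above and, for contrast, a hypothetical symmetric pair of +0.1 and -0.1.

```python
import math

CLOSER_REWARD = 0.1    # reward for a move that reduces the distance to the food
FARTHER_REWARD = -0.2  # reward for a move that increases it

# One lap of a square: two moves toward the food and two moves away from it.
loop_moves = ["closer", "closer", "farther", "farther"]
net = sum(CLOSER_REWARD if m == "closer" else FARTHER_REWARD for m in loop_moves)
print(round(net, 2))  # -0.2

# With symmetric rewards the same lap would cost nothing.
symmetric = sum(0.1 if m == "closer" else -0.1 for m in loop_moves)
print(round(symmetric, 2))  # 0.0
```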

In order to implement this, we need a function that can return the distance the snake is from the food.

#In qsnake.py
#In class QGame
def getDistanceBetweenFoodAndSnake(self):
    x = abs(self.food.x - self.snake.x)
    y = abs(self.food.y - self.snake.y)

    return x + y

We take the absolute value of each difference because we don’t care whether the food is to the left or the right (or above or below); we only want the raw length. Imagine a coordinate plane with its origin at the head: if the food is at the point (-2, -2), the horizontal distance from (0, 0) is 2, not -2, and the vertical distance is also 2. We add the x and y distances together to get the overall number of squares between the food and the snake (4 in this example). This sum of absolute differences is known as the Manhattan distance.
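Here is the same calculation as a standalone function on (x, y) pairs. It is a sketch that mirrors getDistanceBetweenFoodAndSnake() without depending on the game objects; the name manhattan_distance is ours, not the game’s.

```python
def manhattan_distance(head, food):
    """Number of horizontal plus vertical steps between two (x, y) points."""
    dx = abs(food[0] - head[0])  # left or right does not matter, only how far
    dy = abs(food[1] - head[1])  # same for up or down
    return dx + dy


print(manhattan_distance((0, 0), (-2, -2)))  # 4
```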

Now that we can get the distance, we need to create a function to return the correct reward based on what move the snake just made.

#In qsnake.py
#In class Snake
def getReward(self):
    last_distance = self.distance
    self.distance = self.game.getDistanceBetweenFoodAndSnake()
    if self.ate:
        reward = 1
    elif self.distance < last_distance:
        reward = .1   #Moved closer to the food
    else:
        reward = -.2  #Moved farther away from the food

    return reward

Notice that we now have a new variable, self.distance, so we can’t forget to initialize it in the constructor:

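#In qsnake.py
#In Snake.__init__(), after super().__init__(game, size)
    #Measure from this snake's own head: while this snake is being built,
    #self.game.snake may still refer to the previous snake
    self.distance = abs(self.game.food.x - self.x) + abs(self.game.food.y - self.y)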

Now we need to replace the line reward = 1 if self.snake.ate else 0 with reward = self.snake.getReward(). Because we now get a reward every frame, we also need to update the Q-Table every frame.

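#In qsnake.py
#In QGame.play()
    old_state = self.current_state
    #Choose the action from the state the snake is in before it moves
    new_action = self.table.chooseAction(old_state)
    self.snake.changeDirection(snake.Direction[new_action])
    self.snake.move()
    self.snake.checkEat()
    if self.snake.checkDead():
        self.gameOver()
    #Observe the resulting state and reward, then credit them to (old_state, new_action)
    self.current_state = self.table.encodeState(self.snake, self.food)
    reward = self.snake.getReward()
    self.table.updateQValue(old_state, self.current_state, new_action, reward)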
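updateQValue() was written in an earlier part and isn’t shown here. For reference, the standard tabular Q-learning update that it is presumably based on, Q(s, a) += ALPHA * (reward + GAMMA * max Q(s', ·) - Q(s, a)), can be sketched with a plain dictionary instead of the pandas-based QTable. The learning rate ALPHA = 0.1 and discount GAMMA = 0.9 are assumptions for this sketch, not necessarily the values QTable uses.

```python
from collections import defaultdict

ACTIONS = ["UP", "DOWN", "LEFT", "RIGHT"]

# Assumed hyperparameters for this sketch only.
ALPHA = 0.1  # learning rate: how far each update moves the old value
GAMMA = 0.9  # discount factor: how much future rewards count


def make_table():
    """Map a state string to {action: Q-value}; unseen states start at 0."""
    return defaultdict(lambda: {action: 0.0 for action in ACTIONS})


def update_q_value(table, old_state, new_state, action, reward,
                   alpha=ALPHA, gamma=GAMMA):
    """Tabular Q-learning update for the action taken in old_state."""
    best_next = max(table[new_state].values())  # best value reachable from new_state
    old_value = table[old_state][action]
    target = reward + gamma * best_next
    table[old_state][action] = old_value + alpha * (target - old_value)
    return table[old_state][action]


table = make_table()
print(round(update_q_value(table, "101100000000", "101100000001", "DOWN", 0.1), 4))  # 0.01
```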

Now, if we run our code, we can see the snake moving on its own. Let’s change the QGame.gameOver() function to print out the current Q-Table:

#In qsnake.py
#In class Game
def gameOver(self):
    print(self.table)  #The Q-Table lives in self.table
    super().gameOver()  #Still run snake.Game's own game-over logic

This will lead to this type of output:

                   UP       DOWN  LEFT RIGHT
101000000000        0          0  0.01     0
011100000000        0       0.01     0     0
101100000000        0      0.109     0     0
100000001000        0      -0.02 -0.02     0
010001000000     0.01          0     0     0
000000010000        0          0     0  0.01
110000000010        0          0     0  0.01
110100000010     0.01          0     0     0
000100010000  0.11791          0     0     0
001100010000        0          0     0  0.01
111100000110        0       0.01     0     0
101100001001        0       0.01     0     0
101100001000        0  0.0454636     0     0
100010011100        0          0     0  0.01
110010010111        0          0     0     0

Here, we can see that the snake has been updating the Q-Table based on the actions it has been taking. For example, if the snake is going UP and the food is in the NW section, it has received rewards that make it think the optimal move is to go UP. That should be the best move, because going up brings the snake closer to the food in quadrant II. If we play over multiple games, the Q-Table will keep getting larger as the snake encounters new states. To do that, we first have to create a Game.reset() function, so that we can keep the Q-Table and just reset the food and the snake.

#In qsnake.py
#In class Game

def reset(self):
    self.food = snake.Food(self, 40)
    self.snake = Snake(self, 40)
    self.current_state = self.table.encodeState(self.snake, self.food)
    self.score.reset()
    self.done = False

Now, once we call our reset, we can play the game again.

#In qsnake.py
#In main
game = QGame()
for i in range(2):
    game.play()
    game.reset()

This can lead to a Q-Table like the following:

                 UP       DOWN   LEFT       RIGHT
111000000000      0       0.01      0           0
111100000000      0          0      0  0.00940108
111000000000  -0.02      0.019      0       -0.02
001000000000      0          0  0.019           0
011100000000      0  0.0223058  -0.02           0
101100000000      0    0.15031      0           0
...             ...        ...    ...         ...
000010010000      0          0      0        0.01
110000010110  -0.02          0      0           0
001100010100      0          0  -0.02           0
011111010000      0       0.01      0           0
100001101000      0          0      0           0
(65 rows x 4 columns)

This means that the snake encountered 65 states. Because each state is encoded in 12 bits, the Q-Table can hold at most 2^12 = 4,096 possible states.

One thing we are still missing is a negative reward upon death. If we give a reward of -100 when the snake dies (we really don’t want to die), our reward function changes slightly:

#In qsnake.py
#In Snake.getReward()
def getReward(self):
    new_distance = self.game.getDistanceBetweenFoodAndSnake()
    if self.ate:
        reward = 1
    elif self.checkDead():
        reward = -100  #Dying is by far the worst outcome
    elif new_distance < self.distance:
        reward = .1    #Moved closer to the food
    else:
        reward = -.2   #Moved farther away from the food

    self.distance = new_distance

    return reward
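Because the order of the checks matters (eating is checked before dying, and dying before distance), here is the same decision logic as a pure function that is easy to test on its own. The name shaped_reward and its parameters are hypothetical; in the game, the inputs come from self.ate, self.checkDead(), self.distance and getDistanceBetweenFoodAndSnake().

```python
def shaped_reward(ate, dead, old_distance, new_distance):
    """Reward for one move, using the same order of checks as getReward()."""
    if ate:
        return 1        # eating food is the goal
    if dead:
        return -100     # dying is by far the worst outcome
    if new_distance < old_distance:
        return 0.1      # small reward for moving towards the food
    return -0.2         # larger penalty for moving away (or not getting closer)


print(shaped_reward(False, False, old_distance=5, new_distance=4))  # 0.1
```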

Now we may want to train our snake over many trials to see how well it can perform. The overall idea is to train the snake for a number of trial games and then play one “real” game, whose final score we record. This entire process is repeated for a certain number of replications.

For example, if replications = 10 and trial_set = [1], the snake trains for 1 game and then plays a real game, and this pair of games is repeated 10 times, each time with a fresh QGame. All of the real game scores are recorded. We will want to keep this data for later, so we will offer the option to save it to a file.
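The bookkeeping in that example can be checked with a stand-in for QGame. In the sketch below, CountingGame is a hypothetical stub that only counts how often play() is called, so we can confirm that each replication plays trials training games plus one real game, for a total of replications * (trials + 1) games.

```python
class CountingGame:
    """Hypothetical stand-in for QGame that only counts games played."""

    def __init__(self):
        self.games_played = 0

    def play(self):
        self.games_played += 1

    def reset(self):
        pass  # the real QGame resets the snake and food but keeps its Q-Table


def experiment_games(replications, trials, make_game=CountingGame):
    """Return how many games each replication played."""
    played = []
    for _ in range(replications):
        game = make_game()      # a fresh game for every replication
        for _ in range(trials):
            game.play()         # training game
            game.reset()
        game.play()             # the "real", scored game
        played.append(game.games_played)
    return played


counts = experiment_games(replications=10, trials=1)
print(counts)       # [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
print(sum(counts))  # 20
```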

#In qsnake.py
import json
from scipy.stats import describe #pip3 install scipy

#Not in any class
def __experiment(replications, trials):
    """Play `replications` rounds of `trials` training games plus one scored game."""
    final_scores = []
    for replication in range(replications):
        print(f"replication = {replication}")
        game = QGame()  #A fresh game for every replication
        for trial in range(trials):
            print(f"trial = {trial}")
            game.play()
            game.reset()
        game.play()  #The "real" game
        final_scores.append(game.score.value)

    return final_scores

def train(replications, trial_set, out_file_name=None):
    records = []
    for trial in trial_set:
        final_scores = __experiment(replications, trial)
        records.append({
            'trials': trial,
            'replications': replications,
            'final_scores': final_scores
        })

    if out_file_name:
        #Append one JSON line per call to train()
        with open(out_file_name, 'a') as out_file:
            json.dump(records, out_file)
            out_file.write('\n')
    else:
        for record in records:
            print(f"Trials: {record['trials']}; Replications: {replications}")
            print(describe(record['final_scores']))

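Here, there are two new functions that will help us train the snake.

__experiment() runs the games for a single number of trials and returns the list of real game scores, while train() calls it for every entry in trial_set and then either appends the results to a file or prints summary statistics using scipy’s describe().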

Now we can call our train() function from main():

#In qsnake.py
def main():
    train(2, [5], "training_data.txt")

This would produce the following output in “training_data.txt”

[{"trials": 5, "replications": 2, "final_scores": [4, 2]}]
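Because train() opens the file in append mode and writes one JSON list followed by a newline, every call adds exactly one line, so the file can be read back line by line. A minimal sketch of writing and reading that format; append_records and load_records are hypothetical helpers, the temporary file stands in for training_data.txt, and the second record holds illustrative values only.

```python
import json
import os
import tempfile


def append_records(path, records):
    """Append one call's records as a single JSON line (as train() does)."""
    with open(path, "a") as out_file:
        json.dump(records, out_file)
        out_file.write("\n")


def load_records(path):
    """Return a list with one entry (a list of records) per train() call."""
    with open(path) as in_file:
        return [json.loads(line) for line in in_file if line.strip()]


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "training_data.txt")
    append_records(path, [{"trials": 5, "replications": 2, "final_scores": [4, 2]}])
    # Illustrative values only, not real training results.
    append_records(path, [{"trials": 10, "replications": 2, "final_scores": [6, 3]}])
    runs = load_records(path)

print(len(runs))                   # 2
print(runs[0][0]["final_scores"])  # [4, 2]
```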

It won’t take long to find a major flaw with this system: it is incredibly slow. Try 100 replications with 2 training games each and it will take a very long time. There are a couple of things we could do to speed it up. Instead of updating the game on a timer, we could update it as soon as possible. And if we don’t care about watching the snake, why not skip drawing the screen altogether and update even faster?

Timing before the change (as reported by the shell’s time command):

real    0m41.832s
user    0m3.184s
sys     0m1.024s

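The user time is only about 3 seconds, so most of those 41 seconds are spent waiting on the frame timer rather than computing. Let’s change our code to not draw anything when we specify that we are only training.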

#In qsnake.py
#Lines are either commented out or indented differently
#In QGame.play()
#In while not self.done
    #Other code untouched
    #if timer * speed > 1:
    #timer = 0
    old_state = self.current_state
    new_action = self.table.chooseAction(self.current_state, self.snake.direction)
    self.snake.changeDirection(snake.Direction[new_action])
    self.snake.move()
    self.snake.checkEat()
    if self.snake.checkDead():
        self.gameOver()
    self.current_state = self.table.encodeState(self.snake, self.food)
    reward = self.snake.getReward()
    self.table.updateQValue(old_state, self.current_state, new_action, reward)
    self.clock.tick()  #No frame-rate cap: run as fast as possible

    #self.screen.fill((0, 0, 0))
    #self.score.draw()
    #pygame.draw.rect(self.screen, snake.Snake.color, self.snake)
    #pygame.draw.rect(self.screen, snake.Food.color, self.food)

    #for block in self.snake.tail:
    #    pygame.draw.rect(self.screen, snake.Block.color, block)

    #pygame.display.flip()
    #timer += self.clock.tick(self.fps) / 1000

However, if you still want to watch the game normally, a quick fix is to pass a training flag to QGame.__init__():

#In qsnake.py
#In class Game
def __init__(self, training=False):
    #Other code...
    self.training = training

Don’t forget to create the game with game = QGame(training=True) inside __experiment()!

Then, in the play() function, we add a simple check:

#In qsnake.py
#In class Game
def play(self):
    timer = 0
    speed = 10
    moved = False
    input_buffer = []  #Key presses received before the snake has moved
    while not self.done:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.done = True

            if event.type == pygame.KEYDOWN:
                if moved:
                    self.userInput(event.key)
                    moved = False
                else:
                    input_buffer.append(event.key)

        while self.pause:
            for event in pygame.event.get():
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        self.pause = False
                elif event.type == pygame.QUIT:
                    self.done = True
                    self.pause = False

        #When training, step every frame; otherwise wait for the timer
        if self.training or timer * speed > 1:
            timer = 0
            old_state = self.current_state
            new_action = self.table.chooseAction(self.current_state, self.snake.direction)
            self.snake.changeDirection(snake.Direction[new_action])
            self.snake.move()
            self.snake.checkEat()
            if self.snake.checkDead():
                self.gameOver()
            self.current_state = self.table.encodeState(self.snake, self.food)
            reward = self.snake.getReward()
            self.table.updateQValue(old_state, self.current_state, new_action, reward)

        if self.training:
            self.clock.tick()  #No drawing and no frame-rate cap
        else:
            self.screen.fill((0, 0, 0))
            self.score.draw()
            pygame.draw.rect(self.screen, snake.Snake.color, self.snake)
            pygame.draw.rect(self.screen, snake.Food.color, self.food)

            for block in self.snake.tail:
                pygame.draw.rect(self.screen, snake.Block.color, block)

            pygame.display.flip()
            timer += self.clock.tick(self.fps) / 1000

Timing after the change:

real    0m1.933s
user    0m2.051s
sys     0m0.334s

This is much, much faster: from about 41 seconds down to about 2 seconds. However, it is still slow for bigger experiments, for example 100 replications with trial_set = [10, 20]:

real    3m46.570s
user    3m39.304s
sys     0m7.374s

That took almost 4 minutes! We can make this even faster by introducing multiprocessing into our program. Why does this make sense? Our program is very CPU-intensive (it computes the same kinds of things over and over again) and not very I/O-intensive. Also, each trial count in trial_set is independent and uses its own games, so we can do the work for each one in a separate process and combine the results into the file afterwards.

I originally thought about multithreading, but that didn’t work for two reasons. First, pygame’s window and event handling are meant to be driven from the main thread, but every game would want to act as if it were running on the main thread. Second, CPython has a Global Interpreter Lock (GIL), which lets only one thread execute Python bytecode at a time. As a result, CPU-intensive multithreaded programs (like ours) effectively run as single-threaded ones. Multithreading mainly helps when some threads are waiting on I/O; if no thread is waiting, the threads basically just take turns.

With multiprocessing, train() becomes:

#In qsnake.py
import multiprocessing

#Replaces the previous train()
def train(replications, trial_set, out_file_name=None):
    records = []
    #One (replications, trials) argument pair per entry in trial_set
    formatted_input = [(replications, trial) for trial in trial_set]

    #Each pair runs __experiment() in a worker process; starmap returns
    #the results in the same order as formatted_input
    with multiprocessing.Pool() as pool:
        results = pool.starmap(__experiment, formatted_input)

    for final_scores, trial in zip(results, trial_set):
        records.append({
            'trials': trial,
            'replications': replications,
            'final_scores': final_scores
        })

    if out_file_name:
        with open(out_file_name, 'a') as out_file:
            json.dump(records, out_file)
            out_file.write('\n')
    else:
        for record in records:
            print(f"Trials: {record['trials']}; Replications: {replications}")
            print(describe(record['final_scores']))
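One detail this relies on is that Pool.starmap() unpacks each argument tuple into positional arguments and returns the results in input order, which is what lets zip(results, trial_set) pair each list of scores with the right trial count. A minimal sketch, with a hypothetical fake_experiment() standing in for __experiment():

```python
import multiprocessing


def fake_experiment(replications, trials):
    """Hypothetical stand-in for __experiment(): fake scores derived from its inputs."""
    return [trials + r for r in range(replications)]


def run_all(replications, trial_set):
    """Run fake_experiment once per trial count in worker processes."""
    formatted_input = [(replications, trial) for trial in trial_set]
    with multiprocessing.Pool() as pool:
        # Each tuple becomes a call fake_experiment(replications, trial);
        # the results come back in the same order as formatted_input.
        results = pool.starmap(fake_experiment, formatted_input)
    return dict(zip(trial_set, results))
```

Called from under an if __name__ == "__main__": guard (required when worker processes are started by spawning, as on Windows and macOS), run_all(3, [10, 20]) returns {10: [10, 11, 12], 20: [20, 21, 22]}.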

Timing with multiprocessing (same settings):

real    2m47.453s
user    4m21.260s
sys     0m11.165s

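We were able to cut about a full minute off the real time (from 3m46s to 2m47s, roughly 1.35 times faster). It is interesting to note that the time our code actually spent executing on the CPU (the user time) increased from 3m39s to 4m21s. Part of this is the overhead of starting new processes on the OS. The speedup is also limited because formatted_input holds only one task per entry in trial_set: with trial_set = [10, 20], at most two processes do any work, and the 20-trial task likely takes longer than the 10-trial one.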

Now we have a basic Q-Learning snake and the ability to test it over many replications and trials, which will make analyzing the snake easier. In the next post, we will go in depth into the results from our training. Then we will work on adding other obstacles to the game.