1、Q-Learning算法学习

2022-05-29  本文已影响0人  小黄不头秃

一、概述

Q Learning是强化学习算法中的一个经典算法。在一个决策过程中,我们不知道完整的计算模型,所以需要我们去不停的尝试。

Q Learning算法下,模型和模型所处环境不断交互,不停的尝试,学习走出迷宫的规则,找到最优策略,这就是强化学习的学习过程。

二、算法理解

马尔科夫决策问题 Markov decision problem

简单的理解为,在当前状态下 s,进行一个行为 a,去到下一个状态 s‘
所以我们定义几个变量(以走迷宫为例):

其中这个model T可以是一个概率问题,可理解为在s 和a 的条件下,转变为状态s‘ 的概率。

上述的定义可以理解为:

(1)状态 S

假如我们的迷宫是一个 10 * 10 的矩阵,那我们就可以使用 0-99 来表示当前的状态信息。

(2)行为 A

在迷宫中我们就会出现上下左右四种走法,我们可以使用 0-3 表示每一个行为。

(3)奖励 R

R(s),s表示进入的状态;
可以分为:
reward = -1,碰到障碍物或者走出地图;
reward = 1,走到了终点;

(4)模型 T

T(s,a,s’),模型(当前状态,行为,下一状态)
例如:从一个格子走下一步,进入四个方向的概率就会计算出来,比如说,计算得出0.9, 0.1, 0.1, 0.1。那么就会朝0.9的方向走。
如下图所示,这是Q Learning算法的核心运转过程。

9.jpg

Markov decision process 马尔可夫决策过程

  1. 初始化 Q table = 0,所有状态和行为的表。Q(s,a);
  2. 选择一个action, Π(s) => a; 更具一定规则选取一个a,也可以是随机选。
    Π(s) = max a (Q(s, a));
  3. 表现最好的action。
  4. 获取反馈,即根据奖励机制获取reward。
  5. 根据learning rate (α)进行 Q table的更新。(Q' = R(S) + γΣT(s,a,s') * max(Q(s',a')), 查看下一步的操作,找到最佳路径)
    Q' ---α-----> Q
    Q = αQ' + (1-α)Q';
    这里其实主要就是更新a的值,下面实例代码中使用梯度下降的方式。
  6. 返回第2步。

三、环境搭建(python)

依赖的包:

四、代码分析

感谢大佬提供的帮助:https://gitee.com/biangaoyang
代码下载地址:https://gitee.com/biangaoyang/rl_-maze.git

(1)虚拟环境的搭建(有详细注释)
""" 
Reinforcement learning
    Q-learning maze example
Red rectangle:          explorer( 探索者 )
Black rectangle:        hells(地狱, 陷阱)       [reward = -1].
Yellow bin circle:      paradise(天堂, 出口)    [reward = +1].
All other states:       ground(其他状态)      [reward = 0].
""" 
import numpy as np
import time
import tkinter as tk
UNIT = 50 # 单位长度 50px
MAZE_H = 5 # 5个单位长度
MAZE_W = 5 # 5个单位长度

class Maze(tk.Tk, object):
    def __init__(self):
        # 初始化maze
        super(Maze, self).__init__()
        self.actionSpace = ['u', 'd', 'l', 'r']
        self.actionNum = len(self.actionSpace)
        self.title('MAZE') # 设置窗口标题
        self.geometry('{}x{}'.format(MAZE_H * UNIT, MAZE_W * UNIT)) # 窗口大小 5*50px,5*50px,参数是一个字符串“250x250”
        self._buildMaze()

    #构建迷宫地图,探索者,陷阱,出口
    def _buildMaze(self):
        self.canvas = tk.Canvas(self, bg = 'white', height=MAZE_H * UNIT, width=MAZE_W * UNIT) # 设置背景和窗口大小
        
        for c in range(0, MAZE_W * UNIT, UNIT):
            # 每隔50px画一条竖线 从(x0+=50,0)到(x0+=50,250)。
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            # 每隔50px,画一条横线。
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
    
        origin = np.array([25, 25]) #方块的中心点
        #agent 这就是红色的探索者 (5,5,45,45)左上角坐标,右下角坐标
        self.rect = self.canvas.create_rectangle(
            origin[0] - 20, origin[1] - 20,
            origin[0] + 20, origin[1] + 20,
            fill= 'red'
        )

        #第一个陷阱的位置坐标(4,3),陷阱的中心位置(155,125)
        hell1_center = origin + np.array([UNIT * 3, UNIT * 2])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 20, hell1_center[1] - 20, 
            hell1_center[0] + 20, hell1_center[1] + 20, 
            fill='black')

        #第二个陷阱的位置坐标(3,4),陷阱的中位值(125,155)
        hell2_center = origin + np.array([UNIT * 2, UNIT * 3])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - 20, hell2_center[1] - 20, 
            hell2_center[0] + 20, hell2_center[1] + 20, 
            fill='black')
        # exit 出口 位置(4,4),出口的中心位置(155,155)
        oval_center = origin + UNIT * 3
        self.oval = self.canvas.create_oval(
            oval_center[0] - 20, oval_center[1] - 20,
            oval_center[0] + 20, oval_center[1] + 20,
            fill= 'yellow'
        )
        # 将控件放置在主窗口中
        self.canvas.pack()

    def reset(self):
        # 游戏结束,重新设置起始位置
        self.update()   # Enter event loop until all pending events have been processed by Tcl.进入循环直到tcl处理完所有挂起事件
        time.sleep(0.5)
        self.canvas.delete(self.rect) #通过id将探索者清除
        # 重置位置
        origin = np.array([25, 25])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 20, origin[1] - 20,
            origin[0] + 20, origin[1] + 20,
            fill= 'red'
        )
        # 刷新窗口 [5.0, 5.0, 45.0, 45.0]
        return self.canvas.coords(self.rect)    # Return a list of coordinates for the item given in ARGS.返回ARGS中给出的项目的坐标列表  

    def step(self, action):
        # 这里是每走一步的行为self,action:0-3,分别为上下左右
        s = self.canvas.coords(self.rect)
        baseAction = np.array([0, 0])
        if action == 0:         # UP 上
            if s[1] > UNIT:    
                baseAction[1] -= UNIT
        elif action == 1:       # DOWN 下
            if s[1] < (MAZE_H - 1) * UNIT:
                baseAction[1] += UNIT
        elif action == 2:       # LEFT 左
            if s[0] > UNIT:
                baseAction[0] -= UNIT
        elif action == 3:       # RIGHT 右
            if s[0] <(MAZE_W - 1) * UNIT:
                baseAction[0] += UNIT
        # 平移,单位为 UNIT 50px
        self.canvas.move(self.rect, baseAction[0], baseAction[1])

        s_ = self.canvas.coords(self.rect) # 获取探索者的位置

        # 判断探索者现在所处的位置设置reward 当状态是terminal的时候就是需要重开
        if s_ == self.canvas.coords(self.oval):
            reward = 1
            done = True
            s_ = 'terminal'
        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
            reward = -1
            done = True
            s_ = 'terminal'
        else:
            reward = 0
            done  = False

        return s_, reward, done

    def render(self):       # Refresh the current environment 刷新当前窗口
        time.sleep(0.1) # 0.1s
        self.update()

#env = Maze()
def update():
    for t in range(10):
        s = env.reset()
        while True:
            # 刷新窗口
            env.render()
            a = 1 # 往下走
            s, r, done = env.step(a)
            if done:
                break

#  import 到其他的 python 脚本中被调用(模块重用)执行 具体解释可以看后面这位大佬的博客:"https://blog.csdn.net/heqiang525/article/details/89879056?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522163264662216780269835793%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=163264662216780269835793&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-89879056.pc_search_insert_js_new&utm_term=if+__name__+%3D%3D+%22__main__%22%3A&spm=1018.2226.3001.4187"
# 这个mian就是你运行本页面可以运行出来,在别的页面引用就不会执行,因为一个程序永远只有一个main
if __name__ == '__main__':
    # 用于测试环境
    env = Maze()
    # 设定每隔100ms,调用一次update,向下走10格,中途碰到陷阱或者出口就break down
    env.after(100, update)  # Call function once after given time.在给定时间后调用函数一次。
    env.mainloop()

(2)Q_table的设置
'''
Brain of the agent 探索者的大脑!
agent will make desicion here 用于做决策
Q(s,a) <- Q(s,a) + Alpha * [r + gamma * max(Q(s', a')) - Q(s,a)]

下面是Q——table表: (状态:行,行为:列)
        up    down    left    right   
state1  
state2
  .
  .
  .     
'''

import numpy as np
import pandas as pd
# import random

class QLearningTable:
    def __init__(self, actions, learningRate = 0.01, reward_decay = 0.9, e_greedy = 0.9):
        self.actions = actions   # a list 行为列表 [0,1,2,3]
        self.alpha = learningRate # 学习率 0.01
        self.gamma = reward_decay # γ 0.9
        self.epsilon = e_greedy # ε 0.9
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
    
    # Π()随机选择一个action, observation是一个坐标字符串'[5.0, 5.0, 45.0, 45.0]'
    def chooseAction(self, observation):
        self.checkStateExist(observation) #从来没来过这个格子(state),就加上一列

        if np.random.uniform() < self.epsilon: # 0-1随机采样 均匀分布 f(x) = 1/(b-a)
            stateAction = self.q_table.loc[observation, :] # 选出当前状态下的所有列
            # 意思是选出最大的那个action的index, 
            print("stateAction:",stateAction) #stateAction: 0    0.0 四行一列的表,分别对应四个action
            print("stateAction1:",np.max(stateAction)) # 0.0
            print("stateAction2:",stateAction == np.max(stateAction)) # 0    True 四行一列的表 是否等于最大值的bool值
            print("stateAction3:",stateAction[stateAction == np.max(stateAction)].index) #Int64Index([0, 1, 2, 3], dtype='int64')

            # 这里如果没有最大值就[1,2,3,4]随机选,有最大值就选择最大值。
            action = np.random.choice(stateAction[stateAction == np.max(stateAction)].index) 
            print("action:",action)

        else:
            action = np.random.choice(self.actions)
        return action

    # 学习更新q_table
    def learn(self, s, a, r, s_):
        self.checkStateExist(s_)
        q_predict = self.q_table.loc[s, a] # 预测值,一个值
        if s_ != 'terminal':
            # 游戏重开 如果走到陷阱q_target = -1 + 0.9*max (s'a)
            # 这里理解为loss function, 较为巧妙的设计
            q_target = r + self.gamma * self.q_table.loc[s_].max()
        else:
            # 游戏不重开 q_target = 0
            q_target = r
        # 这个公式就是梯度下降 Y = Y + α * loss()
        # Q(s,a) + Alpha * [r + gamma * max(Q(s', a')) - Q(s,a)]
        self.q_table.loc[s, a] += self.alpha * (q_target - q_predict) # 一次更新一个状态下的 a
    
    # check whether the state exist in q_table 检查是否该状态在q_table中存在
    def checkStateExist(self, state):
        if state not in self.q_table.index:
            self.q_table = self.q_table.append(
                # 追加序列
                pd.Series(
                    [0] * len(self.actions), #[0, 0, 0, 0]
                    index = self.q_table.columns, #Int64Index([0, 1, 2, 3], dtype='int64')
                    name = state, # 状态名
                )
            )
(3)主函数
from maze_env import Maze
from RL_brain import QLearningTable

def update():
    for episode in range(100):
        # print(RL.q_table)
        # 循环100次
        observation = env.reset() # s 当前状态

        while True:
            env.render()    # Refresh the current environment 刷新当前的窗口
            action = RL.chooseAction(str(observation)) # a' 下一步动作,随机选出一个action
            observation_, reward, done = env.step(action) # 根据Π("state")计算出的action
            RL.learn(str(observation), action, reward, str(observation_)) # 选择最优的action,根据学习率更新Q_table
            observation = observation_ # s' 下一状态
            if done:
                # 游戏重开跳出循环
                break

    print("Game Over")
    env.destroy()

if __name__ == '__main__':
    # 实例化一个窗口对象,封装好的迷宫
    env = Maze()
    RL = QLearningTable(actions=list(range(env.actionNum))) # actions = [0, 1, 2, 3]
    env.after(100, update)
    # 显示窗口
    env.mainloop()

上一篇 下一篇

猜你喜欢

热点阅读