Prioritized DQN

2018-02-06  本文已影响752人  海街diary

1.简介

Prioritized DQN 是为了解决当在memory中均匀采样时候学习效率低下的问题。原因主要有两个:

1.我们想让new transition立马用于更新,因为这样的new experience对于explore很重要。

2.我们想让large td-error的transition立马用于更新(比如有99次失败的经历和1次成功的经历,我们希望立马学习这个成功的经历)

显然uniform sampling无法做到这两点。
于是便有了伟大的Prioritized Experience Replay.

论文在这里。
代码在这里。
简单介绍在这里。

下面我将分享自己学习这篇论文的时候一些经验。请读完论文和简单介绍后,如有困惑,再阅读以下部分。

2.关键点

Prioritized DQN能够成功的主要原因有两个:sum tree这种数据结构带来的采样的O(log n)的高效率,和Weighted Importance sampling的正确估计。后者,我现在还没有完全搞明白原理。

我简单由谈下自己对于sum tree数据结构的理解。 sum tree存储的元素是样本的优先级,其思想是根据累积概率密度(因此叫sum)来抽取样本。从最左方开始,优先级累积逐渐增大,如果我们的段>左子孩子,(递归地)就在右子孩子中寻找(这时候要做减法,以便又是新的累积优先级)。

如果把累积优先级(离散地)画出来,我们就会发现,高优先级对应的直线段斜率最大,被抽取到的概率最大。(可以以下图为例,自己在每个段中取数字进行验证)。

sum tree.png

3.代码解读

原代码注释较少,我这里列出几个点,方便大家阅读代码。

  • 代码实现的是DQN, 而不是Double DQN。
  • 在插入new transition更新sum tree的时候, 是根据新样本与原来位置的样本的优先级差来更新。(详见SumTree.add)
  • 在memory中插入new transition的时候,给予new transition最大的优先级,因为我们想让new experience立马用于学习。(详见Memory.store)
  • 在memroy中抽取n个samples后,我们会根据nn计算出来的TD-error来更新那些抽取到的样本的优先级,这样的话new transition就不会一直被学习。(详见Memory.batch_update)。

大家最好照着源码自己敲一编(时间大概2~3小时),我这里给出自己在搬砖过程中写的一点注释(也可以自己下载,照着看)。

import numpy as np
import tensorflow as tf

np.random.seed(1)
tf.set_random_seed(1)


class SumTree(object):
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        # [------------ Parent nodes -------------][------ leaves to recode priority ----------]
        #            size: capacity - 1                  size: capacity
        self.data = np.zeros(capacity, dtype=object)
        # [------------ data frame ---------------]
        #            size: capacity

    # memory store_transition的时候使用
    def add(self, p, data):                     # p is the new priority, data is transition
 
    # memory batch_update的时候使用
    def update(self, tree_idx, p):

    # memory 分段采样的时候使用
    def get_leaf(self, v):
        parent_idx = 0
        while True:
            cl_idx = 2 * parent_idx + 1
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):   # 此时parent就是叶子结点
                leaf_idx = parent_idx
                break
            else:
                if v <= self.tree[cl_idx]:  # <= 左子孩子,就向左前进
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]  # > 右子孩子,需要重新当作一颗累积树,因此要减去左子孩子的值
                    parent_idx = cr_idx

        data_idx = leaf_idx - (self.capacity + 1)
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        return self.tree[0]   # the root


class Memory(object):
    epsilon = 0.01     # small amount to avoid zero priority
    alpha = 0.6         # [0, 1] convet the importance of TD error to priority
    beta = 0.4          # importance sampling, from intial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1   # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    # 向sum tree的transitions 中加入 new transition
    def store(self, transition):

    # 从sum tree中采取n个样本
    def sample(self, n):

    # 更新采样过的样本的priority(基于abs_error)
    def batch_update(self, tree_idx, abs_error):
上一篇 下一篇

猜你喜欢

热点阅读