教育

人工智能陪你玩游戏——的士司机接客和倒立摆

2018-09-07  本文已影响10人  圣_狒司机

的士司机接客游戏简介:

在这个游戏中,黄色方块代表出租车,(“|”)表示一堵墙,蓝色字母代表接乘客的位置,紫色字母是乘客下车的位置,出租车上有乘客时就会变绿。你作为的士司机要找到最快速接客和下客的路径。


接到客人了
放下客人,完成!

思路:

客人和下客的位置是随机出现的,不可能用机械的方法找出变动的路径,可以用人工智能强化学习方法学习出最优路径。

代码:

import gym
import numpy as np

env = gym.make('Taxi-v2')
Q = np.zeros((env.observation_space.n,env.action_space.n))

def trainQ():
    for _ in range(10000):
        observation = env.reset()
        while True:
            action = env.action_space.sample()
            observation_,reward, done,info = env.step(action)
            Q[observation,action] = reward + 0.75 * Q[observation_].max()
            observation = observation_
            if done:break
    return Q

def findway():
    observation = env.reset()
    rewards = 0
    while True:
        action = Q[observation].argmax()
        observation_,reward, done,info = env.step(action)
        print(observation_,reward, done,info)
        rewards += reward
        observation = observation_
        env.render()
        if done:
            print(rewards)
            break

Q = trainQ()
findway()

测试:

客人在这
接到客人了 开车
开车
开车 下客,完工!

倒立摆游戏简介:

思路:

倒立摆游戏比的士司机游戏复杂,原因在于倒立摆的连续状态是无穷多个,人工智能 Q-learning 方法需要有限个状态形成知识。
解决方案:只需要将连续状态打散为离散状态即可。

代码:

import gym
import numpy as np
env = gym.make('CartPole-v0')
eplision = 0.01
q_table = np.zeros((256,2))

def bins(clip_min, clip_max, num):
    return np.linspace(clip_min, clip_max, num + 1)[1:-1]

def digitize_state(observation):
    cart_pos, cart_v, pole_angle, pole_v = observation
    digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
                 np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
                 np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
                 np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]
    return sum([x * (4 ** i) for i, x in enumerate(digitized)])

#------------观察期--------------#
for _ in range(1000):
    observation = env.reset()
    s = digitize_state(observation)
    while True:
        action = env.action_space.sample()
        observation_, reward, done, info = env.step(action)
        if done:reward = -200
        s_ = digitize_state(observation_)
        action_ = env.action_space.sample()
        q_table[s,action] = reward + 0.85*q_table[s_,action_]
        s,action = s_,action_
        if done:break
print('观察期结束')

#------------贪心策略期--------------#
for epicode in range(1000):
    observation = env.reset()
    s = digitize_state(observation)
    while True:
        eplision = epicode / 1000
        action = q_table[s].argmax() if np.random.random() < eplision else env.action_space.sample()
        observation_, reward, done, info = env.step(action)
        if done:reward = -200
        s_ = digitize_state(observation_)
        action_ = q_table[s_].argmax() if np.random.random() < eplision else env.action_space.sample()
        q_table[s,action] = reward + 0.85*q_table[s_,action_]
        s,action = s_ ,action_
        if done:break
print('贪心策略期结束')

#------------验证期--------------#
scores = []
for _ in range(100):
    score = 0
    observation = env.reset()
    s = digitize_state(observation)
    while True:
        action = q_table[s].argmax()
        observation_, reward, done, info = env.step(action)
        score += reward
        s = digitize_state(observation_)
        #env.render()
        if done:
            scores.append(score)
            break
print('验证期结束\n验证成绩:%s'%np.max(scores))

测试:

观察期结束
贪心策略期结束
验证期结束
验证成绩:200.0

倒立摆

纪要:

离散化函数原先为

def digitize_state(observation,bin=5):
    high_low = np.vstack([env.observation_space.high,env.observation_space.low]).T[:,::-1]
    bins = np.vstack([np.linspace(*i,bin) for i in high_low])
    state = [np.digitize(state,bin).tolist()  for state,bin in zip(observation,bins)]
    state = sum([value*2**index for index,value in enumerate(state)])
    return state

效果并不好,在网上找了个好的替换:

def bins(clip_min, clip_max, num):
    return np.linspace(clip_min, clip_max, num + 1)[1:-1]

def digitize_state(observation):
    cart_pos, cart_v, pole_angle, pole_v = observation
    digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
                 np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
                 np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
                 np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]
    return sum([x * (4 ** i) for i, x in enumerate(digitized)])

可见要做好人工智能还是需要了解状态参数大致意义,不管含义什么都喂进智能体的话,智能体表现不会太好。

上一篇下一篇

猜你喜欢

热点阅读