Rasa Core源码之Policy训练
上下文的联系与理解是对话系统中重要的一块,直接影响与机器人对话的体验。最近接触了RASA系列,包括自然语言理解的rasa nlu和对话管理的rasa core。简单方便的实现一个任务对话系统的同时,也好奇其内部实现使用的技术。花时间读了Rasa Core关于上下文理解部分的源码,后面有机会再把rasa对话系统的其他模块的实现也做一个源码的分析。
文章分为以下几部分:
- Rasa Core的主要模块概念
- 训练数据准备
- 对话Policy模型训练和实现方法
主要概念
arch.png与对话系统的主要模块对应,如图 Rasa Core的实现也有相应的几个模块。从接受用户消息到机器人做出决策的流程大致如下:
- 接受用户消息,送入Interpreter模块,识别并生成包含消息文本(text)、用户意图(intent)、和实体(entities)的字典。这里Interpreter对意图和实体的识别由上面提到的Rasa NLU实现,不是文章的主题,只要知道其功能即可。
- Tracker 是对对话状态进行追踪(state tracker)的对象,它接受并记录Interpreter识别的新消息。
- Policy接受当前的对话状态,选择响应哪一个Action。
- 被选择的Action别记录在Tracker中,并返回响应给用户。
上述流程是在Interpreter和Policy模型训练好的基础上对话系统的运行流程,下面主要针对Policy选择Action的模型训练部分的源码进行分析。该部分模型需要考虑历史对话对下一步响应进行决策,是整个对话系统的核心。
训练数据
训练Policy之前需要准备两个数据文件:
- domain.yml : 包括对话系统所适用的领域,其中包括intents(意图集合)、slots(实体槽集合)、actions (机器人相应方式的集合)。
- story.md:训练数据集合,这里的训练数据比不是原始的对话数据,而是原始的对话在domain中的映射。
以官方的订餐馆的数据集为例:
restaurant_domain.yml:
slots:
cuisine:
type: text
people:
type: text
location:
type: text
price:
type: text
info:
type: text
matches:
type: list
intents:
- greet
- affirm
- deny
- inform
- thankyou
- request_info
entities:
- location
- info
- people
- price
- cuisine
templates:
utter_greet:
- "hey there!"
utter_goodbye:
- "goodbye :("
- "Bye-bye"
utter_default:
- "default message"
utter_ack_dosearch:
- "ok let me see what I can find"
...
...
actions:
- utter_greet
- utter_goodbye
- utter_default
- utter_ack_dosearch
- utter_ack_findalternatives
...
...
babi_stories.md:
## story_03812903
* greet # 用户打招呼
- utter_ask_howcanhelp # 机器人响应需要什么帮助
* inform{"location": "paris", "people": "six", "price": "cheap"} # 用户回复想订一下Paris便宜的六人桌
- utter_on_it # 机器人回复好的
- utter_ask_cuisine # 机器人继续询问要什么菜系
* inform{"cuisine": "indian"} # 用户说印度菜
- utter_ack_dosearch # 机器人回复稍等帮您查找
- action_search_restaurants # 机器人查库返回结果
...
... # 省略
* affirm # 用户确认
- utter_ack_makereservation # 机器人回复完成订单,询问手机号
* request_info{"info": "phone"} # 用户告知手机号
- action_suggest # 机器人其他推荐
* thankyou # 用户感谢
- utter_ask_helpmore # 机器人询问其他帮助
story中对样例对话进行了简单的注释。
模型训练
准备好训练数据,下面是模型训练。拿官方的一个经典的KerasPolicy模型为例,该模型用Keras实现了一个简单的LSTM作为Policy模型:
def model_architecture(self, num_features, num_actions, max_history_len):
"""Build a keras model and return a compiled model.
:param max_history_len: The maximum number of historical
turns used to decide on next action
"""
from keras.layers import LSTM, Activation, Masking, Dense
from keras.models import Sequential
n_hidden = 32 # Neural Net and training params
batch_shape = (None, max_history_len, num_features)
# Build Model
model = Sequential()
model.add(Masking(-1, batch_input_shape=batch_shape))
model.add(LSTM(n_hidden, batch_input_shape=batch_shape, dropout=0.2))
model.add(Dense(input_dim=n_hidden, units=num_actions))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
logger.debug(model.summary())
return model
模型通过历史对话记录作为输入训练数据,下一个决策Action作为label,进行模型训练。三个参数:
- max_history_len: 记录的最大历史长度。
- num_features: 每个记录的特征维度(intent、slot、action等的数目),包括了该记录的状态。
- num_actions:候选响应数。
模型本质上是num_actions个类别的多分类。下面详细分析对story.md的编码,生成可以直接输入到模型的训练数据(X,y)。
状态追踪(state track)
在搞清楚模型输入的训练数据是什么之前需要了解Rasa Core是如何实现状态追踪。训练阶段,rasa core读入Story,用track记录:在设置最大长度上下文为2时,一条训练数据的会有如下字典的表示:
[
{'entity_location': 1.0, 'entity_people': 1.0, 'entity_price': 1.0, 'slot_cuisine_0': 0.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_utter_on_it': 1},
{'entity_location': 1.0, 'entity_people': 1.0, 'entity_price': 1.0, 'slot_cuisine_0': 0.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_utter_ask_cuisine': 1},
{'entity_cuisine': 1.0, 'slot_cuisine_0': 1.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_action_listen': 1}
]
该部分状态对应上面训练数据的
- utter_on_it # 机器人回复好的
- utter_ask_cuisine # 机器人继续询问要什么菜系
* inform{"cuisine": "indian"} # 用户说印度菜
状态编码
track列表中第一个字典表示utter_on_it后的状态,此时slot_location、slot_people、slot_price等的均已收集到在之前的对话中,对应value为1。第二个字典表示在utter_ask_cuisine后的状态,此时并没有获取到新的信息,而只是记录上一个机器响应prev_utter_ask_cuisine的value为1,表示该阶段状态;第三个字典表示当前状态,在获取新的cuisine信息后对应key的value置为1,同时上一个action为prev_action_listen表示监听。
相应的,根据训练数据下一个机器应该采取的action为:
- utter_ack_dosearch # 机器人回复稍等帮您查找
如此得到一条训练数据(x,y), x经过编码,单条记录为一个二值向量,如果特征出现为1,否则为0,对应上面的第三个字典:
{'entity_cuisine': 1.0, 'slot_cuisine_0': 1.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_action_listen': 1}
[0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0,0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
而对于最大历史信息记录为2的对应单条训练数据:
[array([[0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0,0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)]
对应的y为5(utter_ack_dosearch的编号)。相同方法从Story中读取所有可能的数据对,去重和数据增强(打乱拼接),最终生成训练数据X,y。
- X的维度为:(num_states, max_history, num_features)
- y的维度为:num_states
模型训练
在准备好训练数据之后就可以对LSTM进行训练:
def train(self, X, y, domain, **kwargs):
self.model = self.model_architecture(domain.num_features,
domain.num_actions,
X.shape[1])
y_one_hot = np.zeros((len(y), domain.num_actions))
y_one_hot[np.arange(len(y)), y] = 1
number_of_samples = X.shape[0]
idx = np.arange(number_of_samples)
np.random.shuffle(idx)
shuffled_X = X[idx, :, :]
shuffled_y = y_one_hot[idx, :]
validation_split = kwargs.get("validation_split", 0.0)
logger.info("Fitting model with {} total samples and a validation "
"split of {}".format(number_of_samples, validation_split))
self.model.fit(shuffled_X, shuffled_y, **kwargs)
self.current_epoch = kwargs.get("epochs", 10)
logger.info("Done fitting keras policy model")
和一般LSTM网络的训练方法一样,这里先对y进行one hot编码,shuffle训练集,之后进行训练。对于单个训练数据,对比文本的训练,一个状态相当于一个词,而最大上下文长度为2的单条训练数据可类比为2个词的句子。
而在模型实用的预测阶段,一开始流程也有涉及,显然只要Tracker记录之前的聊天记录,每次拿当前决策的前两个消息作为模型输入,输出即为每个action的概率值,选择最大的响应即可。
小结
到此分析了Rasa Core的Policy训练方式,虽然Rasa Core的代码量并不算大,但这里并没有根据源码细节来看,而只是理清其训练方法。通过一个不错的对话系统的源码阅读,可以对对话管理的几个关键技术有进一步的理解,比如状态追踪、上下文理解以及没有讲的意图识别和实体识别。
相比于高大上的论文的解决方案(如端到端、Memory Network进行上下文理解),Rasa Core显得更加简单可用,同样Rasa Core支持online learning还有点增强学习的意思,感兴趣的可以关注其github。