Temporal Fusion Transformers for
摘要
多步(尺度)预测通常包含一个复杂的输入组合——包括静态(即时不变)协变量、已知的未来输入,以及其他仅在过去观察到的外生时间序列——没有任何关于它们如何与目标相互作用的先验信息。几种深度学习方法已经被提出,但它们通常是“黑盒”模型,并不能阐明它们如何使用实际场景中出现的全部输入。在本文中,我们介绍了时间融合变压器(TFT)——一种新的基于注意的架构,它结合了高性能的多步预测和对时间动态的可解释的洞察力。为了学习不同尺度上的时间关系,TFT使用循环层进行局部处理,并使用可解释的自我注意层进行长期依赖。TFT利用专门的组件来选择相关的特性和一系列的门控层来抑制不必要的组件,从而在广泛的场景中实现高性能。
历史瓶颈
在时序多步预测任务中,DNN面临以下两个挑战:
1. 如何利用多个数据源? 2. 如何解释模型的预测结果?
1. 如何利用多个数据源?
如上图,在时间序列预测中,所有的变量都划分为两个大类:
1、静态变量;
2、动态变量
(1)静态变量(Static Covariates):不会随时间变化的变量,例如商店位置;
(2)时变变量(Time-dependent Inputs):随时间变化的变量;
●过去观测的时变变量(Past-observed Inputs):过去可知,但未来不可知,例如历史客流量
●先验已知未来的时变变量(Apriori-known Future Inputs):过去和未来都可知,例如节假日;
静态变量可以细分为:
再细分可以划分为静态连续变量和静态离散变量,静态离散变量很好理解,例如商品所在的城市,商品的大类,商品所在销售区域等等,这些变量都是不会随着时间发生变化的,而静态连续变量,例如,商品A在2020年的年总销售金额,商品B在2021年双十一期间的销量等,这些不会随着时间变化的变量都是静态变量,或者说,静态特征(这里需要注意定义,如果说的是前一年销量,则是一个动态变量,因为时间点和取值是会发生变化的,真特么绕);
动态变量分为两种(论文里的图没有画的特别详细,所以看后面的地方会比较疑惑)
1、动态时变变量;
2、动态时不变变量
(当然,动态时变和时不变的变量中也都包含了连续和离散型的features,不过这个很好区分,下面的描述主要还是从时变和时不变展开来看的)
所谓动态时变量指的就是随着时间变化的特征,例如我们要预测的销售金额,明天的客流量,网页的pv,uv等等;动态时不变变量也是随着时间变化的特征,例如月份,星期几,那么二者的区别是什么?
二者的核心区别在于是否可以推断出来,动态时变变量是无法推断的,比如在经典的温度预测的例子里,温度,湿度,气压这些都是无法推断的,随着时间变化的,我们无法事先知道的。而动态时不变变量,最典型的就是月份,星期几了,这些变量他们虽然也是随着时间变化的,但是我们是可以轻而易举的进行推断从而应用到模型训练的过程中。
常规的向量输出的深度学习模型在训练的过程中是不会用到未来的特征的,即使是静态特征,例如我们用历史的30天的销量数据预测未来10天的销量数据,我们在模型训练的过程中不会用到未来十天的任何特征。
而很多RNN结构的变体模型,还有Transformer的变体模型,很少在多步预测任务上,认真考虑怎么去利用不同数据源的输入,只是简单把静态变量和时变变量合并在一起,但其实针对不同数据源去设计网络,会给模型带来提升。
因为无论是tcn,wavenet,nbeats,deepar,LSTM,seq2seq based model 或者是attention-based model 或是transformer,我们在构建模型的时候,都是将所有的特征按照time step 直接concat在一起,也就是说,目前现有的处理方式基本上就是“万物皆时序”,将所有的变量全部都扩展到所有的时间步,无论是静态,动态的变量都合并在一起送入模型,文中提到如果模型能够对静态和动态变量加以区分,能够进一步提升模型的泛化性能(存疑,不一定吧?)。其实如果有关注时序比赛的话,可以看看top solutions里关于nn的代码,基本上大家都是直接把所有变量一起放到模型里traing的,而不区分静态,动态,或者动态时变 or 动态时不变。
2. 如何解释模型的预测结果?
除了不考虑常见的多步预测输入的异质性之外,大多数当前架构都是" 黑盒" 模型,预测结果是由许多参数之间的复杂非线性相互作用控制而得到的。这使得很难解释模型如何得出预测,进而让使用者难以信任模型的输出,并且模型构建者也难对症下药去Debug模型。不幸的是,DNN常用的可解释性方法不适合应用于时间序列。在它们的传统方法中,事后方法(Post-hoc Methods),例如LIME和SHAP不考虑输入特征的时间顺序。另一方面,像Transformer架构,它的自相关模块更多是能回答“哪些时间点比较重要?”,而很难回答“该时间点下,哪些特征更重要?”。
论文贡献
本文提出的TFT模型有如下贡献:
1. 静态协变量编码器:可以编码上下文向量,提供给网络其它部分;
2. 门控机制和样本维度的特征选择:最小化无关输入的贡献;
3. sequence-to-sequence层:局部处理时变变量(包括过去和未来已知的时变变量);
4. 时间自注意解码器:用于学习数据集中存在的长期依赖性。这也有助于模型的可解释性,TFT支持三种有价值的可解释性用例,帮助使用者识别:
● 全局重要特征;
● 时间模式;
● 重要事件。
模型
我们设计了时间融合转换器(TFT),使用规范组件为每个输入类型(即静态、已知输入、观察输入)高效地构建特征表示,使其能够在广泛的问题上获得高预测性能。TFT的主要组成部分是:
(1)控制机制,跳过架构中任何未使用的组件,提供自适应深度和网络复杂性,以适应大范围的数据集和场景。门控线性单元在整个体系结构中得到了广泛的应用,门控剩余网络被提出作为主要的构建模块。
(2)变量选择网络,在每个时间步选择相关的输入变量。
(3)静态协变量编码器,通过对上下文向量进行编码,以条件时间动态,将静态特征融入到网络中。
(4)时间处理,学习长期和短期的时间关系,同时自然地处理观察到的和先验知道的时变输入。一个序列-序列层被用于局部特征处理,而长期依赖关系被捕获使用一个新的可解释的多头注意块。
(5)多水平预测区间预测,在每个预测水平产生分位数预测
下图显示了TFT的高级体系结构,后面的部分将详细描述各个组件。模型的开源实现也可以在GitHub1上找到,完全可复制。
google-research/tft at master · google-research/google-research · GitHub
总结一下,GRN用了skip-connection和GLU,主要是控制线性和非线性特征的特征信息的贡献(Gate+Add&Norm),特别是加入静态协变量c,去引导模型的学习。VSN是配合GRN和softmax,进行特征选择。TFD中的多头自关注模块提供了可解释性和时序长依赖关系的捕捉能力。
新认识:
2. VSN软性的特征(变量)选择;
3. GRN作线性特征与非线性特征的融合,对非线性变量作筛选。其中这里用的残差,类似教师机制,对模型的训练很有效果、以及使用GLU来作门限组件,起到了特征选择的作用,与TabNet相似、以及ELU激活函数,比relu效果更好,小于0的部分不会出现梯度消失的情况,呈指数形式。
4. TSL中,多头注意力机制的使用:可解释性多头自关注层比较好理解(与Transformer不同之处),它其实就是针对V是多头共享参数,对Q和K是多头独立参数,然后计算多头attention score加权后的V,
5. 预测分位数:
除了像DeepAR预测均值和标准差,然后对预测目标做高斯采样后,做分位数统计。
分位数损失函数:
当q取越大,使得loss更小,相较y越大;当q取越小,欲使得loss更小,相较y越小。
由于q取90%,训练时,模型会越来越趋向于预测出大的数字,这样Loss下降的更快,则模型的整个拟合的超平面会向上移动,这样便能很好的拟合出目标变量的90分位数值。
我们对于未来的预测就能够产生 prediction intervals了,可以比较好的反应预测结果的不确定性,比如某个点的不同分位数线性回归的预测结果很接近,则其prediction intervals很窄,预测的确定性高,反之亦然。
创新点
对时序模型的输入进行了分类。文章提出的分类方式是从模型视角出发,对不同预测任务输入的特征作了抽象的归类。
指出了探索不同数据源交互信息的重要性。这一点也可以说是用深度学习作预测的重要性,因为深度学习模型的优点在于捕捉非线性关系,这种关系在真实的系统中大量存在。
强调了时序预测可解释性的重要性。时序预测无法像推荐算法、图像算法那样直接产生收益,更多起到辅助决策的作用,在利益方牵涉较广的电力、交通、经济、金融系统中,黑盒模型始终缺乏说服力。
消融实验
另外作者对网络模块做了消融实验,如图12。从下图右侧,我们能看到Self-Attention和Local Processing(LSTM层)贡献最大,但不同数据集上,两者的贡献大小并不绝对,比如对于Traffic数据集,Local Processing更重要,作者认为是Traffic数据集得目标历史观测值更重要,所以Local Processing发挥了更大的作用。而对于Eelectricity数据集,Self-Attention更重要,作者认为是电力的周期性明显,hour-of-day特征甚至比预测目标Power Usage的历史观测值更重要,所以自关注发挥作用更大。
对TFT的解释性,作者从3方面进行展示:(1)检查每个输入变量在预测中的重要性,(2)可视化长期的时间模式,以及(3)识别导致时间动态发生重大变化的任何状态或事件。