两个多元高斯分布之间的wasserstein 距离
2019-04-16 本文已影响0人
世界上的一道风
Wasserstein distance
定义
- Wasserstein distance 是在度量空间 ,定义概率分布之间距离的距离函数。
- 定义:让 作为一个度量空间,每一个在 上的概率测度是一个 测度,(具体这是什么我也不清楚)。对于 的情况, 表示对于在 度量空间上的 ,在 上有有限个 距(moment)的概率测度集合 ,等于:
那么,对于两个分布 和 在 中的(p阶) Wasserstein距离为:
集合 表示在 空间上所有测度的集合,并且这些测度的边缘分布分别是 和 。(又称为 和 的耦合(coupling))。
变换上式得到:
- 注:什么是距(moments):在数学中,力矩是对函数形状的一种特定的定量度量。定义:关于值 的实值连续函数 的第 阶矩为:
如果 是一个概率密度函数,则上述积分的值称为概率分布的 阶矩;如果 是任何概率分布的累积概率分布函数,其中可能没有密度函数,随机变量为,则概率分布的第n个矩为:
由此我们得到公式 到 的表示。
- 什么是测度:测度(Measure)是一个函数,它对一个给定集合的某些子集指定一个数,这个数可以比作大小、体积、概率等等。 传统的积分是在区间上进行的,后来人们希望把积分推广到任意的集合上,就发展出测度的概念,它在数学分析和概率论有重要的地位。
Intuition and connection to optimal transport
在Wiki百科中有
Wasserstein distance between two gaussian
- 两个多元高斯分布之间的2阶Wasserstein距离是什么,公式 中的距离函数 如果是欧几里得距离的话,那么两个分布之间的2阶Wasserstein距离是:
两个多元高斯分布之间的2阶Wasserstein距离 是:
当协方差矩阵可以互换 ,公式 退化为:
- 注:
当 与 都是对称矩阵:,有:
代码:
def Wasserstein(mu, sigma, idx1, idx2):
p1 = torch.sum(torch.pow((mu[idx1] - mu[idx2]),2),1)
p2 = torch.sum(torch.pow(torch.pow(sigma[idx1],1/2) - torch.pow(sigma[idx2], 1/2),2) , 1)
return p1+p2
- 矩阵 范数表示为 ,:
上划线表示矩阵的中每一个数共轭复数。
- 协方差矩阵:
- 多元高斯部分的情况:是正定矩阵。
- 多元高斯部分的情况:考虑一个一般的对称协方差矩阵 ,其有个独立的参数。对于,又有个独立的参数。多以一共有个参数。当变大时,其独立的参数个数以的二次方增长。只考虑 是对角阵,即 ,我们就只用关心2D个独立的参数。
- 参考:
https://zlearning.netlify.com/computer/prml/PRMLch2dot3-gaussian-again.pdf
http://www.robots.ox.ac.uk/~davidc/pubs/tt2015_dac1.pdf
https://en.wikipedia.org/wiki/Wasserstein_metric#cite_note-1
http://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/
https://en.wikipedia.org/wiki/Wasserstein_metric#cite_note-1