模型排序中的逆序对计算

2019-01-08  本文已影响0人  张矩阵

一、问题背景

在使用机器学习模型预测的各类场景中,对象间的序关系是否预测准确,是考量模型效果的指标之一。

正序数、逆序数可作为模型效果的衡量指标。当样本的label之间存在序关系时,由样本间两两组成的pair,若模型预测结果的序关系与label之间的序关系相同,称为正序;若模型预测结果的序关系与label之间的序关系相反,称为逆序。当正序数量越多、逆序数量越少时,表明模型对序关系的刻画越准确,模型效果越好。正逆序即为正序数量与逆序数量的比值。

计算正序数量、逆序数量时,一种直观的方法是暴力构造所有的pair对并一一验证,在N个样本下时间复杂度为O(N^2),当N的量级在5万时普通的服务器上已需要耗费近10分钟(Python),在更大量级时计算时间无法忍受。

进一步,如果在模型训练过程中,希望在训练的每一轮迭代时都查看正逆序的值(据此做终止条件);或是在多组参数间使用正逆序做验证调参的标准时,正逆序的计算速度问题则更为突出。我们需要复杂度更低的算法来快速计算出正逆序。

二、数学表述

给定N个样本,现有人工打好的每个样本的label,记为label_1,\ label_2,\ ...,\ label_N以及模型预测出的每个样本的分值,记为predict_1,\ predict_2,\ ...,\ predict_N

| A | 表示集合A的大小, \land表示逻辑与、\lor表示逻辑或
以下i, j 均在 1,2,...N中取值、且均满足 i < j ,不再特别写出

构造出的所有pair集合为(注意我们只对label不同的样本构造pair
Pair = \{ (i, j) \ | \ label_i \ne label_j \}

正序集合为
Right = \{ (i, j) \ | \ (label_i > label_j \land predict_i \ge predict_j) \lor (label_i < label_j \land predict_i \le predict_j)\}

逆序集合为
Wrong = \{(i, j) \ | \ (label_i > label_j \land predict_i \le predict_j) \lor (label_i < label_j \land predict_i \ge predict_j) \}

严格正序集合为
StrictRight = \{(i,j) \ | \ (label_i > label_j \land predict_i > predict_j) \lor (label_i < label_j \land predict_i < predict_j) \}

严格逆序集合为
StrictWrong = \{(i, j) \ | \ (label_i > label_j \land predict_i < predict_j) \lor (label_i < label_j \land predict_i > predict_j) \}

由以上的定义容易看出,
| Right | + |StrictWrong | = | Wrong | + |StrictRight | = | Pair |

三、算法思路

注:以下排序均指升序排列。

MergeSort

熟悉排序算法的同学应已看出:这个问题像极了经典的找逆序对问题。

给定数组a,找出满足 i < j \ \land \ a[i] > a[j](i,j) pair对数量。使用MergeSort可以在 O(N \log N)时间计算出结果。

MergeSort计算逆序的思路较为直接,使用的是Divide-and-Conquer的思想:

严格逆序 StrictWrong

我们从计算严格逆序集合StrictWrong入手。

根据上一点关于MergeSort的讨论,一个自然的想法出现了:

先按label排序得到数组a,接着计算数组a中的关于predict的逆序对数量。

但这与我们的原始需求仍有细微的差别:在按label排序后,
我们对于下标i, j只能得到i < j \Rightarrow label_i \le label_j
然而 i < j \nRightarrow label_i < label_j
因此直接计算逆序对的话,会把label相等的情形也算进来,得不到正确答案。

为了消除这一影响,我们可以使用一个小trick:

在按label排序时,排序的key不仅仅使用label,而是按照二元组 (label,\ predict)排序。
即:先按label排序,当label值相等时,按predict排序

于是
i < j \Rightarrow (label_i,\ predict_i) \le (label_j,\ predict_j) \Rightarrow (label_i < label_j) \lor (label_i = label_j \land predict_i \le predict_j)
因此在这种情形下我们不会统计到任何label相等时的逆序对。可在O(N \log N)时间内计算得到 |StrictWrong |

严格正序 StrictRight

思路与严格逆序完全一致,只是不等号方向变反。为了程序复用,可将predict正负号变反后、直接调用严格逆序计算的程序得到结果

正序Right, 逆序Wrong, Pair

因为 | Right | + |StrictWrong | = | Wrong | + |StrictRight | = | Pair |
结合前面的结果, 只需计算出 | Pair |, 即可获得 | Right ||Wrong |
| Pair | 的计算较为简单,将label排好序后,依次遍历处理即可,总复杂度为 O(N \log N)

结论及实验

根据以上讨论,各个集合的大小均可在O(N \log N) 时间计算得出。
随机构造的样本在普通服务器上的实际运行时间统计如下(Python),可以看出优化后的算法执行时间大幅提升。

样本数量N 基于 MergeSort 计算时间(秒) 基于 暴力法 计算时间(秒)
500 0.017 0.051
5000 0.164 5.048
50000 2.017 512.371

四、总结与展望

总结

展望

附、代码片段

(代码中的变量true即为上文中的label,取"groundtruth"之意;pred即为上文中的predict

from itertools import groupby


class InversionCounter(object):
    @classmethod
    def merge_sort_count_sub(cls, vals):
        if len(vals) <= 1:
            return vals, 0

        n = len(vals)
        left_vals, left_cnt = cls.merge_sort_count_sub(vals[:n/2])
        right_vals, right_cnt = cls.merge_sort_count_sub(vals[n/2:])

        left_i = 0
        right_i = 0
        
        mid_cnt = 0
        new_vals = []
        while True:
            if left_vals[left_i][1] <= right_vals[right_i][1]:
                new_vals.append(left_vals[left_i])
                left_i += 1
            elif left_vals[left_i][1] > right_vals[right_i][1]:
                mid_cnt += (len(left_vals) - left_i)
                new_vals.append(right_vals[right_i])
                right_i += 1

            if left_i == len(left_vals):
                new_vals.extend(right_vals[right_i:])
                break
            if right_i == len(right_vals):
                new_vals.extend(left_vals[left_i:])
                break

        return new_vals, left_cnt + mid_cnt + right_cnt


    @classmethod
    def merge_sort_count_strict_right(cls, trues, preds):
        neg_preds = (-p for p in preds)
        vals = zip(trues, neg_preds)
        vals.sort()
        return cls.merge_sort_count_sub(vals)[1] 


    @classmethod
    def merge_sort_count_strict_wrong(cls, trues, preds):
        vals = zip(trues, preds)
        vals.sort()
        return cls.merge_sort_count_sub(vals)[1]


    @classmethod
    def merge_sort_count_right(cls, trues, preds):
        return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_wrong(trues, preds)


    @classmethod
    def merge_sort_count_wrong(cls, trues, preds):
        return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_right(trues, preds)


    @classmethod
    def merge_sort_count_pair(cls, trues, preds=None):
        '''
            preds: dummpy variable, no need inside function
        '''
        trues = sorted(trues)
        acc_num = 0
        pair = 0
        for k, ks in groupby(trues):
            current_num = sum(1 for _ in ks)
            acc_num += current_num
            pair += (len(trues) - acc_num) * current_num

        return pair
上一篇 下一篇

猜你喜欢

热点阅读