蓄水池采样算法-Lua版本

2020-03-19  本文已影响0人  慢慢来比较快_

由于业务需要,所以搜索了一些相关的随机算法
代码是参考维基百科进行编写的:https://en.wikipedia.org/wiki/Reservoir_sampling

注意点:
Chao算法存在缺陷:它一开始就把所需数量的数据全部放进池子里。如果存在权重为 0 的数据,且数据量较少,这些数据可能会出现在最终结果里(实验中,Chao算法的结果不正确)。
Res算法中,我使用遍历来求最小值,所以在处理大量数据时可能会存在性能瓶颈。

代码:
--- Reservoir sampling (Algorithm R): draw a uniform random sample of
-- `dNeed` items from the first `dSequenceSize` items of `tbSequence` in a
-- single pass.
-- @tparam table tbSequence sequence to sample from
-- @tparam number dNeed number of items to draw
-- @tparam[opt] number dSequenceSize number of leading items to consider
--   (defaults to `#tbSequence`)
-- @tparam[opt] function funcRandom random source following the
--   `math.random(lo, hi)` integer convention (defaults to `math.random`)
-- @treturn table a new table with the sampled items. It is empty when
--   `dNeed <= 0`, and it is a copy of all considered items when `dNeed` exceeds
--   `dSequenceSize`. The input table is never returned or modified.
function reserviorSampling(tbSequence, dNeed, dSequenceSize, funcRandom)
    dSequenceSize = dSequenceSize or #tbSequence
    funcRandom = funcRandom or math.random

    -- Guard: a non-positive sample size would otherwise reach funcRandom(1, 0).
    if dNeed <= 0 then
        return {}
    end
    -- Asking for more items than exist yields all of them. Clamping keeps the
    -- result a fresh table restricted to the first dSequenceSize items.
    if dNeed > dSequenceSize then
        dNeed = dSequenceSize
    end

    local tbSample = {}
    for i = 1, dNeed do
        tbSample[i] = tbSequence[i]
    end

    -- Item i is kept with probability dNeed / i and replaces a uniformly
    -- chosen slot of the reservoir.
    for i = dNeed + 1, dSequenceSize do
        local j = funcRandom(1, i)
        if j <= dNeed then
            tbSample[j] = tbSequence[i]
        end
    end

    return tbSample
end

--- Weighted reservoir sampling: Algorithm A-Chao.
-- The first `dNeed` items fill the reservoir. Each later item enters with
-- probability dNeed * weight / (total weight seen so far) and replaces a
-- uniformly chosen slot. With equal weights this reduces to Algorithm R.
--
-- Limitations:
--  * The first `dNeed` items are placed unconditionally, so low- or
--    zero-weight items among them can survive when few items follow.
--  * Items whose probability exceeds 1 ("overweight" items) are always
--    accepted; they do not get the special handling of Chao's full scheme.
-- @tparam table tbSequence sequence of items; each is a table with a
--   numeric `weight` field
-- @tparam number dNeed number of items to draw
-- @tparam[opt] number dSequenceSize number of leading items to consider
--   (defaults to `#tbSequence`)
-- @tparam[opt] function funcRandom random source following the `math.random`
--   convention: `funcRandom()` gives a float in [0, 1) and `funcRandom(lo, hi)`
--   gives an integer (defaults to `math.random`)
-- @treturn table a new table with the sampled items. It is empty when
--   `dNeed <= 0`, and it is a copy of all considered items when `dNeed` exceeds
--   `dSequenceSize`.
function weightedReserviorSampling_Chao(tbSequence, dNeed, dSequenceSize, funcRandom)
    dSequenceSize = dSequenceSize or #tbSequence
    funcRandom = funcRandom or math.random

    -- Guard: a non-positive sample size would otherwise reach funcRandom(1, 0).
    if dNeed <= 0 then
        return {}
    end
    -- Not enough items: return all considered items as a fresh table.
    if dNeed > dSequenceSize then
        local tbAll = {}
        for i = 1, dSequenceSize do
            tbAll[i] = tbSequence[i]
        end
        return tbAll
    end

    local dWeightSum = 0
    local tbSample = {}

    for i = 1, dNeed do
        tbSample[i] = tbSequence[i]
        dWeightSum = dWeightSum + tbSequence[i].weight
    end

    for i = dNeed + 1, dSequenceSize do
        local dWeight = tbSequence[i].weight
        dWeightSum = dWeightSum + dWeight
        -- The reservoir has dNeed slots, so the inclusion probability is
        -- dNeed * w / W, not w / W. If W is still 0, p is NaN; the comparison
        -- below is then false and nothing is replaced.
        local p = dNeed * dWeight / dWeightSum
        if funcRandom() <= p then
            tbSample[funcRandom(1, dNeed)] = tbSequence[i]
        end
    end

    return tbSample
end

--- Weighted reservoir sampling: Algorithm A-Res (Efraimidis & Spirakis).
-- Each item gets the key u^(1/weight), where u comes from `funcRandom()`. The
-- `dNeed` items with the largest keys form the sample.
--
-- A linear scan of the reservoir finds the smallest key. The result is cached
-- until the reservoir changes, so the worst case is
-- O(dSequenceSize * dNeed); a min-heap would scale better for large samples.
-- Keys are kept in a private table, so the caller's items are not modified.
-- @tparam table tbSequence sequence of items; each is a table with a
--   numeric `weight` field
-- @tparam number dNeed number of items to draw
-- @tparam[opt] number dSequenceSize number of leading items to consider
--   (defaults to `#tbSequence`)
-- @tparam[opt] function funcRandom random source; `funcRandom()` must return
--   a float in [0, 1) (defaults to `math.random`)
-- @treturn table a new table with the sampled items. It is empty when
--   `dNeed <= 0`, and it is a copy of all considered items when `dNeed` exceeds
--   `dSequenceSize`.
function weightedReserviorSampling_Res(tbSequence, dNeed, dSequenceSize, funcRandom)
    dSequenceSize = dSequenceSize or #tbSequence
    funcRandom = funcRandom or math.random

    -- Guard: with an empty reservoir there is no minimum to compare against.
    if dNeed <= 0 then
        return {}
    end
    -- Not enough items: return all considered items as a fresh table.
    if dNeed > dSequenceSize then
        local tbAll = {}
        for i = 1, dSequenceSize do
            tbAll[i] = tbSequence[i]
        end
        return tbAll
    end

    local tbSample = {}
    local tbKeys = {}      -- tbKeys[n] is the key of tbSample[n]
    local dMinIndex = nil  -- index of the smallest key; nil means "rescan"

    for i = 1, dSequenceSize do
        local item = tbSequence[i]
        local r = funcRandom() ^ (1 / item.weight)
        if i <= dNeed then
            tbSample[i] = item
            tbKeys[i] = r
        else
            if not dMinIndex then
                -- Start from the first slot instead of a magic sentinel, so a
                -- minimum is always found. Strict < keeps the first of any ties.
                dMinIndex = 1
                for n = 2, #tbKeys do
                    if tbKeys[n] < tbKeys[dMinIndex] then
                        dMinIndex = n
                    end
                end
            end
            if r > tbKeys[dMinIndex] then
                -- Drop the smallest-key item and append the new one,
                -- keeping the two parallel arrays aligned.
                table.remove(tbSample, dMinIndex)
                table.remove(tbKeys, dMinIndex)
                tbSample[#tbSample + 1] = item
                tbKeys[#tbKeys + 1] = r
                dMinIndex = nil
            end
        end
    end

    return tbSample
end
上一篇 下一篇

猜你喜欢

热点阅读