优化文件加载从20s到3s

2017-02-25  本文已影响285人  董泽润
背景

现在大多业务都使用机器学习,程序启动时加载训练好的模型文件,运行期也会触发模型的 reload。 在程序启动时如果加载耗时比较长,那么程序自然有段时间不可服务(模型没有准备好),但是运行期由于是双 buffer 切换,耗时长些也无所谓。

优化前

加载 14 个模型文件,并行加载,文件最小的几k,最大有 300m,加载时间取短板最长耗时 20s

优化后

对最大的四个文件,采用并行加载,耗时最大减少到 3s,优化完成

代码串行处理逻辑

原有单个文件也是串行处理逻辑

这个逻辑非常简单,符合人的直觉思维,但同时也非常低效。

并行优化1

思路很简单:模型文件没有顺序,可以一次性全读到内存中,然后按行去并行解析,最后合并到字典,非常类似 MapReduce

第一次优化后耗时降为 10s,初步成效,但是仍然不理想。纺计每一步耗时后发现,对于最大 300m 的文件,bytes.Split 打散耗时 4s, 模型 Map 合并耗时 5s

并行优化2

和同事探讨下如何继续优化,对于 Map 无法并行。当前模型实现方式用单一 Map,如果加锁就和串行合并行为是一致的。当初始化 Map 指定大小时,合并时间从 5s 降到 2s,避免了 rehash copy 的开销,效果很明显。

另外 bytes.Split 打散耗时超长是没有想到的,看了下源码,内部两次遍历,耗时自然和数量成正比。同事提义将打散移到并行阶段,由每个 goroutine 去完成,预估并行数量,然后按 batch 打散。有几点需要注意:

  1. 无所提前知道总数据量大小,模型 Map 初始化要预估大小,按 30 byte 一行猜测即可
  2. 每个 gorouinte 划分数据也是不均等的,但一定要以 '\n' 分隔符打散,不能打数据截断

最后共耗时 3s,一次性加载内存维 150ms,并行解析 1s,合并 Map 2s

代码示例:

当前性价比最高的优化,如果大家有更好的方式可以共同交流一下,第一个是抽象的执行函数,第二个是示例使用方示

// ParallelLoadModelFile 并行加载模型文件
// @params  data         文件二进制数据
// @params  sep          分隔符
// @params  name         识别标记
// @params  parallel     并发数目, 一般不超过20, 过大没用
// @params  parse        用户处理函数
// @params  merge        用户合并函数
// 原类类似 MapReduce, 先将文件并行处理, 最后 reduce 合并。使用请参考 loadPassengerFeatures2
// 原则:尽量将耗时操作并行化
// 注意:
// 1. map 初始化时一定要指定大小,否则 rehash copy 成本非常高 测试 800W 条记录合并消耗 2s
// 2. 数据在 parse 和 merge 函数流动要用 channel, 具体类型及解析合并由调用方决定
// 3. 需要特殊处理行不能使用这个函数, 要单独处理
//
// 流程优化:
// 读文件 |  解析每行数据并写到map
// +------------------------+
//
// load内存并打散  分片     聚合
//             +-----+
// +--------+  |-----| +------+
//             +-----+
// load   打散分片 聚合
//
//        +----+
// +----+ |----| +---+
//        +----+
//
// 1. ioutil.ReadFile一次性读入内存 2. bytes.Split 按\n打散 3. 分片计算  4. 合并merge
// 在大文件时 bytes.Split 非常耗时, 将第2步移到并行阶段, 和3一起算。合并 map 非常耗时
// Map 操作只能串行, 并发也需要加锁来互斥, 等同于串行, 暂时没想到好的合并方法
func ParallelLoadModelFile(data []byte, sep []byte, name string, parallel int, parse func([]byte), merge func()) {
    if parallel <= 0 || parallel > 30 || parse == nil || merge == nil || len(sep) == 0 {
        panic("ParallelLoadModelFile params illegal")
    }

    var (
        wait  = sync.WaitGroup{} // sync
        size  = len(data)        // file size
        batch = size / parallel  // batch size
        num   = size/batch + 1   // parallel goroutine
        start = 0
        end   = batch
    )

    for i := 0; i < num; i++ {
        wait.Add(1)
        // 获取第一个 sep 所在的 index
        idx := bytes.Index(data[end:], sep)
        if idx == -1 {
            end = len(data) - 1
        } else {
            end += idx
        }

        go parse(data[start:end])

        start = end
        if (end + batch) < len(data) {
            end += batch
        } else {
            end = len(data) - 1
        }
    }

    go func() {
        for i := 0; i < num; i++ {
            merge()
            wait.Done()
        }
    }()

    // 同步阻塞,等待所有 MapReduce
    wait.Wait()
}

//加载小时特征 并行版本
func LoadHourGEOInfo2(model_data_center *ModelDataCenter, file_name string) error {
    now := time.Now().UnixNano()
    defer func() {
        logger.Info("load[%s] time=%dms", file_name, (time.Now().UnixNano()-now)/1e6)
    }()

    content, err := ioutil.ReadFile(file_name)
    if err != nil {
        logger.Error("ioutil readfile error, file_name=%s", file_name)
        return err
    }

    // 预估map大小
    model_data_center.HourGEOInfoData = make(map[string]DynamicDiscountGEOInfo, len(content)/30)

    // model 消息
    modelChan := make(chan map[string]DynamicDiscountGEOInfo, 10)

    // map 并行处理函数
    mapParse := func(content []byte) {
        var (
            data = bytes.Split(content, SepLine)
            m    = make(map[string]DynamicDiscountGEOInfo, len(data))
        )

        defer func() {
            // 将数据扔到 chan 待合并
            // 用 defer 防止遗望
            modelChan <- m
        }()

        for _, l := range data {

            line := string(l)
            // 兼容\r\n换行的情况
            line = strings.Replace(line, "\r", "", -1)
            list := strings.Split(line, ",")

            var hour_geo_info DynamicDiscountGEOInfo
            if len(list) != 8 && len(list) != 10 {
                logger.Warn("wrong fomat file=%s line=%s cols.Size=%d", file_name, line, len(list))
                continue
            }

            lng_lat, err := strconv.Atoi(list[0])
            if err != nil {
                logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 0)
                continue
            }
            hour, err := strconv.Atoi(list[1])
            if err != nil {
                logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 1)
                continue
            }
            hour_geo_key := GetGEOKey(hour, lng_lat, "HOUR", 0)
            hour_geo_info.StartGEOInfo.CarpoolNum, err = strconv.Atoi(list[2])
            if err != nil {
                logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 2)
                continue
            }
            hour_geo_info.StartGEOInfo.SucCarpoolNum, err = strconv.Atoi(list[3])
            if err != nil {
                logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 3)
                continue
            }
            hour_geo_info.StartGEOInfo.SucCarpoolRate, err = strconv.Atoi(list[4])
            if err != nil {
                logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 4)
                continue
            }
            hour_geo_info.DestGEOInfo.CarpoolNum, err = strconv.Atoi(list[5])
            if err != nil {
                logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 5)
                continue
            }
            hour_geo_info.DestGEOInfo.SucCarpoolNum, err = strconv.Atoi(list[6])
            if err != nil {
                logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 6)
                continue
            }
            hour_geo_info.DestGEOInfo.SucCarpoolRate, err = strconv.Atoi(list[7])
            if err != nil {
                logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 7)
                continue
            }

            if len(list) == 10 {
                hour_geo_info.StartGEOInfo.InComeRate, err = strconv.ParseFloat(list[8], 64)
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 8)
                    continue
                }
                hour_geo_info.DestGEOInfo.InComeRate, err = strconv.ParseFloat(list[9], 64)
                if err != nil {
                    logger.Warn("wrong format file=%s line=%s item=%d", file_name, line, 9)
                    continue
                }
            } else {
                hour_geo_info.StartGEOInfo.InComeRate = -1.0
                hour_geo_info.DestGEOInfo.InComeRate = -1.0
            }
            // 更新 map
            m[hour_geo_key] = hour_geo_info
        }
    }

    // reduce 最终合并函数
    mergeReduce := func() {
        select {
        // merge model msg
        case m := <-modelChan:
            logger.Info("parallel load[%s]||line_num=%d", file_name, len(m))
            for k := range m {
                model_data_center.HourGEOInfoData[k] = m[k]
            }
        }
    }

    ParallelLoadModelFile(content, SepLine, file_name, 3, mapParse, mergeReduce)
    return nil
}
上一篇下一篇

猜你喜欢

热点阅读