go-xgboost内存占用极高的原因分析

2020-03-04  本文已影响0人  AI_Finance

在使用 go-xgboost 做模型预测时，大家一定很熟悉下面这条语句。但这条语句非常占用内存，因为 NewPredictor 在加载模型时还执行了一个循环（见下方源码）：
// Arguments map to NewPredictor(path, workerCount, optionMask, nTreeLimit, missingValue):
// workerCount = runtime.NumCPU(), and every worker loads its own copy of the model,
// so the model's memory footprint is multiplied by the number of CPUs.
predictor, err := xgboost.NewPredictor(modelDir+fileName, runtime.NumCPU(), 0, 5000, -1)

如果 workerCount=4（例如 runtime.NumCPU() 返回 4），意味着同一个模型要被加载四次，模型本身占用的内存也会相应放大到约四倍。

// NewPredictor starts workerCount worker goroutines. Each worker creates its
// own XGBoost booster and loads the model at xboostSavedModelPath into it.
// Prediction requests sent through the returned Predictor travel over one
// shared unbuffered channel and are served by whichever worker receives them.
//
// Memory note: every worker loads its own copy of the model, so the model's
// memory footprint is multiplied by workerCount. For example, passing
// runtime.NumCPU() on an 8-core machine loads the model 8 times.
//
// Parameters:
//   - xboostSavedModelPath: path of the saved model, passed to booster.LoadModel.
//   - workerCount: number of workers (and model copies); must be > 0.
//   - optionMask, nTreeLimit: forwarded unchanged to booster.Predict.
//   - missingValue: forwarded to core.XGDMatrixCreateFromMat.
//
// Workers are initialized one at a time. If any worker fails to create its
// booster or load the model, that error is returned. Every worker started so
// far is then told to exit, so that it does not stay blocked forever.
func NewPredictor(xboostSavedModelPath string, workerCount int, optionMask int, nTreeLimit uint, missingValue float32) (Predictor, error) {
    if workerCount <= 0 {
        return nil, errors.New("worker count needs to be larger than zero")
    }

    requestChan := make(chan multiBoosterRequest)
    initErrors := make(chan error)
    // Closing here is safe. Each worker sends exactly one init result, and the
    // loop below receives it before starting the next worker. So no send on
    // initErrors can still be pending when NewPredictor returns.
    defer close(initErrors)

    for i := 0; i < workerCount; i++ {
        go func() {
            // Pin this goroutine to a single OS thread for its whole lifetime.
            // NOTE(review): presumably the booster must only be used from the
            // thread that created it — confirm against the core package.
            runtime.LockOSThread()
            defer runtime.UnlockOSThread()

            booster, err := core.XGBoosterCreate(nil)
            if err != nil {
                initErrors <- err
                return
            }

            // This is the per-worker model load that multiplies memory usage.
            err = booster.LoadModel(xboostSavedModelPath)
            if err != nil {
                initErrors <- err
                return
            }

            // No errors occurred during init.
            initErrors <- nil

            // Serve requests until requestChan is closed.
            for req := range requestChan {
                data, rowCount, columnCount := req.matrix.Data()
                // NOTE(review): matrix is never freed explicitly here; confirm
                // that the core package releases it (e.g. via a finalizer).
                matrix, err := core.XGDMatrixCreateFromMat(data, rowCount, columnCount, missingValue)
                if err != nil {
                    req.resultChan <- multiBoosterResponse{
                        err: err,
                    }
                    continue
                }

                res, err := booster.Predict(matrix, optionMask, nTreeLimit)
                req.resultChan <- multiBoosterResponse{
                    err:    err,
                    result: res,
                }
            }
        }()

        if err := <-initErrors; err != nil {
            // Workers started earlier have already loaded the model and are
            // blocked in `range requestChan`. No Predictor is returned, so
            // nothing will ever send on requestChan. Closing it lets those
            // workers exit, instead of leaking the goroutines and keeping
            // their loaded models reachable forever.
            close(requestChan)
            return nil, err
        }
    }

    return &multiBooster{reqChan: requestChan}, nil
}
上一篇 下一篇

猜你喜欢

热点阅读