Alink中机器学习模型的训练,保存及调用

2022-02-15  本文已影响0人  老羊_肖恩

  机器学习训练算法比较复杂去数据集规模较大,通常需要在分布式环境中进行,但是使用训练出来的模型进行预测往往简单很多,一般可以单个或者多个节点对模型进行装载,从而进行多路预测。Alink提供了由参数或模型数据直接构建一个本地的java实例,我们称之为LocalPredictor,可以对单条数据进行预测。这样的话,预测任务不再必须由Flink完成,可以嵌入到提供RestAPI的预测服务系统,或者嵌入到用户的业务系统里。
  以酒店评论情感分析为例,数据集为:ChnSentiCorp_htl_all.csv,通过构建完整的中文情感分析pipeline,并将训练好的模型保存在指定位置:

        //数据文件地址
        String url = "D:\\Workspace\\data\\ChnSentiCorp_htl_small.csv";
        //数据schema
        String schemaStr = "label bigint, review string";
        //定义数据源
        BatchOperator data = new CsvSourceBatchOp()
                .setFilePath(url)
                .setSchemaStr(schemaStr)
                .setIgnoreFirstLine(true);
        //Shuffle数据集
        data = new ShuffleBatchOp().linkFrom(data);

        //按照7:3分割数据,生成训练集和测试集
        SplitBatchOp splitter = new SplitBatchOp().setFraction(0.7);
        BatchOperator trainData = splitter.linkFrom(data);
        BatchOperator testData = splitter.getSideOutput(0);

        //构建文本分类pipeline
        Pipeline pipeline = new Pipeline(
                //缺失值填充
                new Imputer()
                        .setSelectedCols("review")
                        .setOutputCols("featureText")
                        .setStrategy("value")
                        .setFillValue("null"),
                //分词
                new Segment()
                        .setSelectedCol("featureText"),
                //停用词过滤
                new StopWordsRemover()
                        .setSelectedCol("featureText"),
                //文本特征生成
                new DocCountVectorizer()
                        .setFeatureType("TF")
                        .setSelectedCol("featureText")
                        .setOutputCol("featureVector"),
                //逻辑回归二分类
                new LogisticRegression()
                        .setVectorCol("featureVector")
                        .setLabelCol("label")
                        .setPredictionCol("pred")
        );

        //模型训练
        PipelineModel model = pipeline.fit(trainData);

        //模型效果评估
        BatchOperator<?> predict = model.transform(testData);
        MultiClassMetrics metrics = new EvalMultiClassBatchOp()
                .setLabelCol("label")
                .setPredictionCol("pred")
                .linkFrom(predict)
                .collectMetrics();
        System.out.println("accuracy:" + metrics.getAccuracy("1"));
        System.out.println("recall:" + metrics.getRecall("1"));
        System.out.println("Macro Precision:" + metrics.getMacroPrecision());
        System.out.println("Micro Recall:" + metrics.getMicroRecall());
        System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());

        //模型保存
        model.save("D:\\Workspace\\models\\SentimentHotel_model_0001", true);
        //这一行要有,不然不保存
        BatchOperator.execute();

  模型效果的评估结果如下图所示:

模型效果

  如果模型的效果能达到预期,那么将模型保存到指定的位置,方便后续的业务系统进行调用。这里我们可以发现,训练完成的模型保存到本地,生成了一个非常小的模型文件。后期业务系统可以直接使用这个模型对外提供模型预测服务。

模型保存

  业务系统可以使用LocalPredictor对指定位置的模型进行加载和调用,代码如下:

        //数据schema
        String SCHEMA_STR = "review string";

        //读取
        LocalPredictor localPredictor = new LocalPredictor("D:\\Workspace\\models\\SentimentHotel_model_0001", SCHEMA_STR);

        Row[] rows = new Row[]{
                Row.of("这个酒店给人留下了永远难忘的印象--垃圾!!" +
                        "奉劝各位千万不要再去那里了,保你后悔没及!总体印象是:无论酒店,商店," +
                        "交通还是旅游景点,除了收费已和国际接轨了以外," +
                        "其他的都是老、少、边区的水平!!!!!!!!悲哀呀"),
                Row.of("酒店的设施和环境都不错的,就是周围没有什么集市和超市," +
                        "在房间的阳台上就能看到一望无际的大海,真的心情非常的不错." +
                        "唯一的就是每天的早餐都是一样的东西.离机场和市区也不是太远."),
                Row.of("房间很宽敞,干净卫生,就是不知道为啥隔音很差,整体还行"),
                Row.of("价格还不错,出行很方便,临近地铁站,去机场也很方便," +
                        "楼上就四磁悬浮啦,但是房间是真的旧,地板缩水严重。")
        };
        for (Row row : rows) {
            Row predict = localPredictor.map(row);
            System.out.println(predict.getField(0) + "  prediction:" + predict.getField(3));
        }

  该中文情感分析的模型,调用结果如下:

模型调用结果

  以上就是使用Alink进行机器学习建模的全过程,建模的过程中,由于训练数据往往很庞大因此数据处理和模型训练的过程需要放在flink集群中去完成,最终生成满足业务需求的模型。由于生成的数据模型很小,且通常需要内嵌到业务系统中对外提供模型预测服务,因此可以将模型预测的功能于flink集群进行脱离,直接在业务系统中载入模型后对外提供预测服务。

参考:
https://www.yuque.com/pinshu/alink_guide/zo1y6q
https://www.yuque.com/pinshu/alink_guide/pz7rcl

上一篇 下一篇

猜你喜欢

热点阅读