SparkML 实现 LR 算法

2020-06-16  本文已影响0人  乌鲁木齐001号程序员

离散特征

举例
离散特征 | 处理

连续特征

举例
连续特征 | 标准化 | 处理
连续特征 | 离散化 | 处理

特征处理

featurevalue.csv
"用户id","年龄","性别","门店id","评分","人均价格","是否点击"
"1","22","M","315","4","193","0"
"1","16","F","431","3","193","1"
"1","62","F","489","3","72","1"
"1","12","M","398","0","216","1"
"1","76","M","307","3","131","0"
"1","54","M","490","1","205","0"
"1","38","M","308","2","227","1"
"1","56","M","400","3","82","1"
"1","65","F","426","0","136","0"
"2","48","F","328","3","64","1"
feature.csv
"1","0","0","0","1","0","0.8","0","0","1","0","0"
"1","0","0","0","0","1","0.6","0","0","1","0","1"
"0","0","0","1","0","1","0.6","0","1","0","0","1"
"1","0","0","0","1","0","0.0","0","0","0","1","1"
"0","0","0","1","1","0","0.6","0","0","1","0","0"
"0","0","1","0","1","0","0.2","0","0","0","1","0"
"0","1","0","0","1","0","0.4","0","0","0","1","1"
"0","0","1","0","1","0","0.6","0","1","0","0","1"
"0","0","0","1","0","1","0.0","0","0","1","0","0"
"0","0","1","0","0","1","0.6","0","1","0","0","1"

LR 模型生成

LR 模型生成 | 步骤
LR 模型生成 | 代码
package tech.lixinlei.dianping.recommand;

import java.io.IOException;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class LRTrain {

    public static void main(String[] args) throws IOException {

        // Initialize a local Spark session.
        SparkSession spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();

        // Load the training file. Each line of feature.csv holds 12 quoted,
        // comma-separated fields: 11 feature columns followed by the label.
        JavaRDD<String> csvFile = spark.read().textFile("file:///home/lixinlei/project/gitee/dianping/src/main/resources/feature.csv").toJavaRDD();

        // Parse each CSV line into a Row of (label, dense feature vector).
        JavaRDD<Row> rowJavaRDD = csvFile.map(new Function<String, Row>() {
            /**
             * Converts one line of feature.csv into a labeled Row.
             *
             * @param v1 a raw CSV line, fields wrapped in double quotes
             * @return Row(label, 11-dimensional dense feature vector)
             * @throws Exception if a field cannot be parsed as a double
             */
            @Override
            public Row call(String v1) throws Exception {
                String[] strArr = v1.replace("\"", "").split(",");
                // BUGFIX: the original hard-coded the last slot as the constant
                // 10.0 (Double.valueOf(10)) instead of parsing strArr[10],
                // corrupting the 11th feature of every training row.
                double[] features = new double[11];
                for (int i = 0; i < features.length; i++) {
                    features[i] = Double.parseDouble(strArr[i]);
                }
                // Column 11 is the click label; parseDouble avoids the
                // deprecated new Double(...) boxing constructor.
                return RowFactory.create(Double.parseDouble(strArr[11]),
                                         Vectors.dense(features));
            }
        });

        // Schema: a double label column plus an 11-dimensional vector column.
        StructType schema = new StructType(
                new StructField[]{
                        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                        new StructField("features", new VectorUDT(), false, Metadata.empty())
                }
        );

        // data has two columns: label, and the 11-dimensional feature vector.
        Dataset<Row> data = spark.createDataFrame(rowJavaRDD, schema);

        // Random 80/20 train/test split.
        Dataset<Row>[] dataArr = data.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> trainData = dataArr[0];
        Dataset<Row> testData = dataArr[1];

        // Train an elastic-net regularized logistic regression model.
        LogisticRegression lr = new LogisticRegression()
                .setMaxIter(10)          // maximum training iterations
                .setRegParam(0.3)        // regularization strength
                .setElasticNetParam(0.8) // 0 = L2 only, 1 = L1 only
                .setFamily("multinomial");
        LogisticRegressionModel lrModel = lr.fit(trainData);
        lrModel.save("file:///home/lixinlei/project/gitee/dianping/src/main/resources/lrmode");

        // Evaluate on the held-out split.
        Dataset<Row> predictions = lrModel.transform(testData);
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator();
        double accuracy = evaluator.setMetricName("accuracy").evaluate(predictions);

        // BUGFIX: the metric computed above is accuracy, not AUC — the
        // original message "auc = " mislabeled it.
        System.out.println("accuracy = " + accuracy);

        // Release Spark resources; the original leaked the session.
        spark.stop();

    }

}
上一篇 下一篇

猜你喜欢

热点阅读