Spark RDD: Working with MySQL

2020-07-06  布莱安托

Spark can access relational databases through Java JDBC. At the RDD level, reads go through JdbcRDD, as the following example shows:

  1. Add the dependency
// Add the dependency to build.sbt
libraryDependencies ++= Seq (
  "mysql" % "mysql-connector-java" % "5.1.47"
)
  2. Reading from MySQL
import java.sql.DriverManager

import org.apache.spark.rdd.JdbcRDD
import org.apache.spark.{SparkConf, SparkContext}

object MySQLDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("MySQLDemo").setMaster("local[4]")
    val sc = new SparkContext(conf)

    val driver = "com.mysql.jdbc.Driver"
    val url = "jdbc:mysql://172.16.0.31:3306/db_canal_test"
    val username = "root"
    val password = "123456"

    val sql = "select name, age from tbl_person_info where id >= ? and id <= ?"

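    // mapRow runs on the executors, so it returns a value for each row
    // instead of printing there.
    val jdbcRdd = new JdbcRDD(sc,
      () => DriverManager.getConnection(url, username, password),
      sql, 1, 2, 2,
      res => (res.getString(1), res.getInt(2)))

    // Print on the driver after the rows have been collected.
    jdbcRdd.collect().foreach { case (name, age) => println(name + ", " + age) }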
    sc.stop()

  }
}

Result:

john, 18

lucy, 20

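The JdbcRDD constructor takes 7 parameters:

  1. sc: SparkContext - the SparkContext of the current application
  2. getConnection: () => Connection - a function that returns a JDBC connection
  3. sql: String - the SQL query to run
  4. lowerBound: Long - the lower bound of the data range
  5. upperBound: Long - the upper bound of the data range
  6. numPartitions: Int - the number of partitions
  7. mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _ - a function that converts each row of the result set into a result value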

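Parameters 4 and 5 (lowerBound and upperBound) are passed into the SQL through its two ? placeholders: the range is split across the partitions, and each partition binds the start and end of its own sub-range. If the SQL contains no placeholders, the query fails with java.sql.SQLException: Parameter index out of range (1 > number of parameters, which is 0).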
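
To make the interplay of lowerBound, upperBound and numPartitions concrete, here is a minimal sketch of how an inclusive range can be cut into one sub-range per partition. It follows our understanding of what JdbcRDD does internally, so treat it as an illustration rather than Spark's own code. `JdbcBounds` and `split` are hypothetical names, not part of the Spark API.

```scala
object JdbcBounds {
  // Hypothetical helper: splits the inclusive range [lowerBound, upperBound]
  // into numPartitions inclusive sub-ranges. Each (start, end) pair stands for
  // the values bound to the first and second ? placeholder in one partition.
  def split(lowerBound: Long, upperBound: Long, numPartitions: Int): Seq[(Long, Long)] = {
    // BigInt avoids overflow when the two bounds are far apart.
    val lower = BigInt(lowerBound)
    val length = BigInt(upperBound) - lower + 1
    (0 until numPartitions).map { i =>
      val start = lower + (length * i) / numPartitions
      val end = lower + (length * (i + 1)) / numPartitions - 1
      (start.toLong, end.toLong)
    }
  }
}
```

With the values from the example, JdbcBounds.split(1, 2, 2) yields (1, 1) and (2, 2), so each of the two partitions queries a single id. Rows whose id falls outside [lowerBound, upperBound] are never read.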

  3. Writing to MySQL
import java.sql.DriverManager

import org.apache.spark.{SparkConf, SparkContext}

object MySQLWriteDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("MySQLWriteDemo").setMaster("local[4]")
    val sc = new SparkContext(conf)

    val driver = "com.mysql.jdbc.Driver"
    val url = "jdbc:mysql://172.16.0.31:3306/db_canal_test"
    val username = "root"
    val password = "123456"

    val dataRdd = sc.parallelize(List(("steve", 30), ("elly", 21), ("sam", 13)))

    dataRdd.foreach {
      case (name, age) => {
        val conn = DriverManager.getConnection(url, username, password)
        val sql = "insert into tbl_person_info (name, age) values (?, ?)"
        val statement = conn.prepareStatement(sql)
        try {
          statement.setString(1, name)
          statement.setInt(2, age)
          statement.executeUpdate()
        } finally {
          statement.close()
          conn.close()
        }
      }
    }

    sc.stop()

  }
}

Querying the table with SQL gives:

1 john 18
2 lucy 20
3 elly 21
4 sam 13
5 steve 30

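The code above does write to MySQL, but the connection is created inside foreach, so one connection is opened for every element of the RDD. As the data volume grows, the cost of opening all those connections becomes very large, so we move the connection creation ahead of the loop: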

import java.sql.DriverManager

import org.apache.spark.{SparkConf, SparkContext}

object MySQLWriteDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("MySQLWriteDemo").setMaster("local[4]")
    val sc = new SparkContext(conf)

    val driver = "com.mysql.jdbc.Driver"
    val url = "jdbc:mysql://172.16.0.31:3306/db_canal_test"
    val username = "root"
    val password = "123456"

    val dataRdd = sc.parallelize(List(("steve", 30), ("elly", 21), ("sam", 13)))

    val conn = DriverManager.getConnection(url, username, password)

    try {
      dataRdd.foreach {
        case (name, age) => {
          val sql = "insert into tbl_person_info (name, age) values (?, ?)"
          val statement = conn.prepareStatement(sql)
          try {
            statement.setString(1, name)
            statement.setInt(2, age)
            statement.executeUpdate()
          } finally {
            statement.close()
          }
        }
      }
    } finally {
      conn.close()
    }

    sc.stop()

  }
}

Running this version throws an exception: org.apache.spark.SparkException: Task not serializable

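The connection object is created on the driver and used inside the function passed to foreach, but it cannot be serialized, which causes the exception. We therefore revise the program as follows: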

import java.sql.DriverManager

import org.apache.spark.{SparkConf, SparkContext}

object MySQLWriteDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("MySQLWriteDemo").setMaster("local[4]")
    val sc = new SparkContext(conf)

    val driver = "com.mysql.jdbc.Driver"
    val url = "jdbc:mysql://172.16.0.31:3306/db_canal_test"
    val username = "root"
    val password = "123456"

    val dataRdd = sc.parallelize(List(("steve", 30), ("elly", 21), ("sam", 13)))

    dataRdd.foreachPartition(iter => {
      val conn = DriverManager.getConnection(url, username, password)
      try {
        iter.foreach {
          case (name, age) => {
            val sql = "insert into tbl_person_info (name, age) values (?, ?)"
            val statement = conn.prepareStatement(sql)
            try {
              statement.setString(1, name)
              statement.setInt(2, age)
              statement.executeUpdate()
            } finally {
              statement.close()
            }
          }
        }
      } finally {
        conn.close()
      }
    })

    sc.stop()

  }
}
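
The difference between the failing version and this one is what the function shipped to the executors captures: the failing version captures a live connection, while this one captures only the url, username and password strings. The sketch below imitates that with plain Java serialization, without Spark. `FakeConnection` and `Captured` are hypothetical stand-ins for a JDBC connection and for a closure's captured state.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object SerializationCheck {
  // Hypothetical stand-in for a JDBC connection: it is not Serializable.
  class FakeConnection

  // Hypothetical stand-in for a closure: a serializable wrapper around
  // whatever value the closure captured.
  case class Captured(value: AnyRef)

  // Returns true if Java serialization can write the object, false if it
  // runs into a value that is not Serializable.
  def canSerialize(obj: AnyRef): Boolean = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      out.writeObject(obj)
      true
    } catch {
      case _: NotSerializableException => false
    } finally {
      out.close()
    }
  }
}
```

Here canSerialize(Captured(new FakeConnection)) is false, while canSerialize(Captured("jdbc:mysql://172.16.0.31:3306/db_canal_test")) is true. This is why the connection has to be created inside foreachPartition, on the executor, rather than on the driver.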

After deleting the three rows inserted earlier, we run the program again and query the table:

1 john 18
2 lucy 20
6 sam 13
7 steve 30
8 elly 21

Iterating with foreachPartition needs only one connection per partition instead of one per element, which greatly reduces the number of connections.
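
The same idea can be taken a step further. The sketch below prepares the statement once per partition, sends rows in batches through the standard JDBC addBatch and executeBatch calls, and opens no connection at all for an empty partition (three elements spread over local[4] can leave some partitions empty). `BatchWriter`, `writePartition` and the default `batchSize` of 500 are assumptions made for illustration and have not been tuned.

```scala
import java.sql.Connection

object BatchWriter {
  // Hypothetical helper: writes one partition's rows through a single
  // connection and a single PreparedStatement. `connect` is only called
  // when the partition has data. Returns the number of rows submitted.
  def writePartition(connect: () => Connection,
                     rows: Iterator[(String, Int)],
                     batchSize: Int = 500): Int =
    if (!rows.hasNext) 0
    else {
      val conn = connect()
      try {
        val statement = conn.prepareStatement(
          "insert into tbl_person_info (name, age) values (?, ?)")
        try {
          var submitted = 0
          // Send the rows to MySQL in groups of at most batchSize.
          rows.grouped(batchSize).foreach { batch =>
            batch.foreach { case (name, age) =>
              statement.setString(1, name)
              statement.setInt(2, age)
              statement.addBatch()
            }
            statement.executeBatch()
            submitted += batch.size
          }
          submitted
        } finally {
          statement.close()
        }
      } finally {
        conn.close()
      }
    }
}

// Possible use from the program above (requires Spark and a reachable MySQL):
// dataRdd.foreachPartition { iter =>
//   BatchWriter.writePartition(
//     () => DriverManager.getConnection(url, username, password), iter)
// }
```

Passing the connection factory as a function keeps the captured state serializable, since only the url, username and password strings travel to the executors.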
