Scala深海奇遇记-当case class遇到了Spark的聚
自从知道有case class这个东西以后,一直都比较常用这个东西。但是,最近在测试的时候,突然发现,其实这个东西并不简单,它导致了一个看起来很无厘头的错误,并且花了我两天的时间来调试。
在这篇文章里,我会详细记录调试的过程,以及结论。
致谢
在调试的过程中,得到了我们Hadoop组老大,项目组老大,以及其他同事的深度支持与帮助,非常感谢他们。
结论
先说结论。如果有朋友不感兴趣,不想深究原理,只是想知道怎么用,可以跳过后面的部分。
不要把case class放在class里面。例如:
class Test {
case class A(a: String)
}
可以放到object里面,或者放到package里面。例如:
object Test {
case class A(a: String)
}
或者
package Test {
case class A(a: String)
}
如果不遵循这条原则,那么有极大的概率,Spark在执行的时候,相同的key并不会聚集到一起。
分析
首先,我们有如下代码:
package org.apache.spark.test
import org.apache.spark.{Partitioner, SparkConf, SparkContext}
import scala.util.Random
object VerifySparkBug {
def main(args: Array[String]): Unit = {
new VerifySparkBug().run()
}
}
class VerifySparkBug extends Serializable {
def run() = {
val sparkConf = new SparkConf()
sparkConf.setMaster("local")
sparkConf.setAppName("VerifySparkBug")
val sparkContext = new SparkContext(sparkConf)
val inputRDD = sparkContext.parallelize(Seq(
(UidAndPartition("1", "2"), 1L),
(UidAndPartition("1", "1"), 1L),
(UidAndPartition("1", "3"), 1L),
(UidAndPartition("2", "1"), 1L),
(UidAndPartition("2", "2"), 1L),
(UidAndPartition("3", "3"), 1L)
)).partitionBy(new RandomPartitioner(10))
inputRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ inputRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val uidRDD = inputRDD.map{
case pair => {
(Uid(pair._1.uid), pair._2)
}
}
uidRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ uidRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val result = uidRDD.aggregateByKey(0L)(
(counter: Long, number: Long) => {
counter + number
},
(counter1: Long, counter2: Long) => {
counter1 + counter2
}
).collect()
for (pair <- result) {
System.out.println("------ key: " + pair._1 + ", uv: " + pair._2)
}
}
case class UidAndPartition(uid: String, partition: String)
case class Uid(idType: String)
}
class RandomPartitioner(partitions: Int) extends Partitioner {
override def numPartitions: Int = partitions
override def getPartition(key: Any): Int = {
new Random().nextInt(partitions)
}
}
它的输出如下:
------ key: Uid(1), uv: 1
------ key: Uid(1), uv: 1
------ key: Uid(1), uv: 1
------ key: Uid(3), uv: 1
------ key: Uid(2), uv: 1
------ key: Uid(2), uv: 1
很诡异,对吧?明明key相同,却并没有被聚合到一起。而且,更诡异的是,key相同的hash值也相同。
猜想1
开始,我们怀疑是case class虽然默认实现了Serializable接口,但是并没有生成serialVersionUID,所以在序列化以及反序列化时,生成的类其实并不是同一个。导致反序列化出来的对象不是同一个。
反编译后的Uid的部分代码如下:
于是,我就用Netty写了一个服务器端和客户端的程序,来测试这个问题:
服务器端:
package org.apache.spark.test
import java.io._
import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.ByteBuf
import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter, ChannelInitializer, SimpleChannelInboundHandler}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.nio.{NioServerSocketChannel, NioSocketChannel}
import org.apache.spark.test.VerifySparkBug.Uid
object ServerForCaseClassEquals {
def main(args: Array[String]): Unit = {
val serverBootstrap = new ServerBootstrap()
val bossEventLoopGroup = new NioEventLoopGroup()
val workerEventLoopGroup = new NioEventLoopGroup()
serverBootstrap.group(bossEventLoopGroup, workerEventLoopGroup)
.channel(classOf[NioServerSocketChannel])
.childHandler(new ChannelInitializer[NioSocketChannel] {
override def initChannel(c: NioSocketChannel) = {
c.pipeline().addLast(new ChannelInboundHandlerAdapter {
override def channelRead(ctx: ChannelHandlerContext, msg: scala.Any): Unit = {
val byteBuf = msg.asInstanceOf[ByteBuf]
val length = byteBuf.readInt()
val bytes = new Array[Byte](length)
byteBuf.readBytes(bytes)
val byteInputStream = new ByteArrayInputStream(bytes)
val objectInputStream = new ObjectInputStream(byteInputStream)
println(objectInputStream.readObject() == Uid("a"))
}
})
}
})
.bind(8000)
}
}
客户端:
package org.apache.spark.test
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import io.netty.bootstrap.Bootstrap
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter, ChannelInitializer}
import org.apache.spark.test.VerifySparkBug.Uid
object ClientForCaseClassEquals {
def main(args: Array[String]): Unit = {
val bootstrap = new Bootstrap()
val eventLoopGroup = new NioEventLoopGroup()
bootstrap.group(eventLoopGroup)
.channel(classOf[NioSocketChannel])
.handler(new ChannelInitializer[NioSocketChannel] {
override def initChannel(c: NioSocketChannel) = {
c.pipeline().addLast(new ChannelInboundHandlerAdapter {
override def channelActive(ctx: ChannelHandlerContext): Unit = {
val uid = Uid("a")
val byteArrayOutputStream = new ByteArrayOutputStream()
val objectOutputStream = new ObjectOutputStream(byteArrayOutputStream)
objectOutputStream.writeObject(uid)
objectOutputStream.close()
val bytes = byteArrayOutputStream.toByteArray
val byteBuf = ctx.alloc().buffer()
byteBuf.writeInt(bytes.length)
byteBuf.writeBytes(bytes)
ctx.channel().writeAndFlush(byteBuf)
}
})
}
})
.connect("localhost", 8000)
}
}
而这个测试自然失败了。测试的结果是,反序列化解析出来的对象,是==反序列化之前的对象的。
为什么用==?因为在scala中,==其实就是equals()方法的语法糖。而且,spark的代码中,聚合函数中,对key进行比较时,也是用的==。
其实这种假设根本就经不起推敲。
如果序列化和反序列的时候,类不是一个,即serialVersionUID不同,那么,其实反序列化是不会成功的,它是会直接报错的。而我们的程序中并没有报错。说明其实是同一个。
那么既然case class在生成的时候,并没有生成serialVersionUID。那我们怎么能够确定,在运行的时候,即使是在不同的节点,不同的JVM上,它们就是同一个呢?这个问题,在Java Specification中的序列化和反序列化部分有详细描述。总之,就是根据类的相关信息,生成一个serialVersionUID。由于这个信息跟内存地址无关,所以即使是运行时生成,也是唯一且相同的。
猜想2
这个假设给推掉了。为了复现这个问题,我写了如下代码,打算来进行复现。上面的代码是最终版的复现版本。在下面的代码中,虽然问题看似得到了复现,实际上是完全不同的问题。
package org.apache.spark.test
import org.apache.spark.{Partitioner, SparkConf, SparkContext}
import scala.util.Random
object VerifySparkBug {
def main(args: Array[String]): Unit = {
new VerifySparkBug().run()
}
}
class VerifySparkBug extends Serializable {
def run() = {
val sparkConf = new SparkConf()
sparkConf.setMaster("local")
sparkConf.setAppName("VerifySparkBug")
val sparkContext = new SparkContext(sparkConf)
val inputRDD = sparkContext.parallelize(Seq(
(UidAndPartition("1", "2"), 1L),
(UidAndPartition("1", "1"), 1L),
(UidAndPartition("1", "3"), 1L),
(UidAndPartition("2", "1"), 1L),
(UidAndPartition("2", "2"), 1L),
(UidAndPartition("3", "3"), 1L)
))
inputRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ inputRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val uidRDD = inputRDD.map{
case pair => {
(Uid(pair._1.uid), pair._2)
}
}.partitionBy(new RandomPartitioner(10))
uidRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ uidRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val result = uidRDD.aggregateByKey(0L)(
(counter: Long, number: Long) => {
counter + number
},
(counter1: Long, counter2: Long) => {
counter1 + counter2
}
).collect()
for (pair <- result) {
System.out.println("------ key: " + pair._1 + ", uv: " + pair._2)
}
}
case class UidAndPartition(uid: String, partition: String)
case class Uid(idType: String)
}
class RandomPartitioner(partitions: Int) extends Partitioner {
override def numPartitions: Int = partitions
override def getPartition(key: Any): Int = {
new Random().nextInt(partitions)
}
}
代码跟上面的很相似,只是partitionBy()的位置变了一下。
初衷很简单,因为我们并不能保证相同的Uid会在同一个partition,于是就通过RandomPartitioner来模拟随机分区的效果。
这段代码的问题是,如果在进行聚集(如aggregateByKey, reduceByKey等)之前,你有进行过partitionBy操作,那么,Spark会认为,你已经将相同的Key放到一个分区了,所以它在进行聚集的时候,是不会考虑聚集不同partition上,相同的key的情况的。
我们通过调试聚集函数通用的PairRDDFunctions.combineByKeyWithClassTag方法,就能看到这个结论:
我们看到,如果聚集函数前面刚好调用了partitionBy方法,聚集函数内部是不会调用partitionBy来将相同的Key分到一个分区的。
这是因为partitionBy以及聚集函数,其实都是ShuffleDependency。Spark内部做了优化,如果ShuffleDependency前面已经是Narrow Dependency了,那么就会把这个ShuffleDependency转换成Narrow Dependency。
所以上面看似正确的复现,实际上刚好走进了这个陷阱里面。
那么,上面的代码,如果考虑到这个问题,该怎么解决?一种方案是,通过HashPartitioner,将相同的key确实分到同一个分区里面去。另一种方法就是,在上面的代码中的partitionBy(new RandomPartitioner(10)) 后面,加一个map(r => r)。 为什么加个这个就会生效?
我们先看map() 函数,以及MapPartitionRDD的代码。
我们可以看到,map()函数,实际上是生成了一个没有partitioner的MapPartitionRDD。就是这个没有partitioner,帮了我们的大忙。没有partitioner,聚集函数就不知道前面是Narrow Dependency,所以就会考虑相同的key在不同分区的问题。
下面这张图片,是我采用第二种方式,调试的过程。可以看到,现在走的是ShuffleRDD。
然后,我们可以看到,丫的,相同的key还是没有被聚集到一起。
其实当时在写这个复现的Demo的时候,是聚到一起了,一度手舞足蹈。但是聚到一起的原因,是case class是放在Object里面了,没有放到class里,所以并没有正确的复现问题。
猜想3
采用猜想2中的第1种方法,我们把不同分区中,相同的Key放到同一个分区里,有如下代码:
package org.apache.spark.test
import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}
import scala.util.Random
object VerifySparkBug {
def main(args: Array[String]): Unit = {
new VerifySparkBug().run()
}
}
class VerifySparkBug extends Serializable {
def run() = {
val sparkConf = new SparkConf()
sparkConf.setMaster("local")
sparkConf.setAppName("VerifySparkBug")
val sparkContext = new SparkContext(sparkConf)
val inputRDD = sparkContext.parallelize(Seq(
(UidAndPartition("1", "2"), 1L),
(UidAndPartition("1", "1"), 1L),
(UidAndPartition("1", "3"), 1L),
(UidAndPartition("2", "1"), 1L),
(UidAndPartition("2", "2"), 1L),
(UidAndPartition("3", "3"), 1L)
))
inputRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ inputRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val uidRDD = inputRDD.map{
case pair => {
(Uid(pair._1.uid), pair._2)
}
}.partitionBy(new RandomPartitioner(10)).partitionBy(new HashPartitioner(10))
uidRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ uidRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val result = uidRDD.aggregateByKey(0L)(
(counter: Long, number: Long) => {
counter + number
},
(counter1: Long, counter2: Long) => {
counter1 + counter2
}
).collect()
for (pair <- result) {
System.out.println("------ key: " + pair._1 + ", uv: " + pair._2)
}
}
case class UidAndPartition(uid: String, partition: String)
case class Uid(idType: String)
}
class RandomPartitioner(partitions: Int) extends Partitioner {
override def numPartitions: Int = partitions
override def getPartition(key: Any): Int = {
new Random().nextInt(partitions)
}
}
其实这些代码差距很小,用文本比较工具,很容易看出差距。
在运行这个过程中,我们发现,丫的,把它们放到同一个分区,还是不能聚到一起?太他么过分了。
于是,我们查看聚集函数中,是如何比较key的。
我们可以看到,是调用equals方法。
于是,我们就想,case class的equals方法,到底长什么样子?
这里有一点很坑,就是你看反编译出来的代码,在这里是会反编译错误的。你必须找其他的方式来看。我们采用的方式,是看Scala编译的中间结果。
运行scalac -Xprint:typer VerifySparkBug.scala,我们能看到,上面的代码中,Uid的equals()方法,scala会被编译成这种形式:
override <synthetic> def equals(x$1: Any): Boolean = Uid.this.eq(x$1.asInstanceOf[Object]).||(x$1 match {
case (_: VerifySparkBug.this.Uid) => true
case _ => false
}.&&({
<synthetic> val Uid$1: VerifySparkBug.this.Uid = x$1.asInstanceOf[VerifySparkBug.this.Uid];
Uid.this.idType.==(Uid$1.idType).&&(Uid$1.canEqual(Uid.this))
}))
};
我们再来看看,把Uid放到Object里面,生成的equals方法是什么样子。
代码如下:
package org.apache.spark.test
import org.apache.spark.test.VerifySparkBug.{Uid, UidAndPartition}
import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}
import scala.util.Random
object VerifySparkBug {
def main(args: Array[String]): Unit = {
new VerifySparkBug().run()
}
case class UidAndPartition(uid: String, partition: String)
case class Uid(idType: String)
}
class VerifySparkBug extends Serializable {
def run() = {
val sparkConf = new SparkConf()
sparkConf.setMaster("local")
sparkConf.setAppName("VerifySparkBug")
val sparkContext = new SparkContext(sparkConf)
val inputRDD = sparkContext.parallelize(Seq(
(UidAndPartition("1", "2"), 1L),
(UidAndPartition("1", "1"), 1L),
(UidAndPartition("1", "3"), 1L),
(UidAndPartition("2", "1"), 1L),
(UidAndPartition("2", "2"), 1L),
(UidAndPartition("3", "3"), 1L)
))
inputRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ inputRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val uidRDD = inputRDD.map{
case pair => {
(Uid(pair._1.uid), pair._2)
}
}.partitionBy(new RandomPartitioner(10)).partitionBy(new HashPartitioner(10))
uidRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ uidRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val result = uidRDD.aggregateByKey(0L)(
(counter: Long, number: Long) => {
counter + number
},
(counter1: Long, counter2: Long) => {
counter1 + counter2
}
).collect()
for (pair <- result) {
System.out.println("------ key: " + pair._1 + ", uv: " + pair._2)
}
}
}
class RandomPartitioner(partitions: Int) extends Partitioner {
override def numPartitions: Int = partitions
override def getPartition(key: Any): Int = {
new Random().nextInt(partitions)
}
}
生成的equals()方法如下:
override <synthetic> def equals(x$1: Any): Boolean = Uid.this.eq(x$1.asInstanceOf[Object]).||(x$1 match {
case (_: org.apache.spark.test.VerifySparkBug.Uid) => true
case _ => false
}.&&({
<synthetic> val Uid$1: org.apache.spark.test.VerifySparkBug.Uid = x$1.asInstanceOf[org.apache.spark.test.VerifySparkBug.Uid];
Uid.this.idType.==(Uid$1.idType).&&(Uid$1.canEqual(Uid.this))
}))
};
我们可以看到,它们的区别,就在于放在class中的Uid,调用的是VerifySparkBug.this.Uid。而放在object中的Uid,调用的是com.hyper.cdp.label.spark.test.VerifySparkBug.Uid。后者很容易理解,它肯定是唯一且相同的。但是前者呢?它包含了一个VerifySparkBug.this啊。
那么什么是VerifySparkBug.this呢?我也不清楚。但是我写了一个Demo,验证了一个事实,对于每一个VerifySparkBug实例,它们的值都是不同的。
代码如下:
package org.apache.spark.test
class VerifyCaseClassInClass {
case class Test(a: String)
def test() = {
println(VerifyCaseClassInClass.this)
}
}
object VerifyCaseClassInClass {
def main(args: Array[String]): Unit = {
println(new VerifyCaseClassInClass().Test("a").equals(new VerifyCaseClassInClass().Test("a")))
new VerifyCaseClassInClass().test()
new VerifyCaseClassInClass().test()
}
}
这就很容易解释为什么case class在Class中时,即使Key相同,仍然不会不会聚到一起了。
在Shuffle的时候,对于收到的其他分区的Key,肯定是为每个分区都new一个VerifySparkBug。这样就导致看起来相同的key,即使它们的hash值也一样,但是就是聚不到一起。
case class的hashCode()方法倒是很中立,它跟内存地址等都无关。只要字段的值一样,生成的hash就相同。感兴趣的读者可以自行查看scala.runtime.ScalaRunTime._hashCode这个方法。
这次调试,也深刻验证了,两个对象的hashcode相同,它们却不一定equals这一个简单却并没有重视过的问题。
最终,能够正常工作的代码如下:
package org.apache.spark.test
import org.apache.spark.test.VerifySparkBug.{Uid, UidAndPartition}
import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}
import scala.util.Random
object VerifySparkBug {
def main(args: Array[String]): Unit = {
new VerifySparkBug().run()
}
case class UidAndPartition(uid: String, partition: String)
case class Uid(idType: String)
}
class VerifySparkBug extends Serializable {
def run() = {
val sparkConf = new SparkConf()
sparkConf.setMaster("local")
sparkConf.setAppName("VerifySparkBug")
val sparkContext = new SparkContext(sparkConf)
val inputRDD = sparkContext.parallelize(Seq(
(UidAndPartition("1", "2"), 1L),
(UidAndPartition("1", "1"), 1L),
(UidAndPartition("1", "3"), 1L),
(UidAndPartition("2", "1"), 1L),
(UidAndPartition("2", "2"), 1L),
(UidAndPartition("3", "3"), 1L)
))
inputRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ inputRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val uidRDD = inputRDD.map{
case pair => {
(Uid(pair._1.uid), pair._2)
}
}.partitionBy(new RandomPartitioner(10)).partitionBy(new HashPartitioner(10))
uidRDD.foreachPartition {
case iterator => {
val partitionId = new Random().nextInt()
while (iterator.hasNext) {
val item = iterator.next()
System.out.println("------ uidRDD partitionId: " + partitionId
+ ", key: " + item._1
+ ", hash:" + item._1.hashCode()
+ ",value: " + item._2)
}
}
}
val result = uidRDD.aggregateByKey(0L)(
(counter: Long, number: Long) => {
counter + number
},
(counter1: Long, counter2: Long) => {
counter1 + counter2
}
).collect()
for (pair <- result) {
System.out.println("------ key: " + pair._1 + ", uv: " + pair._2)
}
}
}
class RandomPartitioner(partitions: Int) extends Partitioner {
override def numPartitions: Int = partitions
override def getPartition(key: Any): Int = {
new Random().nextInt(partitions)
}
}