[Source Code Analysis] The Read Side of Spark Shuffle

2018-08-02  糖哗啦

The previous post walked through the shuffle write path, which is fairly simple compared with the read path.
The shuffle read path is memory-intensive, because at the end all of the fetched data is pulled into an in-memory map for aggregation.

The read path is also more involved and touches many components. The walkthrough below starts from the entry point.
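
For orientation, here is a minimal sketch (assuming a local SparkContext; the app name, master and data are purely illustrative) of a job that exercises this path: reduceByKey introduces a ShuffleDependency, and the reduce-side tasks read the map output through the machinery analyzed below.

import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    // Illustrative local setup; not part of the Spark source analyzed below
    val conf = new SparkConf().setAppName("shuffle-read-demo").setMaster("local[2]")
    val sc = new SparkContext(conf)

    val words = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"), 2)
    // reduceByKey produces a ShuffledRDD; computing its partitions goes through
    // computeOrReadCheckpoint -> compute -> BlockStoreShuffleReader.read()
    val counts = words.map(w => (w, 1)).reduceByKey(_ + _)
    counts.collect().foreach(println)

    sc.stop()
  }
}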

1. Shuffle read is entered through the RDD method computeOrReadCheckpoint
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointedAndMaterialized) {
      firstParent[T].iterator(split, context)
    } else {
      compute(split, context)
    }
  }
2. ShuffledRDD.compute is invoked

compute first calls getReader on the ShuffleManager (HashShuffleManager or SortShuffleManager) to obtain a ShuffleReader. BlockStoreShuffleReader implements the ShuffleReader trait, and its read method does the actual reading.

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }
3. read first constructs a ShuffleBlockFetcherIterator

This object is an iterator over the blocks to be fetched. When it is constructed, the locations of the blocks, together with each block's ID and size, are first obtained via val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition).
The amount of data fetched and in flight at any one time is capped by spark.reducer.maxSizeInFlight, which defaults to 48 MB.

override def read(): Iterator[Product2[K, C]] = {
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

    // Wrap the streams for compression based on configuration
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      blockManager.wrapForCompression(blockId, inputStream)
    }

    val ser = Serializer.getSerializer(dep.serializer)
    val serializerInstance = ser.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map(record => {
        readMetrics.incRecordsRead(1)
        record
      }),
      context.taskMetrics().updateShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure its compatible w/ this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.internalMetricsToAccumulators(
          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }
4. Breaking the read method down step by step
(a) Obtaining the location information of each block. The details are skipped here and left for a separate analysis.
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    val statuses: Array[MapStatus] = getStatuses(shuffleId)
    // Synchronize on the returned array because, on the driver, it gets mutated in place
    statuses.synchronized {
      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
    }
  }
(b) During construction, ShuffleBlockFetcherIterator first initializes itself by calling its initialize method
private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    // Release the memory held by fetched buffers once the task completes
    context.addTaskCompletionListener(_ => cleanup())

    // Split local and remote blocks, i.e. decide which blocks can be read locally and which must be fetched remotely
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests into our queue in a random order
    fetchRequests ++= Utils.randomize(remoteRequests)

    // Send out initial requests for blocks, up to our maxBytesInFlight
    fetchUpToMaxBytes()

    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // Get Local Blocks
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }
Breaking the initialization down:
(b-1) splitLocalRemoteBlocks decides how each block will be read, locally or remotely. The two kinds are tracked in separate collections, and the method returns the buffer of remote requests.

[1] If a block lives on the same executor as this task's BlockManager, it is added to the localBlocks list.
[2] Otherwise the block goes into a FetchRequest, which is appended to remoteRequests. A FetchRequest is emitted once the accumulated size of its blocks reaches maxBytesInFlight / 5 (9.6 MB with the 48 MB default); the request records each block's location, BlockId, and size (the size of this reduce partition's data within the map output).
[Note: a FetchRequest built here can be risky: if a single block is very large, fetching it occupies a correspondingly large amount of off-heap memory, which can lead to an OOM.]

 private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
    // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
    // nodes, rather than blocking on reading output from one node.
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

    // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
    // at most maxBytesInFlight in order to limit the amount of data in flight.
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Tracks total number of blocks (including zero sized blocks)
    var totalBlocks = 0
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      // Blocks that live on the same executor as this BlockManager are read locally
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // Filter out zero-sized blocks
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else { // blocks on other executors must be fetched remotely
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          // Skip empty blocks
          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          // Note: a single oversized block can still cause an off-heap OOM when fetched
          if (curRequestSize >= targetRequestSize) { // once the accumulated size reaches the threshold, emit a FetchRequest
            // Create the FetchRequest
            remoteRequests += new FetchRequest(address, curBlocks)
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            logDebug(s"Creating fetch request of $curRequestSize at $address")
            curRequestSize = 0
          }
        }
        // Add in the final request
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
    remoteRequests
  }
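
To make the grouping rule concrete, here is a small standalone sketch (not Spark code; the block IDs and sizes are made up) that groups remote blocks into requests of at least targetRequestSize = maxBytesInFlight / 5, mirroring the loop above:

import scala.collection.mutable.ArrayBuffer

object FetchRequestGroupingDemo {
  def main(args: Array[String]): Unit = {
    val maxBytesInFlight = 48L * 1024 * 1024                     // 48 MB default
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)   // ~9.6 MB per request

    // Hypothetical (blockId, size in bytes) pairs reported for one remote executor
    val blockInfos = Seq(
      ("shuffle_0_0_3", 4L * 1024 * 1024),
      ("shuffle_0_1_3", 7L * 1024 * 1024),
      ("shuffle_0_2_3", 2L * 1024 * 1024),
      ("shuffle_0_3_3", 12L * 1024 * 1024))

    val requests = ArrayBuffer[Seq[(String, Long)]]()
    var curBlocks = ArrayBuffer[(String, Long)]()
    var curRequestSize = 0L
    for ((blockId, size) <- blockInfos if size > 0) {
      curBlocks += ((blockId, size))
      curRequestSize += size
      // Same rule as splitLocalRemoteBlocks: close the request once the
      // accumulated size reaches the per-request target
      if (curRequestSize >= targetRequestSize) {
        requests += curBlocks.toSeq
        curBlocks = ArrayBuffer[(String, Long)]()
        curRequestSize = 0L
      }
    }
    if (curBlocks.nonEmpty) requests += curBlocks.toSeq

    requests.zipWithIndex.foreach { case (blocks, i) =>
      println(s"request $i -> ${blocks.map(_._1).mkString(", ")}, " +
        s"${blocks.map(_._2).sum / 1024 / 1024} MB total")
    }
  }
}

With these sizes the first two blocks (11 MB) form one request and the last two (14 MB) another; note that in the real code an oversized single block still ends up alone in its own request, which is exactly the off-heap OOM risk noted above.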
(b-2) fetchUpToMaxBytes

Sends the remote fetch requests. A request is sent only when the fetchRequests queue is non-empty and either nothing is currently in flight (bytesInFlight == 0) or the size of the request at the head of the queue plus the bytes already in flight stays within maxBytesInFlight (48 MB by default). Note that the very first request is sent regardless of its size; if its blocks are huge there is an OOM risk, because the fetched data is held in ManagedBuffers, i.e. off-heap memory.

private def fetchUpToMaxBytes(): Unit = {
    // Send fetch requests up to maxBytesInFlight
    // Requests are sent only while the two conditions below hold
    // fetchRequests.front.size is the total size of this reduce partition's data across the request's blocks
    while (fetchRequests.nonEmpty &&
      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
      sendRequest(fetchRequests.dequeue())
    }
  }
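
A small worked example of the gating condition (the numbers are illustrative, not from the source):

// Illustrative numbers only
val maxBytesInFlight = 48L * 1024 * 1024     // 48 MB
var bytesInFlight    = 40L * 1024 * 1024     // data already being fetched
val headRequestSize  = 10L * 1024 * 1024     // size of fetchRequests.front

// The same condition as in fetchUpToMaxBytes:
val canSend = bytesInFlight == 0 || bytesInFlight + headRequestSize <= maxBytesInFlight
// canSend == false here (40 MB + 10 MB > 48 MB); the request stays queued until
// next() subtracts some completed block sizes from bytesInFlight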

The send path works as follows:

- bytesInFlight += req.size accumulates the size of in-flight requests, which is what keeps the total data being fetched under 48 MB (or whatever value is configured).
- The actual fetching is performed by NettyBlockTransferService (the shuffleClient).
- Whether a fetch succeeds or fails, the result is put into a LinkedBlockingQueue[FetchResult]. Later, during aggregation, ShuffleBlockFetcherIterator.next takes results from this queue; next is analyzed below.
private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size

    // so we can look up the size of each blockID
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val blockIds = req.blocks.map(_._1.toString)

    val address = req.address
    // The actual fetching is implemented by NettyBlockTransferService
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      new BlockFetchingListener {
        // Both success and failure results go into the LinkedBlockingQueue[FetchResult]
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          // Only add the buffer to results queue if the iterator is not zombie,
          // i.e. cleanup() has not been called yet.
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            buf.retain()
            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
            shuffleMetrics.incRemoteBytesRead(buf.size)
            shuffleMetrics.incRemoteBlocksFetched(1)
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }

        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), address, e))
        }
      }
    )
  }
(b-3) fetchLocalBlocks reads the local blocks; success or failure, each result is put into the same LinkedBlockingQueue.
private[this] def fetchLocalBlocks() {
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        val buf = blockManager.getBlockData(blockId)
        shuffleMetrics.incLocalBlocksFetched(1)
        shuffleMetrics.incLocalBytesRead(buf.size)
        buf.retain()
        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
      } catch {
        case e: Exception =>
          // If we see an exception, stop immediately.
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
          return
      }
    }
  }
5. After the ShuffleBlockFetcherIterator is initialized, each fetched stream is wrapped with the configured compression codec, so that the compressed block data is decompressed as it is read
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      blockManager.wrapForCompression(blockId, inputStream)
    }
6. Obtain a serializer instance
 val ser = Serializer.getSerializer(dep.serializer)
 val serializerInstance = ser.newInstance()
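
When the ShuffleDependency does not carry its own serializer, the default comes from configuration; a common, purely illustrative setup (values are examples, not recommendations) is to switch the default to Kryo:

import org.apache.spark.SparkConf

// Kryo is generally faster and more compact than the default JavaSerializer
// for shuffle records
val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryoserializer.buffer.max", "64m")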
7. Deserialize each stream to obtain a key/value iterator
val recordIter = wrappedStreams.flatMap { wrappedStream =>
     serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
 }
8. Wrap the iterator for metrics and cancellation: every record pulled through the iterator increments the shuffle read metrics (incRecordsRead), the CompletionIterator flushes those metrics into the task metrics once the iterator is exhausted, and the InterruptibleIterator allows the task to be killed while it is iterating
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map(record => {
        readMetrics.incRecordsRead(1)
        record
      }),
      context.taskMetrics().updateShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
9. Aggregation
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure its compatible w/ this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }
10. The aggregation path depends on whether a map-side combine was performed

Here we follow the map-side-combined branch into combineCombinersByKey. The method pulls the data, aggregates it into an ExternalAppendOnlyMap, spills to local disk when memory runs short (possibly producing several spill files), and finally merges all spill files with the in-memory data, returning a single combined iterator.

def combineCombinersByKey(
      iter: Iterator[_ <: Product2[K, C]],
      context: TaskContext): Iterator[(K, C)] = {
    val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
    // Pulls the data to this node, aggregating as it goes; may spill sorted runs to disk
    combiners.insertAll(iter)
    updateMetrics(context, combiners)
    // Returns an iterator that merges all on-disk spill files with the in-memory map
    combiners.iterator
  }
11. The actual aggregation is done in insertAll(iter)

The iter here is ultimately backed by the ShuffleBlockFetcherIterator built above, so the entries.next() calls below end up invoking ShuffleBlockFetcherIterator.next.

def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
    if (currentMap == null) {
      throw new IllegalStateException(
        "Cannot insert new elements into a map after calling iterator")
    }
    // An update function for the map that we reuse across entries to avoid allocating
    // a new closure each time
    var curEntry: Product2[K, V] = null
    val update: (Boolean, C) => C = (hadVal, oldVal) => {
      if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
    }

    while (entries.hasNext) {
      curEntry = entries.next()
      val estimatedSize = currentMap.estimateSize()
      if (estimatedSize > _peakMemoryUsedBytes) {
        _peakMemoryUsedBytes = estimatedSize
      }
      // Check whether a spill is needed; see the shuffle write post for details on maybeSpill
      if (maybeSpill(currentMap, estimatedSize)) {
        currentMap = new SizeTrackingAppendOnlyMap[K, C]
      }
      // Aggregate this entry into the map
      currentMap.changeValue(curEntry._1, update)
      addElementsRead()
    }
  }
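
The createCombiner / mergeValue pattern driven by the update closure can be illustrated with a plain in-memory map as a simplified stand-in for SizeTrackingAppendOnlyMap (no spilling; hypothetical word-count data):

import scala.collection.mutable

object UpdatePatternDemo {
  def main(args: Array[String]): Unit = {
    // For a word count, the combiner C is just the running Int total
    val createCombiner: Int => Int = v => v
    val mergeValue: (Int, Int) => Int = (c, v) => c + v

    val entries = Iterator(("a", 1), ("b", 1), ("a", 1), ("c", 1), ("a", 1))
    val map = mutable.HashMap[String, Int]()

    while (entries.hasNext) {
      val (k, v) = entries.next()
      // Equivalent of currentMap.changeValue(k, update): create a combiner the
      // first time a key is seen, merge into the existing combiner afterwards
      map.get(k) match {
        case Some(oldVal) => map.update(k, mergeValue(oldVal, v))
        case None         => map.update(k, createCombiner(v))
      }
    }
    println(map) // e.g. HashMap(a -> 3, b -> 1, c -> 1)
  }
}

The real map additionally tracks its estimated size and may spill between updates (maybeSpill), which this sketch omits.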
12. The next method mentioned above

When a block has been fetched successfully, its size is subtracted from bytesInFlight, releasing in-flight budget, and further fetch requests are then sent for the remaining data. next returns a (BlockId, InputStream) pair.

override def next(): (BlockId, InputStream) = {
    numBlocksProcessed += 1
    val startFetchWait = System.currentTimeMillis()
    currentResult = results.take()
    val result = currentResult
    val stopFetchWait = System.currentTimeMillis()
    shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)

    result match {
      case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size
      case _ =>
    }
    // Send fetch requests up to maxBytesInFlight
    fetchUpToMaxBytes()

    result match {
      case FailureFetchResult(blockId, address, e) =>
        throwFetchFailedException(blockId, address, e)

      case SuccessFetchResult(blockId, address, _, buf) =>
        try {
          (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
        } catch {
          case NonFatal(t) =>
            throwFetchFailedException(blockId, address, t)
        }
    }
  }
13. The iterator method used above: if there are spill files, they are merged with the in-memory data; if everything still fits in memory, the in-memory map's iterator is returned directly.
override def iterator: Iterator[(K, C)] = {
    if (currentMap == null) {
      throw new IllegalStateException(
        "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
    }
    if (spilledMaps.isEmpty) {
      CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
    } else {
      new ExternalIterator()
    }
  }

  private def freeCurrentMap(): Unit = {
    currentMap = null // So that the memory can be garbage-collected
    releaseMemory()
  }
14. If spill files exist, the merge is handled by ExternalIterator: for every input stream (the destructively sorted in-memory map plus each spill file), the next group of pairs sharing a key hash is read into an ArrayBuffer, and each non-empty buffer is pushed onto the mergeHeap priority queue.
private class ExternalIterator extends Iterator[(K, C)] {

    // A queue that maintains a buffer for each stream we are currently merging
    // This queue maintains the invariant that it only contains non-empty buffers
    private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]

    // Input streams are derived both from the in-memory map and spilled maps on disk
    // The in-memory map is sorted in place, while the spilled maps are already in sorted order
    private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
      currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
    private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)

    inputStreams.foreach { it =>
      val kcPairs = new ArrayBuffer[(K, C)]
      readNextHashCode(it, kcPairs)
      if (kcPairs.length > 0) {
        mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
      }
    }
15. The final merge only happens when ExternalIterator.next is called. When is that? See the keyOrdering step below: if a key ordering is defined, the merged iterator is consumed right away by the ExternalSorter while sorting; if not, the merge happens lazily as downstream code pulls records from the iterator.
// Take the key with the smallest hash and merge in that key's values from every stream that contains it
override def next(): (K, C) = {
      if (mergeHeap.length == 0) {
        throw new NoSuchElementException
      }
      // Select a key from the StreamBuffer that holds the lowest key hash
      val minBuffer = mergeHeap.dequeue()  // take (and remove) the buffer with the smallest minimum key hash
      val minPairs = minBuffer.pairs
      val minHash = minBuffer.minKeyHash
      val minPair = removeFromBuffer(minPairs, 0)
      val minKey = minPair._1
      var minCombiner = minPair._2
      assert(hashKey(minPair) == minHash)

      // For all other streams that may have this key (i.e. have the same minimum key hash),
      // merge in the corresponding value (if any) from that stream
      val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer)
      while (mergeHeap.length > 0 && mergeHeap.head.minKeyHash == minHash) { // pull the same key out of the other buffers that share this hash
        val newBuffer = mergeHeap.dequeue()
        minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer)
        mergedBuffers += newBuffer
      }

      // Repopulate each visited stream buffer and add it back to the queue if it is non-empty
      mergedBuffers.foreach { buffer =>
        if (buffer.isEmpty) {
          readNextHashCode(buffer.iterator, buffer.pairs)
        }
        if (!buffer.isEmpty) {
          mergeHeap.enqueue(buffer)
        }
      }

      (minKey, minCombiner)
    }
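
The merge idea in next can be shown with a simplified, self-contained sketch: merge several iterators that are each sorted by key, combining values for equal keys with a hypothetical mergeCombiners function. (The real code orders streams by minimum key hash and resolves hash collisions inside each group, which is omitted here.)

import scala.collection.mutable

object KWayMergeDemo {
  // Merge iterators that are each sorted by key and contain each key at most once
  // (as the sorted in-memory map and the spill files do), combining values for
  // equal keys. Simplified analogue of ExternalIterator.next.
  def mergeSorted(streams: Seq[Iterator[(String, Int)]],
                  mergeCombiners: (Int, Int) => Int): Iterator[(String, Int)] = {
    type Buf = scala.collection.BufferedIterator[(String, Int)]
    // Max-priority queue, reversed so the stream with the smallest head key wins
    implicit val ord: Ordering[Buf] = Ordering.by[Buf, String](_.head._1).reverse
    val heap = mutable.PriorityQueue(streams.map(_.buffered).filter(_.hasNext): _*)

    new Iterator[(String, Int)] {
      def hasNext: Boolean = heap.nonEmpty
      def next(): (String, Int) = {
        val first = heap.dequeue()            // stream holding the smallest key
        val (key, initial) = first.next()
        if (first.hasNext) heap.enqueue(first)
        var combined = initial
        // Merge in this key's value from every other stream whose head matches
        while (heap.nonEmpty && heap.head.head._1 == key) {
          val other = heap.dequeue()
          combined = mergeCombiners(combined, other.next()._2)
          if (other.hasNext) heap.enqueue(other)
        }
        (key, combined)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val inMemory = Iterator(("a", 2), ("c", 1))   // sorted in-memory map
    val spill1   = Iterator(("a", 1), ("b", 4))   // sorted spill file 1
    val spill2   = Iterator(("b", 1), ("c", 3))   // sorted spill file 2
    mergeSorted(Seq(inMemory, spill1, spill2), _ + _).foreach(println)
    // prints (a,3), (b,5), (c,4)
  }
}

The real ExternalIterator additionally keeps all pairs with the same key hash in one StreamBuffer and compares actual keys within a hash group, since hashes can collide.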
16. If a key ordering was specified, the output is additionally sorted as follows

This uses the same external sorter (ExternalSorter) as the shuffle write path, spilling and sorting as needed; see the shuffle write post for details.

dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.internalMetricsToAccumulators(
          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }

Summary

(1) Tunable parameter:
(a) spark.reducer.maxSizeInFlight (default 48 MB) caps how much shuffle data a reduce task keeps in flight at once. Increasing it means fewer, larger fetch requests, so fewer RPCs and less pressure on the network, at the cost of more memory per task; see the sketch after this list.
(2) Potential risk:
(a) A FetchRequest can still cause trouble when a single block is very large: fetching it occupies a correspondingly large amount of off-heap memory and may cause an OOM.
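
As a concrete illustration of (1)(a), raising the limit looks like this (the 96m value is arbitrary; the right number depends on executor memory and the network):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Per reduce task: allow up to 96 MB of shuffle data in flight instead of 48 MB;
  // fewer, larger fetch requests mean fewer RPCs but more off-heap memory in use
  .set("spark.reducer.maxSizeInFlight", "96m")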

Note: this analysis may contain omissions or mistakes; corrections and feedback are welcome.
