[Source Code Analysis] The Spark Shuffle Read Path
The previous post walked through the shuffle write path, which is fairly simple compared with the read path.
Shuffle read is memory-hungry, because in the end the fetched data is pulled into an in-memory map for aggregation (spilling to disk when it does not fit).
The read path is also more involved and touches many more components. Let's start from the entry point.
1. Shuffle read starts from RDD's computeOrReadCheckpoint method
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
if (isCheckpointedAndMaterialized) {
firstParent[T].iterator(split, context)
} else {
compute(split, context)
}
}
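To make the entry point concrete, here is a minimal, hypothetical driver program (the object name and the sample numbers are mine, not from the article): reduceByKey produces a ShuffledRDD, so the collect() launches reduce-side tasks that go through computeOrReadCheckpoint, then ShuffledRDD.compute, and finally the read() call analysed below.
import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("shuffle-read-demo").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)), numSlices = 2)
    // reduceByKey uses map-side combine by default, so the reduce side will take
    // the combineCombinersByKey branch discussed later in this post.
    val counts = pairs.reduceByKey(_ + _)
    counts.collect().foreach(println)
    sc.stop()
  }
}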
2. ShuffledRDD.compute is called
compute first asks the ShuffleManager (HashShuffleManager or SortShuffleManager) for a ShuffleReader via getReader; the returned BlockStoreShuffleReader implements ShuffleReader, and its read method does the actual reading.
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
3. Inside read, the first step is to create a ShuffleBlockFetcherIterator
This object is an iterator over the blocks to be fetched. While it is being constructed it looks up where each block lives, together with each block's id and size, via val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition).
The amount of data being fetched at any one time is capped by spark.reducer.maxSizeInFlight, 48 MB by default.
override def read(): Iterator[Product2[K, C]] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
// Wrap the streams for compression based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}
val ser = Serializer.getSerializer(dep.serializer)
val serializerInstance = ser.newInstance()
// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
}
4. Breaking the read method down step by step
(a) How the block locations are obtained is not explained here; it deserves its own analysis later.
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses: Array[MapStatus] = getStatuses(shuffleId)
// Synchronize on the returned array because, on the driver, it gets mutated in place
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
}
}
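For intuition only, here is a hypothetical illustration of the shape of the value this method returns; the Fake* case classes are stand-ins of mine, not Spark's real BlockManagerId/BlockId. For each executor location it lists the shuffle blocks this reducer must fetch from it, with their sizes.
case class FakeBlockManagerId(executorId: String, host: String, port: Int)
case class FakeShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)

// "the reducer for partition 2 of shuffle 0 needs two blocks from exec-1 and one from exec-2"
val blocksByAddress: Seq[(FakeBlockManagerId, Seq[(FakeShuffleBlockId, Long)])] = Seq(
  FakeBlockManagerId("exec-1", "host-a", 7337) -> Seq(
    FakeShuffleBlockId(0, 0, 2) -> 1024L,   // map output 0, ~1 KB belongs to this reduce partition
    FakeShuffleBlockId(0, 1, 2) -> 2048L),
  FakeBlockManagerId("exec-2", "host-b", 7337) -> Seq(
    FakeShuffleBlockId(0, 2, 2) -> 512L)
)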
(b) When the ShuffleBlockFetcherIterator is constructed, it first initializes itself by calling initialize.
private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
// releases the memory held by the fetch results once the task has finished
context.addTaskCompletionListener(_ => cleanup())
// Split local and remote blocks, i.e. decide which blocks can be read locally and which must be fetched remotely
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
// Send out initial requests for blocks, up to our maxBytesInFlight
fetchUpToMaxBytes()
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
Breaking the initialize step down:
(b-1) splitLocalRemoteBlocks decides how each block is to be read, locally or remotely; the two kinds go into separate buffers, and the method returns the buffer of remote requests.
[1] If a block lives in the same BlockManager as this task (same executor id), its id is added to the localBlocks list.
[2] Otherwise the block is packed into a FetchRequest, which is appended to the remoteRequests list. A FetchRequest is closed once the accumulated size of its blocks reaches maxBytesInFlight / 5; it records each block's location, its blockId and its size (the size of this reduce partition's data within that map output).
[Note: a FetchRequest built here can still lead to an off-heap OOM (rather than a true memory leak): a single block is never split, so if one block is very large, fetching it pulls the whole thing into off-heap memory at once. A worked example of the packing follows the code below.]
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
// Tracks total number of blocks (including zero sized blocks)
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
//blocks that live in this task's own BlockManager (same executor) are recorded in localBlocks
if (address.executorId == blockManager.blockManagerId.executorId) {
// Filter out zero-sized blocks
localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
numBlocksToFetch += localBlocks.size
} else {//blocks hosted by other executors must be fetched remotely
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(BlockId, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
// Skip empty blocks
if (size > 0) {
curBlocks += ((blockId, size))
remoteBlocks += blockId
numBlocksToFetch += 1
curRequestSize += size
} else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
//note: a single oversized block is never split, so fetching it can exhaust off-heap memory
if (curRequestSize >= targetRequestSize) {//once the accumulated size crosses the threshold, close off a FetchRequest
// build the FetchRequest for the blocks gathered so far
remoteRequests += new FetchRequest(address, curBlocks)
curBlocks = new ArrayBuffer[(BlockId, Long)]
logDebug(s"Creating fetch request of $curRequestSize at $address")
curRequestSize = 0
}
}
// Add in the final request
if (curBlocks.nonEmpty) {
remoteRequests += new FetchRequest(address, curBlocks)
}
}
}
logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
remoteRequests
}
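The packing rule is easiest to see with concrete numbers. The standalone sketch below (the block sizes are made up) reproduces the loop above: a request closes only after the accumulated size reaches targetRequestSize = maxBytesInFlight / 5, and a single block is never split.
import scala.collection.mutable.ArrayBuffer

val maxBytesInFlight = 48L * 1024 * 1024
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)        // ~9.6 MB

val blockSizes = Seq(4L, 6L, 2L, 12L, 3L).map(_ * 1024 * 1024)    // five remote blocks

val requests = ArrayBuffer[Seq[Long]]()
var current = ArrayBuffer[Long]()
var currentSize = 0L
for (size <- blockSizes if size > 0) {
  current += size
  currentSize += size
  if (currentSize >= targetRequestSize) {   // close off a request once the threshold is crossed
    requests += current.toSeq
    current = ArrayBuffer[Long]()
    currentSize = 0L
  }
}
if (current.nonEmpty) requests += current.toSeq

// Prints [4, 6] MB, [2, 12] MB, [3] MB -- note the 12 MB block is not split, which is
// why a single oversized block can still blow far past the per-request target.
requests.foreach(r => println(r.map(_ / 1024 / 1024).mkString("[", ", ", "] MB")))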
(b-2) fetchUpToMaxBytes
This sends the remote fetch requests. A request is sent only while there are pending FetchRequests and either nothing is in flight yet, or the size of the request at the head of the queue plus the bytes already in flight stays within maxBytesInFlight (48 MB by default). Note that a request is always sent when nothing is in flight, regardless of its size; if its blocks are huge, this is where the off-heap OOM risk comes from, because the fetched data lands in ManagedBuffers backed by off-heap memory. A toy restatement of this gate follows the code below.
private def fetchUpToMaxBytes(): Unit = {
// Send fetch requests up to maxBytesInFlight
//keep sending queued requests while the gate below allows it
//fetchRequests.front.size is the total size of this reduce partition's data across the blocks in that request
while (fetchRequests.nonEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
}
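The gate can be restated in isolation; canSend below is a toy helper of mine, not Spark code.
def canSend(bytesInFlight: Long, nextRequestSize: Long, maxBytesInFlight: Long): Boolean =
  bytesInFlight == 0 || bytesInFlight + nextRequestSize <= maxBytesInFlight

// A request is always allowed when nothing is in flight, even if it alone exceeds the
// limit -- exactly the oversized-block OOM risk mentioned above:
canSend(0L, 60L << 20, 48L << 20)          // true
// Otherwise requests wait until enough in-flight bytes have been released:
canSend(40L << 20, 10L << 20, 48L << 20)   // false
canSend(30L << 20, 10L << 20, 48L << 20)   // true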
The send path works as follows:
bytesInFlight += req.size accumulates the size of outstanding requests, so that the data in flight never exceeds 48 MB (or whatever value you configured).
The actual fetching is performed by NettyBlockTransferService (behind shuffleClient.fetchBlocks).
Whether a block fetch succeeds or fails, the result is put into a LinkedBlockingQueue[FetchResult]; later, during aggregation, ShuffleBlockFetcherIterator's next method takes results off this queue (next is analysed further below).
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
bytesInFlight += req.size
// so we can look up the size of each blockID
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val blockIds = req.blocks.map(_._1.toString)
val address = req.address
//the actual fetching is implemented by NettyBlockTransferService
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
new BlockFetchingListener {
//success or failure, the result is put into the LinkedBlockingQueue[FetchResult]
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
// i.e. cleanup() has not been called yet.
if (!isZombie) {
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
}
)
}
(b-3) fetchLocalBlocks reads the local blocks; success or failure, the results are likewise put into the LinkedBlockingQueue.
private[this] def fetchLocalBlocks() {
val iter = localBlocks.iterator
while (iter.hasNext) {
val blockId = iter.next()
try {
val buf = blockManager.getBlockData(blockId)
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
return
}
}
}
5. Once the ShuffleBlockFetcherIterator is set up, each fetched stream is wrapped with the configured compression codec, so that the map-side compressed data is decompressed as it is read.
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}
6. Obtain a serializer instance
val ser = Serializer.getSerializer(dep.serializer)
val serializerInstance = ser.newInstance()
7. Deserialize the wrapped streams into one key/value iterator
val recordIter = wrappedStreams.flatMap { wrappedStream =>
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
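To see what asKeyValueIterator yields, here is a small self-contained round trip, assuming the Kryo serializer (the object name and sample data are mine): keys and values are written alternately into a stream and read back as (key, value) pairs, which is exactly what recordIter iterates over for every fetched block.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

object KeyValueStreamDemo {
  def main(args: Array[String]): Unit = {
    val instance = new KryoSerializer(new SparkConf()).newInstance()

    val bytes = new ByteArrayOutputStream()
    val out = instance.serializeStream(bytes)
    out.writeKey("a"); out.writeValue(1)
    out.writeKey("b"); out.writeValue(2)
    out.close()

    val in = instance.deserializeStream(new ByteArrayInputStream(bytes.toByteArray))
    // asKeyValueIterator pairs up the alternating key/value objects into (K, V) tuples.
    in.asKeyValueIterator.foreach(println)   // (a,1) then (b,2)
  }
}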
8. Wrap the iterator for metrics and task cancellation
Every record read bumps the shuffle-read record counter; the CompletionIterator flushes the shuffle read metrics into the task metrics once the iterator is exhausted; and the InterruptibleIterator allows the task to be cancelled while it is still iterating.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
9. Aggregate the records; the branch taken depends on whether an Aggregator is defined and whether map-side combine already ran.
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
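The two branches differ only in which of the aggregator's three functions drive the merge. The plain-Scala sketch below deliberately avoids Spark's Aggregator (which also needs a TaskContext and an ExternalAppendOnlyMap) and uses a word count, where both the value type V and the combiner type C are Int, to show the contract.
val createCombiner: Int => Int        = v => v              // seed a combiner from a key's first value
val mergeValue: (Int, Int) => Int     = (c, v) => c + v     // fold one more raw value into a combiner
val mergeCombiners: (Int, Int) => Int = (c1, c2) => c1 + c2 // merge two partial combiners

// combineValuesByKey path (no map-side combine): raw (K, V) records arrive.
val rawRecords = Seq(("a", 1), ("b", 1), ("a", 1))
val byValues = rawRecords.groupBy(_._1).map { case (k, kvs) =>
  val values = kvs.map(_._2)
  k -> values.tail.foldLeft(createCombiner(values.head))(mergeValue)
}

// combineCombinersByKey path (map-side combine already ran): partial sums arrive.
val partialSums = Seq(("a", 2), ("a", 3), ("b", 1))
val byCombiners = partialSums.groupBy(_._1).map { case (k, kcs) =>
  k -> kcs.map(_._2).reduceLeft(mergeCombiners)
}

println(byValues)     // Map(a -> 2, b -> 1)
println(byCombiners)  // Map(a -> 5, b -> 1)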
10. The aggregation method is chosen according to whether the data was already combined on the map side.
Here we follow the map-side-combined branch, combineCombinersByKey. It feeds the fetched records into an ExternalAppendOnlyMap, which aggregates them as they arrive; when memory runs short the map spills to local disk, possibly producing several spill files. In the end all spill files and the remaining in-memory data are merged and returned as a single iterator.
def combineCombinersByKey(
iter: Iterator[_ <: Product2[K, C]],
context: TaskContext): Iterator[(K, C)] = {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
//feed the fetched records in and aggregate them; sorts and spills to disk when memory is tight
combiners.insertAll(iter)
updateMetrics(context, combiners)
//returns an iterator over the merge of all on-disk spill files and the in-memory map
combiners.iterator
}
11. The actual aggregation happens in insertAll(iter)
The iter here is (a wrapper around) the ShuffleBlockFetcherIterator described above, so entries.next() below ultimately drives ShuffleBlockFetcherIterator's next method.
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
if (currentMap == null) {
throw new IllegalStateException(
"Cannot insert new elements into a map after calling iterator")
}
// An update function for the map that we reuse across entries to avoid allocating
// a new closure each time
var curEntry: Product2[K, V] = null
val update: (Boolean, C) => C = (hadVal, oldVal) => {
if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
}
while (entries.hasNext) {
curEntry = entries.next()
val estimatedSize = currentMap.estimateSize()
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
//check whether the spill threshold has been reached; see the shuffle write post for the details of maybeSpill
if (maybeSpill(currentMap, estimatedSize)) {
currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
//aggregate: merge this record's value into the combiner for its key
currentMap.changeValue(curEntry._1, update)
addElementsRead()
}
}
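What currentMap.changeValue(curEntry._1, update) does can be mimicked with a plain HashMap; the helper below is a toy stand-in of mine for SizeTrackingAppendOnlyMap that ignores size tracking and spilling. The update closure receives a flag saying whether the key already has a combiner, and either seeds one or merges the new value in.
import scala.collection.mutable

def aggregateInMemory[K, V, C](records: Iterator[(K, V)],
                               createCombiner: V => C,
                               mergeValue: (C, V) => C): mutable.HashMap[K, C] = {
  val map = mutable.HashMap.empty[K, C]
  records.foreach { case (k, v) =>
    // Same contract as the `update` closure above.
    val update: (Boolean, C) => C =
      (hadVal, oldVal) => if (hadVal) mergeValue(oldVal, v) else createCombiner(v)
    map.get(k) match {
      case Some(old) => map.update(k, update(true, old))
      case None      => map.update(k, update(false, null.asInstanceOf[C]))
    }
  }
  map
}

// e.g. a word count: aggregateInMemory(Iterator(("a", 1), ("a", 2), ("b", 3)),
//   (v: Int) => v, (c: Int, v: Int) => c + v)  ->  HashMap(a -> 3, b -> 3)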
12. The next method mentioned above
Every time a block has been fetched successfully, next subtracts its size from the in-flight total and then calls fetchUpToMaxBytes again so that further remote requests can go out. It returns the block's id together with an InputStream over the block's data.
override def next(): (BlockId, InputStream) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size
case _ =>
}
// Send fetch requests up to maxBytesInFlight
fetchUpToMaxBytes()
result match {
case FailureFetchResult(blockId, address, e) =>
throwFetchFailedException(blockId, address, e)
case SuccessFetchResult(blockId, address, _, buf) =>
try {
(result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
} catch {
case NonFatal(t) =>
throwFetchFailedException(blockId, address, t)
}
}
}
13. The iterator method used above: if there are spill files, they are merged with the in-memory data through an ExternalIterator; if everything still fits in memory, the in-memory map is returned directly.
override def iterator: Iterator[(K, C)] = {
if (currentMap == null) {
throw new IllegalStateException(
"ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
}
if (spilledMaps.isEmpty) {
CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
} else {
new ExternalIterator()
}
}
private def freeCurrentMap(): Unit = {
currentMap = null // So that the memory can be garbage-collected
releaseMemory()
}
14. When spill files exist, the merge is done by ExternalIterator: every input stream (the sorted in-memory map plus each spilled map) gets an ArrayBuffer holding its pairs for the current minimum key hash, and each non-empty buffer is enqueued into the mergeHeap priority queue.
private class ExternalIterator extends Iterator[(K, C)] {
// A queue that maintains a buffer for each stream we are currently merging
// This queue maintains the invariant that it only contains non-empty buffers
private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
inputStreams.foreach { it =>
val kcPairs = new ArrayBuffer[(K, C)]
readNextHashCode(it, kcPairs)
if (kcPairs.length > 0) {
mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
}
}
15. The final merge only happens when ExternalIterator's next method is called. So when is it called? Look at the keyOrdering handling below: if a key ordering is defined, the merge is driven while the ExternalSorter consumes the iterator; if not, it happens lazily later, whenever downstream code pulls records.
//take the key with the smallest hash and merge all values for that key from every stream that has it
override def next(): (K, C) = {
if (mergeHeap.length == 0) {
throw new NoSuchElementException
}
// Select a key from the StreamBuffer that holds the lowest key hash
val minBuffer = mergeHeap.dequeue() //remove and return the buffer whose head has the smallest key hash
val minPairs = minBuffer.pairs
val minHash = minBuffer.minKeyHash
val minPair = removeFromBuffer(minPairs, 0)
val minKey = minPair._1
var minCombiner = minPair._2
assert(hashKey(minPair) == minHash)
// For all other streams that may have this key (i.e. have the same minimum key hash),
// merge in the corresponding value (if any) from that stream
val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer)
while (mergeHeap.length > 0 && mergeHeap.head.minKeyHash == minHash) {//pull the same key out of the other buffers that share this hash
val newBuffer = mergeHeap.dequeue()
minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer)
mergedBuffers += newBuffer
}
// Repopulate each visited stream buffer and add it back to the queue if it is non-empty
mergedBuffers.foreach { buffer =>
if (buffer.isEmpty) {
readNextHashCode(buffer.iterator, buffer.pairs)
}
if (!buffer.isEmpty) {
mergeHeap.enqueue(buffer)
}
}
(minKey, minCombiner)
}
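The essence of this merge can be sketched independently. The helper below (my simplification, not Spark's code) does a k-way merge of streams that are each sorted by key hash, using a priority queue ordered by the hash of each stream's head; it assumes no two distinct keys share a hash code so it can compare heads by key directly, whereas the real ExternalIterator groups whole hash-equal runs into StreamBuffers to cope with collisions.
import scala.collection.mutable

def mergeSortedByHash[K, C](streams: Seq[Iterator[(K, C)]],
                            mergeCombiners: (C, C) => C): Iterator[(K, C)] = {
  val nonEmpty = streams.map(_.buffered).filter(_.hasNext)
  // PriorityQueue is a max-heap, so reverse the ordering to pop the smallest head hash first.
  implicit val byHeadHash: Ordering[scala.collection.BufferedIterator[(K, C)]] =
    Ordering.by((it: scala.collection.BufferedIterator[(K, C)]) => it.head._1.hashCode).reverse
  val heap = mutable.PriorityQueue(nonEmpty: _*)

  new Iterator[(K, C)] {
    override def hasNext: Boolean = heap.nonEmpty
    override def next(): (K, C) = {
      val minBuf = heap.dequeue()                 // stream whose head has the smallest hash
      val first = minBuf.next()
      val key = first._1
      var combiner = first._2
      // Merge in the same key from every other stream whose head matches it.
      while (heap.nonEmpty && heap.head.head._1 == key) {
        val other = heap.dequeue()
        combiner = mergeCombiners(combiner, other.next()._2)
        if (other.hasNext) heap.enqueue(other)
      }
      if (minBuf.hasNext) heap.enqueue(minBuf)
      (key, combiner)
    }
  }
}

// e.g. merging one in-memory iterator with two "spilled" iterators, all pre-sorted by key hash:
// mergeSortedByHash(Seq(Iterator("a" -> 1), Iterator("a" -> 2, "b" -> 1), Iterator("b" -> 4)),
//                   (x: Int, y: Int) => x + y).toList   ->  List((a,3), (b,5))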
16. If a key ordering was requested, the output is sorted as follows.
The same ExternalSorter used by the shuffle write path does the work and sorts while spilling; see the shuffle write post for its internals.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
Summary
(1) Tunable parameter:
(a) spark.reducer.maxSizeInFlight, 48 MB by default, caps the amount of fetch data in flight. Raising it means fewer, larger fetch requests, so fewer RPC round trips and less pressure on the network, at the price of more reduce-side memory held in flight; a config sketch follows this list.
(2) Potential risk:
(a) A FetchRequest never splits a single block, so if one block is very large, fetching it can exhaust off-heap memory and cause an OOM (an OOM risk rather than a true memory leak).
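A minimal config sketch for the tuning point above; 96m is just an illustrative value, not a recommendation.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("shuffle-read-tuning")
  // Fewer but larger fetch requests per reducer, at the cost of more reduce-side
  // memory being held in flight at once.
  .set("spark.reducer.maxSizeInFlight", "96m")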
Note: these are my own notes and may contain mistakes or gaps; corrections are welcome.