Spark2.x精通:ShuffleReader过程源码深度剖析
作者:互联网
一、概述
之前我们写了几篇文章详细讲解了Spark Shuffle的Writer原理、技术演进历程及Spark2.x中三种Writer机制的具体实现,这里我们对Shuffler Read的源码进行深度剖析。
对于每个stage来说,它的上边界,要么从外部存储读取数据,要么读取parent stage的输出。而下边界要么是写入到本地文件系统(需要有shuffle),提供给child stage进行读取,要么就是最后一个stage,需要输出结果。这里的stage在运行时就可以以流水线的方式进行运行一组Task,除了最后一个stage对应的ResultTask,其余的stage全部对应的ShuffleMapTask。
除了需要从外部存储读取数据和RDD已经做过cache或者checkPoint的Task。一般的Task都是从Shuffle RDD的ShuffleRead开始的。
二、源码剖析
1.我们先从ResultTask的runtask()函数开始讲解,代码如下:
override def runTask(context: TaskContext): U = {
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
//从广播变量获取rdd和func并进行反序列化ResultTask,结果为rdd,和func函数
//taskBinary的值是在DAGScheduler.submitMissingTasks()方法中进行序列化的
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L
// 对rdd中每个parition迭代执行,这里的RDD为ShuffleRDD,会调用ShuffleRDD中
// 的compute()函数,然后从各个ShuffleMapTask的输出结果中拉取数据处理
func(context, rdd.iterator(partition, context))
}
2.这里ShuffleRDD.compute()函数从sparkEnv中获取对应的shuffleManager,这里对应的是SortShuffleManager,代码如下:
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) .read() .asInstanceOf[Iterator[(K, C)]] }
调用getReader函数返回的应该是BlockStoreShuffleReader实例,代码如下:
override def getReader[K, C]( handle: ShuffleHandle, startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) }
3.最后调用BlockStoreShuffleReader的read()函数对数据进行读取,代码如下:
override def read(): Iterator[Product2[K, C]] = {
//实例化手,后面用于拉取多种block数据。对于local block,它将从本地block manager获取block数据;
// 对于remote block,它通过使用BlockTransferService拉取remote block
// 里面有几个重要的参数
//1.spark.reducer.maxSizeInFlight 默认48MB,该参数用于设置shuffle read task的buffer缓冲大小,决定每次拉取多少数据 非常重要 重要 重要!!!
//2.spark.reducer.maxReqsInFlight 远程机器拉取本机器文件块的请求数,随着集群增大,需要对此做出限制。否则可能会使本机负载过大而挂掉,一般不需要修改
//3.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS 每个reduce任务从给定主机端口获取的远程块的数量。当在一次获取或同时从给定地址请求大量块时,这可能会导致服务执行程序或节点管理器崩溃。
// 这对于在启用外部shuffle时减少节点管理器上的负载特别有用。您可以通过将其设置为较低的值来减轻这个问题
//4.spark.maxRemoteBlockSizeFetchToMem ,默认值Int.MaxValue - 512 当块的大小超过此阈值时,远程块将被提取到磁盘.一般都降低spill阈值,增加spill频率减少内存压力
//Reduce 获取数据时,由于数据倾斜,有可能造成单个 Block 的数据非常的大,默认情况下是需要有足够的内存来保存单个 Block 的数据。因此,此时极有可能因为数据倾斜造成 OOM
//spark.shuffle.detectCorrupt 获取数据后是否对数据进行校验,默认为true 一般不需要调整
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
//根据shuffleId和开始结束分区 通过mapOutputTracker获取的需要拉取数据块列表
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
val serializerInstance = dep.serializer.newInstance()
// Create a key/value iterator for each stream
//对每个流建立数据迭代器
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
context.taskMetrics().mergeShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
//下面就是做了一些聚合操作
//判断会否指定了聚合操作
//内部使用ExternalAppendOnlyMap进行聚合操作,类似ExternalSort的实现,
//这里没有排序,只是对key,value进行聚合操作
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
//合并聚合值
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
//如果没有聚合 这里进行完整聚合操作
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
//这里做了一些排序
dep.keyOrdering match {
//判断是否需要排序
case Some(keyOrd: Ordering[K]) =>
//
// 对于需要排序的情况,创建一个ExtenrnalSorter实例,使用ExtenrnalSorter进行排序,
// 这里需要注意,如果spark.shuffle.spill是false的话,数据是不会写入硬盘的。
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
}
4.回到上面第1中的ShuffleBlockFetcherIterator初始化,由于这个涉及到数据的拉取,比较重要,再去跟踪下他的实例化代码,里面主要是调用了initialize()函数进行了初始化,代码如下:
private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
context.addTaskCompletionListener(_ => cleanup())
//区分local block和remote block,本地数据块和远程数据块拉取肯定是不一样的
// 这里获取了哪些需要从远程拉取的
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
assert ((0 == reqsInFlight) == (0 == bytesInFlight),
"expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
//发送fetch request获取remote block
fetchUpToMaxBytes()
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
// 获取本地数据块
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
总结一下,上面的read()函数里面的实现代码比较多,这里只看了重要的流程代码,它主要干了三件事:
1.首先实例化ShuffleBlockFetcherIterator并进行数据的拉取;
2.其次就是对数据进行了聚合,生成聚合迭代器;
3.最后对数据进行了排序,生成排序迭代器。
ShuffleBlockFetcherIterator上面也说了几个参数,但是有一个参数特别重要参数,经常会用来优化shuffle reader:
spark.reducer.maxSizeInFlight
默认值48MB,设置ShuffleReadTask拉取数据的缓冲区大小,决定每次能够拉取多少数据。如果你内存充足,可适当调大成64MB、96MB减少拉取次数和数据传输次数,如果内存不太多,可适当调小为24MB,防止OOM,减少每次拉取的数据。
标签:Iterator,val,get,dep,Spark2,ShuffleReader,拉取,源码,context 来源: https://blog.51cto.com/15080019/2653904