编程语言
首页 > 编程语言> > Spark2.x精通:ShuffleReader过程源码深度剖析

Spark2.x精通:ShuffleReader过程源码深度剖析

作者:互联网

一、概述


    之前我们写了几篇文章详细讲解了Spark Shuffle的Writer原理、技术演进历程及Spark2.x中三种Writer机制的具体实现,这里我们对Shuffler Read的源码进行深度剖析。

    对于每个stage来说,它的上边界,要么从外部存储读取数据,要么读取parent stage的输出。而下边界要么是写入到本地文件系统(需要有shuffle),提供给child stage进行读取,要么就是最后一个stage,需要输出结果。这里的stage在运行时就可以以流水线的方式进行运行一组Task,除了最后一个stage对应的ResultTask,其余的stage全部对应的ShuffleMapTask。

  除了需要从外部存储读取数据和RDD已经做过cache或者checkPoint的Task。一般的Task都是从Shuffle RDD的ShuffleRead开始的。


二、源码剖析


 1.我们先从ResultTask的runtask()函数开始讲解,代码如下:

override def runTask(context: TaskContext): U = {
   val threadMXBean = ManagementFactory.getThreadMXBean    val deserializeStartTime = System.currentTimeMillis()    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {      threadMXBean.getCurrentThreadCpuTime    } else 0L    val ser = SparkEnv.get.closureSerializer.newInstance()    //从广播变量获取rdd和func并进行反序列化ResultTask,结果为rdd,和func函数    //taskBinary的值是在DAGScheduler.submitMissingTasks()方法中进行序列化的    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime    } else 0L    // 对rdd中每个parition迭代执行,这里的RDD为ShuffleRDD,会调用ShuffleRDD中    // 的compute()函数,然后从各个ShuffleMapTask的输出结果中拉取数据处理    func(context, rdd.iterator(partition, context))  }


2.这里ShuffleRDD.compute()函数从sparkEnv中获取对应的shuffleManager,这里对应的是SortShuffleManager,代码如下:

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)      .read()      .asInstanceOf[Iterator[(K, C)]]  }

   

    调用getReader函数返回的应该是BlockStoreShuffleReader实例,代码如下:

  override def getReader[K, C](      handle: ShuffleHandle,      startPartition: Int,      endPartition: Int,      context: TaskContext): ShuffleReader[K, C] = {    new BlockStoreShuffleReader(      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)  }


3.最后调用BlockStoreShuffleReader的read()函数对数据进行读取,代码如下:

 override def read(): Iterator[Product2[K, C]] = {   //实例化手,后面用于拉取多种block数据。对于local block,它将从本地block manager获取block数据;   // 对于remote block,它通过使用BlockTransferService拉取remote block      // 里面有几个重要的参数      //1.spark.reducer.maxSizeInFlight 默认48MB,该参数用于设置shuffle read task的buffer缓冲大小,决定每次拉取多少数据  非常重要 重要 重要!!!     //2.spark.reducer.maxReqsInFlight 远程机器拉取本机器文件块的请求数,随着集群增大,需要对此做出限制。否则可能会使本机负载过大而挂掉,一般不需要修改      //3.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS 每个reduce任务从给定主机端口获取的远程块的数量。当在一次获取或同时从给定地址请求大量块时,这可能会导致服务执行程序或节点管理器崩溃。   // 这对于在启用外部shuffle时减少节点管理器上的负载特别有用。您可以通过将其设置为较低的值来减轻这个问题      //4.spark.maxRemoteBlockSizeFetchToMem ,默认值Int.MaxValue - 512 当块的大小超过此阈值时,远程块将被提取到磁盘.一般都降低spill阈值,增加spill频率减少内存压力    //Reduce 获取数据时,由于数据倾斜,有可能造成单个 Block 的数据非常的大,默认情况下是需要有足够的内存来保存单个 Block 的数据。因此,此时极有可能因为数据倾斜造成 OOM    //spark.shuffle.detectCorrupt  获取数据后是否对数据进行校验,默认为true 一般不需要调整    val wrappedStreams = new ShuffleBlockFetcherIterator(      context,      blockManager.shuffleClient,      blockManager,      //根据shuffleId和开始结束分区 通过mapOutputTracker获取的需要拉取数据块列表      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),      serializerManager.wrapStream,      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
   val serializerInstance = dep.serializer.newInstance()
   // Create a key/value iterator for each stream    //对每个流建立数据迭代器    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a      // NextIterator. The NextIterator makes sure that close() is called on the      // underlying InputStream when all records have been read.      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator    }
   // Update the context task metrics for each record read.    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](      recordIter.map { record =>        readMetrics.incRecordsRead(1)        record      },      context.taskMetrics().mergeShuffleReadMetrics())
   // An interruptible iterator must be used here in order to support task cancellation    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)     //下面就是做了一些聚合操作              //判断会否指定了聚合操作    //内部使用ExternalAppendOnlyMap进行聚合操作,类似ExternalSort的实现,    //这里没有排序,只是对key,value进行聚合操作    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {      if (dep.mapSideCombine) {        // We are reading values that are already combined        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)      } else {        //合并聚合值        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)      }    } else {      //如果没有聚合  这里进行完整聚合操作      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]    }
    //这里做了一些排序    dep.keyOrdering match {     //判断是否需要排序      case Some(keyOrd: Ordering[K]) =>        //         // 对于需要排序的情况,创建一个ExtenrnalSorter实例,使用ExtenrnalSorter进行排序,        // 这里需要注意,如果spark.shuffle.spill是false的话,数据是不会写入硬盘的。        val sorter =          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)        sorter.insertAll(aggregatedIter)        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())      case None =>        aggregatedIter    }  }

   

4.回到上面第1中的ShuffleBlockFetcherIterator初始化,由于这个涉及到数据的拉取,比较重要,再去跟踪下他的实例化代码,里面主要是调用了initialize()函数进行了初始化,代码如下:

private[this] def initialize(): Unit = {    // Add a task completion callback (called in both success case and failure case) to cleanup.    context.addTaskCompletionListener(_ => cleanup())      //区分local block和remote block,本地数据块和远程数据块拉取肯定是不一样的    // 这里获取了哪些需要从远程拉取的    val remoteRequests = splitLocalRemoteBlocks()    // Add the remote requests into our queue in a random order    fetchRequests ++= Utils.randomize(remoteRequests)    assert ((0 == reqsInFlight) == (0 == bytesInFlight),      "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +      ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)     //发送fetch request获取remote block    fetchUpToMaxBytes()
   val numFetches = remoteRequests.size - fetchRequests.size    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
    // 获取本地数据块    fetchLocalBlocks()    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))  }

 

    总结一下,上面的read()函数里面的实现代码比较多,这里只看了重要的流程代码,它主要干了三件事:


    1.首先实例化ShuffleBlockFetcherIterator并进行数据的拉取;

    2.其次就是对数据进行了聚合,生成聚合迭代器;

    3.最后对数据进行了排序,生成排序迭代器。


    ShuffleBlockFetcherIterator上面也说了几个参数,但是有一个参数特别重要参数,经常会用来优化shuffle reader:


    spark.reducer.maxSizeInFlight


         默认值48MB,设置ShuffleReadTask拉取数据的缓冲区大小,决定每次能够拉取多少数据。如果你内存充足,可适当调大成64MB、96MB减少拉取次数和数据传输次数,如果内存不太多,可适当调小为24MB,防止OOM,减少每次拉取的数据。


标签:Iterator,val,get,dep,Spark2,ShuffleReader,拉取,源码,context
来源: https://blog.51cto.com/15080019/2653904