编程语言
首页 > 编程语言> > [源码解析] PyTorch 流水线并行实现 (5)--计算依赖

[源码解析] PyTorch 流水线并行实现 (5)--计算依赖

作者:互联网

[源码解析] PyTorch 流水线并行实现 (5)--计算依赖

目录

0x00 摘要

前几篇文章我们介绍了 PyTorch 流水线并行的基本知识,自动平衡机制和切分数据等,本文我们结合论文内容来看看如何实现流水线依赖,核心就是如何建立这些小批次之间的跨设备依赖关系

流水线并行其他文章链接如下:

[源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现

[源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

[源码解析] 深度学习流水线并行 GPipe(3) ----重计算

[源码解析] 深度学习流水线并行之PipeDream(1)--- Profile阶段

[源码解析] 深度学习流水线并行 PipeDream(2)--- 计算分区

[源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型

[源码解析] 深度学习流水线并行 PipeDream(4)--- 运行时引擎

[源码解析] 深度学习流水线并行 PipeDream(5)--- 通信模块

[源码解析] 深度学习流水线并行 PipeDream(6)--- 1F1B策略

[源码解析] PyTorch 流水线并行实现 (1)--基础知识

[源码解析] PyTorch 流水线并行实现 (2)--如何划分模型

[源码解析] PyTorch 流水线并行实现 (3)--切分数据和运行时系统

[源码解析] PyTorch 流水线并行实现 (4)--前向计算

本文图片来自论文和github源码。

0x01 前文回顾

为了更好的理解本文,我们首先看看前文之中的关键部分。

img

img

因为前文已经介绍了执行顺序方案,所以本文介绍如何计算依赖。

0x02 计算依赖

+-----------------------------------------------------------------------------------------+
|                                                                                         |
| Layer 1 +--->  Layer 2 +-----> Layer 3 +----->  Layer 4 +-----> Layer 5  +---> Layer 6  |
|                                                                                         |
+--------------------------+---------------------------+----------------------------------+
                                          +
                                          |
                                          |
                                          v
 +------------------------------------------------------------------------------------+
 | +--------------------+         +---------------------+      +--------------------+ |
 | |Partition 1         |         |Partition 2          |      |Partition 3         | |
 | |                    |         |                     |      |                    | |
 | |      Layer 1       |    +---------> Layer 4        |      |                    | |
 | |         +          |    |    |         +           |  +------->   Layer 6      | |
 | |         |          |    |    |         |           |  |   |                    | |
 | |         v          |    |    |         |           |  |   |                    | |
 | |      Layer 2       |    |    |         |           |  |   |                    | |
 | |         +          |    |    |         v           |  |   |                    | |
 | |         |          |    |    |      Layer 5 +---------+   |                    | |
 | |         v          |    |    |                     |      |                    | |
 | |      Layer 3  +---------+    |                     |      |                    | |
 | |                    |         |                     |      |                    | |
 | +---------+----------+         +---------+-----------+      +-----------+--------+ |
 |                                                                                    |
 +------------------------------------------------------------------------------------+

为什么需要计算依赖?

所以针对流水线并行,torchgpipe需要自己补充一个本机跨设备伪分布式依赖关系。torchgpipe 通过在前向计算图和后向计算图做各种调整来达到目的。计算图就意味着各种依赖逻辑,依赖逻辑的补足就是依靠本节介绍的 Fork 和 Join 两个函数完成的。

这里最初有一个疑问,就是Torchgpipe怎么在不使用 PyTorch RPC 和 p2p的情况下,构建出来一个异地反向计算图。后来发现,原来是我想多了,因为Torchgpipe没有考虑到这种情况,它针对都是在同一个主机之上的GPU,不涉及异地多机器计算。

Torchgpipe 本质上还是一个进程内运行多个线程进行计算,是 DP 的替代。比如源码中就有对比如下:

### ResNet-101 Accuracy Benchmark

Batch size | torchgpipe | nn.DataParallel | Goyal et al.
---------- | ---------: | --------------: | -----------:
256        | 21.99±0.13 |      22.02±0.11 |   22.08±0.06
1K         | 22.24±0.19 |      22.04±0.24 |          N/A
4K         | 22.13±0.09 |             N/A |          N/A

再比如代码中明确提到:

If you decide not to use checkpointing at all, :class:`nn.DataParallel
<torch.nn.DataParallel>` might be more efficient than GPipe.

0x03 反向传播依赖

我们首先看看反向传播依赖,这个是论文的重点。

2.1 解析

我们还是要回忆一下前面两个图例。

图1

img

图2

img

这里需要完成两种依赖:

假定我们依据确定性时钟周期(deterministic clock-cycle)算法来运行一个前向传播。即使前向传播是按照在第j个设备上应该执行的顺序来执行任务 \(F_{1,j},...,F_{m,j}\) ,得到的后向传播结果计算图看起来也更像图1而非图2,

从图1上看,PyTorch 的 autograd 引擎不知道 \(B_{i+1,j}\) 必须在 \(B_{i,j}\) 之前运行,因此会打乱后向传播的时间流。因此,虚拟依赖(图2的虚线箭头)必须在前向传播中被显式绘制出来。

我们再仔细分析一下图2。图2之中,每一行都表示一个 micro-batch 在训练中的运行流,这个流的前向是由clock算法确定的。后向关系是由前向传播中自动确定完成的

现在的问题是:一个 mini-batch 被分成了4个 micro-batch,分别在不同时钟周期进入训练。就是每一列。这一列由上到下的传播也是由clock算法确定,但是反向传播(由下自上)目前是不确定的。比如最后一列中,反向传播的顺序应是:\(B_{4,1},B_{3,1},B_{2,1},B_{1,1}\)。但是这个目前从前向传播的结果来看,无法确定这个顺序。

所以需要依靠本节介绍的 Fork 和 Join 两个函数完成这个依赖关系。图中斜线表示checkpoint之中需要先有一个重计算,然后才能由下往上走

因此,torchpipe定义两个基础函数,Fork 和 Join 来表达这种依赖关系:

现在,\(F_{i+1,j}\) 对于 \(F_{i,j}\) 的依赖(其在后向传播计算图中被转换为 \(B_{i,j}\) 到 $B_{i+1,j} $ 的依赖关系)可以被如下表示

所以,图中这里实线都是前向传播时候构建的,虚线是由 fork & join 构建的。

原则上,表示虚拟依赖关系的张量可以是任意的。然而,torchgpipe选择使用空张量,以消除由张量引起的任何不必要的计算,例如PyTorch中的梯度累积。

具体如下图。就是使用 Fork 和 Join 的后向计算图。图中,不同颜色对应不同的设备。箭头依据后向传播图的方向来绘制,这些联系是在前向传播中被构建的。因此,\(F^{'}_{i,j}\) 对于 \(B_{i+1,j}\) 的虚拟依赖通过 Fork 和 Join 被构建出来,用虚线表示。

2.2 基础功能

2.2.1 Function

首先,我们要看看 torch.autograd.Function 的作用。

torch.autograd.Function类实际上是一个操作函数的基础父类,这样的操作函数必须具备两个基本的过程,即前向的运算过程和反向的求导过程,

如果某些操作无法通过 PyTorch 已有的层或者是已有的方法实现不了,就需要实现一个新的方法对 PyTorch 进行拓展。当不使用自动求导机制,需要自定义求导规则的时候,就应该拓展torch.autograd.Function类。 由于pytorch不再提供自动求导机制,就要用户自己定义实现前向传播和反向传播的计算过程,这就是 "Extending torch.autograd"。

我们接下来介绍Backward Dependency 的关键算法:Fork and Join。

2.2.2 Fork

Fork 是auto grad 函数,其把一个张量 x 映射到 pair(x, \(\phi\)),这里 \(\phi\) 是一个空张量。Fork 方法就是拓展了torch.autograd.Function

def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branches out from an autograd lane of the given tensor."""
    if torch.is_grad_enabled() and input.requires_grad:
        input, phony = Fork.apply(input)
    else:
        phony = get_phony(input.device, requires_grad=False)

    return input, phony


class Fork(torch.autograd.Function):
    @staticmethod
    def forward(ctx: 'Fork', input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
        phony = get_phony(input.device, requires_grad=False)
        return input.detach(), phony.detach()

    @staticmethod
    def backward(ctx: 'Fork', grad_input: Tensor, grad_grad: Tensor) -> Tensor:  # type: ignore
        return grad_input

2.2.3 Join

Join 是auto grad 函数,其把 pair(x, \(\phi\)) 映射到一个张量 x ,这里 \(\phi\) 是一个空张量。Join 方法也是拓展了torch.autograd.Function

def join(input: Tensor, phony: Tensor) -> Tensor:
    """Merges two autograd lanes."""
    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
        input = Join.apply(input, phony)

    return input


class Join(torch.autograd.Function):
    @staticmethod
    def forward(ctx: 'Join', input: Tensor, phony: Tensor) -> Tensor:  # type: ignore
        return input.detach()

    @staticmethod
    def backward(ctx: 'Join', grad_input: Tensor) -> Tuple[Tensor, None]:  # type: ignore
        return grad_input, None

2.2.4 Phony

Phony是没有空间的张量,因为它不需要任何梯度累积,所以可以在 autograd 图中构建任意的依赖。

def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
    """Gets a phony. Phony is tensor without space. It is useful to make
    arbitrary dependency in a autograd graph because it doesn't require any
    gradient accumulation.

    .. note::

        Phonies for each device are cached. If an autograd function gets a phony
        internally, the phony must be detached to be returned. Otherwise, the
        autograd engine will mutate the cached phony in-place::

            class Phonify(torch.autograd.Function):
                @staticmethod
                def forward(ctx, input):
                    phony = get_phony(input.device, requires_grad=False)
                    return phony.detach()  # detach() is necessary.

    """
    key = (device, requires_grad)

    try:
        phony = _phonies[key]
    except KeyError:
        with use_stream(default_stream(device)):
            phony = torch.empty(0, device=device, requires_grad=requires_grad)

        _phonies[key] = phony

    return phony

2.2.5 detach

在代码中,经常可以见到 detach 的使用,这个从注释可以看出来,是为了解决 PyTorch 的一个bug。

    # A Python autograd function might fail with this error:
    #
    #   RuntimeError: Returning Variables sharing storage with other Variables
    #   that require grad is not supported in Python functions. Please submit a
    #   feature request if you hit this error.
    #
    # It doesn't look like an essential restriction. But it happens on the
    # current PyTorch version. To avoid it, we should detach the tensor before
    # returning by identity autograd functions, such as Wait, Fork, and Join.
    #

2.3 使用

在 Pipeline 之中我们可以看到具体的使用方法,fence 方法(省略部分代码)利用 depend 来构建后向传播的依赖关系,确保 batches[i-1] 在 batches[i] 之后完成。

    def fence(self,
              schedule: List[Tuple[int, int]],
              skip_trackers: List[SkipTrackerThroughPotals],
              ) -> None:
        """Copies micro-batches after computation for the previous
        micro-batches.
        """
        batches = self.batches
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0:
                depend(batches[i-1], batches[i]) # 在这里建立了后向传播依赖关系
                
            next_stream = copy_streams[j][i]

            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j-1][i]
                copy(batches[i], prev_stream, next_stream)                

具体 depend 代码如下:

def depend(fork_from: Batch, join_to: Batch) -> None:
    fork_from[0], phony = fork(fork_from[0])
    join_to[0] = join(join_to[0], phony)

我们结合示例代码把传入的参数赋值一下,重新把方法解释如下,这样大家就可以更好的理解。

def depend(batches[i-1]: Batch, batches[i]: Batch) -> None:
    batches[i-1][0], phony = fork(batches[i-1][0])
    batches[i][0] = join(batches[i][0], phony)

具体逻辑如下,通过 phony 完成了一个桥接,即在正向传播之中,batches[i] 依赖 batches[i-1] 的执行结果

      +----------------+          +--------------+
      |                |          |              |
      |  batches[i-1]  |          |  batches[i]  |
      |                |          |              |
      +----------+-----+          +-----+--------+
                 |                      |
                 |                      |
                 |                      |
                 v                      v
+--------------------------------------------------------+
| depend         |                      |                |
|                |                      |                |
|                |                      |                |
|                v                      |                |
|        +-----------------------+      |                |
|        | fork  |               |      |                |
|        |       |    get_phony  |      |                |
|        |       |        +      |      |                |
|        |       |        |      |      |                |
|        |       |        |      |      |                |
|        +-----------------------+      |                |
|                |        |             |                |
|                |        |             |                |
|                |        |             |                |
|                v        v             |                |
|    +-----------+--+  +--+-----+       |                |
|    |              |  |        |       |                |
|    | batches[i-1] |  | phony  |       |                |
|    |              |  |        |       |                |
|    +--------------+  +--+-----+       |                |
|                         |             |                |
|                         |             |                |
|                         v             v                |
|                      +--+------------------+           |
|                      |Join            |    |           |
|                      |                |    |           |
|                      |                |    |           |
|                      |                v    |           |
|                      +---------------------+           |
|                                       |                |
|                                       |                |
|                                       |                |
|                                       v                |
|                                 +-----+------+         |
|                                 |            |         |
|                                 | batches[i] |         |
|                                 |            |         |
|                                 +------------+         |
|                                                        |
+--------------------------------------------------------+

我们把多个 batches 联合起来看看,这样就能看出来一个依赖链条。

                  +----------------------------------------------------------+
                  | depend                                                   |
                  |                                                          |
                  | +------------+                                           |
 +-------------   | |fork        |     +-----------+                         |
 |            |   | |            |     |           |                         |
 |batches[i]  +----------------------> | batches[i]|                         |
 |            |   | |            |     |           |                         |
 +-------------   | |            |     +-----------+                         |
                  | |            |             +-------+                     |
                  | |            +-----------> | Join  |                     |
                  | |            |             |       |                     |
                  | +------------+             |       |                     |
 +-------------   |                            |       |    +--------------+ |
 |            |   |                            |       |    |              | |
 |batches[i+1]+-------------------------------------------->+ batches[i+1] | |
 |            |   |                            |       |    |              | |
 +---------+---   |                            |       |    +--------------+ |
           |      |                            +-------+                     |
           |      |                                                          |
           |      +----------------------------------------------------------+
           |      +----------------------------------------------------------+
           |      | depend                                                   |
           |      |                                                          |
           |      | +-------------+                                          |
           |      | |fork         |     +------------+                       |
           |      | |             |     |            |                       |
           +--------------------------> |batches[i+1]|                       |
                  | |             |     |            |                       |
                  | |             |     +------------+                       |
                  | |             |           +-------+                      |
                  | |             +---------> |Join   |                      |
                  | +-------------+           |       |                      |
+------------+    |                           |       |     +-------------+  |
|            |    |                           |       |     |             |  |
|batches[i+2]+--------------------------------------------> | batches[i+2]|  |
|            |    |                           |       |     |             |  |
+----------+-+    |                           |       |     +-------------+  |
           |      |                           +-------+                      |
           |      |                                                          |
           |      +----------------------------------------------------------+
           |
           |      +-----------------------------------------------------------+
           |      | depend                                                    |
           |      |                                                           |
           +----------------------------->    ......                          |
                  |                                                           |
                  |                                                           |
                  +-----------------------------------------------------------+

这样,上图就是前向计算图,于是在后向传播之中,batches[i] 就 必须在 batches[i-1] 之前完成了

我们再结合论文的图来看看。

本来示例代码中是:

depend(batches[i-1], batches[i])

为了和论文中的图对应,我们修改为:

depend(batches[i], batches[i+1])

depend 代码也变化为:

def depend(batches[i]: Batch, batches[i+1]: Batch) -> None:
    batches[i][0], phony = fork(batches[i][0])
    batches[i+1][0] = join(batches[i+1][0], phony)

对应下图,就是在后向传播计算图之中 batches[i+1] 通过一个join, 一个fork,排在了 batches[i] 前面,就是下面大箭头所示,具体细化一下:

0x03 正向传播依赖

我们回头再来看正向依赖。因为正向传播的部分目的就是完成反向传播依赖,而目前反向传播只完成了行之间的依赖,列之间的依赖没有完成,我们现在补全

列之间的依赖就是设备之间的依赖,即前一个设备的输出是后一个设备的输入

3.1 分割模型

首先还是需要回顾下如何切分模型,从 split_module 可以看到,

GPipe 的 partitions 成员变量是 nn.ModuleList 类型。nn.ModuleList是一个容器,其储存不同 module,并自动将每个 module 的 parameters 添加到网络中。但是nn.ModuleList 并没有定义一个网络,而只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序,网络的执行顺序是根据 forward 函数来决定的。

def split_module(module: nn.Sequential,
                 balance: Iterable[int],
                 devices: List[torch.device],
                 ) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:

    balance = list(balance)

    j = 0
    partitions = []
    layers: NamedModules = OrderedDict()

    for name, layer in module.named_children(): # 遍历模型包含的层
        layers[name] = layer # 把新的层加入到数组中

        if len(layers) == balance[j]: # 如果数组大小等于balance[j],就是达到了device j应该包含的层数
            # Group buffered layers as a partition.
            partition = nn.Sequential(layers) # 把层数组组合成一个sequential module

            device = devices[j]
            partition.to(device) # 把层放置到相关设备之上

            partitions.append(partition) # 这个新module加入到分区数组中

            # Prepare for the next partition.
            layers.clear()
            j += 1 # 去下一个device看看

    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
    del devices[j:]

    return partitions, balance, devices

随之而来问题就是:partition内部可以用Sequential来进行一系列的前向操作,但是如何配置partitions 之间的执行顺序?

+-----------------------------------------------------------------------------------------+
|                                                                                         |
| Layer 1 +--->  Layer 2 +-----> Layer 3 +----->  Layer 4 +-----> Layer 5  +---> Layer 6  |
|                                                                                         |
+-----------------------------------------+-----------------------------------------------+
                                          |
                                          |
                                          |
                                          v
+-----------------------------------------------------------------------------------------+
| +--------------------+           +---------------------+         +--------------------+ |
| |Partition 1         |           |Partition 2          |         |Partition 3         | |
| |                    |   ???     |                     |         |                    | |
| |      Layer 1       |     +----------> Layer 4        |   ???   |                    | |
| |         +          |     |     |         +           |     +------->   Layer 6      | |
| |         |          |     |     |         |           |     |   |                    | |
| |         v          |     |     |         |           |     |   |                    | |
| |      Layer 2       |     |     |         |           |     |   |                    | |
| |         +          |     |     |         v           |     |   |                    | |
| |         |          |     |     |      Layer 5 +------------+   |                    | |
| |         v          |     |     |                     |         |                    | |
| |      Layer 3  +----------+     |                     |         |                    | |
| |                    |           |                     |         |                    | |
| +--------------------+           +---------------------+         +--------------------+ |
|                                                                                         |
+-----------------------------------------------------------------------------------------+

3.2 建立依赖

我们还是从论文中入手。假定我们有一个神经网络,其由一系列子网络构成。我们假定这些子网络是 \(f^1,...,f^n\),其参数分别是 \(\theta^1,...,\theta^n\),则整个网络是:

参数是 \(\theta = (\theta^1,...,\theta^n)\),为了清楚起见,我们称 \(f^j\) 表示 f 的第 j 个分区,并假设分区的参数是相互不相交的。

在训练网络时,基于梯度的方法(如随机梯度下降法)需要在给定小批量训练数据 x 和相应损失之后,计算网络的输出结果f(x)。以及损失相对于网络参数 \(\theta\) 的梯度g。这两个阶段分别称为向前传播和向后传播。

既然 f 由其 L 层 子模块 (\(f^L, f^{L-1},...f^1\)) 顺序组成,那么前向传播\(f(x)\) 可以通过如下方式计算:让 \(x^0=x\)(就是输入x),然后顺序应用每一个 partition,即 \(x^j = f^j (x^{j-1})\),这里 $ j = 1, ..., L$。就是 \(f(x)\) 可以表示为 :

\[f(x) = f^L(f^{L-1}(f^{L-2}(... f^1(x)))) \]

于是我们知道了,前向传播的顺序是由 \(f(x) = f^L(f^{L-1}(f^{L-2}(... f^1(x))))\) 来确定的

我们可以针对代码,进一步解析,看看如何实施partitions之间的顺序依赖。

    def run(self) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.

        """
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        skip_layout = self.skip_layout

        m = len(batches)
        n = len(partitions)

        skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

        with spawn_workers(devices) as (in_queues, out_queues):
            for schedule in clock_cycles(m, n): # 这里使用,给出了执行序列计划,后续按照这个来执行
                self.fence(schedule, skip_trackers)
                self.compute(schedule, skip_trackers, in_queues, out_queues)

解析的目标是 for schedule in clock_cycles(m, n) 这个 for 循环,其:

现在我们完成了两步:

  1. 确定性时钟周期算法给定了前向传播的执行顺序,我们只要按照 clock_cycles 方法提供的计划一一运行即可
  2. fence 方法通过调用 join 和 fork,我们做到了在后向传播之中,batches[i] 就 必须在 batches[i-1] 之前完成了,即 \(B_{i+1,j}\) 必须在 \(B_{i,j}\) 之前运行。

对于我们的图来说,第二步就是完成了下图的列依赖。

我们的问题是:怎么通过这个 for 循环,做到 \(B_{i,{j+1}}\) 必须在 \(B_{i,j}\) 之前运行?,即怎么安排反向传播逐次运行?就是怎么完成行内的依赖?

这就要通过 compute 的源码进行分析。重点说明的是:

    def compute(self,
                schedule: List[Tuple[int, int]],
                skip_trackers: List[SkipTrackerThroughPotals],
                in_queues: List[InQueue],
                out_queues: List[OutQueue],
                ) -> None:
        """Runs tasks with synchronization to copy streams."""
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        n = len(partitions)
        streams = [current_stream(d) for d in devices]
  
        for i, j in schedule: # 针对 schedule 之中的每一对 i,j
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)

            # Determine whether checkpointing or not.

            if checkpoint:
							# 忽略
            else:
                def compute(batch: Batch = batch,
                            partition: nn.Sequential = partition,
                            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                            ) -> Batch:
                    with use_skip_tracker(skip_tracker):
                        return batch.call(partition) # 前向计算,计算以 partition为单位计算,partition内部的层是顺序计算,由 Sequential保证。

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            in_queues[j].put(task) # 让 worker计算

        for i, j in schedule:
            ok, payload = out_queues[j].get() # 获取 worker 的前向计算结果,就是 第 j 个device 对 第 i 个 batch 的计算结果

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)

            # 第 j 个device 对 第 i 个 batch 的计算 就是 F[i,j]

            batches[i] = batch # 这里是关键,就是把 第 j 个device 对 第 i 个 batch 的计算结果 赋值到 batches[i],batches[i]就是 batches[i][j],在下次计算时候,构建的就是 F[i,j+1], 下一次 fence 之中的 depend 操作,就是针对 batches[i,j+1]

关于这个赋值操作,其对应的grad_fn 是 PermuteBackward,比如:

a = torch.tensor([2., 3.], requires_grad=True)
c = a
c.backward(gradient=external_grad)
print(c)

具体是:

c = {Tensor: 2} tensor([2., 3.], requires_grad=True)
  T = {Tensor: 2} tensor([2., 3.], grad_fn=<PermuteBackward>)

现在,我们把下图进行升级。

                 +-------------------------------------------------------------------+
                 | depend                                                            |
                 |                                                                   |
                 | +---------------+                                                 |
                 | |fork           |                                                 |
+-------------   | |               |     +-----------+                               |
|            |   | |               |     |           |                               |
|batches[i]  +-------------------------> | batches[i]|                               |
|            |   | |               |     |           |                               |
+-------------   | |               |     +-----------+                               |
                 | |               |                                                 |
                 | |               |                                                 |
                 | |               |     +--------+    +-------+                     |
                 | |  get_phony +------> |        +--->+ Join  |                     |
                 | |               |     | phony  |    |       |                     |
                 | +---------------+     |        |    |       |                     |
                 |                       +--------+    |       |                     |
                 |                                     |       |                     |
+-------------   |                                     |       |    +--------------+ |
|            |   |                                     |       |    |              | |
|batches[i+1]+----------------------------------------------------->+ batches[i+1] | |
|            |   |                                     |       |    |              | |
+-------------   |                                     |       |    +--------------+ |
                 |                                     +-------+                     |
                 |                                                                   |
                 +-------------------------------------------------------------------+

我们进行横向拓展,得到如下,即一个batch 被分成两个小批次: batches[i],batches[i+1] ,它们在两个设备 partitions[j],partitions[j + 1] 之上流水线,这样行和列都有反向传播的依赖。

                                 F[i,j]                                                                            F[i,j+1]

                    +------------------------------------------------+                            +-----------------------------------------------+
                    | partitions[j]                                  |                            |  partitions[j+1]                              |
                    |                                                |                            |                                               |
                    | +--------------------+   +------------------+  |                            | +-------------------+   +------------------+  |
                    | |fence               |   | compute          |  |                            | | fence             |   | compute          |  |
                    | |                    |   |                  |  |                            | |                   |   |                  |  |
+--------------+    | |  +--------------+  |   |  +------------+  |  |     +-----------------+    | |   +-------------+ |   |  +------------+  |  |       +-----------------+
|              |    | |  | depend       |  |   |  |forward     |  |  |     |                 |    | |   | depend      | |   |  |forward     |  |  |       |                 |
|  batches[i]  +---------------------------------------------------------> | batches[i][j]   +----------------------------------------------------------> | batches[i][j+1] |
|              |    | |  |              |  |   |  |            |  |  |     |                 |    | |   |             | |   |  |            |  |  |       |                 |
+--------------+    | |  |              |  |   |  |            |  |  |     +-----------------+    | |   |             | |   |  |            |  |  |       +-----------------+
                    | |  |              |  |   |  +------------+  |  |                            | |   |             | |   |  +------------+  |  |
                    | |  |              |  |   |                  |  |                            | |   |             | |   |                  |  |
+--------------+    | |  |              |  |   +------------------+  |     +-----------------+    | |   |             | |   +------------------+  |       +-------------------+
|              |    | |  |              |  |                         |     |                 |    | |   |             | |                         |       |                   |
|  batches[i+1]+---------------------------------------------------------> | batches[i+1][j] +----------------------------------------------------------> | batches[i+1][j+1] |
|              |    | |  |              |  |                         |     |                 |    | |   |             | |                         |       |                   |
+--------------+    | |  +--------------+  |                         |     +-----------------+    | |   +-------------+ |                         |       +-------------------+
                    | |                    |                         |                            | |                   |                         |
                    | +--------------------+                         |                            | +-------------------+                         |
                    +------------------------------------------------+                            +-----------------------------------------------+

手机如下:

0x04 总结

下图 $ m = 4, n = 3$。即,模型被分成3个子网络,小批次被分割成 4个微批次。F 和 B 的下标是 (m, n)。

img

如上图,这里需要完成两种依赖:

如上图,我们需要完成行,列两方面的依赖。

至此,我们完成了执行顺序和依赖关系的设定,下一篇我们介绍如何并行处理。

0xFF 参考

Markdown公式用法大全

markdown中公式编辑教程

https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html#stream-sync-behavior

CUDA学习:基础知识小结

CUDA随笔之Stream的使用

NVIDIA解决方案架构师深度解析大规模参数语言模型Megatron-BERT

Accelerating Wide & Deep Recommender Inference on GPUs

HugeCTR: High-Performance Click-Through Rate Estimation Training

https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548

https://github.com/NVIDIA/apex/

https://github.com/justheuristic/prefetch_generator

https://pytorch.org/tutorials/intermediate/model_parallel_turotial.html

https://pytorch.org/docs/stable/autograd.html

https://pytorch.org/docs/notes/cuda.html

https://zhuanlan.zhihu.com/p/61765561

https://pytorch.apachen.org/docs/1.7/64.html

https://zhidx.com/p/217999.html

标签:Layer,batches,依赖,--,phony,传播,PyTorch,源码,grad
来源: https://www.cnblogs.com/rossiXYZ/p/15370448.html