计算图(Computation Graph)

计算图

1. 什么是计算图

计算图是由节点和边构成的有向图(Graph),其中:

  • 节点(Nodes):表示操作(例如加法、乘法、激活函数等)或变量(例如输入、权重、偏置等)。
  • 边(Edges):表示数据在节点之间的流动,即操作的输入和输出之间的依赖关系。

在计算图中,每个节点都对应一个操作或变量,边则表示这些操作之间的依赖顺序。

2. 计算图的优势

  1. 自动微分:能够自动计算复杂函数的梯度(链式法则)
  2. 并行计算:可以识别独立的操作并并行执行
  3. 优化:通过图重写、操作融合、内存优化等技术优化计算效率
    1. 操作融合:将多个小操作合并为一个更大的操作,以减少数据传输和中间结果的存储需求,如y = 2x + 3x可优化为y = 5x
    2. 内存优化:通过在计算图中智能地分配和释放内存,减少内存占用
    3. 利用计算图的结构,框架可以识别出独立的操作,并在多个处理器或 GPU 上并行执行这些操作

3. 计算图的构建

前向计算的过程大概如下

Forward Calculation

根据链式法则,反向计算的过程大概如下

Backward Calculation
  1. 先按着计算图往前计算得到各节点的值;
  2. 然后反方向计算出各直接连接的节点间(如x和z、y和z)的偏导;
  3. 最后对有效路径求和,即可得到r对a、b、c、d、e的偏导。

4. 计算图的实现

计算图的实现一般分为两种:

  1. 静态计算图:在计算图构建完成后,图的结构和节点的计算顺序都是固定的,如 TensorFlow 的计算图。
  2. 动态计算图:在计算图构建的过程中,可以根据需要动态地构建计算图的结构和节点的计算顺序,如 PyTorch 的计算图。
  3. 混合计算图:结合了静态计算图和动态计算图的优点,如 TensorFlow 2.0 的计算图。

5. 总结

计算图是深度学习框架的核心,它能够自动微分、并行计算、优化计算效率等,是实现深度学习算法的基础。

6. 参考文献

  1. 计算图的概念与理解
  2. 浅显易懂的计算图、链式法则讲解