Welcome to Hexo! This is your very first post. Check documentation for more info. If you get any problems when using Hexo, you can find the answer in troubleshooting or you can ask me on GitHub.

net

utils

st_gcn.py

  1. 导入库
1
2
3
4
5
6
7
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from net.utils.tgcn import ConvTemporalGraphical
from net.utils.graph import Graph
  1. class 参数
  • in_channels: 输入通道数

  • out_channels: 输出通道数

  • kernel_size: 卷积核大小。(时间卷积核大小,空间卷积核大小)

  • stride(int): 时间步长

  • dropout:

  • residual: 是否使用残差

  1. 检测 kernel_size 长度是否为 2

  2. ?检测时间卷积核大小是否是奇数

  3. ?padding: 时间填充,空间不填充)

  4. gcn 函数: 自定义,空间卷积(输入通道数,输出通道数,空间卷积核大小)

返回图 X 和 邻接矩阵 A

  1. tcn 函数: 时间卷积,套用 Covn2d
  • BatchNorm2d

  • ReLU: 激活函数

  • Conv2d: (输入通道数, 输出通道数, 卷积核大小, 填充)

  • BatchNorm2d

  • Dropout

    返回 图 X

  1. 残差
  • 无残差

    lambda x: 0: 返回 0

  • 输入输出通道数一样, 并且步长为 1

    lambda x: x: 返回本身

  • 其他情况

    • 卷积 Conv2d (输入通道数, 输出通道数, 卷积核大小: 1, 步长: (时间步长, 1))

    • ?BatchNorm2d

  1. relu

nn.BatchNorm2d(out_channels)

inplace-选择是否进行覆盖运算, 对从上层网络 Conv2d 中传递下来的 tensor 直接进行修改,这样能够节省运算内存,不用多存储其他变量

  1. forward
1
2
3
res = self.residual(x)
x, A = self.gcn(x, A)
x = self.tcn(x) + res

relu -> gcn -> tcn