点击上方“MLNLP”,选择“星标”公众号
重磅干货,第一时间送达
作者丨CV路上一名研究僧
知乎专栏丨深度图像与视频增强
地址丨https://zhuanlan.zhihu.com/p/79046709

0. 遇到大坑

笔者在最近的项目中用到了自定义loss函数,代码一切都准备就绪后,在训练时遇到了梯度爆炸的问题,每次训练几个iterations后,梯度和loss都会变为nan。一般情况下,梯度变为nan都是出现了 
 , 
 等情况,导致结果变为+inf,也就成了nan。

1. 问题分析

笔者需要的loss函数如下:
其中, 
 。
从理论上分析,这个loss函数在反向传播过程中很可能会遇到梯度爆炸,这是为什么呢?反向传播的过程是对loss链式求一阶导数的过程,那么, 
 的导数为:
由于 
 ,这个导数又可以表示为:
这样的话,出现了类似于 
 的表达式,也就会出现典型的$0/1$问题了。为了避免这个问题,首先进行了如下的 
 改变:
经过改变,在$x_i=0$时,不再是 
 问题了,而是转换为了一个线性函数,梯度成为了恒定的12.9,从理论上来看,避免了梯度爆炸的问题。

2. PyTorch初步实现

在实现这一过程时,依旧...遇到了大坑,下面通过示例代码来说明:
"""

loss = mse(X, gamma_inv(X))

"""

def loss_function
(
x
):

mask = (x <
0.003
).float()

gamma_x = mask *
12.9
* x + (
1
-mask) * (x **
0.5
)

loss = torch.mean((x - gamma_x) **
2
)

return
loss


if
__name__ ==
'__main__'
:

x = Variable(torch.FloatTensor([
0, 0.0025, 0.5, 0.8, 1
])
, requires_grad
=
True
)

loss = loss_function(x)

print
(
'loss:',
loss)

loss.backward()

print
(x.grad)
改进后的 
 是一个分支结构,在实现时,就采用了类似于Matlab中矩阵计算的mask方式,mask定义为 
 ,满足条件的$x_i$在mask中对应位置的值为1,因此, 
 的结构只会保留 
 的结果,同样的道理, 
 就实现了上述改进后的 
 公式。
按理来说,此时,在反向传播过程中的梯度应该是正确的,但是,上面代码的输出结果为:
loss: tensor(
0.0105, grad_fn
=<MeanBackward1>)

tensor([ nan
, 0.1416,
-
0.0243,
-
0.0167, 0.0000
])
emmm....依旧为nan,问题在理论层面得到了解决,但是,在实现层面依旧没能解决.....

3. 源码调试分析

上面源码的问题依旧在 
 的实现,这个过程,在Python解释器解释的过程或许是这样的:
  1. 计算 
     ,对mask进行广播式的乘法,结果为:原本为1的位置变为了12.9,原本为0的位置依旧为0;
  2. 将1.的结果继续与x相乘,本质上仍然是与x的每个元素相乘,只是mask中不满足条件的 
     位置为0,表现出的结果是仅对满足条件的 
     进行了计算;
  3. 按照2.所述的原理, 
     公式的后半部分也是同样的计算过程,即, 
     中的每个值依旧会进行 
     的计算;
按照上述过程进行前向传播,在反向传播时,梯度不是从某一个分支得到的,而是两个分支的题目相加得到的,换句话说,依旧没能解决梯度变为nan的问题。

4. 源码改进及问题解决

经过第三部分的分析,知道了梯度变为nan的根本原因是当 
 时依旧参与了 
 的计算,导致在反向传播时计算出的梯度为nan。
要解决这个问题,就要保证在 
 时不会进行这样的计算。
新的PyTorch代码如下:
def loss_function
(x):

mask = x <
0.003

gamma_x = torch.FloatTensor(x.size()).type_as(x)

gamma_x[mask] =
12.9
* x[mask]

mask = x >=
0.003

gamma_x[mask] = x[mask] **
0.5

loss = torch.mean((x - gamma_x) **
2
)

return
loss


if
__name__ ==
'__main__'
:

x = Variable(torch.FloatTensor([
0, 0.0025, 0.5, 0.8, 1
])
, requires_grad
=
True
)

loss = loss_function(x)

print
(
'loss:',
loss)

loss.backward()

print
(x.grad)
改变的地方位于`loss_function`,改变了对于 
 分支的处理方式,控制并保住每次计算仅有满足条件的值可以参与。此时输出为:
loss: tensor(
0.0105, grad_fn
=<MeanBackward1>)

tensor([
0.0000, 0.1416,
-
0.0243,
-
0.0167, 0.0000
])
就此,问题解决!
如有疑问,欢迎留言~
推荐阅读:
继续阅读