公众号关注 “ML_NLP
设为 “星标”,重磅干货,第一时间送达!

作者:Erfandi Maula Yusnu, Lalu
编译:ronghuaiyang
导读
对使用PyTorch Lightning的训练代码和原始的PyTorch代码进行了对比,展示了其简单,干净,灵活的优点,相信你会喜欢的。
PyTorch Lightning是为ML研究人员设计的轻型PyTorch封装。它帮助你扩展模型并编写更少的样板文件,同时维护代码干净和灵活同时进行扩展。它帮助研究人员更多地专注于解决问题,而不是编写工程代码。
我从两年前就开始使用PyTorch了,我从0.3.0版本开始使用。在我使用PyTorch之前,我使用Keras作为我的深度学习框架,但后来我开始切换到PyTorch,原因有几个。如果你想知道我的原因,看看下面这篇文章:https://medium.com/swlh/why-i-switch-from-keras-to-pytorch-e48922f5846。
由于我一直在使用PyTorch,所以我需要牺牲在Keras中只用几行简单的行代码就可以进行训练的乐趣,而编写自己的训练代码。它有优点也有缺点,但是我选择PyTorch编写代码的方式来获得对我的训练代码的更多控制。但每当我想在深度学习中尝试一些新的模型时,就意味着我每次都需要编写训练和评估代码。
所以,我决定建立我自己的库,我称之为torchwisdom,但我陷入了困境,因为我仍在为我的公司构建OCR全pipeline系统。所以,我试图找到另一个解决方案,然后我找到了PyTorch Lightning,在我看到代码后,它让我一见钟情。
因此,我将在本文中介绍的内容是安装、基本的代码比较以及通过示例进行比较,这些示例是我自己通过从pytorch lightning site获取的,一些代码自己创建的。最后是本文的结论。

安装

好的,让我们从安装pytorch-lighting开始,这样你就可以跟着我一起做了。你可以使用pip或者conda安装pytorch lightning。
pip install
pip install pytorch-lightning

conda install
conda install pytorch-lightning -c conda-forge

对我来说,我更喜欢用anaconda作为我的python解释器,它对于深度学习和数据科学的人来说更完整。从第一次安装开始,它就自带了许多标准机器学习和数据处理库包。

基本代码的比较

在我们进入代码之前,我想让你看看下面的图片。下面有2张图片解释了pytorch和pytorch lightning在编码、建模和训练上的区别。在左边,你可以看到,pytorch需要更多的代码行来创建模型和训练。
有了pytorch lightning,代码就变成了Lightning模块的内部,所有的训练工程代码都被pytorch lightning解决了。但是你需要在一定程度上定制你的训练步骤,如下面的示例代码所示。
对于训练代码,你只需要3行代码,第一行是用于实例化模型类,第二行是用于实例化Trainer类,第三行是用于训练模型。
这个例子是用pytorch lightning训练的一种方法。当然,你可以对pytorch进行自定义风格的编码,因为pytorch lightning具有不同程度的灵活性。你想看吗?让我们继续。

通过例子进行比较

好了,在完成安装之后,让我们开始编写代码。要做的第一件事是导入需要使用的所有库。在此之后,你需要构建将用于训练的数据集和数据加载器。
# import all you need
import
 os

import
 torch

import
 torchvision

import
 torch.nn 
as
 nn

import
 torch.nn.functional 
as
 F

from
 torch.utils.data 
import
 DataLoader, random_split

from
 torchvision.datasets 
import
 MNIST

from
 torchvision 
import
 datasets, transforms

import
 pytorch_lightning 
as
 pl

from
 pytorch_lightning 
import
 Trainer

from
 pytorch_lightning.core.lightning 
import
 LightningModule



# transforms
# prepare transforms standard to MNIST
transform=transforms.Compose([transforms.ToTensor(),

                              transforms.Normalize((
0.1307
,), (
0.3081
,))])


# data
mnist_train = MNIST(os.getcwd(), train=
True
, download=
True
, transform=transform)

mnist_train_loader = DataLoader(mnist_train, batch_size=
64
)

正如上面看到的代码,我们使用来自torchvision的MNIST数据集,并使用torch.utils.DataLoader创建数据加载器。现在,在下面的代码中,我们使网络与28x28像素的MNIST数据集想匹配。第一层有128个隐藏节点,第二层有256个隐藏节点,第三层为输出层,有10个类作为输出。
# build your model
classCustomMNIST(LightningModule):
def__init__(self):
        super().__init__()

# mnist images are (1, 28, 28) (channels, width, height)
        self.layer1 = torch.nn.Linear(
28
 * 
28
128
)

        self.layer2 = torch.nn.Linear(
128
256
)

        self.layer3 = torch.nn.Linear(
256
10
)


defforward(self, x):
        batch_size, channels, width, height = x.size()


# (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, 
-1
)


        x = self.layer1(x)

        x = torch.relu(x)


        x = self.layer2(x)

        x = torch.relu(x)


        x = self.layer3(x)

        x = torch.log_softmax(x, dim=
1
)


return
 x


deftraining_step(self, batch, batch_idx):
        data, target = batch

        logits = self.forward(data)

        loss = F.nll_loss(logits, target)

return
 {
'loss'
: loss}


defconfigure_optimizers(self):
return
 torch.optim.Adam(self.parameters(), lr=
1e-3
)



# train your model
model = CustomMNIST()

trainer = Trainer(max_epochs=
5
, gpus=
1
)

如果你在上面的gist代码中看到第27和33行,你会看到training_stepconfigure_optimators方法,它覆盖了在第2行中扩展的类LightningModule中的方法。这使得pytorch中标准的nn.Module不同于LightningModule,它有一些方法使它与第39行中的Trainer类兼容。
现在,让我们尝试另一种方法来编写代码。假设你必须编写一个库,或者希望其他人使用纯pytorch编写的库。你该怎样使用pytorch lightning?
下面的代码有两个类,第一个类使用标准的pytorch的nn.Module作为其父类。它是按照标准pytorch模块中通常编写的方式编写的,但是看第30行,有一个名为ExtendMNIST的类继承了两个类。这两个类由StandardMNIST类和LightningModule类组合在一起。这就是我喜欢python的地方,一个类可以有多个父类。
# build your model
classStandardMNIST(nn.Module):
def__init__(self):
        super().__init__()

# mnist images are (1, 28, 28) (channels, width, height)
        self.layer1 = torch.nn.Linear(
28
 * 
28
128
)

        self.layer2 = torch.nn.Linear(
128
256
)

        self.layer3 = torch.nn.Linear(
256
10
)


defforward(self, x):
        batch_size, channels, width, height = x.size()


# (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, 
-1
)


        x = self.layer1(x)

        x = torch.relu(x)


        x = self.layer2(x)

        x = torch.relu(x)


        x = self.layer3(x)

        x = torch.log_softmax(x, dim=
1
)


return
 x



# extend StandardMNIST and LightningModule at the same time
# this is what I like from python, extend two class at the same time
classExtendMNIST(StandardMNIST, LightningModule):
def__init__(self):
        super().__init__()  


deftraining_step(self, batch, batch_idx):
        data, target = batch

        logits = self.forward(data)

        loss = F.nll_loss(logits, target)

return
 {
'loss'
: loss}


defconfigure_optimizers(self):
return
 torch.optim.Adam(self.parameters(), lr=
1e-3
)



# run the training
model = ExtendMNIST()

trainer = Trainer(max_epochs=
5
, gpus=
1
)

trainer.fit(model, mnist_train_loader)

如果你看到ExtendMNIST类中的代码,你会看到它只是覆盖了LightningModule类。使用这种编写代码的方法,你可以扩展以前编写的任何其他模型,而无需更改它,并且仍然可以使用pytorch lightning库。
那么,你能在训练时给我看一下结果吗?好,让我们看看它在训练时是什么样子。
这样你就有了它在训练时的屏幕截图。它有一个很好的进度条,显示了网络的损失,这不是让你更容易训练一个模型吗?
如果你想查看实际运行的代码,可以单击下面的链接。第一个是pytorch lightning的标准方式,第二个是自定义方式。
PyTorch Lightning StandardStandard waycolab.research.google.com
PyTorch Lightning CustomCustom Waycolab.research.google.com

总结

PyTorch Lightning已经开发出了一个很好的标准代码,它有229个贡献者,并且它的开发非常活跃。现在,它甚至有风险投资,因为它达到了版本0.7。
在这种情况下(风险投资),我相信pytorch lightning将足够稳定,可以用作你编写pytorch代码的标准库,而不必担心将来开发会停止。
对于我来说,我选择在我的下一个项目中使用pytorch lighting,我喜欢它的灵活性,简单和干净的方式来编写用于深度学习研究的代码。
好了,今天就到这里,祝你愉快。记住要去尝试,不会有什么损失。
仓库地址共享:
在机器学习算法与自然语言处理公众号后台回复“代码”
即可获取195篇NAACL+295篇ACL2019有代码开源的论文。开源地址如下:https://github.com/yizhen20133868/NLP-Conferences-Code
重磅!机器学习算法与自然语言处理交流群已正式成立
群内有大量资源,欢迎大家进群学习!
额外赠送福利资源!邱锡鹏深度学习与神经网络,pytorch官方中文教程,利用Python进行数据分析,机器学习学习笔记,pandas官方文档中文版,effective java(中文版)等20项福利资源
获取方式:进入群后点开群公告即可领取下载链接
注意:请大家添加时修改备注为 [学校/公司 + 姓名 + 方向]
例如 —— 哈工大+张三+对话系统。
号主,微商请自觉绕道。谢谢!
继续阅读
阅读原文