高友谊

记录科研、技术与思考

深度学习——nn.Sequential 的使用

这篇文章主要记录 深度学习——nn.Sequential 的使用 的学习过程,方便后续快速回顾核心概念、代码写法与实验细节。

目标流程

流程解析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear

class MyNN(nn.Module):
def __init__(self):
super(MyNN, self).__init__()
# Sequential组织多个卷积等操作,相当于transforms中的compose
self.module1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)

def forward(self, x):
x = self.module1(x)
return x

mynn = MyNN()
print(mynn)

# https://www.w3cschool.cn/pytorch/pytorch-skug3bpf.html,查看张量的四个属性
# torch.ones函数返回全1的张量,tensor数据类型

# 生成一个64Batch,3Channel,32Width, 32Height的源数据
input = torch.ones((64, 3, 32, 32))

# 卷积池化等操作
output = mynn(input)

# 查看结果
print(output.shape)