1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| import torch from torch import nn from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
class MyNN(nn.Module): def __init__(self): super(MyNN, self).__init__() # Sequential组织多个卷积等操作,相当于transforms中的compose self.module1 = Sequential( Conv2d(3, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 64, 5, padding=2), MaxPool2d(2), Flatten(), Linear(1024, 64), Linear(64, 10) )
def forward(self, x): x = self.module1(x) return x
mynn = MyNN() print(mynn)
# https://www.w3cschool.cn/pytorch/pytorch-skug3bpf.html,查看张量的四个属性 # torch.ones函数返回全1的张量,tensor数据类型
# 生成一个64Batch,3Channel,32Width, 32Height的源数据 input = torch.ones((64, 3, 32, 32))
# 卷积池化等操作 output = mynn(input)
# 查看结果 print(output.shape)
|