torch.nn.Sequential
是 PyTorch 中一个模块容器,用于将一系列层或模块按顺序连接在一起,简化前向传播过程。在 Sequential
中,所有的子模块会按照添加的顺序被执行,适合那些有明确顺序的神经网络结构,比如卷积神经网络、全连接网络等。
主要特点
- 按顺序执行: 将多个子模块按顺序组合,前向传播时依次调用。
- 简洁代码: 减少显式定义
forward
方法的需求,对于简单的网络结构,使用Sequential
可以大大简化代码。 - 嵌套支持:
Sequential
容器可以嵌套,允许将多个Sequential
容器嵌套在一起。
使用方式
- 直接传入模块: 可以通过将模块按顺序传入
Sequential
。 - 有序字典: 可以使用
OrderedDict
来为每个模块指定名字。
基本用法
1. 直接传入模块
import torch
import torch.nn as nnmodel = nn.Sequential(nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 5)
)input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)
在这个例