Sequence
pad_sequence
: 填充序列
1 | # 定义序列个数、最小长度和最大长度 |
pack_padded_sequence
pad_packed_sequence
- 使用RNN(如LSTM或GRU)处理序列数据时,
pack_padded_sequence
将填充后的序列数据压缩为紧凑的表示形式,以提高计算效率。 - 在RNN处理完紧凑的序列后,可能需要将输出转换回填充序列的格式,以便进行后续操作(如解码、分类等)。
pad_packed_sequence
函数将RNN的输出转换回填充序列格式,同时保留原始序列的长度信息。
1 | lengths = torch.tensor([len(x) for x in random_sequences]) |
torch.nn.utils.rnn.pad_sequence — PyTorch 2.1 documentation
torch.nn.utils.rnn.pad_packed_sequence — PyTorch 2.1 documentation