1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
class K_mer_aggregate(nn.Module):
    """Apply several 1-D convolutions of different widths and join the results.

    One ``nn.Conv1d(in_dim, out_dim, k, padding=0)`` is created for every
    kernel size ``k`` in ``kmers``.  In ``forward`` each convolution is run
    over the same input, the results are concatenated along the *length*
    axis (not the channel axis), and ``LayerNorm(out_dim)`` is applied to
    the channel axis of the concatenated sequence.

    Args:
        kmers: Iterable of kernel sizes, one convolution per entry.
        in_dim: Number of input channels of every convolution.
        out_dim: Number of output channels of every convolution; also the
            normalized dimension of the final ``LayerNorm``.
        dropout: Probability used to build ``self.dropout``.

    Note:
        ``self.dropout`` and ``self.activation`` are constructed but are not
        applied in ``forward``.  They are kept as attributes so existing code
        that accesses them keeps working.
    """

    def __init__(self, kmers, in_dim: int, out_dim: int, dropout: float = 0.1):
        super().__init__()
        # Submodules are assigned in the same order as before so that
        # repr()/named_modules() ordering and seeded parameter
        # initialisation are unchanged.
        self.dropout = nn.Dropout(dropout)
        self.convs = nn.ModuleList(
            nn.Conv1d(in_dim, out_dim, kernel_size, padding=0)
            for kernel_size in kmers
        )
        self.activation = nn.ReLU(inplace=True)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every convolution over ``x`` and normalise the joined output.

        Args:
            x: 3-D tensor laid out as ``(batch, length, in_dim)``; it is
                permuted to ``(batch, in_dim, length)`` before being fed to
                the convolutions.

        Returns:
            Tensor of shape ``(batch, total_length, out_dim)``.  With the
            unpadded, stride-1 convolutions built in ``__init__``, each
            kernel size ``k`` contributes ``length - k + 1`` positions to
            ``total_length``.
        """
        # Conv1d expects channels on axis 1.
        channels_first = x.permute(0, 2, 1)
        # Concatenate along the length axis (dim=2); the channel count stays
        # at out_dim, which is what LayerNorm below normalises over.
        joined = torch.cat(
            [conv(channels_first) for conv in self.convs], dim=2
        )
        # Back to channels-last so LayerNorm acts on the out_dim axis.
        return self.norm(joined.permute(0, 2, 1))
|