```python
import torch
import torch.nn as nn


class netA(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, bidirectional):
        """
        input_size: embedding_dim
        bidirectional: if True, the GRU is bidirectional and its output size is hidden_size * 2
        """
        super(netA, self).__init__()
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, bidirectional=bidirectional,
                          batch_first=True, dropout=0.1)
        # A bidirectional RNN concatenates forward and backward hidden states.
        if bidirectional:
            self.in_features = hidden_size * 2
        else:
            self.in_features = hidden_size
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Tanh(),
            nn.Linear(in_features=self.in_features, out_features=output_size),
            nn.Tanh()
        )

    def forward(self, x):
        # x: (batch, seq_len, input_size) -> (batch, seq_len, in_features)
        x, h = self.gru(x)
        out = self.fc(x)
        return out
class netB(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, bidirectional):
        """
        input_size: embedding_dim
        bidirectional: if True, the LSTM is bidirectional and its output size is hidden_size * 2
        """
        super(netB, self).__init__()
        # This branch is an LSTM; the original attribute was misleadingly
        # named `gru`, so it is renamed `lstm` here.
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, bidirectional=bidirectional,
                            batch_first=True, dropout=0.1)
        if bidirectional:
            self.in_features = hidden_size * 2
        else:
            self.in_features = hidden_size
        # Softmax over the last (feature) dimension. With batch_first=True the
        # output is (batch, seq_len, output_size), so the original dim=1 would
        # have normalized over time steps rather than over the output classes.
        self.softmax = nn.Softmax(dim=-1)
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Tanh(),
            nn.Linear(in_features=self.in_features, out_features=output_size)
        )

    def forward(self, x):
        x, (hn, cn) = self.lstm(x)
        out = self.fc(x)
        out = self.softmax(out)
        return out
class netC(nn.Module):
    def __init__(self, vocab_size, input_size, hidden_size, num_layers, output_size, bidirectional):
        super(netC, self).__init__()
        self.embedding = nn.Embedding(vocab_size, input_size)
        self.net1 = netA(input_size, hidden_size, num_layers, output_size, bidirectional)
        self.net2 = netB(input_size, hidden_size, num_layers, output_size, bidirectional)

    def forward(self, x):
        # x: (batch, seq_len) token ids -> (batch, seq_len, input_size) embeddings
        x = self.embedding(x)
        x_A = self.net1(x)  # (batch, seq_len, output_size)
        x_B = self.net2(x)  # (batch, seq_len, output_size)
        # Concatenate the two branches along dim=1 (the sequence dimension),
        # giving (batch, 2 * seq_len, output_size).
        out = torch.cat([x_A, x_B], dim=1)
        return out
```
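As a quick sanity check, here is a minimal sketch of running the combined model on dummy token ids. All hyperparameter values below (vocabulary size, dimensions, batch and sequence length) are made-up assumptions for illustration, not values from the original code:

```python
# Smoke test with assumed hyperparameters.
model = netC(vocab_size=1000, input_size=64, hidden_size=128,
             num_layers=2, output_size=10, bidirectional=True)
tokens = torch.randint(0, 1000, (4, 20))  # batch of 4 sequences, 20 tokens each
out = model(tokens)
print(out.shape)  # torch.Size([4, 40, 10]): the two branch outputs stacked along dim=1
```

Note that because `torch.cat` is applied on dim=1, the GRU and LSTM branch outputs are stacked along the sequence axis; if the intent were to combine them per time step, dim=-1 would be used instead.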