RNN.lua (1307B)
local RNN = {}

-- Feature width assumed for the output of `embedding` when the caller does
-- not pass `embedding_size`. Kept at 200 so existing callers are unaffected.
local DEFAULT_EMBEDDING_SIZE = 200

--- Build an nngraph module for one time step of a stacked tanh RNN.
--
-- The returned graph takes n+1 inputs: `x` followed by the previous hidden
-- state of every layer. It produces n+1 outputs: the next hidden state of
-- every layer followed by the LogSoftMax of the decoder projection.
--
-- @param input_size size passed to `OneHot` when no embedding is used, and
--   the output size of the decoder `nn.Linear` (presumably the vocabulary
--   size -- confirm against the caller).
-- @param rnn_size number of hidden units in every layer.
-- @param n number of stacked layers; must be at least 1.
-- @param dropout probability handed to `nn.Dropout`. Applied to the input of
--   layers 2..n and to the top hidden state before the decoder; skipped when
--   not greater than 0. Optional, defaults to 0.
-- @param embedding optional; when not nil it is called on the `x` node and
--   its result is passed through `nn.Tanh` instead of one-hot encoding `x`.
--   NOTE(review): expected to be an nn module usable as an nngraph node
--   (e.g. a lookup table) -- confirm against the caller.
-- @param embedding_size feature width produced by `embedding`; used as the
--   input size of the first layer's `nn.Linear`. Optional, defaults to 200.
-- @return an `nn.gModule` wired as described above.
function RNN.rnn(input_size, rnn_size, n, dropout, embedding, embedding_size)
  -- With zero layers there is no top hidden state to decode, so fail here
  -- with a clear message instead of deep inside graph construction.
  local layer_count = tonumber(n)
  if not layer_count or layer_count < 1 then
    error("RNN.rnn: n (number of layers) must be a number >= 1", 2)
  end

  -- A missing dropout means "no dropout" rather than a nil comparison error.
  dropout = dropout or 0
  embedding_size = embedding_size or DEFAULT_EMBEDDING_SIZE

  -- there are n+1 inputs (x and the previous hidden state of each layer)
  local inputs = {}
  inputs[1] = nn.Identity()() -- x
  for L = 1, n do
    inputs[L + 1] = nn.Identity()() -- prev_h[L]
  end

  local outputs = {}
  for L = 1, n do
    local prev_h = inputs[L + 1]
    local x, input_size_L

    if L == 1 then
      -- the first layer reads the raw input, embedded or one-hot encoded
      if embedding ~= nil then
        local embedded = embedding(inputs[1])
        x = nn.Tanh()(embedded)
        input_size_L = embedding_size
      else
        x = OneHot(input_size)(inputs[1])
        input_size_L = input_size
      end
    else
      -- deeper layers read the hidden state just produced by the layer below
      x = outputs[L - 1]
      if dropout > 0 then
        x = nn.Dropout(dropout)(x) -- apply dropout, if any
      end
      input_size_L = rnn_size
    end

    -- RNN tick: next_h = tanh(W_x * x + W_h * prev_h)
    local i2h = nn.Linear(input_size_L, rnn_size)(x)
    local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)
    local next_h = nn.Tanh()(nn.CAddTable()({ i2h, h2h }))

    outputs[L] = next_h
  end

  -- set up the decoder on the top layer's hidden state
  local top_h = outputs[#outputs]
  if dropout > 0 then
    top_h = nn.Dropout(dropout)(top_h)
  end
  local proj = nn.Linear(rnn_size, input_size)(top_h)
  local logsoft = nn.LogSoftMax()(proj)
  outputs[#outputs + 1] = logsoft

  return nn.gModule(inputs, outputs)
end

return RNN