IRNN.lua (2126B)
--require 'util.ELU'
--require 'util.LinearFixed'
--require 'util.L2Linear'
--require 'util.EigenvalueL2Linear'

local RNN = {}

-- Width assumed for the output of `embedding` when the caller does not
-- pass `embedding_size` (this is the value that used to be hard-coded).
local DEFAULT_EMBEDDING_SIZE = 200

-- Factor applied to the identity matrix written into square weight matrices.
local IDENTITY_INIT_SCALE = 0.9

--- Build an nngraph module for one time step of a stacked, ELU-activated RNN.
--
-- The graph has n+1 inputs: input 1 is x, inputs 2..n+1 are the previous
-- hidden states of layers 1..n. It has n+1 outputs: the next hidden state of
-- each layer followed by the LogSoftMax of a linear projection of the top
-- layer onto `input_size` units.
--
-- Every hidden-to-hidden weight matrix, and every input-to-hidden weight
-- matrix whose input width equals `rnn_size`, is overwritten with
-- IDENTITY_INIT_SCALE times the identity. Only the weights are touched; the
-- biases are left exactly as nn.Linear created them.
--
-- Relies on the globals `nn` (with nngraph loaded) and, when no embedding is
-- supplied, `OneHot`; neither is defined in this file.
--
-- @param input_size     size passed to OneHot and width of the decoder output
-- @param rnn_size       number of hidden units per layer
-- @param n              number of stacked layers (must be >= 1)
-- @param dropout        dropout probability; nil or 0 disables dropout
-- @param embedding      optional callable applied to the x node as
--                       `embedding(node)`; its result is passed through
--                       nn.Tanh before entering layer 1
-- @param embedding_size optional output width of `embedding`, used as the
--                       input width of layer 1 (default 200)
-- @return the nn.gModule built from the inputs and outputs described above
-- @return a list with the hidden-to-hidden graph node of each layer
function RNN.rnn(input_size, rnn_size, n, dropout, embedding, embedding_size)
  -- With no layers there is no top hidden state to decode; fail early with a
  -- readable message instead of deep inside nngraph.
  assert(type(n) == 'number' and n >= 1, 'RNN.rnn: n must be a number >= 1')
  dropout = dropout or 0
  embedding_size = embedding_size or DEFAULT_EMBEDDING_SIZE

  -- there are n+1 inputs (hiddens on each layer and x)
  local inputs = {}
  inputs[1] = nn.Identity()() -- x
  for L = 1, n do
    inputs[L + 1] = nn.Identity()() -- prev_h[L]
  end

  local h2hs = {}
  local outputs = {}
  for L = 1, n do
    local prev_h = inputs[L + 1]

    -- Pick the input of this layer and its width.
    local x, input_size_L
    if L == 1 then
      if embedding ~= nil then
        input_size_L = embedding_size
        local embedded = embedding(inputs[1])
        print("**********EMBEDDING**********")
        x = nn.Tanh()(embedded)
      else
        print("############OneHot###########")
        x = OneHot(input_size)(inputs[1])
        input_size_L = input_size
      end
    else
      x = outputs[L - 1]
      if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
      input_size_L = rnn_size
    end

    -- RNN tick
    -- i2h is built before h2h so that modules are constructed in the same
    -- order as before.
    local i2h = nn.Linear(input_size_L, rnn_size)(x)
    --local i2h = nn.EigenvalueL2Linear(input_size_L, rnn_size, "i-h"..L)(x)
    if input_size_L == rnn_size then
      -- A scaled identity only makes sense for a square weight matrix.
      i2h.data.module.weight:eye(rnn_size):mul(IDENTITY_INIT_SCALE)
    end

    local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)
    --local h2h = nn.L2Linear(rnn_size, rnn_size, 0.0015)(prev_h)
    --local h2h = nn.EigenvalueL2Linear(rnn_size, rnn_size, "h-h"..L)(prev_h)
    h2h.data.module.weight:eye(rnn_size):mul(IDENTITY_INIT_SCALE)

    local next_h = nn.ELU()(nn.CAddTable(){i2h, h2h})
    --local next_h = nn.ReLU()(nn.CAddTable(){i2h, h2h})
    --local next_h = nn.Tanh()(nn.CAddTable(){i2h, h2h})

    outputs[#outputs + 1] = next_h
    h2hs[#h2hs + 1] = h2h
  end

  -- set up the decoder
  local top_h = outputs[#outputs]
  if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
  local proj = nn.Linear(rnn_size, input_size)(top_h)
  local logsoft = nn.LogSoftMax()(proj)
  outputs[#outputs + 1] = logsoft

  return nn.gModule(inputs, outputs), h2hs
end

return RNN