LSTM.lua (3420B)
-- Intentionally global (as before): `_` is used below via _.map / _.flatten,
-- and other code may rely on this module having set it.
_ = require 'underscore'
local nninit = require 'nninit'

local LSTM = {}

-- Width of the embedding output used when `embedding` is given without an
-- explicit `embedding_size` (this value used to be hard-coded).
local DEFAULT_EMBEDDING_SIZE = 200

-- Trainable Linear weights are initialised uniformly in [-INIT_RANGE, INIT_RANGE].
local INIT_RANGE = 0.08

-- Build one projection node `in_size -> out_size` applied to `input_node`.
-- When `fixed` is true the module is an nn.LinearFixed whose weight/bias files
-- are "<prefix>-w.t7" / "<prefix>-b.t7", and the module is appended to `fixeds`;
-- otherwise it is an nn.Linear with uniform weight init. The node is annotated
-- with `name`.
local function projection(in_size, out_size, input_node, prefix, name, fixed, fixeds)
  local mod
  if fixed then
    mod = nn.LinearFixed(in_size, out_size, prefix .. "-w.t7", prefix .. "-b.t7")
    fixeds[#fixeds + 1] = mod
  else
    mod = nn.Linear(in_size, out_size):init('weight', nninit.uniform, -INIT_RANGE, INIT_RANGE)
  end
  return mod(input_node):annotate{name = name}
end

-- Wire the LSTM cell update for one layer from the summed 4*rnn_size
-- pre-activations; returns the next_c and next_h nodes.
-- NOTE(review): assumes batched input (batch x 4*rnn_size), so that after
-- Reshape(4, rnn_size) SplitTable(2) splits along the gate axis — confirm.
local function lstm_cell(all_input_sums, prev_c, rnn_size)
  local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
  local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
  -- decode the gates
  local in_gate = nn.Sigmoid()(n1)
  local forget_gate = nn.Sigmoid()(n2)
  local out_gate = nn.Sigmoid()(n3)
  -- decode the write inputs
  local in_transform = nn.Tanh()(n4)
  -- next_c = forget_gate * prev_c + in_gate * in_transform
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform}),
  })
  -- gated cells form the output: next_h = out_gate * tanh(next_c)
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
  return next_c, next_h
end

--- Build an n-layer LSTM as an nngraph gModule.
--
-- Graph inputs (2*n + 1): x, then prev_c[L], prev_h[L] for L = 1..n.
-- Graph outputs (2*n + 1): next_c[L], next_h[L] for L = 1..n, then a
-- LogSoftMax over `input_size` outputs computed from the top layer's h.
--
-- @param input_size        width of the OneHot input (when no embedding is
--                          given) and of the decoder output
-- @param rnn_size          cell/hidden width of every layer
-- @param n                 number of stacked layers (must be >= 1)
-- @param dropout           rate of nn.Dropout between layers and before the
--                          decoder; 0 disables it (default 0)
-- @param recurrent_dropout when > 0, nn.SharedDropout is applied to prev_h
--                          (default 0)
-- @param embedding         optional module applied to x (followed by Tanh)
--                          instead of OneHot
-- @param num_fixed         layers 1..num_fixed use nn.LinearFixed loaded from
--                          "lstm-l<L>-{i2h,h2h}-{w,b}.t7" (default 0)
-- @param embedding_size    output width of `embedding` (default 200)
-- @return the gModule, extended with a `parametersNoGrad()` function
function LSTM.lstm(input_size, rnn_size, n, dropout, recurrent_dropout, embedding, num_fixed, embedding_size)
  assert(type(n) == 'number' and n >= 1, 'LSTM.lstm: n (number of layers) must be >= 1')
  dropout = dropout or 0
  -- Previously only `dropout` was defaulted; omitting these crashed on `nil > 0`.
  recurrent_dropout = recurrent_dropout or 0
  num_fixed = num_fixed or 0
  embedding_size = embedding_size or DEFAULT_EMBEDDING_SIZE

  -- there will be 2*n+1 inputs
  local inputs = {}
  table.insert(inputs, nn.Identity()()) -- x
  for _ = 1, n do
    table.insert(inputs, nn.Identity()()) -- prev_c[L]
    table.insert(inputs, nn.Identity()()) -- prev_h[L]
  end

  local x, input_size_L
  local outputs = {}
  local fixeds = {} -- LinearFixed modules, exposed via parametersNoGrad()
  for L = 1, n do
    -- c,h from the previous timestep
    local prev_c = inputs[L * 2]
    local prev_h = inputs[L * 2 + 1]
    -- the input to this layer
    if L == 1 then
      if embedding ~= nil then
        input_size_L = embedding_size
        x = nn.Tanh()(embedding(inputs[1]))
      else
        -- OneHot is a global defined outside this file.
        x = OneHot(input_size)(inputs[1])
        input_size_L = input_size
      end
    else
      x = outputs[(L - 1) * 2] -- next_h of the layer below
      if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
      input_size_L = rnn_size
    end

    if recurrent_dropout > 0 then
      -- NOTE(review): the rate passed is `dropout`, not `recurrent_dropout`, so
      -- recurrent_dropout only acts as an on/off switch here. Confirm intent.
      prev_h = nn.SharedDropout(dropout)(prev_h)
    end

    local fixed = L <= num_fixed
    local i2h = projection(input_size_L, 4 * rnn_size, x,
                           "lstm-l" .. L .. "-i2h", 'i2h_' .. L, fixed, fixeds)
    local h2h = projection(rnn_size, 4 * rnn_size, prev_h,
                           "lstm-l" .. L .. "-h2h", 'h2h_' .. L, fixed, fixeds)

    local all_input_sums = nn.CAddTable()({i2h, h2h})
    local next_c, next_h = lstm_cell(all_input_sums, prev_c, rnn_size)
    table.insert(outputs, next_c)
    table.insert(outputs, next_h)
  end

  -- set up the decoder on the top layer's h
  local top_h = outputs[#outputs]
  if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
  local proj = nn.Linear(rnn_size, input_size)
  proj = proj:init('weight', nninit.uniform, -INIT_RANGE, INIT_RANGE)(top_h):annotate{name = 'decoder'}
  local logsoft = nn.LogSoftMax()(proj)
  table.insert(outputs, logsoft)

  local gmod = nn.gModule(inputs, outputs)
  -- Flattened list of whatever each LinearFixed module's parametersNoGrad()
  -- returns. Defined with '.', so both gmod.parametersNoGrad() and
  -- gmod:parametersNoGrad() work.
  function gmod.parametersNoGrad()
    return _.flatten(_.map(fixeds, function(fixed) return fixed:parametersNoGrad() end))
  end
  return gmod
end

--- Recursively flatten nested tables (array parts only) into one flat list.
-- A non-table argument is wrapped as a one-element list.
-- Kept global for backward compatibility; parametersNoGrad uses _.flatten instead.
function flatten(list)
  if type(list) ~= "table" then return {list} end
  local flat_list = {}
  for _, elem in ipairs(list) do
    for _, val in ipairs(flatten(elem)) do
      flat_list[#flat_list + 1] = val
    end
  end
  return flat_list
end

return LSTM