‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

LSTM.lua (3420B)


      1 _ = require 'underscore'
      2 local nninit = require 'nninit'
      3 
      4 local LSTM = {}
      5 function LSTM.lstm(input_size, rnn_size, n, dropout, recurrent_dropout, embedding, num_fixed)
      6   dropout = dropout or 0 
      7 
      8   -- there will be 2*n+1 inputs
      9   local inputs = {}
     10   table.insert(inputs, nn.Identity()()) -- x
     11   for L = 1,n do
     12     table.insert(inputs, nn.Identity()()) -- prev_c[L]
     13     table.insert(inputs, nn.Identity()()) -- prev_h[L]
     14   end
     15 
     16   local x, input_size_L
     17   local outputs = {}
     18   local fixeds = {}
     19   for L = 1,n do
     20     -- c,h from previos timesteps
     21     local prev_h = inputs[L*2+1]
     22     local prev_c = inputs[L*2]
     23     -- the input to this layer
     24     if L == 1 then
     25       if embedding ~= nil then
     26         input_size_L = 200
     27         local embedded = embedding(inputs[1])
     28         x = nn.Tanh()(embedded)
     29       else
     30         x = OneHot(input_size)(inputs[1])
     31         input_size_L = input_size
     32       end
     33     else 
     34       x = outputs[(L-1)*2] 
     35       if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
     36       input_size_L = rnn_size
     37     end
     38 
     39 
     40     if recurrent_dropout > 0 then prev_h = nn.SharedDropout(dropout)(prev_h) end
     41 
     42     local i2h = nil
     43     local h2h = nil
     44     if L <= num_fixed then
     45 	i2h = nn.LinearFixed(input_size_L, 4 * rnn_size, "lstm-l"..L.."-i2h-w.t7", "lstm-l"..L.."-i2h-b.t7")(x):annotate{name='i2h_'..L}
     46     	h2h = nn.LinearFixed(rnn_size, 4 * rnn_size, "lstm-l"..L.."-h2h-w.t7", "lstm-l"..L.."-h2h-b.t7")(prev_h):annotate{name='h2h_'..L}
     47         fixeds[#fixeds+1] = i2h.data.module
     48         fixeds[#fixeds+1] = h2h.data.module
     49     else 
     50     	i2h = nn.Linear(input_size_L, 4 * rnn_size):init('weight', nninit.uniform, -0.08, 0.08)(x):annotate{name='i2h_'..L}
     51     	h2h = nn.Linear(rnn_size, 4 * rnn_size):init('weight', nninit.uniform, -0.08, 0.08)(prev_h):annotate{name='h2h_'..L}
     52     end
     53     
     54     local all_input_sums = nn.CAddTable()({i2h, h2h})
     55 
     56     local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
     57     local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
     58     -- decode the gates
     59     local in_gate = nn.Sigmoid()(n1)
     60     local forget_gate = nn.Sigmoid()(n2)
     61     local out_gate = nn.Sigmoid()(n3)
     62     -- decode the write inputs
     63     local in_transform = nn.Tanh()(n4)
     64     -- perform the LSTM update
     65     local next_c           = nn.CAddTable()({
     66         nn.CMulTable()({forget_gate, prev_c}),
     67         nn.CMulTable()({in_gate,     in_transform})
     68       })
     69     -- gated cells form the output
     70     local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
     71     table.insert(outputs, next_c)
     72     table.insert(outputs, next_h)
     73   end
     74   -- set up the decoder
     75   local top_h = outputs[#outputs]
     76   if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
     77   local proj = nn.Linear(rnn_size, input_size):init('weight', nninit.uniform, -0.08, 0.08)(top_h):annotate{name='decoder'}
     78   local logsoft = nn.LogSoftMax()(proj)
     79   table.insert(outputs, logsoft)
     80 
     81   local module = nn.gModule(inputs, outputs)
     82   function module.parametersNoGrad()
     83     --print(_.map(fixeds, function(fixed) return fixed:parametersNoGrad() end))
     84     return _.flatten(_.map(fixeds, function(fixed) return fixed:parametersNoGrad() end))
     85   end
     86   return module
     87 end
     88 
     89 function flatten(list)
     90   if type(list) ~= "table" then return {list} end
     91   local flat_list = {}
     92   for _, elem in ipairs(list) do
     93     for _, val in ipairs(flatten(elem)) do
     94       flat_list[#flat_list + 1] = val
     95     end
     96   end
     97   return flat_list
     98 end
     99 
    100 return LSTM
    101