‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

RNN.lua (1307B)


      1 local RNN = {}
      2 
      3 function RNN.rnn(input_size, rnn_size, n, dropout, embedding)
      4   
      5   -- there are n+1 inputs (hiddens on each layer and x)
      6   local inputs = {}
      7   table.insert(inputs, nn.Identity()()) -- x
      8   for L = 1,n do
      9     table.insert(inputs, nn.Identity()()) -- prev_h[L]
     10 
     11   end
     12 
     13   local x, input_size_L
     14   local outputs = {}
     15   for L = 1,n do
     16     
     17     local prev_h = inputs[L+1]
     18     if L == 1 then 
     19       if embedding ~= nil then
     20         input_size_L = 200
     21         local embedded = embedding(inputs[1])
     22         x = nn.Tanh()(embedded)
     23       else
     24         x = OneHot(input_size)(inputs[1])
     25         input_size_L = input_size
     26       end
     27     else 
     28       x = outputs[(L-1)] 
     29       if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
     30       input_size_L = rnn_size
     31     end
     32 
     33     -- RNN tick
     34     local i2h = nn.Linear(input_size_L, rnn_size)(x)
     35     local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)
     36     local next_h = nn.Tanh()(nn.CAddTable(){i2h, h2h})
     37 
     38     table.insert(outputs, next_h)
     39   end
     40 -- set up the decoder
     41   local top_h = outputs[#outputs]
     42   if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
     43   local proj = nn.Linear(rnn_size, input_size)(top_h)
     44   local logsoft = nn.LogSoftMax()(proj)
     45   table.insert(outputs, logsoft)
     46 
     47   return nn.gModule(inputs, outputs)
     48 end
     49 
     50 return RNN