‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

IRNN.lua (2126B)


      1 --require 'util.ELU'
      2 --require 'util.LinearFixed'
      3 --require 'util.L2Linear'
      4 --require 'util.EigenvalueL2Linear'
      5 
      6 local RNN = {}
      7 
      8 function RNN.rnn(input_size, rnn_size, n, dropout, embedding)
      9   
     10   -- there are n+1 inputs (hiddens on each layer and x)
     11   local inputs = {}
     12   table.insert(inputs, nn.Identity()()) -- x
     13   for L = 1,n do
     14     table.insert(inputs, nn.Identity()()) -- prev_h[L]
     15   end
     16   local h2hs = {}
     17   local x, input_size_L
     18   local outputs = {}
     19   for L = 1,n do
     20     
     21     local prev_h = inputs[L+1]
     22     if L == 1 then 
     23       if embedding ~= nil then
     24         input_size_L = 200
     25         local embedded = embedding(inputs[1])
     26 print("**********EMBEDDING**********")
     27         x = nn.Tanh()(embedded)
     28       else
     29 print("############OneHot###########")
     30         x = OneHot(input_size)(inputs[1])
     31         input_size_L = input_size
     32       end
     33     else 
     34       x = outputs[(L-1)] 
     35       if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
     36       input_size_L = rnn_size
     37     end
     38 
     39     -- RNN tick
     40     local init = 0.9
     41     local i2h = nil
     42     if input_size_L == rnn_size then
     43       i2h = nn.Linear(input_size_L, rnn_size)(x) 
     44       --i2h = nn.EigenvalueL2Linear(input_size_L, rnn_size, "i-h"..L)(x)
     45       i2h.data.module.weight:eye(rnn_size):mul(init)
     46     else
     47       i2h = nn.Linear(input_size_L, rnn_size)(x)
     48     end
     49     local h2h = nn.Linear(rnn_size, rnn_size)(prev_h) 
     50     --local h2h = nn.L2Linear(rnn_size, rnn_size, 0.0015)(prev_h) 
     51     --local h2h = nn.EigenvalueL2Linear(rnn_size, rnn_size, "h-h"..L)(prev_h)
     52 
     53     h2h.data.module.weight:eye(rnn_size):mul(init)
     54     local next_h = nn.ELU()(nn.CAddTable(){i2h, h2h})
     55     --local next_h = nn.ReLU()(nn.CAddTable(){i2h, h2h})
     56     --local next_h = nn.Tanh()(nn.CAddTable(){i2h, h2h})
     57 
     58     table.insert(outputs, next_h)
     59     table.insert(h2hs, h2h)
     60   end
     61 -- set up the decoder
     62   local top_h = outputs[#outputs]
     63   if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
     64   local proj = nn.Linear(rnn_size, input_size)(top_h)
     65   local logsoft = nn.LogSoftMax()(proj)
     66   table.insert(outputs, logsoft)
     67 
     68   return nn.gModule(inputs, outputs), h2hs
     69 end
     70 
     71 return RNN