‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

GRU.lua (2230B)


      1 
      2 local GRU = {}
      3 
      4 --[[
      5 Creates one timestep of one GRU
      6 Paper reference: http://arxiv.org/pdf/1412.3555v1.pdf
      7 ]]--
      8 function GRU.gru(input_size, rnn_size, n, dropout, embedding)
      9   dropout = dropout or 0 
     10   -- there are n+1 inputs (hiddens on each layer and x)
     11   local inputs = {}
     12   table.insert(inputs, nn.Identity()()) -- x
     13   for L = 1,n do
     14     table.insert(inputs, nn.Identity()()) -- prev_h[L]
     15   end
     16 
     17   function new_input_sum(insize, xv, hv)
     18     local i2h = nn.Linear(insize, rnn_size)(xv)
     19     local h2h = nn.Linear(rnn_size, rnn_size)(hv)
     20     return nn.CAddTable()({i2h, h2h})
     21   end
     22 
     23   local x, input_size_L
     24   local outputs = {}
     25   for L = 1,n do
     26 
     27     local prev_h = inputs[L+1]
     28     -- the input to this layer
     29     if L == 1 then 
     30       if embedding ~= nil then
     31         input_size_L = 200
     32         local embedded = embedding(inputs[1])
     33         x = nn.Tanh()(embedded)
     34       else
     35         x = OneHot(input_size)(inputs[1])
     36         input_size_L = input_size
     37       end
     38     else 
     39       x = outputs[(L-1)] 
     40       if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
     41       input_size_L = rnn_size
     42     end
     43     -- GRU tick
     44     -- forward the update and reset gates
     45     local update_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))
     46     local reset_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))
     47     -- compute candidate hidden state
     48     local gated_hidden = nn.CMulTable()({reset_gate, prev_h})
     49     local p2 = nn.Linear(rnn_size, rnn_size)(gated_hidden)
     50     local p1 = nn.Linear(input_size_L, rnn_size)(x)
     51     local hidden_candidate = nn.Tanh()(nn.CAddTable()({p1,p2}))
     52     -- compute new interpolated hidden state, based on the update gate
     53     local zh = nn.CMulTable()({update_gate, hidden_candidate})
     54     local zhm1 = nn.CMulTable()({nn.AddConstant(1,false)(nn.MulConstant(-1,false)(update_gate)), prev_h})
     55     local next_h = nn.CAddTable()({zh, zhm1})
     56 
     57     table.insert(outputs, next_h)
     58   end
     59 -- set up the decoder
     60   local top_h = outputs[#outputs]
     61   if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
     62   local proj = nn.Linear(rnn_size, input_size)(top_h)
     63   local logsoft = nn.LogSoftMax()(proj)
     64   table.insert(outputs, logsoft)
     65 
     66   return nn.gModule(inputs, outputs)
     67 end
     68 
     69 return GRU