‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

dropout.lua (825B)


      1 require 'nn'
      2 
      --- Inverted dropout layer, registered as nn.Dropout.
      -- While training, each element is zeroed independently with probability
      -- p and the surviving elements are scaled by 1/(1-p), so no rescaling is
      -- needed at evaluation time. When p == 0, or when the module is in
      -- evaluation mode (self.train == false, as set by nn.Module:evaluate()),
      -- the layer is the identity in both the forward and backward pass.
      local Dropout, Parent = torch.class('nn.Dropout', 'nn.Module')

      --- Construct a dropout layer.
      -- @tparam[opt=0.5] number p probability of dropping an element, 0 <= p < 1
      -- @raise error if p is not a number in [0, 1)
      function Dropout:__init(p)
         Parent.__init(self)
         self.p = p or 0.5
         -- Negated range test so NaN (which fails every comparison) and
         -- non-numbers are rejected with the intended message.
         if type(self.p) ~= 'number' or not (self.p >= 0 and self.p < 1) then
            error('<Dropout> illegal percentage, must be 0 <= p < 1')
         end
         self.train = true
         -- Keep-mask sampled by the last training forward pass; reused by backward.
         self.noise = torch.Tensor()
      end

      --- Forward pass.
      -- @param input tensor of any shape
      -- @return self.output, same shape as input
      function Dropout:updateOutput(input)
         self.output:resizeAs(input):copy(input)
         -- `== false` (not `not self.train`) so instances lacking the field,
         -- e.g. ones serialized before it existed, keep training behaviour.
         if self.p == 0 or self.train == false then
            return self.output
         end
         self.noise:resizeAs(input)
         self.noise:bernoulli(1 - self.p) -- 1 = keep, with probability 1-p
         self.output:cmul(self.noise)
         self.output:div(1 - self.p)
         return self.output
      end

      --- Backward pass: gradient w.r.t. input.
      -- Applies the same mask and 1/(1-p) scale as the preceding forward pass.
      -- @param input tensor passed to the matching updateOutput call (unused)
      -- @param gradOutput gradient w.r.t. the output, same shape as input
      -- @return self.gradInput
      function Dropout:updateGradInput(input, gradOutput)
         self.gradInput:resizeAs(gradOutput):copy(gradOutput)
         -- updateOutput did not sample a mask in these cases (self.noise may be
         -- empty or stale). The layer acted as the identity, so its gradient is too.
         if self.p == 0 or self.train == false then
            return self.gradInput
         end
         self.gradInput:cmul(self.noise) -- mask the gradients with the forward noise
         self.gradInput:div(1 - self.p)
         return self.gradInput
      end