‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

OneHot.lua (670B)


      1 
      2 local OneHot, parent = torch.class('OneHot', 'nn.Module')
      3 
      4 function OneHot:__init(outputSize)
      5   parent.__init(self)
      6   self.outputSize = outputSize
      7   -- We'll construct one-hot encodings by using the index method to
      8   -- reshuffle the rows of an identity matrix. To avoid recreating
      9   -- it every iteration we'll cache it.
     10   self._eye = torch.eye(outputSize)
     11 end
     12 
     13 function OneHot:updateOutput(input)
     14   self.output:resize(input:size(1), self.outputSize):zero()
     15   if self._eye == nil then self._eye = torch.eye(self.outputSize) end
     16   self._eye = self._eye:float()
     17   local longInput = input:long()
     18   self.output:copy(self._eye:index(1, longInput))
     19   return self.output
     20 end