OneHot.lua (670B)
local OneHot, parent = torch.class('OneHot', 'nn.Module')

--- nn.Module that turns a vector of class indices into one-hot rows.
-- @tparam number outputSize number of classes, i.e. the width of each one-hot row
function OneHot:__init(outputSize)
  parent.__init(self)
  self.outputSize = outputSize
  -- One-hot encodings are built by using `index` to pick rows of an
  -- identity matrix. The matrix is cached so it is not rebuilt on every call.
  self._eye = torch.eye(outputSize)
end

--- Forward pass: row i of the output is row `input[i]` of the identity matrix.
-- NOTE(review): `input` is presumably a 1-D tensor of 1-based indices in
-- [1, outputSize]. `index(1, ...)` needs a 1-D index vector whose values
-- address rows of the outputSize x outputSize identity. Confirm with callers.
-- @param input tensor of class indices (converted to LongTensor for indexing)
-- @return self.output, resized to input:size(1) x outputSize
function OneHot:updateOutput(input)
  -- The copy below overwrites every element of the resized output,
  -- so a separate zero() pass would only be wasted work.
  self.output:resize(input:size(1), self.outputSize)
  -- Rebuild the cached identity lazily if this instance lacks it
  -- (presumably instances created/saved without `_eye` — confirm).
  if self._eye == nil then
    self._eye = torch.eye(self.outputSize)
  end
  -- Force the lookup table to FloatTensor. The module's tensors may have been
  -- converted to another type (e.g. via :type()); :float() returns the same
  -- tensor when it is already a FloatTensor.
  self._eye = self._eye:float()
  self.output:copy(self._eye:index(1, input:long()))
  return self.output
end