GloVeEmbedding.lua (6534B)
--- Split a string into the pieces found between separator characters.
-- `sep` is spliced verbatim into a Lua character class, so it can be a set of
-- literal characters or a class such as "%s". Runs of separators never
-- produce empty pieces.
-- @param inputstr string to split
-- @param sep separator characters; defaults to "%s" when nil
-- @return array-like table holding the pieces in order of appearance
function split(inputstr, sep)
  if sep == nil then
    sep = "%s"
  end
  local t = {}
  -- each match is a maximal run of characters that are not separators
  for str in string.gmatch(inputstr, "([^" .. sep .. "]+)") do
    t[#t + 1] = str
  end
  return t
end

--- Count every key of a table (array part and hash part alike).
-- @param T table to inspect
-- @return number of key/value pairs visited by pairs(T)
function numkeys(T)
  local count = 0
  for _ in pairs(T) do
    count = count + 1
  end
  return count
end

local GloVeEmbedding, parent = torch.class('GloVeEmbedding', 'nn.LookupTable')

--- Create a lookup table whose weights are taken from GloVe vectors.
-- The weight tensor is cached as a .t7 file inside `data_dir`; when that cache
-- file exists it is loaded instead of re-parsing the GloVe text file.
-- @param word2idx table mapping each vocabulary word to its row index
-- @param embedding_size width of the embedding produced by this module
-- @param data_dir directory in which the cached .t7 tensor is read/written
function GloVeEmbedding:__init(word2idx, embedding_size, data_dir)
  -- you need glove embeddings in the directory ./util/glove/
  -- download them from http://nlp.stanford.edu/projects/glove/
  local embedding_file = 'util/glove/vectors.6B.200d.txt'
  local file_embedding_size = 200
  self.vocab_size = numkeys(word2idx)
  self.word2idx = word2idx
  parent.__init(self, self.vocab_size, embedding_size)
  print("loading glove vectors")
  self.embedding_size = embedding_size
  -- load glove vectors as a tensor from a .t7 file if it exists, otherwise
  -- generate that .t7 file
  local vocab_embedding_file = path.join(data_dir, "glove_" .. self.vocab_size .. "x" .. embedding_size .. ".t7")
  --vocab_embedding_file = "lstm-glove-w.t7"
  --(if you want to load a different word vector here, just swap the vocab_embedding_file file name with something else)
  if path.exists(vocab_embedding_file) then
    self.weight = torch.load(vocab_embedding_file):clone()
  else
    -- `w` is local so the freshly parsed tensor does not leak into the
    -- global environment
    local w = self:parseEmbeddingFile(embedding_file, file_embedding_size, word2idx)
    if file_embedding_size ~= embedding_size then
      -- change the width by multiplying with a uniformly random matrix
      w = torch.mm(w, torch.rand(file_embedding_size, embedding_size))
    end
    self.weight = w:contiguous()
    torch.save(vocab_embedding_file, self.weight)
  end
  print("loaded glove vectors")
end

--- Build the vocabulary weight matrix from a GloVe text file.
-- Each line of the file is split on spaces: the first piece is the word and
-- the remaining pieces are converted with tonumber() and stored as its vector.
-- Vocabulary words are matched after lower-casing them; words that differ only
-- in case all receive the same vector. Any vocabulary entry that is not found
-- in the file is reported and filled with samples from torch.normal(0, 0.9).
-- @param embedding_file path of the GloVe text file
-- @param file_embedding_size number of values per word in that file
-- @param word2idx table mapping each vocabulary word to its row index
-- @return tensor of size vocab_size x file_embedding_size
function GloVeEmbedding:parseEmbeddingFile(embedding_file, file_embedding_size, word2idx)
  -- lower-cased word -> list of every row index that should get its vector;
  -- a list is needed because several vocabulary words may share one
  -- lower-cased form
  local lower2indices = {}
  -- row index -> true once that row has been filled from the file
  local loaded = {}
  local weight = torch.Tensor(self.vocab_size, file_embedding_size)
  for word, idx in pairs(word2idx) do
    local key = word:lower()
    local bucket = lower2indices[key]
    if bucket == nil then
      bucket = {}
      lower2indices[key] = bucket
    end
    bucket[#bucket + 1] = idx
  end

  for line in io.lines(embedding_file) do
    local parts = split(line, " ")
    local indices = lower2indices[parts[1]]
    if indices then
      for _, idx in ipairs(indices) do
        -- take the row view once instead of once per value
        local row = weight[idx]
        for i = 2, #parts do
          row[i - 1] = tonumber(parts[i])
        end
        loaded[idx] = true
      end
    end
  end

  -- tracking by row index guarantees that no row is left unwritten
  for word, idx in pairs(word2idx) do
    if not loaded[idx] then
      print("Not loaded: " .. word:lower())
      local row = weight[idx]
      for i = 1, file_embedding_size do
        row[i] = torch.normal(0, 0.9) --better way to do this?
      end
    end
  end
  return weight
end

--- Forward pass; hands a contiguous copy/view of the input to nn.LookupTable.
function GloVeEmbedding:updateOutput(input)
  return parent.updateOutput(self, input:contiguous())
end

--- Gradient accumulation; hands contiguous tensors to nn.LookupTable.
function GloVeEmbedding:accGradParameters(input, gradOutput, scale)
  return parent.accGradParameters(self, input:contiguous(), gradOutput:contiguous(), scale)
end


-- Variant of GloVeEmbedding that accumulates no gradients and exposes no
-- parameters. This `parent` is a new local; the methods defined above keep
-- referring to the earlier one.
local GloVeEmbeddingFixed, parent = torch.class('GloVeEmbeddingFixed', 'GloVeEmbedding')

--- Does nothing, so the weights receive no gradient.
function GloVeEmbeddingFixed:accGradParameters(input, gradOutput, scale)
  return nil
end

--- Reports empty parameter and gradient lists.
function GloVeEmbeddingFixed:parameters()
  return {}, {}
end

-- Disabled alternative implementations, kept commented out.
--[[
function GloVeEmbeddingProject:__init(word2idx, embedding_size, data_dir)
  local embedding_file = 'util/glove/vectors.6B.200d.txt'
  local file_embedding_size = 200
  self.vocab_size = numkeys(word2idx)
  parent.__init(self, self.vocab_size, embedding_size)
  print("loading glove vectors")
  self.embedding_size = embedding_size
  local vocab_embedding_file = path.join(data_dir, "glove_" .. self.vocab_size .. "x" .. embedding_size .. ".t7")
  --print("loading pretrained word vectors")
  --vocab_embedding_file = "glove_embeddings_pretrained1862x200-4sitelow-descs.t7"
  --vocab_embedding_file = "glove_embeddings_pretrained1704x200-5site-title.t7"
  if path.exists(vocab_embedding_file) then
    self.weight = torch.load(vocab_embedding_file):clone()
  else
    w = self:parseEmbeddingFile(embedding_file, file_embedding_size, word2idx)
    if file_embedding_size ~= embedding_size then
      w = torch.mm(w, torch.rand(file_embedding_size, embedding_size))
    end
    self.weight = w:contiguous()
    torch.save(vocab_embedding_file, self.weight)
  end
  print("self.weight size")
  print(self.weight:size())
  print("loaded glove vectors")
end



local GloVeEmbeddingFixed, parent = torch.class('GloVeEmbeddingFixed', 'nn.Module')
function GloVeEmbeddingFixed:__init(word2idx, embedding_file, embedding_size, data_dir)
  self.vocab_size = numkeys(word2idx)
  parent.__init(self, self.vocab_size, embedding_size)
  print("loading glove vectors")
  self.embedding_size = embedding_size
  local vocab_embedding_file = path.join(data_dir, "glove_" .. self.vocab_size .. "x" .. embedding_size .. ".t7")
  print("loading pretrained word vectors")
  vocab_embedding_file = "glove_embeddings_pretrained1704x200.t7" --pretrained vectors
  local loaded = {}
  if path.exists(vocab_embedding_file) then
    self.weight = torch.load(vocab_embedding_file)
  else
    self.weight = torch.Tensor(self.vocab_size, embedding_size)
    local word_lower2idx = {}
    for word, idx in pairs(word2idx) do
      word_lower2idx[word:lower()] = idx
    end

    for line in io.lines(embedding_file) do
      local parts = split(line, " ")
      local word = parts[1]
      if word_lower2idx[word] then
        local idx = word_lower2idx[word]
        for i=2, #parts do
          self.weight[idx][i-1] = tonumber(parts[i])
        end
        loaded[word] = true
      end
    end
    for word, idx in pairs(word2idx) do
      if not loaded[word:lower()] then
        print("Not loaded: " .. word:lower())
        for i=1, self.embedding_size do
          self.weight[idx][i] = torch.normal(0, 0.9) --better way to do this?
        end
      end
    end
    torch.save(vocab_embedding_file, self.weight)
  end
  print("loaded glove vectors")
end

function GloVeEmbeddingFixed:updateOutput(input)
  print("GloVeEmbeddingFixed input size: " .. input:size())
  self.output:resize(input:size(1), self.embedding_size):zero()
  local longInput = input:long()
  self.output:copy(self.weight:index(1, longInput))
  return self.output
end
]]--