CharSplitLMMinibatchLoader.lua (9534B)
local rex = require 'rex_pcre'

-- Modified from https://github.com/oxford-cs-ml-2015/practical6
-- the modification included support for train/val/test splits
--
-- NOTE: this module uses `torch`, `path` and `lfs` as globals; it does not
-- require them itself, so the caller must have loaded them beforehand.

local CharSplitLMMinibatchLoader = {}
CharSplitLMMinibatchLoader.__index = CharSplitLMMinibatchLoader

--- Build a minibatch loader for the text in `data_dir/input.txt`.
--
-- Preprocessed vocabulary/tensor files are cached next to the input file and
-- regenerated when either one is missing or older than the input file.
--
-- @param data_dir        directory containing `input.txt`
-- @param batch_size      number of rows in each batch
-- @param seq_length      number of time steps (columns) in each batch
-- @param split_fractions table {train, val, test}, e.g. {0.9, 0.05, 0.05}
-- @param word_level      truthy to tokenize into words instead of characters
-- @param threshold       word-level only: tokens must occur more than this
--                        many times to enter the vocabulary (defaults to 0)
-- @return the loader object
function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, split_fractions, word_level, threshold)
    -- the threshold is meaningless at character level; it also defaults to 0
    -- when omitted so that the file names below can always be built
    if not word_level or threshold == nil then
        threshold = 0
    end
    local self = {}
    setmetatable(self, CharSplitLMMinibatchLoader)

    self.word_level = word_level
    local input_file = path.join(data_dir, 'input.txt')
    local vocab_file = path.join(data_dir, word_level and 'vocab_w' .. threshold .. '.t7' or 'vocab.t7')
    local tensor_file = path.join(data_dir, word_level and 'data_w' .. threshold .. '.t7' or 'data.t7')

    -- fetch file attributes to determine if we need to rerun preprocessing
    local run_prepro = false
    print(vocab_file)
    print(tensor_file)
    -- BOTH cache files are needed. If either one is missing we must
    -- regenerate; otherwise the staleness check below would index the nil
    -- attributes of the missing file.
    if not (path.exists(vocab_file) and path.exists(tensor_file)) then
        -- prepro files do not exist, generate them
        print('vocab.t7 and data.t7 do not exist. Running preprocessing...')
        run_prepro = true
    else
        -- check if the input file was modified since last time we
        -- ran the prepro. if so, we have to rerun the preprocessing
        local input_attr = lfs.attributes(input_file)
        local vocab_attr = lfs.attributes(vocab_file)
        local tensor_attr = lfs.attributes(tensor_file)
        if input_attr.modification > vocab_attr.modification or input_attr.modification > tensor_attr.modification then
            print('vocab.t7 or data.t7 detected as stale. Re-running preprocessing...')
            run_prepro = true
        end
    end
    if run_prepro then
        -- construct a tensor with all the data, and vocab file
        print('one-time setup: preprocessing input text file ' .. input_file .. '...')
        CharSplitLMMinibatchLoader.text_to_tensor(word_level, threshold, input_file, vocab_file, tensor_file)
    end

    print('loading data files...')
    local data = torch.load(tensor_file)
    self.vocab_mapping = torch.load(vocab_file)

    -- cut off the end so that it divides evenly
    local len = data:size(1)
    local chunk = batch_size * seq_length
    -- fail with a readable message instead of an out-of-range error in sub()
    assert(len >= chunk, string.format(
        'not enough data: %d tokens, but batch_size * seq_length = %d', len, chunk))
    if len % chunk ~= 0 then
        print('cutting off end of data so that the batches/sequences divide evenly')
        data = data:sub(1, chunk * math.floor(len / chunk))
    end

    -- count vocab (the mapping is keyed by token, so `#` cannot be used)
    self.vocab_size = 0
    for _ in pairs(self.vocab_mapping) do
        self.vocab_size = self.vocab_size + 1
    end

    -- self.batches is a table of tensors
    print('reshaping tensor...')
    self.batch_size = batch_size
    self.seq_length = seq_length

    -- targets are the inputs shifted left by one position; the very last
    -- target wraps around to the first input element
    local ydata = data:clone()
    ydata:sub(1,-2):copy(data:sub(2,-1))
    ydata[-1] = data[1]
    self.x_batches = data:view(batch_size, -1):split(seq_length, 2)  -- #rows = #batches
    self.nbatches = #self.x_batches
    self.y_batches = ydata:view(batch_size, -1):split(seq_length, 2)  -- #rows = #batches
    assert(#self.x_batches == #self.y_batches)

    -- lets try to be helpful here
    if self.nbatches < 50 then
        print('WARNING: less than 50 batches in the data in total? Looks like very small dataset. You probably want to use smaller batch_size and/or seq_length.')
    end

    -- perform safety checks on split_fractions
    assert(split_fractions[1] >= 0 and split_fractions[1] <= 1, 'bad split fraction ' .. split_fractions[1] .. ' for train, not between 0 and 1')
    assert(split_fractions[2] >= 0 and split_fractions[2] <= 1, 'bad split fraction ' .. split_fractions[2] .. ' for val, not between 0 and 1')
    assert(split_fractions[3] >= 0 and split_fractions[3] <= 1, 'bad split fraction ' .. split_fractions[3] .. ' for test, not between 0 and 1')
    if split_fractions[3] == 0 then
        -- catch a common special case where the user might not want a test set
        self.ntrain = math.floor(self.nbatches * split_fractions[1])
        self.nval = self.nbatches - self.ntrain
        self.ntest = 0
    else
        -- divide data to train/val and allocate rest to test
        self.ntrain = math.floor(self.nbatches * split_fractions[1])
        self.nval = math.floor(self.nbatches * split_fractions[2])
        self.ntest = self.nbatches - self.nval - self.ntrain -- the rest goes to test (to ensure this adds up exactly)
    end

    self.split_sizes = {self.ntrain, self.nval, self.ntest}
    self.batch_ix = {0,0,0}

    print(string.format('data load done. Number of data batches in train: %d, val: %d, test: %d', self.ntrain, self.nval, self.ntest))
    collectgarbage()
    return self
end

--- Set the batch cursor of one split.
-- @param split_index 1 = train, 2 = val, 3 = test
-- @param batch_index cursor value to store (defaults to 0, so the next call
--                    to next_batch returns the first batch of the split)
function CharSplitLMMinibatchLoader:reset_batch_pointer(split_index, batch_index)
    batch_index = batch_index or 0
    self.batch_ix[split_index] = batch_index
end

--- Advance the cursor of a split and return its next (x, y) batch pair.
-- The cursor wraps back to the first batch after the last one.
-- @param split_index integer: 1 = train, 2 = val, 3 = test
-- @return x batch, y batch (entries of self.x_batches / self.y_batches)
function CharSplitLMMinibatchLoader:next_batch(split_index)
    if self.split_sizes[split_index] == 0 then
        -- perform a check here to make sure the user isn't screwing something up
        local split_names = {'train', 'val', 'test'}
        print('ERROR. Code requested a batch for split ' .. split_names[split_index] .. ', but this split has no data.')
        os.exit(1) -- crash violently, with a failure status
    end
    self.batch_ix[split_index] = self.batch_ix[split_index] + 1
    if self.batch_ix[split_index] > self.split_sizes[split_index] then
        self.batch_ix[split_index] = 1 -- cycle around to beginning
    end
    -- pull out the correct next batch; splits are stored back to back in the
    -- order train, val, test
    local ix = self.batch_ix[split_index]
    if split_index == 2 then ix = ix + self.ntrain end -- offset by train set size
    if split_index == 3 then ix = ix + self.ntrain + self.nval end -- offset by train + val
    return self.x_batches[ix], self.y_batches[ix]
end

-- *** STATIC method ***
--- Preprocess a text file into a vocabulary mapping and a 1D tensor of ids,
-- and save both with torch.save.
-- @param word_level     truthy for word tokens, falsy for single characters
-- @param threshold      a token enters the vocabulary only if it occurs more
--                       than this many times
-- @param in_textfile    path of the raw text to read
-- @param out_vocabfile  path the token -> id table is saved to
-- @param out_tensorfile path the id tensor is saved to
function CharSplitLMMinibatchLoader.text_to_tensor(word_level, threshold, in_textfile, out_vocabfile, out_tensorfile)
    print('loading text file...')
    local f = torch.DiskFile(in_textfile)
    local rawdata = f:readString('*a') -- NOTE: this reads the whole file at once
    f:close()

    -- create vocabulary if it doesn't exist yet
    print('creating vocabulary mapping...')
    print('word occurence threshold is ' .. threshold)
    -- count the occurrences of every token
    local unordered = {}
    --rawdata = re.sub('([%s])' % (re.escape(string.punctuation)+"1234567890"), r" \1 ", rawdata)
    local numtokens = 0
    for token in CharSplitLMMinibatchLoader.tokens(rawdata, word_level) do
        unordered[token] = (unordered[token] or 0) + 1
        numtokens = numtokens + 1
    end
    -- sort into a table (i.e. keys become 1..N)
    local ordered = {}
    for token, count in pairs(unordered) do
        if count > threshold then
            ordered[#ordered + 1] = token
        end
    end
    if word_level then
        ordered[#ordered + 1] = "UNK" --represents unknown words
    end
    table.sort(ordered)
    -- invert `ordered` to create the token->int mapping
    local vocab_mapping = {}
    for i, token in ipairs(ordered) do
        vocab_mapping[token] = i
    end
    -- construct a tensor with all the data
    print('putting data into tensor...')
    -- NOTE(review): ids larger than what a ShortTensor / ByteTensor element can
    -- hold would presumably not be stored correctly -- confirm that the
    -- vocabulary sizes in use stay within range.
    local data = word_level and torch.ShortTensor(numtokens) or torch.ByteTensor(#rawdata) -- store it into 1D first, then rearrange
    if word_level then
        local i = 1
        for token in CharSplitLMMinibatchLoader.tokens(rawdata, word_level) do
            -- tokens filtered out by the threshold map to "UNK"
            data[i] = vocab_mapping[token] or vocab_mapping["UNK"]
            i = i + 1
        end
    else
        for i = 1, #rawdata do
            data[i] = vocab_mapping[rawdata:sub(i, i)] -- lua has no string indexing using []
        end
    end

    -- save output preprocessed files
    print('saving ' .. out_vocabfile)
    torch.save(out_vocabfile, vocab_mapping)
    print('saving ' .. out_tensorfile)
    torch.save(out_tensorfile, data)
end

--- Word-level tokenizer (private helper of `tokens`).
-- Returns an iterator that walks `str` one character at a time and yields:
--   * each punctuation or digit character as its own token,
--   * each newline as the token '\n',
--   * every other maximal run of characters, delimited by space/tab,
--     punctuation/digit or newline, as one word.
-- Spaces and tabs are skipped and never yielded.
local function word_iter(str)
    local punctdigit = rex.new('[[:punct:][:digit:]]')
    local newline = rex.new('\\n')
    local whitespace = rex.new('[ \\t]') --dont match newlines
    local char_iter = str:gmatch'.'
    -- `c` is always the next character that has not been consumed yet
    local c = char_iter()
    return function()
        if c == nil then return nil end
        -- skip leading blanks
        while rex.count(c, whitespace) > 0 do
            c = char_iter()
            if c == nil then return nil end
        end
        if rex.count(c, punctdigit) > 0 then
            local ret = c
            c = char_iter()
            return ret
        end
        if rex.count(c, newline) > 0 then
            c = char_iter()
            return '\n'
        end
        -- accumulate a word until the next delimiter or the end of input
        local word = ''
        repeat
            word = word .. c
            c = char_iter()
            if c == nil then return word end
        until rex.count(c, whitespace) > 0 or rex.count(c, punctdigit) > 0 or rex.count(c, newline) > 0

        return word
    end
end

--- Return an iterator over the tokens of `rawstr`.
-- @param rawstr     the text to tokenize
-- @param word_level truthy: word tokens (see word_iter); falsy: one token per
--                   character
function CharSplitLMMinibatchLoader.tokens(rawstr, word_level)
    if word_level then
        --local str, _, _ = rex.gsub(rawstr, '[[:punct:][:digit:]]', ' %0 ')
        --str, _, _ = rex.gsub(str, '\\n', ' RN ')
        --return rex.split(str, "\\s+")
        return word_iter(rawstr)
    else
        return rawstr:gmatch'.'
    end
end

return CharSplitLMMinibatchLoader