‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

CharSplitLMMinibatchLoader.lua (9534B)


      1 local rex = require 'rex_pcre'
      2 
-- Modified from https://github.com/oxford-cs-ml-2015/practical6
-- The modification adds support for train/val/test splits.

-- Module table. It is also the metatable of loader instances: create() calls
-- setmetatable(self, CharSplitLMMinibatchLoader), and the __index
-- self-reference below lets instances find the methods defined on this table.
local CharSplitLMMinibatchLoader = {}
CharSplitLMMinibatchLoader.__index = CharSplitLMMinibatchLoader
      8 
      9 function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, split_fractions, word_level, threshold)
     10     -- split_fractions is e.g. {0.9, 0.05, 0.05}
     11 
     12     if not word_level then 
     13         threshold = 0 
     14     end
     15     local self = {}
     16     setmetatable(self, CharSplitLMMinibatchLoader)
     17 
     18     self.word_level = word_level
     19     local input_file = path.join(data_dir,  'input.txt')
     20     local vocab_file = path.join(data_dir, word_level and 'vocab_w' .. threshold .. '.t7' or 'vocab.t7')
     21     local tensor_file = path.join(data_dir, word_level and 'data_w' .. threshold .. '.t7' or 'data.t7')
     22 
     23     -- fetch file attributes to determine if we need to rerun preprocessing
     24     local run_prepro = false
     25     print(vocab_file)
     26     print(tensor_file)
     27     if not (path.exists(vocab_file) or path.exists(tensor_file)) then
     28         -- prepro files do not exist, generate them
     29         print('vocab.t7 and data.t7 do not exist. Running preprocessing...')
     30         run_prepro = true
     31     else
     32         -- check if the input file was modified since last time we 
     33         -- ran the prepro. if so, we have to rerun the preprocessing
     34         local input_attr = lfs.attributes(input_file)
     35         local vocab_attr = lfs.attributes(vocab_file)
     36         local tensor_attr = lfs.attributes(tensor_file)
     37         if input_attr.modification > vocab_attr.modification or input_attr.modification > tensor_attr.modification then
     38             print('vocab.t7 or data.t7 detected as stale. Re-running preprocessing...')
     39             run_prepro = true
     40         end
     41     end
     42     if run_prepro then
     43         -- construct a tensor with all the data, and vocab file
     44         print('one-time setup: preprocessing input text file ' .. input_file .. '...')
     45         CharSplitLMMinibatchLoader.text_to_tensor(word_level, threshold, input_file, vocab_file, tensor_file)
     46     end
     47 
     48     print('loading data files...')
     49     local data = torch.load(tensor_file)
     50     self.vocab_mapping = torch.load(vocab_file)
     51 
     52     -- cut off the end so that it divides evenly
     53     local len = data:size(1)
     54     if len % (batch_size * seq_length) ~= 0 then
     55         print('cutting off end of data so that the batches/sequences divide evenly')
     56         data = data:sub(1, batch_size * seq_length 
     57                     * math.floor(len / (batch_size * seq_length)))
     58     end
     59 
     60     -- count vocab
     61     self.vocab_size = 0
     62     for _ in pairs(self.vocab_mapping) do 
     63         self.vocab_size = self.vocab_size + 1 
     64     end
     65 
     66     -- self.batches is a table of tensors
     67     print('reshaping tensor...')
     68     self.batch_size = batch_size
     69     self.seq_length = seq_length
     70 
     71     local ydata = data:clone()
     72     ydata:sub(1,-2):copy(data:sub(2,-1))
     73     ydata[-1] = data[1]
     74     self.x_batches = data:view(batch_size, -1):split(seq_length, 2)  -- #rows = #batches
     75     self.nbatches = #self.x_batches
     76     self.y_batches = ydata:view(batch_size, -1):split(seq_length, 2)  -- #rows = #batches
     77     assert(#self.x_batches == #self.y_batches)
     78 
     79     -- lets try to be helpful here
     80     if self.nbatches < 50 then
     81         print('WARNING: less than 50 batches in the data in total? Looks like very small dataset. You probably want to use smaller batch_size and/or seq_length.')
     82     end
     83 
     84     -- perform safety checks on split_fractions
     85     assert(split_fractions[1] >= 0 and split_fractions[1] <= 1, 'bad split fraction ' .. split_fractions[1] .. ' for train, not between 0 and 1')
     86     assert(split_fractions[2] >= 0 and split_fractions[2] <= 1, 'bad split fraction ' .. split_fractions[2] .. ' for val, not between 0 and 1')
     87     assert(split_fractions[3] >= 0 and split_fractions[3] <= 1, 'bad split fraction ' .. split_fractions[3] .. ' for test, not between 0 and 1')
     88     if split_fractions[3] == 0 then 
     89         -- catch a common special case where the user might not want a test set
     90         self.ntrain = math.floor(self.nbatches * split_fractions[1])
     91         self.nval = self.nbatches - self.ntrain
     92         self.ntest = 0
     93     else
     94         -- divide data to train/val and allocate rest to test
     95         self.ntrain = math.floor(self.nbatches * split_fractions[1])
     96         self.nval = math.floor(self.nbatches * split_fractions[2])
     97         self.ntest = self.nbatches - self.nval - self.ntrain -- the rest goes to test (to ensure this adds up exactly)
     98     end
     99 
    100     self.split_sizes = {self.ntrain, self.nval, self.ntest}
    101     self.batch_ix = {0,0,0}
    102 
    103     print(string.format('data load done. Number of data batches in train: %d, val: %d, test: %d', self.ntrain, self.nval, self.ntest))
    104     collectgarbage()
    105     return self
    106 end
    107 
    108 function CharSplitLMMinibatchLoader:reset_batch_pointer(split_index, batch_index)
    109     batch_index = batch_index or 0
    110     self.batch_ix[split_index] = batch_index
    111 end
    112 
    113 function CharSplitLMMinibatchLoader:next_batch(split_index)
    114     if self.split_sizes[split_index] == 0 then
    115         -- perform a check here to make sure the user isn't screwing something up
    116         local split_names = {'train', 'val', 'test'}
    117         print('ERROR. Code requested a batch for split ' .. split_names[split_index] .. ', but this split has no data.')
    118         os.exit() -- crash violently
    119     end
    120     -- split_index is integer: 1 = train, 2 = val, 3 = test
    121     self.batch_ix[split_index] = self.batch_ix[split_index] + 1
    122     if self.batch_ix[split_index] > self.split_sizes[split_index] then
    123         self.batch_ix[split_index] = 1 -- cycle around to beginning
    124     end
    125     -- pull out the correct next batch
    126     local ix = self.batch_ix[split_index]
    127     if split_index == 2 then ix = ix + self.ntrain end -- offset by train set size
    128     if split_index == 3 then ix = ix + self.ntrain + self.nval end -- offset by train + val
    129     return self.x_batches[ix], self.y_batches[ix]
    130 end
    131 
    132 -- *** STATIC method ***
    133 function CharSplitLMMinibatchLoader.text_to_tensor(word_level, threshold, in_textfile, out_vocabfile, out_tensorfile)
    134     local timer = torch.Timer()
    135     print('loading text file...')
    136     local f = torch.DiskFile(in_textfile)
    137     local rawdata = f:readString('*a') -- NOTE: this reads the whole file at once
    138     f:close()
    139 
    140     -- create vocabulary if it doesn't exist yet
    141     print('creating vocabulary mapping...')
    142     print('word occurence threshold is ' .. threshold)
    143     -- record all characters to a set
    144     local unordered = {}
    145     --rawdata = re.sub('([%s])' % (re.escape(string.punctuation)+"1234567890"), r" \1 ", rawdata)
    146     local numtokens = 0    
    147     for token in CharSplitLMMinibatchLoader.tokens(rawdata, word_level) do
    148         if not unordered[token] then 
    149             unordered[token] = 1 
    150         else
    151             unordered[token] = unordered[token] + 1 
    152         end
    153         numtokens = numtokens + 1
    154     end
    155     -- sort into a table (i.e. keys become 1..N)
    156     local ordered = {}
    157     for token, count in pairs(unordered) do 
    158         if count > threshold then
    159             ordered[#ordered + 1] = token 
    160         end
    161     end
    162     if word_level then
    163         ordered[#ordered + 1] = "UNK" --represents unknown words
    164     end
    165     table.sort(ordered)
    166     -- invert `ordered` to create the char->int mapping
    167     local vocab_mapping = {}
    168     for i, char in ipairs(ordered) do
    169         vocab_mapping[char] = i
    170     end
    171     -- construct a tensor with all the data
    172     print('putting data into tensor...')
    173     local data = word_level and torch.ShortTensor(numtokens) or torch.ByteTensor(#rawdata) -- store it into 1D first, then rearrange
    174     if word_level then
    175         local i = 1
    176         for token in CharSplitLMMinibatchLoader.tokens(rawdata, word_level) do
    177             data[i] = vocab_mapping[token] or vocab_mapping["UNK"]
    178             i = i + 1
    179         end
    180     else
    181 		for i=1, #rawdata do
    182 			data[i] = vocab_mapping[rawdata:sub(i, i)] -- lua has no string indexing using []
    183 		end
    184     end
    185 
    186     -- save output preprocessed files
    187     print('saving ' .. out_vocabfile)
    188     torch.save(out_vocabfile, vocab_mapping)
    189     print('saving ' .. out_tensorfile)
    190     torch.save(out_tensorfile, data)
    191 end
    192 
    193 function CharSplitLMMinibatchLoader.tokens(rawstr, word_level)
    194     if word_level then
    195         --local str, _, _ = rex.gsub(rawstr, '[[:punct:][:digit:]]', ' %0 ')
    196         --str, _, _ = rex.gsub(str, '\\n', ' RN ')
    197         --return rex.split(str, "\\s+")
    198         return word_iter(rawstr)
    199     else
    200         return rawstr:gmatch'.'
    201     end
    202 end
    203 
    204 function word_iter(str)
    205     local n = str:len()
    206     local punctdigit = rex.new('[[:punct:][:digit:]]')
    207     local newline = rex.new('\\n')
    208     local whitespace = rex.new('[ \\t]') --dont match newlines
    209     local char_iter = str:gmatch'.'
    210     local c = char_iter()
    211     return function()
    212         if c == nil then return nil end
    213         while rex.count(c, whitespace) > 0 do
    214             c = char_iter()
    215             if c == nil then return nil end
    216         end
    217         if rex.count(c, punctdigit) > 0 then
    218             local ret = c
    219             c = char_iter()
    220             return ret
    221         end
    222         if rex.count(c, newline) > 0 then
    223             c = char_iter()
    224             return '\n'
    225         end
    226         local word = ''
    227         repeat
    228             word = word .. c
    229             c = char_iter()
    230             if c == nil then return word end
    231         until rex.count(c, whitespace) > 0 or rex.count(c, punctdigit) > 0 or rex.count(c, newline) > 0
    232         
    233         return word
    234     end
    235 end
    236 
    237 return CharSplitLMMinibatchLoader
    238 
    239 
    240 
    241 
    242 
    243 
    244 
    245 
    246 
    247 
    248