‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

train.lua (20096B)


      1 
      2 --[[
      3 
      4 This file trains a character-level multi-layer RNN on text data
      5 
      6 Code is based on implementation in 
      7 https://github.com/oxford-cs-ml-2015/practical6
      8 but modified to have multi-layer support, GPU support, as well as
      9 many other common model/optimization bells and whistles.
     10 The practical6 code is in turn based on 
     11 https://github.com/wojciechz/learning_to_execute
     12 which is turn based on other stuff in Torch, etc... (long lineage)
     13 
     14 ]]--
     15 
     16 require 'torch'
     17 require 'nn'
     18 require 'nngraph'
     19 require 'optim'
     20 require 'lfs'
     21 
     22 require 'util.OneHot'
     23 require 'util.GloVeEmbedding'
     24 require 'util.misc'
     25 local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader'
     26 local model_utils = require 'util.model_utils'
     27 local LSTM = require 'model.LSTM'
     28 local GRU = require 'model.GRU'
     29 local RNN = require 'model.RNN'
     30 local IRNN = require 'model.IRNN'
     31 
     32 cmd = torch.CmdLine()
     33 cmd:text()
     34 cmd:text('Train a character-level language model')
     35 cmd:text()
     36 cmd:text('Options')
     37 -- data
     38 cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data')
     39 -- model params
     40 cmd:option('-rnn_size', 128, 'size of LSTM internal state')
     41 cmd:option('-num_layers', 2, 'number of layers in the LSTM')
     42 cmd:option('-num_fixed', 0 ,'number of recurrent layers to remain fixed (untrained), pretrained (LSTM only)')
     43 cmd:option('-model', 'lstm', 'lstm, gru, rnn or irnn')
     44 -- optimization
     45 cmd:option('-learning_rate',2e-3,'learning rate')
     46 cmd:option('-learning_rate_decay',0.97,'learning rate decay')
     47 cmd:option('-learning_rate_decay_after',10,'in number of epochs, when to start decaying the learning rate')
     48 cmd:option('-decay_rate',0.95,'decay rate for rmsprop')
     49 cmd:option('-dropout',0,'dropout for regularization, used after each RNN hidden layer. 0 = no dropout, .3 = 30% dropout')
     50 cmd:option('-recurrent_dropout',0,'dropout for regularization, used on recurrent connections. 0 = no dropout')
     51 cmd:option('-seq_length',50,'number of timesteps to unroll for')
     52 cmd:option('-batch_size',50,'number of sequences to train on in parallel')
     53 cmd:option('-max_epochs',50,'number of full passes through the training data')
     54 cmd:option('-grad_clip',5,'clip gradients at this value')
     55 cmd:option('-train_frac',0.95,'fraction of data that goes into train set')
     56 cmd:option('-val_frac',0.05,'fraction of data that goes into validation set')
     57             -- test_frac will be computed as (1 - train_frac - val_frac)
     58 cmd:option('-init_from', '', 'initialize network parameters from checkpoint at this path')
     59 -- bookkeeping
     60 cmd:option('-seed',123,'torch manual random number generator seed')
     61 cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss')
     62 cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?')
     63 cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
     64 cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be inside checkpoint_dir/')
     65 cmd:option('-accurate_gpu_timing',0,'set this flag to 1 to get precise timings when using GPU. Might make code bit slower but reports accurate timings.')
     66 -- GPU/CPU
     67 cmd:option('-gpuid',-1,'which gpu to use. -1 = use CPU')
     68 cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
     69 cmd:option('-word_level',1,'whether to operate on the word level, instead of character level (0: use chars, 1: use words)')
     70 cmd:option('-threshold',0,'minimum number of occurences a token must have to be included (ignored if -word_level is 0)')
     71 cmd:option('-glove',0,'whether or not to use GloVe embeddings')
     72 cmd:option('-optimizer','eamsgd','which optimizer to use: adam or rmsprop')
     73 
     74 cmd:text()
     75 
     76 -- parse input params
     77 opt = cmd:parse(arg)
     78 torch.manualSeed(opt.seed)
     79 -- train / val / test split for data, in fractions
     80 local test_frac = math.max(0, 1 - (opt.train_frac + opt.val_frac))
     81 local split_sizes = {opt.train_frac, opt.val_frac, test_frac} 
     82 
     83 -- initialize cunn/cutorch for training on the GPU and fall back to CPU gracefully
     84 if opt.gpuid >= 0 and opt.opencl == 0 then
     85     local ok, cunn = pcall(require, 'cunn')
     86     local ok2, cutorch = pcall(require, 'cutorch')
     87     if not ok then print('package cunn not found!') end
     88     if not ok2 then print('package cutorch not found!') end
     89     if ok and ok2 then
     90         print('using CUDA on GPU ' .. opt.gpuid .. '...')
     91         cutorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
     92         cutorch.manualSeed(opt.seed)
     93     else
     94         print('If cutorch and cunn are installed, your CUDA toolkit may be improperly configured.')
     95         print('Check your CUDA toolkit installation, rebuild cutorch and cunn, and try again.')
     96         print('Falling back on CPU mode')
     97         opt.gpuid = -1 -- overwrite user setting
     98     end
     99 end
    100 
    101 -- initialize clnn/cltorch for training on the GPU and fall back to CPU gracefully
    102 if opt.gpuid >= 0 and opt.opencl == 1 then
    103     local ok, cunn = pcall(require, 'clnn')
    104     local ok2, cutorch = pcall(require, 'cltorch')
    105     if not ok then print('package clnn not found!') end
    106     if not ok2 then print('package cltorch not found!') end
    107     if ok and ok2 then
    108         print('using OpenCL on GPU ' .. opt.gpuid .. '...')
    109         cltorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
    110         torch.manualSeed(opt.seed)
    111     else
    112         print('If cltorch and clnn are installed, your OpenCL driver may be improperly configured.')
    113         print('Check your OpenCL driver installation, check output of clinfo command, and try again.')
    114         print('Falling back on CPU mode')
    115         opt.gpuid = -1 -- overwrite user setting
    116     end
    117 end
    118 
    119 require 'util.SharedDropout'
    120 
    121 -- create the data loader class
    122 local loader = CharSplitLMMinibatchLoader.create(opt.data_dir, opt.batch_size, opt.seq_length, split_sizes, opt.word_level == 1, opt.threshold)
    123 local vocab_size = loader.vocab_size  -- the number of distinct characters
    124 local vocab = loader.vocab_mapping
    125 local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate }
    126 print('vocab size: ' .. vocab_size)
    127 -- make sure output directory exists
    128 if not path.exists(opt.checkpoint_dir) then lfs.mkdir(opt.checkpoint_dir) end
    129 
    130 -- define the model: prototypes for one timestep, then clone them in time
    131 local h2hs = nil
    132 if string.len(opt.init_from) > 0 then
    133     print('loading a model from checkpoint ' .. opt.init_from)
    134     local checkpoint = torch.load(opt.init_from)
    135     protos = checkpoint.protos
    136     optim_state = checkpoint.optim_state
    137     optim_state.learningRate = opt.learning_rate
    138     -- make sure the vocabs are the same
    139     local vocab_compatible = true
    140     local checkpoint_vocab_size = 0
    141     for c,i in pairs(checkpoint.vocab) do
    142         if not (vocab[c] == i) then
    143             vocab_compatible = false
    144         end
    145         checkpoint_vocab_size = checkpoint_vocab_size + 1
    146     end
    147     if not (checkpoint_vocab_size == vocab_size) then
    148         vocab_compatible = false
    149         print('checkpoint_vocab_size: ' .. checkpoint_vocab_size)
    150     end
    151     assert(vocab_compatible, 'error, the character vocabulary for this dataset and the one in the saved checkpoint are not the same. This is trouble.')
    152     -- overwrite model settings based on checkpoint to ensure compatibility
    153     print('overwriting rnn_size=' .. checkpoint.opt.rnn_size .. ', num_layers=' .. checkpoint.opt.num_layers .. ', model=' .. checkpoint.opt.model .. ' based on the checkpoint.')
    154     opt.rnn_size = checkpoint.opt.rnn_size
    155     opt.num_layers = checkpoint.opt.num_layers
    156     opt.optimizer = checkpoint.optimizer
    157     opt.model = checkpoint.opt.model
    158 else
    159     print('creating an ' .. opt.model .. ' with ' .. opt.num_layers .. ' layers')
    160     protos = {}
    161     local embedding = nil
    162     if opt.glove == 1 then
    163         embedding = GloVeEmbedding(vocab, 200, opt.data_dir) --GloVeEmbeddingFixed(vocab, 200, opt.data_dir)
    164     end
    165     if opt.model == 'lstm' then
    166 	protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, opt.recurrent_dropout, embedding, opt.num_fixed)
    167     elseif opt.model == 'gru' then
    168         protos.rnn = GRU.gru(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, embedding)
    169     elseif opt.model == 'rnn' then
    170         protos.rnn = RNN.rnn(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, embedding)
    171     elseif opt.model == 'irnn' then
    172         protos.rnn, h2hs = IRNN.rnn(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, embedding)
    173     end
    174     protos.criterion = nn.ClassNLLCriterion()
    175 
    176     --local clusters = {}
    177     --for w,i in pairs(vocab) do 
    178     --    clusters[#clusters+1] = {1, i}
    179     --end
    180     --protos.criterion = nn.HSM(torch.Tensor(clusters), opt.rnn_size, 0) --vocab['UNK'])
    181 end
    182 
    183 print('using optimizer ' .. opt.optimizer)
    184 -- the initial state of the cell/hidden states
    185 init_state = {}
    186 for L=1,opt.num_layers do
    187     local h_init = torch.zeros(opt.batch_size, opt.rnn_size)
    188     if opt.gpuid >=0 and opt.opencl == 0 then h_init = h_init:cuda() end
    189     if opt.gpuid >=0 and opt.opencl == 1 then h_init = h_init:cl() end
    190     table.insert(init_state, h_init:clone())
    191     if opt.model == 'lstm' then
    192         table.insert(init_state, h_init:clone())
    193     end
    194 end
    195 
    196 -- ship the model to the GPU if desired
    197 if opt.gpuid >= 0 and opt.opencl == 0 then
    198     for k,v in pairs(protos) do v:cuda() end
    199 end
    200 if opt.gpuid >= 0 and opt.opencl == 1 then
    201     for k,v in pairs(protos) do v:cl() end
    202 end
    203 
    204 -- put the above things into one flattened parameters tensor
    205 params, grad_params = model_utils.combine_all_parameters(protos.rnn)
    206 
    207 -- initialize the LSTM forget gates with slightly higher biases to encourage remembering in the beginning
    208 if opt.model == 'lstm' and string.len(opt.init_from) == 0 then
    209     for layer_idx = 1, opt.num_layers do
    210         for _,node in ipairs(protos.rnn.forwardnodes) do
    211             if node.data.annotations.name == "i2h_" .. layer_idx and layer_idx > opt.num_fixed then
    212                 print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx)
    213                 -- the gates are, in order, i,f,o,g, so f is the 2nd block of weights
    214                 node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1.0)
    215             end
    216         end
    217     end
    218 end
    219 
    220 print('number of parameters in the model: ' .. params:nElement())
    221 -- make a bunch of clones after flattening, as that reallocates memory
    222 clones = {}
    223 for name,proto in pairs(protos) do
    224     print('cloning ' .. name)
    225     clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters)
    226 end
    227 
    228 -- preprocessing helper function
    229 function prepro(x,y)
    230     x = x:transpose(1,2):contiguous() -- swap the axes for faster indexing
    231     y = y:transpose(1,2):contiguous()
    232     if opt.gpuid >= 0 and opt.opencl == 0 then -- ship the input arrays to GPU
    233         -- have to convert to float because integers can't be cuda()'d
    234         x = x:float():cuda()
    235         y = y:float():cuda()
    236     end
    237     if opt.gpuid >= 0 and opt.opencl == 1 then -- ship the input arrays to GPU
    238         x = x:cl()
    239         y = y:cl()
    240     end
    241     return x,y
    242 end
    243 
    244 -- evaluate the loss over an entire split
    245 function eval_split(split_index, max_batches)
    246     print('evaluating loss over split index ' .. split_index)
    247     local n = loader.split_sizes[split_index]
    248     if max_batches ~= nil then n = math.min(max_batches, n) end
    249 
    250     loader:reset_batch_pointer(split_index) -- move batch iteration pointer for this split to front
    251     local loss = 0
    252     local rnn_state = {[0] = init_state}
    253     
    254     for i = 1,n do -- iterate over batches in the split
    255         -- fetch a batch
    256         local x, y = loader:next_batch(split_index)
    257         x,y = prepro(x,y)
    258         -- forward pass
    259         for t=1,opt.seq_length do
    260             clones.rnn[t]:evaluate() -- for dropout proper functioning
    261             local lst = clones.rnn[t]:forward{x[t], unpack(rnn_state[t-1])}
    262             rnn_state[t] = {}
    263             for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end
    264             prediction = lst[#lst] 
    265             loss = loss + clones.criterion[t]:forward(prediction, y[t])
    266         end
    267         -- carry over lstm state
    268         rnn_state[0] = rnn_state[#rnn_state]
    269         print(i .. '/' .. n .. '...')
    270     end
    271 
    272     loss = loss / opt.seq_length / n
    273     return loss
    274 end
    275 
    276 -- do fwd/bwd and return loss, grad_params
    277 local init_state_global = clone_list(init_state)
    278 function feval(x)
    279     if x ~= params then
    280         params:copy(x)
    281     end
    282     grad_params:zero()
    283 
    284     ------------------ get minibatch -------------------
    285     local x, y = loader:next_batch(1)
    286     x,y = prepro(x,y)
    287     ------------------- forward pass -------------------
    288     if opt.recurrent_dropout ~= 0 then 
    289         --todo: these are shared across all layers in depth also. that's not optimal
    290         SharedDropout_noise:resize(opt.batch_size, opt.rnn_size)
    291         SharedDropout_noise:bernoulli(1 - opt.recurrent_dropout)
    292         SharedDropout_noise:div(1 - opt.recurrent_dropout)
    293     end
    294     local rnn_state = {[0] = init_state_global}
    295     local predictions = {}           -- softmax outputs
    296     local loss = 0
    297     for t=1,opt.seq_length do
    298         clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
    299         local lst = clones.rnn[t]:forward{x[t], unpack(rnn_state[t-1])}
    300         rnn_state[t] = {}
    301         for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
    302         predictions[t] = lst[#lst] -- last element is the prediction
    303         loss = loss + clones.criterion[t]:forward(predictions[t], y[t])
    304     end
    305     loss = loss / opt.seq_length
    306     ------------------ backward pass -------------------
    307     -- initialize gradient at time t to be zeros (there's no influence from future)
    308     local drnn_state = {[opt.seq_length] = clone_list(init_state, true)} -- true also zeros the clones
    309     for t=opt.seq_length,1,-1 do
    310         -- backprop through loss, and softmax/linear
    311         local doutput_t = clones.criterion[t]:backward(predictions[t], y[t])
    312         table.insert(drnn_state[t], doutput_t)
    313         local dlst = clones.rnn[t]:backward({x[t], unpack(rnn_state[t-1])}, drnn_state[t])
    314         drnn_state[t-1] = {}
    315         for k,v in pairs(dlst) do
    316             if k > 1 then -- k == 1 is gradient on x, which we dont need
    317                 -- note we do k-1 because first item is dembeddings, and then follow the 
    318                 -- derivatives of the state, starting at index 2. I know...
    319                 drnn_state[t-1][k-1] = v
    320             end
    321         end
    322     end
    323     ------------------------ misc ----------------------
    324     -- transfer final state to initial state (BPTT)
    325     init_state_global = rnn_state[#rnn_state] -- NOTE: I don't think this needs to be a clone, right?
    326     -- grad_params:div(opt.seq_length) -- this line should be here but since we use rmsprop it would have no effect. Removing for efficiency
    327     -- clip gradient element-wise
    328     grad_params:clamp(-opt.grad_clip, opt.grad_clip)
    329     return loss, grad_params
    330 end
    331 
    332 -- start optimization here
    333 train_losses = {}
    334 val_losses = {}
    335 local iterations = math.floor(opt.max_epochs * loader.ntrain)
    336 local iterations_per_epoch = loader.ntrain
    337 local loss0 = nil
    338 
    339 local optimizer = nil
    340 
    341 if opt.optimizer == 'adam' then
    342     optimizer = optim.adam
    343 elseif opt.optimizer == 'sgd' then
    344     optimizer = optim.sgd
    345     optim_state.learningRateDecay = opt.decay_rate
    346     optim_state.momentum = 0.99
    347     optim_state.nesterov = true
    348     optim_state.dampening = 0
    349 elseif opt.optimizer == 'eamsgd' then
    350     optimizer = optim.eamsgd
    351     optim_state.learningRate = mpiOptions.learningRate
    352     optim_state.momentum = mpiOptions.momentum
    353     optim_state.pclient = mpiOptions.pclient
    354     optim_state.communicationPeriod = mpiOptions.communicationPeriod
    355     optim_state.movingRateAlpha = mpiOptions.movingRateAlpha
    356     optim_state.learningRateDecay = mpiOptions.learningRateDecay
    357     optim_state.learningRateDecayPower = mpiOptions.learningRateDecayPower
    358 else
    359     optimizer = optim.rmsprop
    360 end
    361 
    362 -- initialize MPI optimizer clients
    363 rank = mpiOptions.rank or -1
    364 pclient = mpiOptions.pclient or nil
    365 print('i am ' .. rank .. ' ready to run')
    366 if pclient then
    367    pclient:start(params,grad_params)
    368    assert(rank == pclient.rank)
    369    print('pc ' .. rank .. ' started')
    370 end
    371 
    372 -- run optimizer
    373 sys.tic() -- time the training procedure
    374 for i = 1, iterations do
    375     local epoch = i / loader.ntrain
    376 
    377     local timer = torch.Timer()
    378 
    379     local _, loss = optimizer(feval, params, optim_state)
    380     if opt.accurate_gpu_timing == 1 and opt.gpuid >= 0 then
    381         --[[
    382         Note on timing: The reported time can be off because the GPU is invoked async. If one
    383         wants to have exactly accurate timings one must call cutorch.synchronize() right here.
    384         I will avoid doing so by default because this can incur computational overhead.
    385         --]]
    386         cutorch.synchronize()
    387     end
    388     local time = timer:time().real
    389 
    390     local train_loss = loss[1] -- the loss is inside a list, pop it
    391     train_losses[i] = train_loss
    392 
    393     -- exponential learning rate decay for rmsprop
    394     if opt.optimizer == 'rmsprop' and i % loader.ntrain == 0 and opt.learning_rate_decay < 1 then
    395         if epoch >= opt.learning_rate_decay_after then
    396             local decay_factor = opt.learning_rate_decay
    397             optim_state.learningRate = optim_state.learningRate * decay_factor -- decay it
    398             print('decayed learning rate by a factor ' .. decay_factor .. ' to ' .. optim_state.learningRate)
    399         end
    400     end
    401 
    402     -- every now and then or on last iteration
    403     local eval_multiplier = 1
    404     if epoch < 10 then
    405         eval_multiplier = 1 --increase this to eval less often in the first iterations
    406     end
    407     print('current iteration: ' .. i)
    408     print('total iterations:  ' .. iterations)
    409     if i % (opt.eval_val_every * eval_multiplier) == 0 or i == iterations then
    410         -- evaluate loss on validation data
    411         local val_loss = eval_split(2) -- 2 = validation
    412         val_losses[i] = val_loss
    413 
    414         local savefile = string.format('%s/lm_%s_epoch%.2f_%.4f.t7', opt.checkpoint_dir, opt.savefile, epoch, val_loss)
    415         print('saving checkpoint to ' .. savefile)
    416         local checkpoint = {}
    417         checkpoint.protos = protos
    418         checkpoint.opt = opt
    419         checkpoint.train_losses = train_losses
    420         checkpoint.val_loss = val_loss
    421         checkpoint.val_losses = val_losses
    422         checkpoint.i = i
    423         checkpoint.epoch = epoch
    424         checkpoint.vocab = loader.vocab_mapping
    425         --checkpoint.optim_state = optim_state
    426         --checkpoint.optimizer = opt.optimizer
    427         torch.save(savefile, checkpoint)
    428     end
    429 
    430     if i % opt.print_every == 0 then
    431     machine_name = io.popen('hostname -s'):read()
    432         print(string.format("%s Rank %s %d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.4fs",machine_name, rank, i, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time))
    433     end
    434 
    435     if i % (opt.print_every*10) == 0 then
    436         --print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.4fs", i, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time))
    437         --e, V = torch.eig(h2hs[1].data.module.weight:float(), 'N')
    438         --print(e[1])
    439         --e, V = torch.eig(h2hs[2].data.module.weight:float(), 'N')
    440         --print(e[1])
    441         --e, V = torch.eig(h2hs[3].data.module.weight:float(), 'N')
    442         --print(e[1])
    443     end
    444    
    445     if i % 10 == 0 then collectgarbage() end
    446 
    447     -- handle early stopping if things are going really bad
    448     if loss[1] ~= loss[1] then
    449         print('loss is NaN.  This usually indicates a bug.  Please check the issues page for existing issues, or create a new issue, if none exist.  Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?')
    450         break -- halt
    451     end
    452     if loss0 == nil then loss0 = loss[1] end
    453     if loss[1] > loss0 * 100 then
    454         print(string.format("loss is exploding, aborting. (%6.2f vs %6.2f)", loss0, loss[1]))
    455         break -- halt
    456     end
    457 end
    458 
    459 -- stop optimizer clients
    460 if pclient then
    461    pclient:stop()
    462 end
    463 
    464 print(rank,'total training time is', sys.toc())