train.lua (20096B)
--[[

This file trains a character-level multi-layer RNN on text data

Code is based on implementation in
https://github.com/oxford-cs-ml-2015/practical6
but modified to have multi-layer support, GPU support, as well as
many other common model/optimization bells and whistles.
The practical6 code is in turn based on
https://github.com/wojciechz/learning_to_execute
which is in turn based on other stuff in Torch, etc... (long lineage)

This variant additionally supports word-level tokens, optional GloVe
embeddings, and an 'eamsgd' optimizer that is configured through a global
`mpiOptions` table which is expected to be set up by an external launcher
before this script runs.

]]--

require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
require 'lfs'

require 'util.OneHot'
require 'util.GloVeEmbedding'
require 'util.misc'
local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader'
local model_utils = require 'util.model_utils'
local LSTM = require 'model.LSTM'
local GRU = require 'model.GRU'
local RNN = require 'model.RNN'
local IRNN = require 'model.IRNN'

cmd = torch.CmdLine()
cmd:text()
cmd:text('Train a character-level language model')
cmd:text()
cmd:text('Options')
-- data
cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data')
-- model params
cmd:option('-rnn_size', 128, 'size of LSTM internal state')
cmd:option('-num_layers', 2, 'number of layers in the LSTM')
cmd:option('-num_fixed', 0 ,'number of recurrent layers to remain fixed (untrained), pretrained (LSTM only)')
cmd:option('-model', 'lstm', 'lstm, gru, rnn or irnn')
-- optimization
cmd:option('-learning_rate',2e-3,'learning rate')
cmd:option('-learning_rate_decay',0.97,'learning rate decay')
cmd:option('-learning_rate_decay_after',10,'in number of epochs, when to start decaying the learning rate')
cmd:option('-decay_rate',0.95,'decay rate for rmsprop')
cmd:option('-dropout',0,'dropout for regularization, used after each RNN hidden layer. 0 = no dropout, .3 = 30% dropout')
cmd:option('-recurrent_dropout',0,'dropout for regularization, used on recurrent connections. 0 = no dropout')
cmd:option('-seq_length',50,'number of timesteps to unroll for')
cmd:option('-batch_size',50,'number of sequences to train on in parallel')
cmd:option('-max_epochs',50,'number of full passes through the training data')
cmd:option('-grad_clip',5,'clip gradients at this value')
cmd:option('-train_frac',0.95,'fraction of data that goes into train set')
cmd:option('-val_frac',0.05,'fraction of data that goes into validation set')
            -- test_frac will be computed as (1 - train_frac - val_frac)
cmd:option('-init_from', '', 'initialize network parameters from checkpoint at this path')
-- bookkeeping
cmd:option('-seed',123,'torch manual random number generator seed')
cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss')
cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?')
cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be inside checkpoint_dir/')
cmd:option('-accurate_gpu_timing',0,'set this flag to 1 to get precise timings when using GPU. Might make code bit slower but reports accurate timings.')
-- GPU/CPU
cmd:option('-gpuid',-1,'which gpu to use. -1 = use CPU')
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
cmd:option('-word_level',1,'whether to operate on the word level, instead of character level (0: use chars, 1: use words)')
cmd:option('-threshold',0,'minimum number of occurences a token must have to be included (ignored if -word_level is 0)')
cmd:option('-glove',0,'whether or not to use GloVe embeddings')
-- any value other than adam/sgd/eamsgd selects rmsprop (see optimizer selection below)
cmd:option('-optimizer','eamsgd','which optimizer to use: adam, sgd, eamsgd or rmsprop')

cmd:text()

-- parse input params
opt = cmd:parse(arg)
torch.manualSeed(opt.seed)
-- train / val / test split for data, in fractions
local test_frac = math.max(0, 1 - (opt.train_frac + opt.val_frac))
local split_sizes = {opt.train_frac, opt.val_frac, test_frac}

-- initialize cunn/cutorch for training on the GPU and fall back to CPU gracefully
if opt.gpuid >= 0 and opt.opencl == 0 then
    local ok, cunn = pcall(require, 'cunn')
    local ok2, cutorch = pcall(require, 'cutorch')
    if not ok then print('package cunn not found!') end
    if not ok2 then print('package cutorch not found!') end
    if ok and ok2 then
        print('using CUDA on GPU ' .. opt.gpuid .. '...')
        cutorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
        cutorch.manualSeed(opt.seed)
    else
        print('If cutorch and cunn are installed, your CUDA toolkit may be improperly configured.')
        print('Check your CUDA toolkit installation, rebuild cutorch and cunn, and try again.')
        print('Falling back on CPU mode')
        opt.gpuid = -1 -- overwrite user setting
    end
end

-- initialize clnn/cltorch for training on the GPU and fall back to CPU gracefully
if opt.gpuid >= 0 and opt.opencl == 1 then
    -- Only the success flags are kept; `cltorch` below is the global that
    -- require 'cltorch' is expected to define (same reliance as before).
    local ok = pcall(require, 'clnn')
    local ok2 = pcall(require, 'cltorch')
    if not ok then print('package clnn not found!') end
    if not ok2 then print('package cltorch not found!') end
    if ok and ok2 then
        print('using OpenCL on GPU ' .. opt.gpuid .. '...')
        cltorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
        torch.manualSeed(opt.seed)
    else
        print('If cltorch and clnn are installed, your OpenCL driver may be improperly configured.')
        print('Check your OpenCL driver installation, check output of clinfo command, and try again.')
        print('Falling back on CPU mode')
        opt.gpuid = -1 -- overwrite user setting
    end
end

-- NOTE(review): loaded after backend selection, as in the original; presumably
-- this module defines the global SharedDropout_noise used in feval -- confirm.
require 'util.SharedDropout'

-- create the data loader class
local loader = CharSplitLMMinibatchLoader.create(opt.data_dir, opt.batch_size, opt.seq_length, split_sizes, opt.word_level == 1, opt.threshold)
local vocab_size = loader.vocab_size  -- the number of distinct tokens (characters or words)
local vocab = loader.vocab_mapping
local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate }
print('vocab size: ' .. vocab_size)
-- make sure output directory exists
if not path.exists(opt.checkpoint_dir) then lfs.mkdir(opt.checkpoint_dir) end

-- define the model: prototypes for one timestep, then clone them in time
local h2hs = nil
if string.len(opt.init_from) > 0 then
    print('loading a model from checkpoint ' .. opt.init_from)
    local checkpoint = torch.load(opt.init_from)
    protos = checkpoint.protos
    -- Checkpoints written by this script do not contain optim_state (the save
    -- code below leaves it out), so keep the freshly built state in that case
    -- instead of indexing nil.
    optim_state = checkpoint.optim_state or optim_state
    optim_state.learningRate = opt.learning_rate
    -- make sure the vocabs are the same
    local vocab_compatible = true
    local checkpoint_vocab_size = 0
    for c,i in pairs(checkpoint.vocab) do
        if not (vocab[c] == i) then
            vocab_compatible = false
        end
        checkpoint_vocab_size = checkpoint_vocab_size + 1
    end
    if not (checkpoint_vocab_size == vocab_size) then
        vocab_compatible = false
        print('checkpoint_vocab_size: ' .. checkpoint_vocab_size)
    end
    assert(vocab_compatible, 'error, the character vocabulary for this dataset and the one in the saved checkpoint are not the same. This is trouble.')
    -- overwrite model settings based on checkpoint to ensure compatibility
    print('overwriting rnn_size=' .. checkpoint.opt.rnn_size .. ', num_layers=' .. checkpoint.opt.num_layers .. ', model=' .. checkpoint.opt.model .. ' based on the checkpoint.')
    opt.rnn_size = checkpoint.opt.rnn_size
    opt.num_layers = checkpoint.opt.num_layers
    -- checkpoint.optimizer is likewise not written by the save code below;
    -- fall back to the options stored in the checkpoint, then to the
    -- command-line value, so opt.optimizer is never nil.
    opt.optimizer = checkpoint.optimizer or checkpoint.opt.optimizer or opt.optimizer
    opt.model = checkpoint.opt.model
else
    print('creating an ' .. opt.model .. ' with ' .. opt.num_layers .. ' layers')
    protos = {}
    local embedding = nil
    if opt.glove == 1 then
        embedding = GloVeEmbedding(vocab, 200, opt.data_dir) --GloVeEmbeddingFixed(vocab, 200, opt.data_dir)
    end
    if opt.model == 'lstm' then
        protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, opt.recurrent_dropout, embedding, opt.num_fixed)
    elseif opt.model == 'gru' then
        protos.rnn = GRU.gru(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, embedding)
    elseif opt.model == 'rnn' then
        protos.rnn = RNN.rnn(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, embedding)
    elseif opt.model == 'irnn' then
        protos.rnn, h2hs = IRNN.rnn(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, embedding)
    end
    protos.criterion = nn.ClassNLLCriterion()

    --local clusters = {}
    --for w,i in pairs(vocab) do
    --    clusters[#clusters+1] = {1, i}
    --end
    --protos.criterion = nn.HSM(torch.Tensor(clusters), opt.rnn_size, 0) --vocab['UNK'])
end

print('using optimizer ' .. opt.optimizer)
-- the initial state of the cell/hidden states
init_state = {}
for L=1,opt.num_layers do
    local h_init = torch.zeros(opt.batch_size, opt.rnn_size)
    if opt.gpuid >=0 and opt.opencl == 0 then h_init = h_init:cuda() end
    if opt.gpuid >=0 and opt.opencl == 1 then h_init = h_init:cl() end
    table.insert(init_state, h_init:clone())
    if opt.model == 'lstm' then
        -- LSTM layers get a second state tensor per layer
        table.insert(init_state, h_init:clone())
    end
end

-- ship the model to the GPU if desired
if opt.gpuid >= 0 and opt.opencl == 0 then
    for k,v in pairs(protos) do v:cuda() end
end
if opt.gpuid >= 0 and opt.opencl == 1 then
    for k,v in pairs(protos) do v:cl() end
end

-- put the above things into one flattened parameters tensor
params, grad_params = model_utils.combine_all_parameters(protos.rnn)

-- initialize the LSTM forget gates with slightly higher biases to encourage remembering in the beginning
if opt.model == 'lstm' and string.len(opt.init_from) == 0 then
    for layer_idx = 1, opt.num_layers do
        for _,node in ipairs(protos.rnn.forwardnodes) do
            -- layers up to num_fixed are left untouched
            if node.data.annotations.name == "i2h_" .. layer_idx and layer_idx > opt.num_fixed then
                print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx)
                -- the gates are, in order, i,f,o,g, so f is the 2nd block of weights
                node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1.0)
            end
        end
    end
end

print('number of parameters in the model: ' .. params:nElement())
-- make a bunch of clones after flattening, as that reallocates memory
clones = {}
for name,proto in pairs(protos) do
    print('cloning ' .. name)
    clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters)
end

--- Preprocessing helper: prepares one minibatch for the network.
-- Swaps the first two axes of both tensors and, when a GPU backend is
-- selected, moves them to that backend.
-- @param x input tensor as returned by loader:next_batch
-- @param y target tensor as returned by loader:next_batch
-- @return x, y transposed (and possibly moved to the GPU)
function prepro(x,y)
    x = x:transpose(1,2):contiguous() -- swap the axes for faster indexing
    y = y:transpose(1,2):contiguous()
    if opt.gpuid >= 0 and opt.opencl == 0 then -- ship the input arrays to GPU
        -- have to convert to float because integers can't be cuda()'d
        x = x:float():cuda()
        y = y:float():cuda()
    end
    if opt.gpuid >= 0 and opt.opencl == 1 then -- ship the input arrays to GPU
        x = x:cl()
        y = y:cl()
    end
    return x,y
end

--- Evaluate the loss over an entire split.
-- Resets the loader's batch pointer for the split, runs forward passes only
-- (modules in evaluate mode) and carries the recurrent state across batches.
-- @param split_index index of the split in the loader (2 = validation)
-- @param max_batches optional cap on the number of batches to evaluate
-- @return summed criterion loss divided by seq_length and by the batch count
function eval_split(split_index, max_batches)
    print('evaluating loss over split index ' .. split_index)
    local n = loader.split_sizes[split_index]
    if max_batches ~= nil then n = math.min(max_batches, n) end

    loader:reset_batch_pointer(split_index) -- move batch iteration pointer for this split to front
    local loss = 0
    local rnn_state = {[0] = init_state}

    for i = 1,n do -- iterate over batches in the split
        -- fetch a batch
        local x, y = loader:next_batch(split_index)
        x,y = prepro(x,y)
        -- forward pass
        for t=1,opt.seq_length do
            clones.rnn[t]:evaluate() -- for dropout proper functioning
            local lst = clones.rnn[t]:forward{x[t], unpack(rnn_state[t-1])}
            rnn_state[t] = {}
            -- the first #init_state outputs are the new recurrent state
            for s=1,#init_state do table.insert(rnn_state[t], lst[s]) end
            local prediction = lst[#lst] -- last output is the prediction
            loss = loss + clones.criterion[t]:forward(prediction, y[t])
        end
        -- carry over lstm state
        rnn_state[0] = rnn_state[#rnn_state]
        print(i .. '/' .. n .. '...')
    end

    loss = loss / opt.seq_length / n
    return loss
end

-- recurrent state carried from one training minibatch to the next
local init_state_global = clone_list(init_state)

--- Do fwd/bwd on the next training minibatch.
-- This is the closure handed to the optim routines.
-- @param x parameter tensor; copied into `params` when it is a different tensor
-- @return loss (averaged over seq_length) and the flattened, clamped grad_params
function feval(x)
    if x ~= params then
        params:copy(x)
    end
    grad_params:zero()

    ------------------ get minibatch -------------------
    local x, y = loader:next_batch(1)
    x,y = prepro(x,y)
    ------------------- forward pass -------------------
    if opt.recurrent_dropout ~= 0 then
        --todo: these are shared across all layers in depth also. that's not optimal
        -- resample one mask per minibatch, scaled so the expected value is 1
        SharedDropout_noise:resize(opt.batch_size, opt.rnn_size)
        SharedDropout_noise:bernoulli(1 - opt.recurrent_dropout)
        SharedDropout_noise:div(1 - opt.recurrent_dropout)
    end
    local rnn_state = {[0] = init_state_global}
    local predictions = {}           -- softmax outputs
    local loss = 0
    for t=1,opt.seq_length do
        clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
        local lst = clones.rnn[t]:forward{x[t], unpack(rnn_state[t-1])}
        rnn_state[t] = {}
        for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
        predictions[t] = lst[#lst] -- last element is the prediction
        loss = loss + clones.criterion[t]:forward(predictions[t], y[t])
    end
    loss = loss / opt.seq_length
    ------------------ backward pass -------------------
    -- initialize gradient at time t to be zeros (there's no influence from future)
    local drnn_state = {[opt.seq_length] = clone_list(init_state, true)} -- true also zeros the clones
    for t=opt.seq_length,1,-1 do
        -- backprop through loss, and softmax/linear
        local doutput_t = clones.criterion[t]:backward(predictions[t], y[t])
        table.insert(drnn_state[t], doutput_t)
        local dlst = clones.rnn[t]:backward({x[t], unpack(rnn_state[t-1])}, drnn_state[t])
        drnn_state[t-1] = {}
        for k,v in pairs(dlst) do
            if k > 1 then -- k == 1 is gradient on x, which we dont need
                -- note we do k-1 because first item is dembeddings, and then follow the
                -- derivatives of the state, starting at index 2. I know...
                drnn_state[t-1][k-1] = v
            end
        end
    end
    ------------------------ misc ----------------------
    -- transfer final state to initial state (BPTT)
    init_state_global = rnn_state[#rnn_state] -- NOTE: I don't think this needs to be a clone, right?
    -- grad_params:div(opt.seq_length) -- this line should be here but since we use rmsprop it would have no effect. Removing for efficiency
    -- clip gradient element-wise
    grad_params:clamp(-opt.grad_clip, opt.grad_clip)
    return loss, grad_params
end

-- start optimization here
train_losses = {}
val_losses = {}
local iterations = math.floor(opt.max_epochs * loader.ntrain)
local iterations_per_epoch = loader.ntrain
local loss0 = nil

-- `mpiOptions` is a global expected to be provided by the MPI launcher.
-- Treat it as optional so the non-distributed optimizers can run without it.
local mpi_options = mpiOptions or {}

local optimizer = nil

if opt.optimizer == 'adam' then
    optimizer = optim.adam
elseif opt.optimizer == 'sgd' then
    optimizer = optim.sgd
    optim_state.learningRateDecay = opt.decay_rate
    optim_state.momentum = 0.99
    optim_state.nesterov = true
    optim_state.dampening = 0
elseif opt.optimizer == 'eamsgd' then
    -- eamsgd takes all of its hyper-parameters from the launcher
    assert(mpiOptions ~= nil, "optimizer 'eamsgd' requires the global mpiOptions table to be set before running train.lua")
    optimizer = optim.eamsgd
    optim_state.learningRate = mpi_options.learningRate
    optim_state.momentum = mpi_options.momentum
    optim_state.pclient = mpi_options.pclient
    optim_state.communicationPeriod = mpi_options.communicationPeriod
    optim_state.movingRateAlpha = mpi_options.movingRateAlpha
    optim_state.learningRateDecay = mpi_options.learningRateDecay
    optim_state.learningRateDecayPower = mpi_options.learningRateDecayPower
else
    optimizer = optim.rmsprop
end

-- initialize MPI optimizer clients
rank = mpi_options.rank or -1
pclient = mpi_options.pclient or nil
print('i am ' .. rank .. ' ready to run')
if pclient then
    pclient:start(params,grad_params)
    assert(rank == pclient.rank)
    print('pc ' .. rank .. ' started')
end

-- Resolve the host name once: it does not change between iterations, and
-- opening a pipe on every printed step leaked an unclosed handle each time.
local machine_name = nil
do
    local pipe = io.popen('hostname -s')
    if pipe then
        machine_name = pipe:read()
        pipe:close()
    end
    -- string.format('%s', nil) is an error on stock Lua 5.1
    machine_name = machine_name or 'unknown'
end

-- run optimizer
sys.tic() -- time the training procedure
for i = 1, iterations do
    local epoch = i / iterations_per_epoch

    local timer = torch.Timer()

    local _, loss = optimizer(feval, params, optim_state)
    -- cutorch is only loaded for the CUDA backend, so skip this under OpenCL
    if opt.accurate_gpu_timing == 1 and opt.gpuid >= 0 and opt.opencl == 0 then
        --[[
        Note on timing: The reported time can be off because the GPU is invoked async. If one
        wants to have exactly accurate timings one must call cutorch.synchronize() right here.
        I will avoid doing so by default because this can incur computational overhead.
        --]]
        cutorch.synchronize()
    end
    local time = timer:time().real

    local train_loss = loss[1] -- the loss is inside a list, pop it
    train_losses[i] = train_loss

    -- exponential learning rate decay for rmsprop
    if opt.optimizer == 'rmsprop' and i % loader.ntrain == 0 and opt.learning_rate_decay < 1 then
        if epoch >= opt.learning_rate_decay_after then
            local decay_factor = opt.learning_rate_decay
            optim_state.learningRate = optim_state.learningRate * decay_factor -- decay it
            print('decayed learning rate by a factor ' .. decay_factor .. ' to ' .. optim_state.learningRate)
        end
    end

    -- every now and then or on last iteration
    local eval_multiplier = 1
    if epoch < 10 then
        eval_multiplier = 1 --increase this to eval less often in the first iterations
    end
    print('current iteration: ' .. i)
    print('total iterations: ' .. iterations)
    if i % (opt.eval_val_every * eval_multiplier) == 0 or i == iterations then
        -- evaluate loss on validation data
        local val_loss = eval_split(2) -- 2 = validation
        val_losses[i] = val_loss

        local savefile = string.format('%s/lm_%s_epoch%.2f_%.4f.t7', opt.checkpoint_dir, opt.savefile, epoch, val_loss)
        print('saving checkpoint to ' .. savefile)
        local checkpoint = {}
        checkpoint.protos = protos
        checkpoint.opt = opt
        checkpoint.train_losses = train_losses
        checkpoint.val_loss = val_loss
        checkpoint.val_losses = val_losses
        checkpoint.i = i
        checkpoint.epoch = epoch
        checkpoint.vocab = loader.vocab_mapping
        -- optimizer state is intentionally not saved; the -init_from path
        -- above copes with its absence
        --checkpoint.optim_state = optim_state
        --checkpoint.optimizer = opt.optimizer
        torch.save(savefile, checkpoint)
    end

    if i % opt.print_every == 0 then
        print(string.format("%s Rank %s %d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.4fs",machine_name, rank, i, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time))
    end

    if i % (opt.print_every*10) == 0 then
        -- disabled debugging output: eigenvalues of the IRNN hidden-to-hidden weights
        --print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.4fs", i, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time))
        --e, V = torch.eig(h2hs[1].data.module.weight:float(), 'N')
        --print(e[1])
        --e, V = torch.eig(h2hs[2].data.module.weight:float(), 'N')
        --print(e[1])
        --e, V = torch.eig(h2hs[3].data.module.weight:float(), 'N')
        --print(e[1])
    end

    if i % 10 == 0 then collectgarbage() end

    -- handle early stopping if things are going really bad
    if loss[1] ~= loss[1] then -- NaN is the only value not equal to itself
        print('loss is NaN.  This usually indicates a bug.  Please check the issues page for existing issues, or create a new issue, if none exist.  Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?')
        break -- halt
    end
    if loss0 == nil then loss0 = loss[1] end
    if loss[1] > loss0 * 100 then
        print(string.format("loss is exploding, aborting. (%6.2f vs %6.2f)", loss0, loss[1]))
        break -- halt
    end
end

-- stop optimizer clients
if pclient then
    pclient:stop()
end

print(rank,'total training time is', sys.toc())