‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

sample.lua (7346B)


      1 
      2 --[[
      3 
      4 This file samples characters from a trained model
      5 
      6 Code is based on implementation in 
      7 https://github.com/oxford-cs-ml-2015/practical6
      8 
      9 ]]--
     10 
     11 require 'torch'
     12 require 'nn'
     13 require 'nngraph'
     14 require 'optim'
     15 require 'lfs'
     16 
     17 require 'util.GloVeEmbedding'
     18 require 'util.OneHot'
     19 require 'util.misc'
     20 local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader'
     21 
     22 cmd = torch.CmdLine()
     23 cmd:text()
     24 cmd:text('Sample from a character-level language model')
     25 cmd:text()
     26 cmd:text('Options')
     27 -- required:
     28 cmd:argument('-model','model checkpoint to use for sampling')
     29 -- optional parameters
     30 cmd:option('-seed',123,'random number generator\'s seed')
     31 cmd:option('-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep')
     32 cmd:option('-primetext',"",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.')
     33 cmd:option('-length',2000,'number of characters to sample')
     34 cmd:option('-temperature',1,'temperature of sampling')
     35 cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
     36 cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
     37 cmd:option('-verbose',1,'set to 0 to ONLY print the sampled text, no diagnostics')
     38 cmd:option('-skip_unk',0,'whether to skip UNK tokens when sampling')
     39 cmd:option('-input_loop',0,'whether to read new seed text from stdin after having finished sampling')
     40 cmd:option('-word_level',1,'whether to operate on the word level, instead of character level (0: use chars, 1: use words)') --todo: set this in checkpoint
     41 cmd:text()
     42 
     43 -- parse input params
     44 opt = cmd:parse(arg)
     45 
     46 -- gated print: simple utility function wrapping a print
     47 function gprint(str)
     48     if opt.verbose == 1 then print(str) end
     49 end
     50 
     51 -- check that cunn/cutorch are installed if user wants to use the GPU
     52 if opt.gpuid >= 0 and opt.opencl == 0 then
     53     local ok, cunn = pcall(require, 'cunn')
     54     local ok2, cutorch = pcall(require, 'cutorch')
     55     if not ok then gprint('package cunn not found!') end
     56     if not ok2 then gprint('package cutorch not found!') end
     57     if ok and ok2 then
     58         gprint('using CUDA on GPU ' .. opt.gpuid .. '...')
     59         gprint('Make sure that your saved checkpoint was also trained with GPU. If it was trained with CPU use -gpuid -1 for sampling as well')
     60         cutorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
     61         cutorch.manualSeed(opt.seed)
     62     else
     63         gprint('Falling back on CPU mode')
     64         opt.gpuid = -1 -- overwrite user setting
     65     end
     66 end
     67 
     68 -- check that clnn/cltorch are installed if user wants to use OpenCL
     69 if opt.gpuid >= 0 and opt.opencl == 1 then
     70     local ok, cunn = pcall(require, 'clnn')
     71     local ok2, cutorch = pcall(require, 'cltorch')
     72     if not ok then print('package clnn not found!') end
     73     if not ok2 then print('package cltorch not found!') end
     74     if ok and ok2 then
     75         gprint('using OpenCL on GPU ' .. opt.gpuid .. '...')
     76         gprint('Make sure that your saved checkpoint was also trained with GPU. If it was trained with CPU use -gpuid -1 for sampling as well')
     77         cltorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
     78         torch.manualSeed(opt.seed)
     79     else
     80         gprint('Falling back on CPU mode')
     81         opt.gpuid = -1 -- overwrite user setting
     82     end
     83 end
     84 require 'util.SharedDropout'
     85 
     86 torch.manualSeed(opt.seed)
     87 
     88 -- load the model checkpoint
     89 if not lfs.attributes(opt.model, 'mode') then
     90     gprint('Error: File ' .. opt.model .. ' does not exist. Are you sure you didn\'t forget to prepend cv/ ?')
     91 end
     92 checkpoint = torch.load(opt.model)
     93 protos = checkpoint.protos
     94 protos.rnn:evaluate() -- put in eval mode so that dropout works properly
     95 
     96 -- initialize the vocabulary (and its inverted version)
     97 local vocab = checkpoint.vocab
     98 local ivocab = {}
     99 for c,i in pairs(vocab) do ivocab[i] = c end
    100 
    101 -- initialize the rnn state to all zeros
    102 gprint('creating an ' .. checkpoint.opt.model .. '...')
    103 
    104 
    105 -- do a few seeded timesteps
    106 local seed_text = opt.primetext
    107 
    108 repeat
    109 
    110 	local current_state
    111 	current_state = {}
    112 	for L = 1,checkpoint.opt.num_layers do
    113 		-- c and h for all layers
    114 		local h_init = torch.zeros(1, checkpoint.opt.rnn_size):double()
    115 		if opt.gpuid >= 0 and opt.opencl == 0 then h_init = h_init:cuda() end
    116 		if opt.gpuid >= 0 and opt.opencl == 1 then h_init = h_init:cl() end
    117 		table.insert(current_state, h_init:clone())
    118 		if checkpoint.opt.model == 'lstm' then
    119 			table.insert(current_state, h_init:clone())
    120 		end
    121 	end
    122 	state_size = #current_state
    123 	if string.len(seed_text) > 0 then
    124 		gprint('seeding with ' .. seed_text)
    125 		gprint('--------------------------')
    126         local tokens = {}
    127         for c in CharSplitLMMinibatchLoader.tokens(seed_text, opt.word_level == 1) do
    128             if vocab[c] == nil then c = c:lower() end
    129             tokens[#tokens + 1] = c
    130         end
    131         --tokens[#tokens + 1] = '.'
    132         for _, c in ipairs(tokens) do --todo: word_level should be stored in checkpoint
    133             local idx = vocab[c]
    134             if idx ~= nil then
    135                 prev_char = torch.Tensor{idx}
    136 				if opt.gpuid >= 0 and opt.opencl == 0 then prev_char = prev_char:cuda() end
    137 				if opt.gpuid >= 0 and opt.opencl == 1 then prev_char = prev_char:cl() end
    138 			local lst = protos.rnn:forward{prev_char, unpack(current_state)}
    139 			-- lst is a list of [state1,state2,..stateN,output]. We want everything but last piece
    140 			current_state = {}
    141 			for i=1,state_size do table.insert(current_state, lst[i]) end
    142 			prediction = lst[#lst] -- last element holds the log probabilities
    143 		end
    144         end
    145 	else
    146 		-- fill with uniform probabilities over characters (? hmm)
    147 		gprint('missing seed text, using uniform probability over first character')
    148 		gprint('--------------------------')
    149         print('1')
    150 		prediction = torch.Tensor(1, #ivocab):fill(1)/(#ivocab)
    151         prin('2')
    152 		if opt.gpuid >= 0 and opt.opencl == 0 then prediction = prediction:cuda() end
    153 		print('3')
    154         if opt.gpuid >= 0 and opt.opencl == 1 then prediction = prediction:cl() end
    155 	end
    156 
    157 	-- start sampling/argmaxing
    158 	for i=1, opt.length do
    159         print('4')
    160 		-- log probabilities from the previous timestep
    161 		if opt.sample == 0 then
    162 			-- use argmax
    163             -- TODO: Skip UNK
    164 			local _, prev_char_ = prediction:max(2)
    165 			prev_char = prev_char_:resize(1)
    166 		else
    167 			-- use sampling
    168 			prediction:div(opt.temperature) -- scale by temperature
    169 			local probs = torch.exp(prediction):squeeze()
    170 			probs:div(torch.sum(probs)) -- renormalize so probs sum to one
    171 
    172             if opt.skip_unk then
    173                 prev_char = torch.multinomial(probs:float(), 2):float()
    174                 prev_char = prev_char[1] == vocab["UNK"] and prev_char[{{2}}] or prev_char[{{1}}]
    175             else
    176 				prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
    177 			end
    178         end
    179 
    180 		-- forward the rnn for next character
    181 		local lst = protos.rnn:forward{prev_char, unpack(current_state)}
    182 		current_state = {}
    183 		for i=1,state_size do table.insert(current_state, lst[i]) end
    184 		prediction = lst[#lst] -- last element holds the log probabilities
    185 
    186         word = ivocab[prev_char[1]]
    187         if opt.word_level and word == "RN" then 
    188             word = "\n"
    189         end
    190         io.write(word)
    191         if opt.word_level then
    192             io.write(" ")
    193         end
    194 	end
    195 	io.write('\n') io.flush()
    196     if opt.input_loop == 1 then
    197         seed_text = io.read()
    198     end
    199 until opt.input_loop ~= 1
    200