sample.lua (7346B)
--[[

This file samples characters (or words) from a trained model.

Code is based on implementation in
https://github.com/oxford-cs-ml-2015/practical6

Flow:
  1. parse command-line options
  2. optionally initialise CUDA / OpenCL (falls back to CPU if unavailable)
  3. load the checkpoint and build the inverse vocabulary
  4. repeat: reset the RNN state, feed the prime text (if any), then sample
     `-length` tokens and write them to stdout; with `-input_loop 1` a new
     prime text is read from stdin after every sample until EOF.

]]--

require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
require 'lfs'

require 'util.GloVeEmbedding'
require 'util.OneHot'
require 'util.misc'
local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader'

-- NOTE: `cmd`, `opt`, `gprint`, `checkpoint` and `protos` are intentionally
-- left global, exactly as before, in case the util.* modules read them.
cmd = torch.CmdLine()
cmd:text()
cmd:text('Sample from a character-level language model')
cmd:text()
cmd:text('Options')
-- required:
cmd:argument('-model','model checkpoint to use for sampling')
-- optional parameters
cmd:option('-seed',123,'random number generator\'s seed')
cmd:option('-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep')
cmd:option('-primetext',"",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.')
cmd:option('-length',2000,'number of characters to sample')
cmd:option('-temperature',1,'temperature of sampling')
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
cmd:option('-verbose',1,'set to 0 to ONLY print the sampled text, no diagnostics')
cmd:option('-skip_unk',0,'whether to skip UNK tokens when sampling')
cmd:option('-input_loop',0,'whether to read new seed text from stdin after having finished sampling')
cmd:option('-word_level',1,'whether to operate on the word level, instead of character level (0: use chars, 1: use words)') --todo: set this in checkpoint
cmd:text()

-- parse input params
opt = cmd:parse(arg)

-- gated print: prints `str` only when diagnostics are enabled (-verbose 1)
function gprint(str)
    if opt.verbose == 1 then print(str) end
end

-- check that cunn/cutorch are installed if user wants to use the GPU
if opt.gpuid >= 0 and opt.opencl == 0 then
    local ok = pcall(require, 'cunn')
    local ok2, cutorch = pcall(require, 'cutorch')
    if not ok then gprint('package cunn not found!') end
    if not ok2 then gprint('package cutorch not found!') end
    if ok and ok2 then
        gprint('using CUDA on GPU ' .. opt.gpuid .. '...')
        gprint('Make sure that your saved checkpoint was also trained with GPU. If it was trained with CPU use -gpuid -1 for sampling as well')
        cutorch.setDevice(opt.gpuid + 1) -- -gpuid is 0-indexed, torch devices are 1-indexed
        cutorch.manualSeed(opt.seed)
    else
        gprint('Falling back on CPU mode')
        opt.gpuid = -1 -- overwrite user setting
    end
end

-- check that clnn/cltorch are installed if user wants to use OpenCL
if opt.gpuid >= 0 and opt.opencl == 1 then
    -- only the success flags are kept; `cltorch` below is the global that
    -- requiring the package defines (same lookup the script always used)
    local ok = pcall(require, 'clnn')
    local ok2 = pcall(require, 'cltorch')
    -- diagnostics go through gprint so that -verbose 0 stays silent,
    -- consistent with the CUDA branch above
    if not ok then gprint('package clnn not found!') end
    if not ok2 then gprint('package cltorch not found!') end
    if ok and ok2 then
        gprint('using OpenCL on GPU ' .. opt.gpuid .. '...')
        gprint('Make sure that your saved checkpoint was also trained with GPU. If it was trained with CPU use -gpuid -1 for sampling as well')
        cltorch.setDevice(opt.gpuid + 1) -- -gpuid is 0-indexed, torch devices are 1-indexed
        torch.manualSeed(opt.seed)
    else
        gprint('Falling back on CPU mode')
        opt.gpuid = -1 -- overwrite user setting
    end
end
-- loaded after the GPU setup, as in the original ordering
require 'util.SharedDropout'

torch.manualSeed(opt.seed)

-- Move a tensor to the compute device selected on the command line.
-- Returns the tensor unchanged in CPU mode. opt.gpuid is read at call time,
-- so a CPU fallback decided above is honoured.
local function to_device(tensor)
    if opt.gpuid >= 0 and opt.opencl == 0 then return tensor:cuda() end
    if opt.gpuid >= 0 and opt.opencl == 1 then return tensor:cl() end
    return tensor
end

-- load the model checkpoint
if not lfs.attributes(opt.model, 'mode') then
    -- Fatal: report on stderr (regardless of -verbose) and stop, instead of
    -- falling through to torch.load and failing with an unrelated error.
    io.stderr:write('Error: File ' .. opt.model .. ' does not exist. Are you sure you didn\'t forget to prepend cv/ ?\n')
    os.exit(1)
end
checkpoint = torch.load(opt.model)
protos = checkpoint.protos
protos.rnn:evaluate() -- put in eval mode so that dropout works properly

-- initialize the vocabulary (and its inverted version)
local vocab = checkpoint.vocab
local ivocab = {}
for c, i in pairs(vocab) do ivocab[i] = c end

gprint('creating an ' .. checkpoint.opt.model .. '...')

-- Build the fallback "prediction" used when there is nothing to condition on:
-- a constant row over the vocabulary. It is later passed through exp() and
-- renormalised, which yields a uniform distribution.
local function uniform_prediction()
    return to_device(torch.Tensor(1, #ivocab):fill(1)/(#ivocab))
end

local seed_text = opt.primetext

repeat

    -- initialize the rnn state to all zeros: one h (and, for LSTMs, one c)
    -- tensor per layer
    local current_state = {}
    for L = 1, checkpoint.opt.num_layers do
        local h_init = to_device(torch.zeros(1, checkpoint.opt.rnn_size):double())
        table.insert(current_state, h_init:clone())
        if checkpoint.opt.model == 'lstm' then
            table.insert(current_state, h_init:clone())
        end
    end
    local state_size = #current_state

    -- output of the most recent forward pass (treated as log probabilities)
    local prediction
    local prev_char

    -- do a few seeded timesteps
    if string.len(seed_text) > 0 then
        gprint('seeding with ' .. seed_text)
        gprint('--------------------------')
        local tokens = {}
        for c in CharSplitLMMinibatchLoader.tokens(seed_text, opt.word_level == 1) do
            -- retry out-of-vocabulary tokens in lower case
            if vocab[c] == nil then c = c:lower() end
            tokens[#tokens + 1] = c
        end
        --tokens[#tokens + 1] = '.'
        for _, c in ipairs(tokens) do --todo: word_level should be stored in checkpoint
            local idx = vocab[c]
            -- tokens that are still unknown are silently skipped
            if idx ~= nil then
                prev_char = to_device(torch.Tensor{idx})
                local lst = protos.rnn:forward{prev_char, unpack(current_state)}
                -- lst is a list of [state1,state2,..stateN,output]. We want everything but last piece
                current_state = {}
                for i = 1, state_size do table.insert(current_state, lst[i]) end
                prediction = lst[#lst] -- last element holds the log probabilities
            end
        end
        if prediction == nil then
            -- none of the seed tokens were in the vocabulary, so the network
            -- was never run; start from the uniform prior instead of crashing
            gprint('no seed token found in vocabulary, using uniform probability over first character')
            gprint('--------------------------')
            prediction = uniform_prediction()
        end
    else
        -- fill with uniform probabilities over characters (? hmm)
        gprint('missing seed text, using uniform probability over first character')
        gprint('--------------------------')
        prediction = uniform_prediction()
    end

    -- start sampling/argmaxing
    for i = 1, opt.length do
        -- log probabilities from the previous timestep
        if opt.sample == 0 then
            -- use argmax
            -- TODO: Skip UNK
            local _, prev_char_ = prediction:max(2)
            prev_char = prev_char_:resize(1)
        else
            -- use sampling
            prediction:div(opt.temperature) -- scale by temperature
            local probs = torch.exp(prediction):squeeze()
            probs:div(torch.sum(probs)) -- renormalize so probs sum to one

            -- compare against 1 explicitly: 0 is truthy in Lua
            if opt.skip_unk == 1 then
                -- draw two distinct candidates and use the second one if the
                -- first happens to be the UNK token
                prev_char = torch.multinomial(probs:float(), 2):float()
                prev_char = prev_char[1] == vocab["UNK"] and prev_char[{{2}}] or prev_char[{{1}}]
            else
                prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
            end
        end

        -- forward the rnn for next character
        local lst = protos.rnn:forward{prev_char, unpack(current_state)}
        current_state = {}
        for j = 1, state_size do table.insert(current_state, lst[j]) end
        prediction = lst[#lst] -- last element holds the log probabilities

        local word = ivocab[prev_char[1]]
        if opt.word_level == 1 then
            -- "RN" is the vocabulary token that stands for a line break
            if word == "RN" then word = "\n" end
            io.write(word)
            io.write(" ") -- words are separated by single spaces
        else
            io.write(word)
        end
    end
    io.write('\n') io.flush()

    if opt.input_loop == 1 then
        seed_text = io.read()
        -- io.read() returns nil at end of input: stop looping
        if seed_text == nil then break end
    end
until opt.input_loop ~= 1