inspect_checkpoint.lua (954B)
-- Simple script that loads a checkpoint and prints its options (model.opt)
-- and its validation losses (model.val_losses).
--
-- Usage: th inspect_checkpoint.lua <model> [-gpuid N] [-opencl 0|1]
--   -gpuid  0-based GPU index; a negative value skips loading any GPU backend
--   -opencl 1 selects the OpenCL backend instead of CUDA

require 'torch'
require 'nn'
require 'nngraph'

-- Project modules; presumably required so that torch.load can deserialize
-- custom classes stored in the checkpoint (e.g. OneHot) -- NOTE: confirm.
require 'util.OneHot'
require 'util.misc'

-- Require every named package, stopping at the first one that fails.
-- Returns true only if all of them loaded; reports the failing package
-- instead of raising, so the caller can fall back to CPU mode.
local function require_all(...)
    for _, name in ipairs({...}) do
        local ok, err = pcall(require, name)
        if not ok then
            print('package ' .. name .. ' not found!')
            print(err)
            return false
        end
    end
    return true
end

local cmd = torch.CmdLine()
cmd:text()
cmd:text('Load a checkpoint and print its options and validation losses.')
cmd:text()
cmd:text('Options')
cmd:argument('-model','model to load')
cmd:option('-gpuid',0,'gpu to use')
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
cmd:text()

-- parse input params
local opt = cmd:parse(arg)

-- The GPU backend is loaded before torch.load, presumably so that
-- checkpoints containing CUDA/OpenCL tensors can be deserialized.
if opt.gpuid >= 0 and opt.opencl == 0 then
    if require_all('cutorch', 'cunn') then
        print('using CUDA on GPU ' .. opt.gpuid .. '...')
        -- -gpuid is 0-based; cutorch device indices are 1-based
        cutorch.setDevice(opt.gpuid + 1)
    else
        print('Falling back on CPU mode')
        opt.gpuid = -1
    end
end

if opt.gpuid >= 0 and opt.opencl == 1 then
    if require_all('cltorch', 'clnn') then
        print('using OpenCL on GPU ' .. opt.gpuid .. '...')
        -- -gpuid is 0-based; cltorch device indices are 1-based
        cltorch.setDevice(opt.gpuid + 1)
    else
        print('Falling back on CPU mode')
        opt.gpuid = -1
    end
end

local model = torch.load(opt.model)

print('opt:')
print(model.opt)
print('val losses:')
print(model.val_losses)