convert_gpu_cpu_checkpoint.lua (2176B)
--[[
Converts a GPU model checkpoint into a CPU model checkpoint.

A quick patch until a more long-term solution is implemented. Takes the
path to the model checkpoint, loads it, converts every network stored in
checkpoint.protos with :double(), and saves the result to the same path
with "_cpu.t7" appended.

Exits with a non-zero status when the requested GPU backend (CUDA or
OpenCL) cannot be loaded.
]]--

require 'torch'
require 'nn'
require 'nngraph'
require 'lfs'

require 'util.OneHot'
require 'util.misc'

-- NOTE(review): cmd, opt, checkpoint and protos are intentionally kept as
-- globals, exactly as in the original script; the util.* modules required
-- above are not visible here and may read them -- confirm before localizing.
cmd = torch.CmdLine()
cmd:text()
cmd:text('Convert a GPU model checkpoint to a CPU model checkpoint')
cmd:text()
cmd:text('Options')
-- read back below as opt.model
cmd:argument('-model','GPU model checkpoint to convert')
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
cmd:text()

-- parse input params
opt = cmd:parse(arg)

-- Report that the requested GPU backend is unusable and stop. The exit
-- status is non-zero so that calling shells/scripts can detect the failure
-- (a bare os.exit() would report success).
local function abort_no_gpu()
    print('Error, no GPU available?')
    os.exit(1)
end

-- check that cunn/cutorch are installed if user wants to use the GPU
if opt.gpuid >= 0 and opt.opencl == 0 then
    local ok = pcall(require, 'cunn')
    local ok2, cutorch = pcall(require, 'cutorch')
    if not ok then print('package cunn not found!') end
    if not ok2 then print('package cutorch not found!') end
    if ok and ok2 then
        print('using CUDA on GPU ' .. opt.gpuid .. '...')
        -- -gpuid is zero-based on the command line; +1 turns it into the
        -- one-based device index passed to setDevice
        cutorch.setDevice(opt.gpuid + 1)
    else
        abort_no_gpu()
    end
end

-- check that clnn/cltorch are installed if user wants to use OpenCL
if opt.gpuid >= 0 and opt.opencl == 1 then
    local ok = pcall(require, 'clnn')
    local ok2 = pcall(require, 'cltorch')
    if not ok then print('package clnn not found!') end
    if not ok2 then print('package cltorch not found!') end
    if ok and ok2 then
        print('using OpenCL on GPU ' .. opt.gpuid .. '...')
        -- As in the original, cltorch is reached through the global of that
        -- name (presumably defined by require 'cltorch' -- TODO confirm);
        -- the value returned by require is deliberately not used.
        -- Same zero-based -> one-based conversion as in the CUDA branch.
        cltorch.setDevice(opt.gpuid + 1)
    else
        abort_no_gpu()
    end
end

print('loading ' .. opt.model)
checkpoint = torch.load(opt.model)
protos = checkpoint.protos

-- convert the networks to be CPU models; :double() is called for its
-- effect on the object itself, the return value is not stored
for k, v in pairs(protos) do
    print('converting ' .. k .. ' to CPU')
    v:double()
end

-- the whole checkpoint table (not just protos) is written back out
local savefile = opt.model .. '_cpu.t7' -- append "_cpu.t7" to filename
torch.save(savefile, checkpoint)
print('saved ' .. savefile)