‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

convert_gpu_cpu_checkpoint.lua (2176B)


      1 --[[
      2 A quick patch for converting GPU checkpoints to 
      3 CPU checkpoints until I implement a more long-term
      4 solution. Takes the path to the model and creates
      5 a file in the same location and path, but with _cpu.t7
      6 appended.
      7 ]]--
      8 
      9 require 'torch'
     10 require 'nn'
     11 require 'nngraph'
     12 require 'lfs'
     13 
     14 require 'util.OneHot'
     15 require 'util.misc'
     16 
     17 cmd = torch.CmdLine()
     18 cmd:text()
     19 cmd:text('Sample from a character-level language model')
     20 cmd:text()
     21 cmd:text('Options')
     22 cmd:argument('-model','GPU model checkpoint to convert')
     23 cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
     24 cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
     25 cmd:text()
     26 
     27 -- parse input params
     28 opt = cmd:parse(arg)
     29 
     30 -- check that cunn/cutorch are installed if user wants to use the GPU
     31 if opt.gpuid >= 0 and opt.opencl == 0 then
     32     local ok, cunn = pcall(require, 'cunn')
     33     local ok2, cutorch = pcall(require, 'cutorch')
     34     if not ok then print('package cunn not found!') end
     35     if not ok2 then print('package cutorch not found!') end
     36     if ok and ok2 then
     37         print('using CUDA on GPU ' .. opt.gpuid .. '...')
     38         cutorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
     39     else
     40     	print('Error, no GPU available?')
     41         os.exit()
     42     end
     43 end
     44 
     45 -- check that clnn/cltorch are installed if user wants to use OpenCL
     46 if opt.gpuid >= 0 and opt.opencl == 1 then
     47     local ok, cunn = pcall(require, 'clnn')
     48     local ok2, cutorch = pcall(require, 'cltorch')
     49     if not ok then print('package clnn not found!') end
     50     if not ok2 then print('package cltorch not found!') end
     51     if ok and ok2 then
     52         print('using OpenCL on GPU ' .. opt.gpuid .. '...')
     53         cltorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
     54     else
     55         print('Error, no GPU available?')
     56         os.exit()
     57     end
     58 end
     59 
     60 print('loading ' .. opt.model)
     61 checkpoint = torch.load(opt.model)
     62 protos = checkpoint.protos
     63 
     64 -- convert the networks to be CPU models
     65 for k,v in pairs(protos) do
     66 	print('converting ' .. k .. ' to CPU')
     67 	protos[k]:double()
     68 end
     69 
     70 local savefile = opt.model .. '_cpu.t7' -- append "cpu.t7" to filename
     71 torch.save(savefile, checkpoint)
     72 print('saved ' .. savefile)
     73 
     74