‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

inspect_checkpoint.lua (954B)


      1 -- simple script that loads a checkpoint and prints its opts
      2 
      3 require 'torch'
      4 require 'nn'
      5 require 'nngraph'
      6 
      7 require 'util.OneHot'
      8 require 'util.misc'
      9 
     10 cmd = torch.CmdLine()
     11 cmd:text()
     12 cmd:text('Load a checkpoint and print its options and validation losses.')
     13 cmd:text()
     14 cmd:text('Options')
     15 cmd:argument('-model','model to load')
     16 cmd:option('-gpuid',0,'gpu to use')
     17 cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
     18 cmd:text()
     19 
     20 -- parse input params
     21 opt = cmd:parse(arg)
     22 
     23 if opt.gpuid >= 0 and opt.opencl == 0 then
     24     print('using CUDA on GPU ' .. opt.gpuid .. '...')
     25     require 'cutorch'
     26     require 'cunn'
     27     cutorch.setDevice(opt.gpuid + 1)
     28 end
     29 
     30 if opt.gpuid >= 0 and opt.opencl == 1 then
     31     print('using OpenCL on GPU ' .. opt.gpuid .. '...')
     32     require 'cltorch'
     33     require 'clnn'
     34     cltorch.setDevice(opt.gpuid + 1)
     35 end
     36 
     37 local model = torch.load(opt.model)
     38 
     39 print('opt:')
     40 print(model.opt)
     41 print('val losses:')
     42 print(model.val_losses)
     43