‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

model_utils.lua (5152B)


      1 
-- Adapted from https://github.com/wojciechz/learning_to_execute
-- Utilities for combining/flattening the parameters of one or more modules,
-- and for cloning a module many times so that the clones share parameters.
-- The code in this script is more general than it needs to be, which is
-- why it is rather large.
      6 
      7 require 'torch'
-- Module table: the utility functions below are attached to it, and it is
-- returned at the end of this file.
local model_utils = {}
      9 function model_utils.combine_all_parameters(...)
     10     --[[ like module:getParameters, but operates on many modules ]]--
     11 
     12     -- get parameters
     13     local networks = {...}
     14     local parameters = {}
     15     local gradParameters = {}
     16     for i = 1, #networks do
     17         local net_params, net_grads = networks[i]:parameters()
     18 
     19         if net_params then
     20             for _, p in pairs(net_params) do
     21                 parameters[#parameters + 1] = p
     22             end
     23             for _, g in pairs(net_grads) do
     24                 gradParameters[#gradParameters + 1] = g
     25             end
     26         end
     27     end
     28 
     29     local function storageInSet(set, storage)
     30         local storageAndOffset = set[torch.pointer(storage)]
     31         if storageAndOffset == nil then
     32             return nil
     33         end
     34         local _, offset = unpack(storageAndOffset)
     35         return offset
     36     end
     37 
     38     -- this function flattens arbitrary lists of parameters,
     39     -- even complex shared ones
     40     local function flatten(parameters)
     41         if not parameters or #parameters == 0 then
     42             return torch.Tensor()
     43         end
     44         local Tensor = parameters[1].new
     45 
     46         local storages = {}
     47         local nParameters = 0
     48         for k = 1,#parameters do
     49             local storage = parameters[k]:storage()
     50             if not storageInSet(storages, storage) then
     51                 storages[torch.pointer(storage)] = {storage, nParameters}
     52                 nParameters = nParameters + storage:size()
     53             end
     54         end
     55 
     56         local flatParameters = Tensor(nParameters):fill(1)
     57         local flatStorage = flatParameters:storage()
     58 
     59         for k = 1,#parameters do
     60             local storageOffset = storageInSet(storages, parameters[k]:storage())
     61             parameters[k]:set(flatStorage,
     62                 storageOffset + parameters[k]:storageOffset(),
     63                 parameters[k]:size(),
     64                 parameters[k]:stride())
     65             parameters[k]:zero()
     66         end
     67 
     68         local maskParameters=  flatParameters:float():clone()
     69         local cumSumOfHoles = flatParameters:float():cumsum(1)
     70         local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
     71         local flatUsedParameters = Tensor(nUsedParameters)
     72         local flatUsedStorage = flatUsedParameters:storage()
     73 
     74         for k = 1,#parameters do
     75             local offset = cumSumOfHoles[parameters[k]:storageOffset()]
     76             parameters[k]:set(flatUsedStorage,
     77                 parameters[k]:storageOffset() - offset,
     78                 parameters[k]:size(),
     79                 parameters[k]:stride())
     80         end
     81 
     82         for _, storageAndOffset in pairs(storages) do
     83             local k, v = unpack(storageAndOffset)
     84             flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
     85         end
     86 
     87         if cumSumOfHoles:sum() == 0 then
     88             flatUsedParameters:copy(flatParameters)
     89         else
     90             local counter = 0
     91             for k = 1,flatParameters:nElement() do
     92                 if maskParameters[k] == 0 then
     93                     counter = counter + 1
     94                     flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
     95                 end
     96             end
     97             assert (counter == nUsedParameters)
     98         end
     99         return flatUsedParameters
    100     end
    101 
    102     -- flatten parameters and gradients
    103     local flatParameters = flatten(parameters)
    104     local flatGradParameters = flatten(gradParameters)
    105 
    106     -- return new flat vector that contains all discrete parameters
    107     return flatParameters, flatGradParameters
    108 end
    109 
    110 
    111 
    112 
    113 function model_utils.clone_many_times(net, T)
    114     local clones = {}
    115 
    116     local params, gradParams
    117     if net.parameters then
    118         params, gradParams = net:parameters()
    119         if params == nil then
    120             params = {}
    121         end
    122     end
    123 
    124     local paramsNoGrad
    125     if net.parametersNoGrad then
    126         paramsNoGrad = net:parametersNoGrad() 
    127     end
    128 
    129     local mem = torch.MemoryFile("w"):binary()
    130     mem:writeObject(net)
    131 
    132     for t = 1, T do
    133         -- We need to use a new reader for each clone.
    134         -- We don't want to use the pointers to already read objects.
    135         local reader = torch.MemoryFile(mem:storage(), "r"):binary()
    136         local clone = reader:readObject()
    137         reader:close()
    138 
    139         if net.parameters then
    140             local cloneParams, cloneGradParams = clone:parameters()
    141             local cloneParamsNoGrad
    142             for i = 1, #params do
    143                 cloneParams[i]:set(params[i])
    144                 cloneGradParams[i]:set(gradParams[i])
    145             end
    146             if paramsNoGrad then
    147                 cloneParamsNoGrad = clone:parametersNoGrad()
    148                 for i =1,#paramsNoGrad do
    149                     cloneParamsNoGrad[i]:set(paramsNoGrad[i])
    150                 end
    151             end
    152         end
    153 
    154         clones[t] = clone
    155         collectgarbage()
    156     end
    157 
    158     mem:close()
    159     return clones
    160 end
    161 
-- export the module table (`require` returns this value)
return model_utils