model_utils.lua (5152B)
-- adapted from https://github.com/wojciechz/learning_to_execute
-- utilities for combining/flattening parameters in a model
-- the code in this script is more general than it needs to be, which is
-- why it is rather large

require 'torch'

-- `unpack` is a global in Lua 5.1/LuaJIT but lives in `table` from Lua 5.2 on.
local unpack = unpack or table.unpack

local model_utils = {}

-- Like module:getParameters, but operates on many modules.
--
-- Arguments: any number of objects exposing a `:parameters()` method that
-- returns two lists (parameter tensors, gradient tensors) or nil.
--
-- Returns two flat tensors (parameters, gradParameters). Every tensor that was
-- collected is re-pointed (via `:set`) into the storage of the corresponding
-- flat tensor, so the modules and the flat vectors share memory afterwards.
function model_utils.combine_all_parameters(...)
  -- gather the parameter / gradient tensors of every network, in order
  local networks = {...}
  local parameters = {}
  local gradParameters = {}
  for i = 1, #networks do
    local net_params, net_grads = networks[i]:parameters()

    if net_params then
      -- ipairs (not pairs): iteration order must be deterministic so that
      -- parameters and gradients end up at matching positions
      for _, p in ipairs(net_params) do
        parameters[#parameters + 1] = p
      end
      for _, g in ipairs(net_grads) do
        gradParameters[#gradParameters + 1] = g
      end
    end
  end

  -- Returns the offset recorded for `storage` in `set`, or nil when the
  -- storage has not been registered yet. `set` is keyed by storage pointer
  -- and holds {storage, offset} pairs.
  local function storageInSet(set, storage)
    local storageAndOffset = set[torch.pointer(storage)]
    if storageAndOffset == nil then
      return nil
    end
    local _, offset = unpack(storageAndOffset)
    return offset
  end

  -- this function flattens arbitrary lists of parameters,
  -- even complex shared ones
  local function flatten(tensors)
    if not tensors or #tensors == 0 then
      return torch.Tensor()
    end
    -- constructor of the same tensor type as the first parameter
    local Tensor = tensors[1].new

    -- 1. register every distinct underlying storage exactly once and give it
    --    an offset inside one big concatenated buffer
    local storages = {}
    local nParameters = 0
    for k = 1, #tensors do
      local storage = tensors[k]:storage()
      if not storageInSet(storages, storage) then
        storages[torch.pointer(storage)] = {storage, nParameters}
        nParameters = nParameters + storage:size()
      end
    end

    -- 2. buffer filled with ones; every element covered by a parameter is
    --    zeroed below, so the ones that remain mark "holes" (storage
    --    elements that no parameter references)
    local flatParameters = Tensor(nParameters):fill(1)
    local flatStorage = flatParameters:storage()

    for k = 1, #tensors do
      local storageOffset = storageInSet(storages, tensors[k]:storage())
      tensors[k]:set(flatStorage,
                     storageOffset + tensors[k]:storageOffset(),
                     tensors[k]:size(),
                     tensors[k]:stride())
      tensors[k]:zero()
    end

    -- 3. hole mask and its running sum. The 0/1 mask is converted to a
    --    LongTensor before the prefix sum: a float32 prefix sum cannot
    --    represent integers above 2^24 exactly, which would corrupt the
    --    offsets computed below for very large parameter sets.
    local holeMask = flatParameters:float():long()
    local cumSumOfHoles = holeMask:cumsum(1)
    local nHoles = cumSumOfHoles[nParameters]
    local nUsedParameters = nParameters - nHoles
    local flatUsedParameters = Tensor(nUsedParameters)
    local flatUsedStorage = flatUsedParameters:storage()

    -- 4. re-point every parameter into the compacted (hole-free) storage by
    --    shifting its offset left by the number of holes preceding it
    for k = 1, #tensors do
      local offset = cumSumOfHoles[tensors[k]:storageOffset()]
      tensors[k]:set(flatUsedStorage,
                     tensors[k]:storageOffset() - offset,
                     tensors[k]:size(),
                     tensors[k]:stride())
    end

    -- 5. copy the original values (still held by the registered storages)
    --    into the concatenated buffer
    for _, storageAndOffset in pairs(storages) do
      local storage, offset = unpack(storageAndOffset)
      flatParameters[{{offset + 1, offset + storage:size()}}]:copy(Tensor():set(storage))
    end

    -- 6. move the values into the compact tensor, skipping holes
    if nHoles == 0 then
      -- the prefix sum is non-decreasing, so a zero last entry means there
      -- are no holes at all and the buffers have identical layout
      flatUsedParameters:copy(flatParameters)
    else
      local counter = 0
      for k = 1, flatParameters:nElement() do
        if holeMask[k] == 0 then
          counter = counter + 1
          flatUsedParameters[counter] = flatParameters[counter + cumSumOfHoles[k]]
        end
      end
      assert (counter == nUsedParameters)
    end
    return flatUsedParameters
  end

  -- flatten parameters and gradients
  local flatParameters = flatten(parameters)
  local flatGradParameters = flatten(gradParameters)

  -- return new flat vector that contains all discrete parameters
  return flatParameters, flatGradParameters
end


-- Builds T copies of `net` whose parameter tensors are re-pointed (via `:set`)
-- at the parameter tensors of `net`.
--
-- net: an object serializable with torch.MemoryFile:writeObject; if it has a
--      `parameters` method (and optionally `parametersNoGrad`), the tensors
--      these return are the ones the clones are tied to.
-- T:   number of clones to create.
--
-- Returns a list `clones` with clones[1..T].
function model_utils.clone_many_times(net, T)
  local clones = {}

  local params, gradParams
  if net.parameters then
    params, gradParams = net:parameters()
    if params == nil then
      -- nothing to tie; keeps the loop below a no-op
      params = {}
    end
  end

  local paramsNoGrad
  if net.parametersNoGrad then
    paramsNoGrad = net:parametersNoGrad()
  end

  -- serialize the network once; every clone is deserialized from this buffer
  local mem = torch.MemoryFile("w"):binary()
  mem:writeObject(net)

  for t = 1, T do
    -- We need to use a new reader for each clone.
    -- We don't want to use the pointers to already read objects.
    local reader = torch.MemoryFile(mem:storage(), "r"):binary()
    local clone = reader:readObject()
    reader:close()

    if net.parameters then
      local cloneParams, cloneGradParams = clone:parameters()
      -- point each clone tensor at the corresponding tensor of `net`
      for i = 1, #params do
        cloneParams[i]:set(params[i])
        cloneGradParams[i]:set(gradParams[i])
      end
      if paramsNoGrad then
        local cloneParamsNoGrad = clone:parametersNoGrad()
        for i = 1, #paramsNoGrad do
          cloneParamsNoGrad[i]:set(paramsNoGrad[i])
        end
      end
    end

    clones[t] = clone
    -- run a collection after each clone is built
    collectgarbage()
  end

  mem:close()
  return clones
end

return model_utils