FrEIA.framework module
The framework module contains the logic used in building the graph and inferring the order that the nodes have to be executed in forward and backward direction.
'''The framework module contains the logic used in building the graph and
inferring the order that the nodes have to be executed in forward and backward
direction.'''

import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

from . import dummy_modules as dummys


class Node:
    '''The Node class represents one transformation in the graph, with an
    arbitrary number of in- and outputs.'''

    def __init__(self, inputs, module_type, module_args, conditions=None,
                 name=None):
        '''Store the wiring and the module recipe of this node.

        inputs:      a Node, a single (node, output index) pair, or a list of
                     such pairs (see parse_inputs).
        module_type: callable invoked in build_modules to create the module.
        module_args: dict of keyword arguments forwarded to module_type.
        conditions:  a single condition node or a list/tuple of them;
                     None (the default) means "no conditions".
        name:        optional name; defaults to the last six characters of
                     hex(id(self)).
        '''
        # The name is assigned first so that parse_inputs can mention it in
        # its error messages.
        if name:
            self.name = name
        else:
            self.name = hex(id(self))[-6:]

        self.inputs = self.parse_inputs(inputs)

        if conditions is None:
            # A fresh list per node; a literal [] default would be shared
            # between all nodes created without conditions.
            self.conditions = []
        elif isinstance(conditions, (list, tuple)):
            self.conditions = conditions
        else:
            self.conditions = [conditions,]

        self.outputs = []
        self.module_type = module_type
        self.module_args = module_args

        self.input_dims = None
        self.module = None
        self.computed = None
        self.computed_rev = None
        self.id = None

        # Convenience handles node.out0 ... node.out254, each a
        # (node, output index) pair.
        for i in range(255):
            setattr(self, f'out{i}', (self, i))

    def parse_inputs(self, inputs):
        '''Normalise `inputs` to a list of (node, output index) pairs.

        A list/tuple whose first element is itself a list/tuple is returned
        unchanged, a bare two-element list/tuple is wrapped into a list, and a
        single Node is interpreted as its output 0.

        Raises RuntimeError for any other list/tuple and AssertionError for
        an object that is neither a list/tuple nor a Node.
        '''
        # Subclasses may call this before self.name has been assigned.
        name = getattr(self, 'name', '<unnamed>')

        if isinstance(inputs, (list, tuple)):
            if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
                return inputs
            elif len(inputs) == 2:
                return [inputs,]
            else:
                raise RuntimeError(
                    f"Cannot parse inputs provided to node '{name}'.")
        else:
            assert isinstance(inputs, Node), "Received object of invalid type "\
                f"({type(inputs)}) as input for node '{name}'."
            return [(inputs, 0),]

    def build_modules(self, verbose=True):
        ''' Returns a list with the dimension of each output of this node,
        recursively calling build_modules of the nodes connected to the input.
        Use this information to initialize the pytorch nn.Module of this node.
        '''

        if not self.input_dims:  # Only do it if this hasn't been computed yet
            self.input_dims = [n.build_modules(verbose=verbose)[c]
                               for n, c in self.inputs]
            c_dims = []
            try:
                if len(self.conditions) > 0:
                    c_dims = [c.build_modules(verbose=verbose)[0]
                              for c in self.conditions]
                    self.module = self.module_type(self.input_dims,
                                                   dims_c=c_dims,
                                                   **self.module_args)
                else:
                    self.module = self.module_type(self.input_dims,
                                                   **self.module_args)
            except Exception:
                print('Error in node %s' % (self.name))
                raise

            if verbose:
                print(f"Node '{self.name}' takes the following inputs:")
                for d, (n, c) in zip(self.input_dims, self.inputs):
                    print(f"\t Output #{c} of node '{n.name}' with dims {d}")
                # Report the dims obtained from build_modules, so a condition
                # does not need to expose a `data` attribute.
                for c, d in zip(self.conditions, c_dims):
                    print(f"\t conditioned on node '{c.name}' " +
                          f"with dims {d}")
                print()

            self.output_dims = self.module.output_dims(self.input_dims)
            self.n_outputs = len(self.output_dims)

        return self.output_dims

    def run_forward(self, op_list):
        '''Determine the order of operations needed to reach this node. Calls
        run_forward of parent nodes recursively. Each operation is appended to
        the global list op_list, in the form (node ID, input variable IDs,
        output variable IDs, condition variable IDs)'''

        if not self.computed:

            # Compute all nodes which provide inputs, filter out the
            # channels you need
            self.input_vars = []
            for i, (n, c) in enumerate(self.inputs):
                self.input_vars.append(n.run_forward(op_list)[c])
                # Register self as an output in the input node
                n.outputs.append((self, i))

            # Compute all nodes which provide conditioning
            self.condition_vars = []
            for i, c in enumerate(self.conditions):
                self.condition_vars.append(c.run_forward(op_list)[0])
                # Register self as an output in the condition node
                c.outputs.append((self, i))

            # All outputs could now be computed
            self.computed = [(self.id, i) for i in range(self.n_outputs)]
            op_list.append((self.id, self.input_vars, self.computed,
                            self.condition_vars))

        # Return the variables you have computed (this happens multiple times
        # without recomputing if called repeatedly)
        return self.computed

    def run_backward(self, op_list):
        '''See run_forward, this is the same, only for the reverse computation.
        Need to call run_forward first, otherwise this function will not
        work'''

        assert len(self.outputs) > 0, "Call run_forward first"
        if not self.computed_rev:

            # These are the input variables that must be computed first
            output_vars = [(self.id, i) for i in range(self.n_outputs)]

            # Recursively compute these
            for n, c in self.outputs:
                n.run_backward(op_list)

            # The variables that this node computes are the input variables
            # from the forward pass
            self.computed_rev = self.input_vars
            if len(self.condition_vars) == 0:
                self.condition_vars = [c.run_forward(op_list)[0]
                                       for c in self.conditions]
            op_list.append((self.id, output_vars, self.computed_rev,
                            self.condition_vars))

        return self.computed_rev


class InputNode(Node):
    '''Special type of node that represents the input data of the whole net (or
    output when running reverse)'''

    def __init__(self, *dims, name='node'):
        '''Create an input node; `dims` are forwarded to dummys.dummy_data.'''
        self.name = name
        self.data = dummys.dummy_data(*dims)
        self.outputs = []
        self.conditions = []
        self.condition_vars = []
        self.module = None
        self.computed_rev = None
        self.n_outputs = 1
        self.input_vars = []
        self.out0 = (self, 0)

    def build_modules(self, verbose=True):
        '''Return a one-element list holding the shape of the dummy data.'''
        return [self.data.shape]

    def run_forward(self, op_list):
        '''Return the single variable (node ID, 0); op_list is left untouched.'''
        return [(self.id, 0)]


class ConditionNode(Node):
    '''Special type of node that represents conditional input to the internal
    networks inside coupling layers'''

    def __init__(self, *dims, name='node'):
        '''Create a condition node; `dims` are forwarded to dummys.dummy_data.'''
        self.name = name
        self.data = dummys.dummy_data(*dims)
        self.outputs = []
        self.conditions = []
        self.condition_vars = []
        self.module = None
        self.computed_rev = None
        self.n_outputs = 1
        self.input_vars = []
        self.out0 = (self, 0)

    def build_modules(self, verbose=True):
        '''Return a one-element list holding the shape of the dummy data.'''
        return [self.data.shape]

    def run_forward(self, op_list):
        '''Return the single variable (node ID, 0); op_list is left untouched.'''
        return [(self.id, 0)]


class OutputNode(Node):
    '''Special type of node that represents the output of the whole net (or the
    input when running in reverse)'''

    class dummy(nn.Module):
        '''Placeholder module type used by build_modules for output nodes.'''

        def __init__(self, *args):
            super(OutputNode.dummy, self).__init__()

        # NOTE(review): both methods below take no explicit `self`, so the
        # instance itself is part of the returned tuple. The number of
        # outputs recorded for an OutputNode depends on this; left unchanged.
        def __call__(*args):
            return args

        def output_dims(*args):
            return args

    def __init__(self, inputs, name='node'):
        '''Create an output node and register it with the nodes it reads from.

        inputs: anything accepted by parse_inputs.
        name:   name of the node.
        '''
        # Assigned first so that parse_inputs can use it in error messages.
        self.name = name
        self.module_type, self.module_args = self.dummy, {}
        self.output_dims = []
        self.inputs = self.parse_inputs(inputs)
        self.conditions = []
        self.input_dims, self.module = None, None
        self.computed = None
        self.id = None

        for c, inp in enumerate(self.inputs):
            inp[0].outputs.append((self, c))

    def run_backward(self, op_list):
        '''Return the single variable (node ID, 0); op_list is left untouched.'''
        return [(self.id, 0)]


class ReversibleGraphNet(nn.Module):
    '''This class represents the invertible net itself. It is a subclass of
    torch.nn.Module and supports the same methods. The forward method has an
    additional option 'rev', with which the net can be computed in reverse.'''

    def __init__(self, node_list, ind_in=None, ind_out=None, verbose=True):
        '''node_list should be a list of all nodes involved, and ind_in,
        ind_out are the indexes of the special nodes InputNode and OutputNode
        in this list.'''
        super(ReversibleGraphNet, self).__init__()

        # Gather lists of input, output and condition nodes
        if ind_in is not None:
            warnings.warn("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, " +
                          "input and output nodes are detected automatically.")
            if isinstance(ind_in, int):
                self.ind_in = list([ind_in])
            else:
                self.ind_in = ind_in
        else:
            self.ind_in = [i for i in range(len(node_list))
                           if isinstance(node_list[i], InputNode)]
            assert len(self.ind_in) > 0, "No input nodes specified."
        if ind_out is not None:
            warnings.warn("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, " +
                          "input and output nodes are detected automatically.")
            if isinstance(ind_out, int):
                self.ind_out = list([ind_out])
            else:
                self.ind_out = ind_out
        else:
            self.ind_out = [i for i in range(len(node_list))
                            if isinstance(node_list[i], OutputNode)]
            assert len(self.ind_out) > 0, "No output nodes specified."
        self.ind_cond = [i for i in range(len(node_list))
                         if isinstance(node_list[i], ConditionNode)]

        self.return_vars = []
        self.input_vars = []
        self.cond_vars = []

        # Assign each node a unique ID
        self.node_list = node_list
        for i, n in enumerate(node_list):
            n.id = i
            n.graph = self

        # Recursively build the nodes nn.Modules and determine order of
        # operations
        ops = []
        for i in self.ind_out:
            node_list[i].build_modules(verbose=verbose)
            node_list[i].run_forward(ops)

        # create list of Pytorch variables that are used
        variables = set()
        for o in ops:
            variables = variables.union(set(o[1] + o[2] + o[3]))
        self.variables_ind = list(variables)

        self.indexed_ops = self.ops_to_indexed(ops)

        self.module_list = nn.ModuleList([n.module for n in node_list])
        self.module_cond = [(len(n.conditions) > 0) for n in node_list]
        self._buffers = {F'tmp_var_{i}' : None for i in range(len(variables))}

        # Find out the order of operations for reverse calculations
        ops_rev = []
        for i in self.ind_in + self.ind_cond:
            node_list[i].run_backward(ops_rev)
        self.indexed_ops_rev = self.ops_to_indexed(ops_rev)

    def ops_to_indexed(self, ops):
        '''Helper function to translate the list of variables (origin ID,
        channel), to variable IDs.'''
        result = []

        for o in ops:
            try:
                vars_in = [self.variables_ind.index(v) for v in o[1]]
            except ValueError:
                vars_in = -1

            vars_out = [self.variables_ind.index(v) for v in o[2]]
            vars_cond = [self.variables_ind.index(v) for v in o[3]]

            # Collect input/output/conditioning nodes in separate lists, but
            # don't add to indexed ops
            if o[0] in self.ind_out:
                self.return_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_in:
                self.input_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_cond:
                if self.variables_ind.index(o[1][0]) not in self.cond_vars:
                    self.cond_vars.append(self.variables_ind.index(o[1][0]))
                else:
                    print('Is this branch ever reached?')
                continue

            result.append((o[0], vars_in, vars_out, vars_cond))

        # Sort input/output/conditioning variables so they correspond to
        # initial node list order
        self.return_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.input_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.cond_vars.sort(key=lambda i: self.variables_ind[i][0])

        return result

    def forward(self, x, c=None, rev=False, intermediate_outputs=False):
        '''Forward or backward computation of the whole net.

        x:   a single input or a list/tuple with one entry per input variable
             (per output variable when rev is True).
        c:   None, a single condition, or a list/tuple of conditions.
        rev: run the reverse list of operations instead of the forward one.
        intermediate_outputs: return a dict mapping node names to the results
             of their modules instead of the final outputs.
        '''
        if rev:
            use_list = self.indexed_ops_rev
            input_vars, output_vars = self.return_vars, self.input_vars
        else:
            use_list = self.indexed_ops
            input_vars, output_vars = self.input_vars, self.return_vars

        # Assign input data to respective variables
        if isinstance(x, (list, tuple)):
            assert len(x) == len(input_vars), (
                f"Got list of {len(x)} input tensors for "
                f"{'inverse' if rev else 'forward'} pass, but expected "
                f"{len(input_vars)}."
            )
            for i in range(len(input_vars)):
                self._buffers[F'tmp_var_{input_vars[i]}'] = x[i]
        else:
            assert len(input_vars) == 1, (f"Got single input tensor for "
                                          f"{'inverse' if rev else 'forward'} "
                                          f"pass, but expected list of "
                                          f"{len(input_vars)}.")
            self._buffers[F'tmp_var_{input_vars[0]}'] = x

        # Assign conditioning data to respective variables
        if c is None:
            assert len(self.cond_vars) == 0
        elif isinstance(c, (list, tuple)):
            assert len(c) == len(self.cond_vars), f'{len(c)}, {len(self.cond_vars)}'
            for i in range(len(self.cond_vars)):
                self._buffers[F'tmp_var_{self.cond_vars[i]}'] = c[i]
        else:
            assert len(self.cond_vars) == 1
            self._buffers[F'tmp_var_{self.cond_vars[0]}'] = c

        # Prepare dictionary for intermediate node outputs
        out_dict = {}

        # Run all modules with the given inputs
        for o in use_list:
            try:
                x = [self._buffers[F'tmp_var_{i}'] for i in o[1]]
                if self.module_cond[o[0]]:
                    c = [self._buffers[F'tmp_var_{i}'] for i in o[3]]
                    results = self.module_list[o[0]](x, c=c, rev=rev)
                else:
                    results = self.module_list[o[0]](x, rev=rev)
            except TypeError as err:
                # Keep the original TypeError attached as the cause.
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?") from err
            out_dict[self.node_list[o[0]].name] = results
            for i, r in zip(o[2], results):
                self._buffers[F'tmp_var_{i}'] = r

        if intermediate_outputs:
            return out_dict
        else:
            out = [self._buffers[F'tmp_var_{output_vars[i]}']
                   for i in range(len(output_vars))]
            if len(out) == 1:
                return out[0]
            else:
                return out

    def log_jacobian(self, x=None, c=None, rev=False, run_forward=True,
                     intermediate_outputs=False):
        '''Compute the log jacobian determinant of the whole net.

        With run_forward=True (default) the net is first evaluated on x;
        otherwise the variables left over from a previous forward call are
        used. With intermediate_outputs=True a dict mapping node names to the
        per-module jacobian terms is returned instead of their sum.
        '''
        if run_forward or c is not None:
            self.condition = c
        jacobian = 0

        if rev:
            use_list = self.indexed_ops_rev
        else:
            use_list = self.indexed_ops

        if run_forward:
            if x is None:
                raise RuntimeError("You need to provide an input if you want "
                                   "to run a forward pass")
            self.forward(x, c, rev=rev)

        # Prepare dictionary for intermediate node outputs
        jacobian_dict = {}

        # Run all modules with the given inputs
        for o in use_list:
            x = [self._buffers[F'tmp_var_{i}'] for i in o[1]]
            if self.module_cond[o[0]]:
                c = [self._buffers[F'tmp_var_{i}'] for i in o[3]]
                module_jacobian = self.module_list[o[0]].jacobian(x, c=c, rev=rev)
            else:
                module_jacobian = self.module_list[o[0]].jacobian(x, rev=rev)
            jacobian += module_jacobian
            jacobian_dict[self.node_list[o[0]].name] = module_jacobian

        if intermediate_outputs:
            return jacobian_dict
        else:
            return jacobian

    def jacobian(self, *args, **kwargs):
        '''Compute the log jacobian determinant of the whole net.'''
        warnings.warn("This function computes the log-jacobian determinant, not the "
                      "jacobian as the name suggest. Will be removed in the future.")
        return self.log_jacobian(*args, **kwargs)

    def log_jacobian_numerical(self, x, c=None, rev=False, h=1e-04):
        '''Approximate log Jacobian determinant via finite differences.

        x:   a single tensor or a list/tuple of tensors; the first axis is
             treated as the batch axis.
        c:   conditioning data, passed on to forward.
        rev: differentiate the reverse pass instead of the forward pass.
        h:   step size of the central differences.
        '''
        if isinstance(x, (list, tuple)):
            batch_size = x[0].shape[0]
            # Plain ints: np.prod returns numpy scalars, which are not
            # accepted everywhere a Python int is expected.
            ndim_x_separate = [int(np.prod(x_i.shape[1:])) for x_i in x]
            ndim_x_total = sum(ndim_x_separate)
            x_flat = torch.cat([x_i.view(batch_size, -1) for x_i in x], dim=1)
        else:
            batch_size = x.shape[0]
            ndim_x_total = int(np.prod(x.shape[1:]))
            x_flat = x.reshape(batch_size, -1)

        J_num = torch.zeros(batch_size, ndim_x_total, ndim_x_total)
        for i in range(ndim_x_total):
            offset = x[0].new_zeros(batch_size, ndim_x_total)
            offset[:,i] = h
            if isinstance(x, (list, tuple)):
                x_upper = torch.split(x_flat + offset, ndim_x_separate, dim=1)
                x_upper = [x_upper[j].view(*x[j].shape) for j in range(len(x))]
                x_lower = torch.split(x_flat - offset, ndim_x_separate, dim=1)
                x_lower = [x_lower[j].view(*x[j].shape) for j in range(len(x))]
            else:
                x_upper = (x_flat + offset).view(*x.shape)
                x_lower = (x_flat - offset).view(*x.shape)
            # `rev` must reach forward, otherwise the forward pass is
            # differentiated regardless of the requested direction.
            y_upper = self.forward(x_upper, c=c, rev=rev)
            y_lower = self.forward(x_lower, c=c, rev=rev)
            if isinstance(y_upper, (list, tuple)):
                y_upper = torch.cat([y_i.view(batch_size, -1) for y_i in y_upper], dim=1)
                y_lower = torch.cat([y_i.view(batch_size, -1) for y_i in y_lower], dim=1)
            J_num[:,:,i] = (y_upper - y_lower).view(batch_size, -1) / (2*h)
        logdet_num = x[0].new_zeros(batch_size)
        for i in range(batch_size):
            logdet_num[i] = torch.det(J_num[i,:,:]).abs().log()

        return logdet_num

    def load_state_dict(self, state_dict, *args, **kwargs):
        '''Load state_dict, skipping every key that names an entry of
        self._buffers which is currently None.'''
        state_dict_no_buffers = {}
        for k,p in state_dict.items():
            if k in self._buffers and self._buffers[k] is None:
                continue
            state_dict_no_buffers[k] = p

        return super().load_state_dict(state_dict_no_buffers, *args, **kwargs)


# Testing example
if __name__ == '__main__':
    inp = InputNode(4, 64, 64, name='input')
    t1 = Node([(inp, 0)], dummys.dummy_mux, {}, name='t1')
    s1 = Node([(t1, 0)], dummys.dummy_2split, {}, name='s1')

    t2 = Node([(s1, 0)], dummys.dummy_module, {}, name='t2')
    s2 = Node([(s1, 1)], dummys.dummy_2split, {}, name='s2')
    t3 = Node([(s2, 0)], dummys.dummy_module, {}, name='t3')

    m1 = Node([(t3, 0), (s2, 1)], dummys.dummy_2merge, {}, name='m1')
    m2 = Node([(t2, 0), (m1, 0)], dummys.dummy_2merge, {}, name='m2')
    outp = OutputNode([(m2, 0)], name='output')

    all_nodes = [inp, outp, t1, s1, t2, s2, t3, m1, m2]

    net = ReversibleGraphNet(all_nodes, 0, 1)
Classes
class ConditionNode
Special type of node that represents conditional input to the internal networks inside coupling layers
class ConditionNode(Node):
    '''Special type of node that represents conditional input to the internal
    networks inside coupling layers'''

    def __init__(self, *dims, name='node'):
        '''Set up the node; `dims` are handed to dummys.dummy_data to create
        the placeholder stored in self.data.'''
        self.name = name
        self.data = dummys.dummy_data(*dims)

        # This node exposes exactly one output, reachable as node.out0.
        self.out0 = (self, 0)
        self.n_outputs = 1

        self.module = None
        self.computed_rev = None
        self.outputs, self.conditions = [], []
        self.input_vars, self.condition_vars = [], []

    def build_modules(self, verbose=True):
        '''Return a one-element list containing the shape of self.data.'''
        shape = self.data.shape
        return [shape]

    def run_forward(self, op_list):
        '''Return the single variable (node ID, 0) without touching op_list.'''
        own_var = (self.id, 0)
        return [own_var]
Ancestors (in MRO)
- ConditionNode
- Node
- builtins.object
Static methods
def __init__(
self, *dims)
Initialize self. See help(type(self)) for accurate signature.
def __init__(self, *dims, name='node'):
    '''Set up the node; `dims` are handed to dummys.dummy_data to create the
    placeholder stored in self.data.'''
    self.name = name
    self.data = dummys.dummy_data(*dims)

    # This node exposes exactly one output, reachable as node.out0.
    self.out0 = (self, 0)
    self.n_outputs = 1

    self.module = None
    self.computed_rev = None
    self.outputs, self.conditions = [], []
    self.input_vars, self.condition_vars = [], []
def build_modules(
self, verbose=True)
Returns a list with the dimension of each output of this node, recursively calling build_modules of the nodes connected to the input. Use this information to initialize the pytorch nn.Module of this node.
def build_modules(self, verbose=True):
    '''Return a one-element list containing the shape of self.data.

    `verbose` is accepted for signature compatibility and is not used.
    '''
    shape = self.data.shape
    return [shape]
def parse_inputs(
self, inputs)
def parse_inputs(self, inputs):
    '''Normalise `inputs` to a list of (node, output index) pairs.

    A list/tuple whose first element is itself a list/tuple is returned
    unchanged, a bare two-element list/tuple is wrapped into a list, and a
    single Node is interpreted as its output 0.

    Raises RuntimeError for any other list/tuple (including an empty one) and
    AssertionError for an object that is neither a list/tuple nor a Node.
    '''
    # The error messages used to format an undefined global `name`, which
    # turned every error into a NameError. Use the node's own name instead;
    # it may not be assigned yet when this runs early in a constructor.
    name = getattr(self, 'name', '<unnamed>')

    if isinstance(inputs, (list, tuple)):
        if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
            return inputs
        elif len(inputs) == 2:
            return [inputs,]
        else:
            raise RuntimeError(
                f"Cannot parse inputs provided to node '{name}'.")
    else:
        assert isinstance(inputs, Node), "Received object of invalid type "\
            f"({type(inputs)}) as input for node '{name}'."
        return [(inputs, 0),]
def run_backward(
self, op_list)
See run_forward, this is the same, only for the reverse computation. Need to call run_forward first, otherwise this function will not work
def run_backward(self, op_list):
    '''Reverse-direction counterpart of run_forward.

    Appends (node ID, output variable IDs, input variable IDs, condition
    variable IDs) to op_list after the nodes consuming this node's outputs
    have been processed. run_forward has to be called first, otherwise this
    function will not work.
    '''
    assert len(self.outputs) > 0, "Call run_forward first"

    # Already handled during an earlier call: just report the result.
    if self.computed_rev:
        return self.computed_rev

    # In reverse direction the forward outputs are what must be known first.
    own_vars = [(self.id, k) for k in range(self.n_outputs)]

    for consumer, _slot in self.outputs:
        consumer.run_backward(op_list)

    # The reverse pass produces the variables the forward pass consumed.
    self.computed_rev = self.input_vars
    if not len(self.condition_vars):
        self.condition_vars = [cond.run_forward(op_list)[0]
                               for cond in self.conditions]

    op_list.append(
        (self.id, own_vars, self.computed_rev, self.condition_vars))
    return self.computed_rev
def run_forward(
self, op_list)
Determine the order of operations needed to reach this node. Calls run_forward of parent nodes recursively. Each operation is appended to the global list op_list, in the form (node ID, input variable IDs, output variable IDs)
def run_forward(self, op_list):
    '''Return the single variable (node ID, 0) without touching op_list.'''
    own_var = (self.id, 0)
    return [own_var]
Instance variables
var computed_rev
var condition_vars
var conditions
var data
var input_vars
var module
var n_outputs
var name
var out0
var outputs
class InputNode
Special type of node that represents the input data of the whole net (or output when running reverse)
class InputNode(Node):
    '''Special type of node that represents the input data of the whole net (or
    output when running reverse)'''

    def __init__(self, *dims, name='node'):
        '''Set up the node; `dims` are handed to dummys.dummy_data to create
        the placeholder stored in self.data.'''
        self.name = name
        self.data = dummys.dummy_data(*dims)

        # This node exposes exactly one output, reachable as node.out0.
        self.out0 = (self, 0)
        self.n_outputs = 1

        self.module = None
        self.computed_rev = None
        self.outputs, self.conditions = [], []
        self.input_vars, self.condition_vars = [], []

    def build_modules(self, verbose=True):
        '''Return a one-element list containing the shape of self.data.'''
        shape = self.data.shape
        return [shape]

    def run_forward(self, op_list):
        '''Return the single variable (node ID, 0) without touching op_list.'''
        own_var = (self.id, 0)
        return [own_var]
Ancestors (in MRO)
Static methods
def __init__(
self, *dims)
Initialize self. See help(type(self)) for accurate signature.
def __init__(self, *dims, name='node'):
    '''Set up the node; `dims` are handed to dummys.dummy_data to create the
    placeholder stored in self.data.'''
    self.name = name
    self.data = dummys.dummy_data(*dims)

    # This node exposes exactly one output, reachable as node.out0.
    self.out0 = (self, 0)
    self.n_outputs = 1

    self.module = None
    self.computed_rev = None
    self.outputs, self.conditions = [], []
    self.input_vars, self.condition_vars = [], []
def build_modules(
self, verbose=True)
Returns a list with the dimension of each output of this node, recursively calling build_modules of the nodes connected to the input. Use this information to initialize the pytorch nn.Module of this node.
def build_modules(self, verbose=True):
    '''Return a one-element list containing the shape of self.data.

    `verbose` is accepted for signature compatibility and is not used.
    '''
    shape = self.data.shape
    return [shape]
def parse_inputs(
self, inputs)
def parse_inputs(self, inputs):
    '''Normalise `inputs` to a list of (node, output index) pairs.

    A list/tuple whose first element is itself a list/tuple is returned
    unchanged, a bare two-element list/tuple is wrapped into a list, and a
    single Node is interpreted as its output 0.

    Raises RuntimeError for any other list/tuple (including an empty one) and
    AssertionError for an object that is neither a list/tuple nor a Node.
    '''
    # The error messages used to format an undefined global `name`, which
    # turned every error into a NameError. Use the node's own name instead;
    # it may not be assigned yet when this runs early in a constructor.
    name = getattr(self, 'name', '<unnamed>')

    if isinstance(inputs, (list, tuple)):
        if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
            return inputs
        elif len(inputs) == 2:
            return [inputs,]
        else:
            raise RuntimeError(
                f"Cannot parse inputs provided to node '{name}'.")
    else:
        assert isinstance(inputs, Node), "Received object of invalid type "\
            f"({type(inputs)}) as input for node '{name}'."
        return [(inputs, 0),]
def run_backward(
self, op_list)
See run_forward, this is the same, only for the reverse computation. Need to call run_forward first, otherwise this function will not work
def run_backward(self, op_list):
    '''Reverse-direction counterpart of run_forward.

    Appends (node ID, output variable IDs, input variable IDs, condition
    variable IDs) to op_list after the nodes consuming this node's outputs
    have been processed. run_forward has to be called first, otherwise this
    function will not work.
    '''
    assert len(self.outputs) > 0, "Call run_forward first"

    # Already handled during an earlier call: just report the result.
    if self.computed_rev:
        return self.computed_rev

    # In reverse direction the forward outputs are what must be known first.
    own_vars = [(self.id, k) for k in range(self.n_outputs)]

    for consumer, _slot in self.outputs:
        consumer.run_backward(op_list)

    # The reverse pass produces the variables the forward pass consumed.
    self.computed_rev = self.input_vars
    if not len(self.condition_vars):
        self.condition_vars = [cond.run_forward(op_list)[0]
                               for cond in self.conditions]

    op_list.append(
        (self.id, own_vars, self.computed_rev, self.condition_vars))
    return self.computed_rev
def run_forward(
self, op_list)
Determine the order of operations needed to reach this node. Calls run_forward of parent nodes recursively. Each operation is appended to the global list op_list, in the form (node ID, input variable IDs, output variable IDs)
def run_forward(self, op_list):
    '''Return the single variable (node ID, 0) without touching op_list.'''
    own_var = (self.id, 0)
    return [own_var]
Instance variables
var computed_rev
var condition_vars
var conditions
var data
var input_vars
var module
var n_outputs
var name
var out0
var outputs
class Node
The Node class represents one transformation in the graph, with an arbitrary number of in- and outputs.
class Node:
    '''The Node class represents one transformation in the graph, with an
    arbitrary number of in- and outputs.'''

    def __init__(self, inputs, module_type, module_args, conditions=None,
                 name=None):
        '''Store the wiring and the module recipe of this node.

        inputs:      a Node, a single (node, output index) pair, or a list of
                     such pairs (see parse_inputs).
        module_type: callable invoked in build_modules to create the module.
        module_args: dict of keyword arguments forwarded to module_type.
        conditions:  a single condition node or a list/tuple of them;
                     None (the default) means "no conditions".
        name:        optional name; defaults to the last six characters of
                     hex(id(self)).
        '''
        # The name is assigned first so that parse_inputs can mention it in
        # its error messages.
        if name:
            self.name = name
        else:
            self.name = hex(id(self))[-6:]

        self.inputs = self.parse_inputs(inputs)

        if conditions is None:
            # A fresh list per node; a literal [] default would be shared
            # between all nodes created without conditions.
            self.conditions = []
        elif isinstance(conditions, (list, tuple)):
            self.conditions = conditions
        else:
            self.conditions = [conditions,]

        self.outputs = []
        self.module_type = module_type
        self.module_args = module_args

        self.input_dims = None
        self.module = None
        self.computed = None
        self.computed_rev = None
        self.id = None

        # Convenience handles node.out0 ... node.out254, each a
        # (node, output index) pair.
        for i in range(255):
            setattr(self, f'out{i}', (self, i))

    def parse_inputs(self, inputs):
        '''Normalise `inputs` to a list of (node, output index) pairs.

        A list/tuple whose first element is itself a list/tuple is returned
        unchanged, a bare two-element list/tuple is wrapped into a list, and a
        single Node is interpreted as its output 0.

        Raises RuntimeError for any other list/tuple and AssertionError for
        an object that is neither a list/tuple nor a Node.
        '''
        # Subclasses may call this before self.name has been assigned.
        name = getattr(self, 'name', '<unnamed>')

        if isinstance(inputs, (list, tuple)):
            if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
                return inputs
            elif len(inputs) == 2:
                return [inputs,]
            else:
                raise RuntimeError(
                    f"Cannot parse inputs provided to node '{name}'.")
        else:
            assert isinstance(inputs, Node), "Received object of invalid type "\
                f"({type(inputs)}) as input for node '{name}'."
            return [(inputs, 0),]

    def build_modules(self, verbose=True):
        ''' Returns a list with the dimension of each output of this node,
        recursively calling build_modules of the nodes connected to the input.
        Use this information to initialize the pytorch nn.Module of this node.
        '''

        if not self.input_dims:  # Only do it if this hasn't been computed yet
            self.input_dims = [n.build_modules(verbose=verbose)[c]
                               for n, c in self.inputs]
            c_dims = []
            try:
                if len(self.conditions) > 0:
                    c_dims = [c.build_modules(verbose=verbose)[0]
                              for c in self.conditions]
                    self.module = self.module_type(self.input_dims,
                                                   dims_c=c_dims,
                                                   **self.module_args)
                else:
                    self.module = self.module_type(self.input_dims,
                                                   **self.module_args)
            except Exception:
                print('Error in node %s' % (self.name))
                raise

            if verbose:
                print(f"Node '{self.name}' takes the following inputs:")
                for d, (n, c) in zip(self.input_dims, self.inputs):
                    print(f"\t Output #{c} of node '{n.name}' with dims {d}")
                # Report the dims obtained from build_modules, so a condition
                # does not need to expose a `data` attribute.
                for c, d in zip(self.conditions, c_dims):
                    print(f"\t conditioned on node '{c.name}' " +
                          f"with dims {d}")
                print()

            self.output_dims = self.module.output_dims(self.input_dims)
            self.n_outputs = len(self.output_dims)

        return self.output_dims

    def run_forward(self, op_list):
        '''Determine the order of operations needed to reach this node. Calls
        run_forward of parent nodes recursively. Each operation is appended to
        the global list op_list, in the form (node ID, input variable IDs,
        output variable IDs, condition variable IDs)'''

        if not self.computed:

            # Compute all nodes which provide inputs, filter out the
            # channels you need
            self.input_vars = []
            for i, (n, c) in enumerate(self.inputs):
                self.input_vars.append(n.run_forward(op_list)[c])
                # Register self as an output in the input node
                n.outputs.append((self, i))

            # Compute all nodes which provide conditioning
            self.condition_vars = []
            for i, c in enumerate(self.conditions):
                self.condition_vars.append(c.run_forward(op_list)[0])
                # Register self as an output in the condition node
                c.outputs.append((self, i))

            # All outputs could now be computed
            self.computed = [(self.id, i) for i in range(self.n_outputs)]
            op_list.append((self.id, self.input_vars, self.computed,
                            self.condition_vars))

        # Return the variables you have computed (this happens multiple times
        # without recomputing if called repeatedly)
        return self.computed

    def run_backward(self, op_list):
        '''See run_forward, this is the same, only for the reverse computation.
        Need to call run_forward first, otherwise this function will not
        work'''

        assert len(self.outputs) > 0, "Call run_forward first"
        if not self.computed_rev:

            # These are the input variables that must be computed first
            output_vars = [(self.id, i) for i in range(self.n_outputs)]

            # Recursively compute these
            for n, c in self.outputs:
                n.run_backward(op_list)

            # The variables that this node computes are the input variables
            # from the forward pass
            self.computed_rev = self.input_vars
            if len(self.condition_vars) == 0:
                self.condition_vars = [c.run_forward(op_list)[0]
                                       for c in self.conditions]
            op_list.append((self.id, output_vars, self.computed_rev,
                            self.condition_vars))

        return self.computed_rev
Ancestors (in MRO)
- Node
- builtins.object
Static methods
def __init__(
self, inputs, module_type, module_args, conditions=[], name=None)
Initialize self. See help(type(self)) for accurate signature.
def __init__(self, inputs, module_type, module_args, conditions=None,
             name=None):
    '''Store the wiring and the module recipe of this node.

    inputs:      passed through self.parse_inputs and stored as self.inputs.
    module_type: callable stored for later creation of the module.
    module_args: dict of keyword arguments stored alongside module_type.
    conditions:  a single condition node or a list/tuple of them;
                 None (the default) means "no conditions".
    name:        optional name; defaults to the last six characters of
                 hex(id(self)).
    '''
    # The name is assigned before the inputs are parsed so that it is
    # available while parsing (e.g. for error messages).
    if name:
        self.name = name
    else:
        self.name = hex(id(self))[-6:]

    self.inputs = self.parse_inputs(inputs)

    if conditions is None:
        # A fresh list per node; a literal [] default would be shared between
        # all nodes created without conditions.
        self.conditions = []
    elif isinstance(conditions, (list, tuple)):
        self.conditions = conditions
    else:
        self.conditions = [conditions,]

    self.outputs = []
    self.module_type = module_type
    self.module_args = module_args

    self.input_dims = None
    self.module = None
    self.computed = None
    self.computed_rev = None
    self.id = None

    # Convenience handles node.out0 ... node.out254, each a
    # (node, output index) pair; setattr replaces the former exec call.
    for i in range(255):
        setattr(self, f'out{i}', (self, i))
def build_modules(
self, verbose=True)
Returns a list with the dimension of each output of this node, recursively calling build_modules of the nodes connected to the input. Use this information to initialize the pytorch nn.Module of this node.
def build_modules(self, verbose=True):
    ''' Returns a list with the dimension of each output of this node,
    recursively calling build_modules of the nodes connected to the input.
    Use this information to initialize the pytorch nn.Module of this node.

    With verbose=True the inputs and conditions of the node are printed the
    first time the module is built. Any exception raised while creating the
    module is re-raised after the node name has been printed.
    '''

    if not self.input_dims:  # Only do it if this hasn't been computed yet
        self.input_dims = [n.build_modules(verbose=verbose)[c]
                           for n, c in self.inputs]
        c_dims = []
        try:
            if len(self.conditions) > 0:
                c_dims = [c.build_modules(verbose=verbose)[0]
                          for c in self.conditions]
                self.module = self.module_type(self.input_dims,
                                               dims_c=c_dims,
                                               **self.module_args)
            else:
                self.module = self.module_type(self.input_dims,
                                               **self.module_args)
        except Exception:
            print('Error in node %s' % (self.name))
            # Bare raise keeps the original traceback intact.
            raise

        if verbose:
            print(f"Node '{self.name}' takes the following inputs:")
            for d, (n, c) in zip(self.input_dims, self.inputs):
                print(f"\t Output #{c} of node '{n.name}' with dims {d}")
            # Report the dims obtained from build_modules, so a condition
            # does not need to expose a `data` attribute.
            for c, d in zip(self.conditions, c_dims):
                print(f"\t conditioned on node '{c.name}' " +
                      f"with dims {d}")
            print()

        self.output_dims = self.module.output_dims(self.input_dims)
        self.n_outputs = len(self.output_dims)

    return self.output_dims
def parse_inputs(
self, inputs)
def parse_inputs(self, inputs):
    '''Normalise `inputs` to a list of (node, output index) pairs.

    A list/tuple whose first element is itself a list/tuple is returned
    unchanged, a bare two-element list/tuple is wrapped into a list, and a
    single Node is interpreted as its output 0.

    Raises RuntimeError for any other list/tuple (including an empty one) and
    AssertionError for an object that is neither a list/tuple nor a Node.
    '''
    # The error messages used to format an undefined global `name`, which
    # turned every error into a NameError. Use the node's own name instead;
    # it may not be assigned yet when this runs early in a constructor.
    name = getattr(self, 'name', '<unnamed>')

    if isinstance(inputs, (list, tuple)):
        if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
            return inputs
        elif len(inputs) == 2:
            return [inputs,]
        else:
            raise RuntimeError(
                f"Cannot parse inputs provided to node '{name}'.")
    else:
        assert isinstance(inputs, Node), "Received object of invalid type "\
            f"({type(inputs)}) as input for node '{name}'."
        return [(inputs, 0),]
def run_backward(
self, op_list)
See run_forward, this is the same, only for the reverse computation. Need to call run_forward first, otherwise this function will not work
def run_backward(self, op_list):
    '''Reverse-direction counterpart of run_forward.

    Appends (node ID, output variable IDs, input variable IDs, condition
    variable IDs) to op_list after the nodes consuming this node's outputs
    have been processed. run_forward has to be called first, otherwise this
    function will not work.
    '''
    assert len(self.outputs) > 0, "Call run_forward first"

    # Already handled during an earlier call: just report the result.
    if self.computed_rev:
        return self.computed_rev

    # In reverse direction the forward outputs are what must be known first.
    own_vars = [(self.id, k) for k in range(self.n_outputs)]

    for consumer, _slot in self.outputs:
        consumer.run_backward(op_list)

    # The reverse pass produces the variables the forward pass consumed.
    self.computed_rev = self.input_vars
    if not len(self.condition_vars):
        self.condition_vars = [cond.run_forward(op_list)[0]
                               for cond in self.conditions]

    op_list.append(
        (self.id, own_vars, self.computed_rev, self.condition_vars))
    return self.computed_rev
def run_forward(
self, op_list)
Determine the order of operations needed to reach this node. Calls run_forward of parent nodes recursively. Each operation is appended to the global list op_list, in the form (node ID, input variable IDs, output variable IDs, condition variable IDs)
def run_forward(self, op_list):
    '''Work out which operations are required to reach this node.

    The parents' run_forward is invoked recursively, then this node's own
    operation is appended to the shared list op_list as (node ID, input
    variable IDs, output variable IDs, condition variable IDs). Returns the
    list of variables this node computes.
    '''
    # Repeated calls return the stored result without recomputing.
    if self.computed:
        return self.computed

    # Visit every node feeding this one and pick the requested channel.
    self.input_vars = []
    for slot, (parent, channel) in enumerate(self.inputs):
        parent_vars = parent.run_forward(op_list)
        self.input_vars.append(parent_vars[channel])
        # Let the parent know that this node consumes one of its outputs.
        parent.outputs.append((self, slot))

    # Same for the nodes that provide conditioning (always their output 0).
    self.condition_vars = []
    for slot, cond in enumerate(self.conditions):
        cond_vars = cond.run_forward(op_list)
        self.condition_vars.append(cond_vars[0])
        cond.outputs.append((self, slot))

    # With all prerequisites scheduled, this node's outputs can be computed.
    self.computed = [(self.id, k) for k in range(self.n_outputs)]
    op_list.append(
        (self.id, self.input_vars, self.computed, self.condition_vars))
    return self.computed
Instance variables
var computed
var computed_rev
var id
var input_dims
var inputs
var module
var module_args
var module_type
var outputs
class OutputNode
Special type of node that represents the output of the whole net (or the input when running in reverse)
class OutputNode(Node):
    '''Special type of node that represents the output of the whole net (or
    the input when running in reverse).'''

    class dummy(nn.Module):
        '''Placeholder module type used for output nodes.'''

        def __init__(self, *args):
            super().__init__()

        # NOTE: the two functions below are declared without an explicit
        # ``self``, so the instance is the first element of the returned tuple.
        def __call__(*args):
            return args

        def output_dims(*args):
            return args

    def __init__(self, inputs, name='node'):
        '''Set up the node and register it as a consumer of its input nodes.'''
        self.module_type = self.dummy
        self.module_args = {}
        self.output_dims = []
        self.inputs = self.parse_inputs(inputs)
        self.conditions = []
        self.input_dims = None
        self.module = None
        self.computed = None
        self.id = None
        self.name = name

        # Tell every input node that this node consumes one of its outputs
        for slot, entry in enumerate(self.inputs):
            entry[0].outputs.append((self, slot))

    def run_backward(self, op_list):
        '''Return this node's single variable; nothing is added to op_list.'''
        return [(self.id, 0)]
Ancestors (in MRO)
- OutputNode
- Node
- builtins.object
Class variables
var dummy
Static methods
def __init__(
self, inputs, name='node')
Initialize self. See help(type(self)) for accurate signature.
def __init__(self, inputs, name='node'):
    '''Set up the output node and register it with its input nodes.

    inputs is normalised by self.parse_inputs. For every resulting entry
    the tuple (self, position) is appended to the ``outputs`` list of the
    entry's node.'''
    self.module_type = self.dummy
    self.module_args = {}
    self.output_dims = []
    self.inputs = self.parse_inputs(inputs)
    self.conditions = []
    self.input_dims = None
    self.module = None
    self.computed = None
    self.id = None
    self.name = name

    # Tell every input node that this node consumes one of its outputs
    for slot, entry in enumerate(self.inputs):
        entry[0].outputs.append((self, slot))
def build_modules(
self, verbose=True)
Returns a list with the dimension of each output of this node, recursively calling build_modules of the nodes connected to the input. Use this information to initialize the pytorch nn.Module of this node.
def build_modules(self, verbose=True):
    '''Instantiate this node's module and return its output dimensions.

    The input dimensions are obtained by recursively calling build_modules
    on the nodes connected to the inputs (and, if present, on the condition
    nodes). They are then passed to self.module_type together with
    self.module_args. The module is only built once; later calls merely
    re-query the output dimensions.

    Returns a list with the dimension of each output of this node.'''
    already_built = bool(self.input_dims)
    if not already_built:
        self.input_dims = [source.build_modules(verbose=verbose)[channel]
                           for source, channel in self.inputs]
        try:
            if len(self.conditions) > 0:
                c_dims = [cond.build_modules(verbose=verbose)[0]
                          for cond in self.conditions]
                self.module = self.module_type(self.input_dims, dims_c=c_dims,
                                               **self.module_args)
            else:
                self.module = self.module_type(self.input_dims,
                                               **self.module_args)
        except Exception as e:
            # Name the offending node before passing the error on
            print('Error in node %s' % (self.name))
            raise e

        if verbose:
            print(f"Node '{self.name}' takes the following inputs:")
            for dims, (source, channel) in zip(self.input_dims, self.inputs):
                print(f"\t Output #{channel} of node '{source.name}' with dims {dims}")
            for cond in self.conditions:
                print(f"\t conditioned on node '{cond.name}' " +
                      f"with dims {cond.data.shape}")
            print()

    self.output_dims = self.module.output_dims(self.input_dims)
    self.n_outputs = len(self.output_dims)
    return self.output_dims
def parse_inputs(
self, inputs)
def parse_inputs(self, inputs):
    '''Normalise ``inputs`` to a list of (node, output index) pairs.

    Accepted forms:
      * a list/tuple whose first element is itself a list/tuple: returned
        unchanged (assumed to already be a list of pairs);
      * a list/tuple of length two: treated as a single (node, index) pair;
      * a single Node: shorthand for output 0 of that node.

    Raises:
        RuntimeError: if a list/tuple is given that matches neither form
            (including an empty one).
        AssertionError: if ``inputs`` is neither a list/tuple nor a Node.'''
    # The constructors call parse_inputs before they assign self.name, so the
    # name may not exist yet when an error message has to be built.
    node_name = getattr(self, 'name', '<unnamed>')

    if isinstance(inputs, (list, tuple)):
        if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
            return inputs
        if len(inputs) == 2:
            return [inputs, ]
        raise RuntimeError(
            f"Cannot parse inputs provided to node '{node_name}'.")

    assert isinstance(inputs, Node), "Received object of invalid type "\
        f"({type(inputs)}) as input for node '{node_name}'."
    return [(inputs, 0), ]
def run_backward(
self, op_list)
See run_forward, this is the same, only for the reverse computation. Need to call run_forward first, otherwise this function will not work
def run_backward(self, op_list):
    '''Return the single variable of this node as [(node ID, 0)].

    Nothing is appended to op_list.'''
    own_variable = (self.id, 0)
    return [own_variable]
def run_forward(
self, op_list)
Determine the order of operations needed to reach this node. Calls run_forward of parent nodes recursively. Each operation is appended to the global list op_list, in the form (node ID, input variable IDs, output variable IDs, conditioning variable IDs)
def run_forward(self, op_list):
    '''Work out which operations have to run before this node, then add it.

    Input and condition nodes are visited recursively first. Afterwards
    this node appends its own entry
    (node ID, input variable IDs, output variable IDs, condition variable IDs)
    to op_list. The list of computed variables is cached, so calling this
    again returns it without scheduling anything twice.'''
    if not self.computed:
        collected_inputs = []
        self.input_vars = collected_inputs
        for position, (provider, channel) in enumerate(self.inputs):
            # Only the connected channel of the provider is of interest
            collected_inputs.append(provider.run_forward(op_list)[channel])
            provider.outputs.append((self, position))

        collected_conditions = []
        self.condition_vars = collected_conditions
        for position, condition in enumerate(self.conditions):
            collected_conditions.append(condition.run_forward(op_list)[0])
            condition.outputs.append((self, position))

        # All dependencies are scheduled; this node's outputs come next
        self.computed = [(self.id, k) for k in range(self.n_outputs)]
        op_list.append((self.id, self.input_vars, self.computed,
                        self.condition_vars))

    return self.computed
Instance variables
var computed
var conditions
var id
var inputs
var name
var output_dims
class ReversibleGraphNet
This class represents the invertible net itself. It is a subclass of torch.nn.Module and supports the same methods. The forward method has an additional option 'rev', with which the net can be computed in reverse.
class ReversibleGraphNet(nn.Module):
    '''This class represents the invertible net itself. It is a subclass of
    torch.nn.Module and supports the same methods. The forward method has an
    additional option 'rev', with which the net can be computed in reverse.'''

    def __init__(self, node_list, ind_in=None, ind_out=None, verbose=True):
        '''node_list should be a list of all nodes involved, and ind_in,
        ind_out are the indexes of the special nodes InputNode and OutputNode
        in this list (deprecated: they are detected automatically if omitted).'''
        super(ReversibleGraphNet, self).__init__()

        # Gather lists of input, output and condition nodes
        if ind_in is not None:
            warnings.warn("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, " +
                          "input and output nodes are detected automatically.")
            if isinstance(ind_in, int):
                self.ind_in = list([ind_in])
            else:
                self.ind_in = ind_in
        else:
            self.ind_in = [i for i in range(len(node_list))
                           if isinstance(node_list[i], InputNode)]
            assert len(self.ind_in) > 0, "No input nodes specified."
        if ind_out is not None:
            warnings.warn("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, " +
                          "input and output nodes are detected automatically.")
            if isinstance(ind_out, int):
                self.ind_out = list([ind_out])
            else:
                self.ind_out = ind_out
        else:
            self.ind_out = [i for i in range(len(node_list))
                            if isinstance(node_list[i], OutputNode)]
            assert len(self.ind_out) > 0, "No output nodes specified."

        self.ind_cond = [i for i in range(len(node_list))
                         if isinstance(node_list[i], ConditionNode)]

        self.return_vars = []
        self.input_vars = []
        self.cond_vars = []

        # Assign each node a unique ID
        self.node_list = node_list
        for i, n in enumerate(node_list):
            n.id = i
            n.graph = self

        # Recursively build the nodes nn.Modules and determine order of
        # operations
        ops = []
        for i in self.ind_out:
            node_list[i].build_modules(verbose=verbose)
            node_list[i].run_forward(ops)

        # create list of Pytorch variables that are used
        variables = set()
        for o in ops:
            variables = variables.union(set(o[1] + o[2] + o[3]))
        self.variables_ind = list(variables)

        self.indexed_ops = self.ops_to_indexed(ops)

        self.module_list = nn.ModuleList([n.module for n in node_list])
        self.module_cond = [(len(n.conditions) > 0) for n in node_list]
        # One slot per graph variable; filled with tensors during forward()
        self._buffers = {F'tmp_var_{i}': None for i in range(len(variables))}

        # Find out the order of operations for reverse calculations
        ops_rev = []
        for i in self.ind_in + self.ind_cond:
            node_list[i].run_backward(ops_rev)
        self.indexed_ops_rev = self.ops_to_indexed(ops_rev)

    def ops_to_indexed(self, ops):
        '''Helper function to translate the list of variables (origin ID,
        channel), to variable IDs.

        Operations belonging to input, output or condition nodes are not
        part of the returned list; their first variable is recorded in
        self.input_vars, self.return_vars or self.cond_vars instead.'''
        result = []

        for o in ops:
            try:
                vars_in = [self.variables_ind.index(v) for v in o[1]]
            except ValueError:
                vars_in = -1

            vars_out = [self.variables_ind.index(v) for v in o[2]]
            vars_cond = [self.variables_ind.index(v) for v in o[3]]

            # Collect input/output/conditioning nodes in separate lists, but
            # don't add to indexed ops
            if o[0] in self.ind_out:
                self.return_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_in:
                self.input_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_cond:
                if self.variables_ind.index(o[1][0]) not in self.cond_vars:
                    self.cond_vars.append(self.variables_ind.index(o[1][0]))
                else:
                    print('Is this branch ever reached?')
                continue

            result.append((o[0], vars_in, vars_out, vars_cond))

        # Sort input/output/conditioning variables so they correspond to
        # initial node list order
        self.return_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.input_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.cond_vars.sort(key=lambda i: self.variables_ind[i][0])

        return result

    def forward(self, x, c=None, rev=False, intermediate_outputs=False):
        '''Forward or backward computation of the whole net.

        x is a single tensor or a list/tuple with one tensor per input
        (per output if rev is set); c holds the conditioning data in the
        same way. Returns the single output, a list of outputs, or, with
        intermediate_outputs, a dict mapping node names to their results.'''
        if rev:
            use_list = self.indexed_ops_rev
            input_vars, output_vars = self.return_vars, self.input_vars
        else:
            use_list = self.indexed_ops
            input_vars, output_vars = self.input_vars, self.return_vars

        # Assign input data to respective variables
        if isinstance(x, (list, tuple)):
            assert len(x) == len(input_vars), (
                f"Got list of {len(x)} input tensors for "
                f"{'inverse' if rev else 'forward'} pass, but expected "
                f"{len(input_vars)}."
            )
            for i in range(len(input_vars)):
                self._buffers[F'tmp_var_{input_vars[i]}'] = x[i]
        else:
            assert len(input_vars) == 1, (f"Got single input tensor for "
                                          f"{'inverse' if rev else 'forward'} "
                                          f"pass, but expected list of "
                                          f"{len(input_vars)}.")
            self._buffers[F'tmp_var_{input_vars[0]}'] = x

        # Assign conditioning data to respective variables
        if c is None:
            assert len(self.cond_vars) == 0
        elif isinstance(c, (list, tuple)):
            assert len(c) == len(self.cond_vars), f'{len(c)}, {len(self.cond_vars)}'
            for i in range(len(self.cond_vars)):
                self._buffers[F'tmp_var_{self.cond_vars[i]}'] = c[i]
        else:
            assert len(self.cond_vars) == 1
            self._buffers[F'tmp_var_{self.cond_vars[0]}'] = c

        # Prepare dictionary for intermediate node outputs
        out_dict = {}

        # Run all modules with the given inputs
        for o in use_list:
            try:
                x = [self._buffers[F'tmp_var_{i}'] for i in o[1]]
                if self.module_cond[o[0]]:
                    c = [self._buffers[F'tmp_var_{i}'] for i in o[3]]
                    results = self.module_list[o[0]](x, c=c, rev=rev)
                else:
                    results = self.module_list[o[0]](x, rev=rev)
            except TypeError as err:
                # Keep the original TypeError attached as the explicit cause
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?") from err
            out_dict[self.node_list[o[0]].name] = results
            for i, r in zip(o[2], results):
                self._buffers[F'tmp_var_{i}'] = r

        if intermediate_outputs:
            return out_dict
        else:
            out = [self._buffers[F'tmp_var_{output_vars[i]}']
                   for i in range(len(output_vars))]
            if len(out) == 1:
                return out[0]
            else:
                return out

    def log_jacobian(self, x=None, c=None, rev=False, run_forward=True,
                     intermediate_outputs=False):
        '''Compute the log jacobian determinant of the whole net.

        With run_forward set, forward(x, c, rev=rev) is executed first and x
        is mandatory. Returns the sum over all modules, or, with
        intermediate_outputs, a dict mapping node names to their contribution.'''
        if run_forward or c is not None:
            self.condition = c
        jacobian = 0

        if rev:
            use_list = self.indexed_ops_rev
        else:
            use_list = self.indexed_ops

        if run_forward:
            if x is None:
                raise RuntimeError("You need to provide an input if you want "
                                   "to run a forward pass")
            self.forward(x, c, rev=rev)

        # Prepare dictionary for intermediate node outputs
        jacobian_dict = {}

        # Run all modules with the given inputs
        for o in use_list:
            x = [self._buffers[F'tmp_var_{i}'] for i in o[1]]
            if self.module_cond[o[0]]:
                c = [self._buffers[F'tmp_var_{i}'] for i in o[3]]
                module_jacobian = self.module_list[o[0]].jacobian(x, c=c, rev=rev)
            else:
                module_jacobian = self.module_list[o[0]].jacobian(x, rev=rev)
            jacobian += module_jacobian
            jacobian_dict[self.node_list[o[0]].name] = module_jacobian

        if intermediate_outputs:
            return jacobian_dict
        else:
            return jacobian

    def jacobian(self, *args, **kwargs):
        '''Compute the log jacobian determinant of the whole net.

        Deprecated alias of log_jacobian.'''
        warnings.warn("This function computes the log-jacobian determinant, not the "
                      "jacobian as the name suggest. Will be removed in the future.")
        return self.log_jacobian(*args, **kwargs)

    def log_jacobian_numerical(self, x, c=None, rev=False, h=1e-04):
        '''Approximate log Jacobian determinant via finite differences.

        Every flattened input dimension is perturbed by +/- h and the net is
        evaluated in the direction selected by rev. Returns one value per
        batch entry, with the dtype and device of the input.'''
        multi_input = isinstance(x, (list, tuple))
        if multi_input:
            batch_size = x[0].shape[0]
            # np.prod yields numpy scalars (a float for an empty shape)
            ndim_x_separate = [int(np.prod(x_i.shape[1:])) for x_i in x]
            ndim_x_total = sum(ndim_x_separate)
            x_flat = torch.cat([x_i.reshape(batch_size, -1) for x_i in x], dim=1)
        else:
            batch_size = x.shape[0]
            ndim_x_total = int(np.prod(x.shape[1:]))
            x_flat = x.reshape(batch_size, -1)

        # Allocate next to the input so dtype and device always agree
        J_num = x_flat.new_zeros(batch_size, ndim_x_total, ndim_x_total)
        for i in range(ndim_x_total):
            offset = x_flat.new_zeros(batch_size, ndim_x_total)
            offset[:, i] = h
            if multi_input:
                x_upper = torch.split(x_flat + offset, ndim_x_separate, dim=1)
                x_upper = [x_upper[k].view(*x[k].shape) for k in range(len(x))]
                x_lower = torch.split(x_flat - offset, ndim_x_separate, dim=1)
                x_lower = [x_lower[k].view(*x[k].shape) for k in range(len(x))]
            else:
                x_upper = (x_flat + offset).view(*x.shape)
                x_lower = (x_flat - offset).view(*x.shape)
            # Honour the requested direction
            y_upper = self.forward(x_upper, c=c, rev=rev)
            y_lower = self.forward(x_lower, c=c, rev=rev)
            if isinstance(y_upper, (list, tuple)):
                y_upper = torch.cat([y_i.reshape(batch_size, -1) for y_i in y_upper], dim=1)
                y_lower = torch.cat([y_i.reshape(batch_size, -1) for y_i in y_lower], dim=1)
            J_num[:, :, i] = (y_upper - y_lower).reshape(batch_size, -1) / (2*h)

        logdet_num = x_flat.new_zeros(batch_size)
        for i in range(batch_size):
            logdet_num[i] = torch.det(J_num[i, :, :]).abs().log()

        return logdet_num

    def load_state_dict(self, state_dict, *args, **kwargs):
        '''Load state_dict, ignoring entries for still-empty temporary buffers.

        Keys that name one of this net's buffers whose current value is None
        are dropped before delegating to nn.Module.load_state_dict.'''
        state_dict_no_buffers = {}
        for k, p in state_dict.items():
            if k in self._buffers and self._buffers[k] is None:
                continue
            state_dict_no_buffers[k] = p
        return super().load_state_dict(state_dict_no_buffers, *args, **kwargs)
Ancestors (in MRO)
- ReversibleGraphNet
- torch.nn.modules.module.Module
- builtins.object
Class variables
var dump_patches
Static methods
def __init__(
self, node_list, ind_in=None, ind_out=None, verbose=True)
node_list should be a list of all nodes involved, and ind_in, ind_out are the indexes of the special nodes InputNode and OutputNode in this list.
def __init__(self, node_list, ind_in=None, ind_out=None, verbose=True):
    '''Build the net from node_list.

    node_list should be a list of all nodes involved. ind_in and ind_out
    are the indexes of the InputNode and OutputNode instances in this list;
    passing them is deprecated, they are detected automatically if left
    at None.'''
    super(ReversibleGraphNet, self).__init__()

    deprecation_note = ("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, "
                        "input and output nodes are detected automatically.")

    # Gather lists of input, output and condition nodes
    if ind_in is None:
        self.ind_in = [idx for idx, node in enumerate(node_list)
                       if isinstance(node, InputNode)]
        assert len(self.ind_in) > 0, "No input nodes specified."
    else:
        warnings.warn(deprecation_note)
        self.ind_in = [ind_in] if isinstance(ind_in, int) else ind_in

    if ind_out is None:
        self.ind_out = [idx for idx, node in enumerate(node_list)
                        if isinstance(node, OutputNode)]
        assert len(self.ind_out) > 0, "No output nodes specified."
    else:
        warnings.warn(deprecation_note)
        self.ind_out = [ind_out] if isinstance(ind_out, int) else ind_out

    self.ind_cond = [idx for idx, node in enumerate(node_list)
                     if isinstance(node, ConditionNode)]

    self.return_vars = []
    self.input_vars = []
    self.cond_vars = []

    # Assign each node a unique ID (its position in the list)
    self.node_list = node_list
    for node_id, node in enumerate(node_list):
        node.id = node_id
        node.graph = self

    # Recursively build the nodes' nn.Modules and determine the order of
    # operations, starting from every output node
    ops = []
    for out_idx in self.ind_out:
        out_node = node_list[out_idx]
        out_node.build_modules(verbose=verbose)
        out_node.run_forward(ops)

    # Create the list of Pytorch variables that are used
    variables = set()
    for op in ops:
        variables = variables | set(op[1] + op[2] + op[3])
    self.variables_ind = list(variables)

    self.indexed_ops = self.ops_to_indexed(ops)

    self.module_list = nn.ModuleList([node.module for node in node_list])
    self.module_cond = [(len(node.conditions) > 0) for node in node_list]
    self._buffers = {F'tmp_var_{i}': None for i in range(len(variables))}

    # Find out the order of operations for reverse calculations
    ops_rev = []
    for start_idx in self.ind_in + self.ind_cond:
        node_list[start_idx].run_backward(ops_rev)
    self.indexed_ops_rev = self.ops_to_indexed(ops_rev)
def add_module(
self, name, module)
Adds a child module to the current module.
The module can be accessed as an attribute using the given name.
Args: name (string): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module.
def add_module(self, name, module):
    r"""Register ``module`` as a child of this module under ``name``.

    The child can afterwards be accessed as an attribute using that name.

    Args:
        name (string): name of the child module.
        module (Module): child module to be added; ``None`` is accepted.

    Raises:
        TypeError: if ``module`` is neither a Module nor ``None``, or if
            ``name`` is not a string.
        KeyError: if ``name`` clashes with an existing non-module attribute,
            contains a ``.``, or is empty.
    """
    if module is not None and not isinstance(module, Module):
        raise TypeError("{} is not a Module subclass".format(
            torch.typename(module)))
    if not isinstance(name, torch._six.string_classes):
        raise TypeError("module name should be a string. Got {}".format(
            torch.typename(name)))
    if hasattr(self, name) and name not in self._modules:
        raise KeyError("attribute '{}' already exists".format(name))
    if '.' in name:
        raise KeyError("module name can't contain \".\"")
    if name == '':
        raise KeyError("module name can't be empty string \"\"")

    self._modules[name] = module
def apply(
self, fn)
Applies fn recursively to every submodule (as returned by .children())
as well as self. Typical use includes initializing the parameters of a model
(see also :ref:torch-nn-init).
Args:
fn (:class:Module -> None): function to be applied to each submodule
Returns: Module: self
Example::
>>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.data.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[ 1., 1.], [ 1., 1.]]) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[ 1., 1.], [ 1., 1.]]) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )
def apply(self, fn):
    r"""Call ``fn`` on every submodule and finally on this module itself.

    The children (as returned by ``.children()``) are processed first, each
    via its own ``apply``. A typical use is parameter initialisation.

    Args:
        fn (:class:`Module` -> None): function to be applied to each module

    Returns:
        Module: self

    Example::

        >>> def init_weights(m):
        >>>     if type(m) == nn.Linear:
        >>>         m.weight.data.fill_(1.0)
        >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        >>> net.apply(init_weights)
    """
    for child in self.children():
        child.apply(fn)

    fn(self)
    return self
def buffers(
self, recurse=True)
Returns an iterator over module buffers.
Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
Yields: torch.Tensor: module buffer
Example::
>>> for buf in model.buffers(): >>> print(type(buf.data), buf.size()) <class 'torch.FloatTensor'> (20L,) <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
def buffers(self, recurse=True):
    r"""Iterate over the module's buffers, without their names.

    Args:
        recurse (bool): forwarded to ``named_buffers``; if True, buffers of
            all submodules are included, otherwise only direct members.

    Yields:
        torch.Tensor: module buffer
    """
    yield from (buf for _, buf in self.named_buffers(recurse=recurse))
def children(
self)
Returns an iterator over immediate children modules.
Yields: Module: a child module
def children(self):
    r"""Iterate over the immediate child modules, without their names.

    Yields:
        Module: a child module
    """
    yield from (child for _, child in self.named_children())
def cpu(
self)
Moves all model parameters and buffers to the CPU.
Returns: Module: self
def cpu(self):
    r"""Move all model parameters and buffers to the CPU.

    Returns:
        Module: self
    """
    def to_cpu(tensor):
        return tensor.cpu()

    return self._apply(to_cpu)
def cuda(
self, device=None)
Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.
Arguments: device (int, optional): if specified, all parameters will be copied to that device
Returns: Module: self
def cuda(self, device=None):
    r"""Move all model parameters and buffers to the GPU.

    The parameters and buffers become different objects, so this should be
    called before constructing an optimizer if the module is to live on the
    GPU while being optimized.

    Arguments:
        device (int, optional): if specified, all parameters will be
            copied to that device

    Returns:
        Module: self
    """
    def to_cuda(tensor):
        return tensor.cuda(device)

    return self._apply(to_cuda)
def double(
self)
Casts all floating point parameters and buffers to double datatype.
Returns: Module: self
def double(self):
    r"""Cast all floating point parameters and buffers to ``double``.

    Tensors that are not floating point are left untouched.

    Returns:
        Module: self
    """
    def to_double(tensor):
        if tensor.is_floating_point():
            return tensor.double()
        return tensor

    return self._apply(to_double)
def eval(
self)
Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:Dropout, :class:BatchNorm,
etc.
def eval(self):
    r"""Put the module into evaluation mode.

    Equivalent to ``self.train(False)``. This only has an effect on modules
    that behave differently in training and evaluation mode, e.g.
    :class:`Dropout` or :class:`BatchNorm`.
    """
    training = False
    return self.train(training)
def extra_repr(
self)
Set the extra representation of the module
To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.
def extra_repr(self):
    r"""Return the extra representation of the module.

    This default returns an empty string. Reimplement it in your own
    modules to print customized extra information; single-line and
    multi-line strings are both acceptable.
    """
    nothing_extra = ''
    return nothing_extra
def float(
self)
Casts all floating point parameters and buffers to float datatype.
Returns: Module: self
def float(self):
    r"""Cast all floating point parameters and buffers to float datatype.

    Tensors that are not floating point are left untouched.

    Returns:
        Module: self
    """
    def to_float(tensor):
        if tensor.is_floating_point():
            return tensor.float()
        return tensor

    return self._apply(to_float)
def forward(
self, x, c=None, rev=False, intermediate_outputs=False)
Forward or backward computation of the whole net.
def forward(self, x, c=None, rev=False, intermediate_outputs=False):
    '''Forward or backward computation of the whole net.

    x is either a single tensor or a list/tuple holding one tensor per
    input of the pass (the net's outputs take that role if rev is set).
    c carries the conditioning data in the same fashion. The return value
    is the single result, a list of results, or, if intermediate_outputs is
    set, a dict that maps each executed node's name to what it produced.'''
    if rev:
        use_list, input_vars, output_vars = (
            self.indexed_ops_rev, self.return_vars, self.input_vars)
    else:
        use_list, input_vars, output_vars = (
            self.indexed_ops, self.input_vars, self.return_vars)

    # Assign input data to respective variables
    if not isinstance(x, (list, tuple)):
        assert len(input_vars) == 1, (f"Got single input tensor for "
                                      f"{'inverse' if rev else 'forward'} "
                                      f"pass, but expected list of "
                                      f"{len(input_vars)}.")
        self._buffers[f'tmp_var_{input_vars[0]}'] = x
    else:
        assert len(x) == len(input_vars), (
            f"Got list of {len(x)} input tensors for "
            f"{'inverse' if rev else 'forward'} pass, but expected "
            f"{len(input_vars)}."
        )
        for pos, var_id in enumerate(input_vars):
            self._buffers[f'tmp_var_{var_id}'] = x[pos]

    # Assign conditioning data to respective variables
    if c is None:
        assert len(self.cond_vars) == 0
    elif isinstance(c, (list, tuple)):
        assert len(c) == len(self.cond_vars), f'{len(c)}, {len(self.cond_vars)}'
        for pos, var_id in enumerate(self.cond_vars):
            self._buffers[f'tmp_var_{var_id}'] = c[pos]
    else:
        assert len(self.cond_vars) == 1
        self._buffers[f'tmp_var_{self.cond_vars[0]}'] = c

    # Results of every executed node, keyed by node name
    out_dict = {}

    # Run all modules with the given inputs
    for node_id, in_ids, out_ids, cond_ids in use_list:
        try:
            module_in = [self._buffers[f'tmp_var_{i}'] for i in in_ids]
            if self.module_cond[node_id]:
                module_cond = [self._buffers[f'tmp_var_{i}'] for i in cond_ids]
                results = self.module_list[node_id](module_in, c=module_cond,
                                                    rev=rev)
            else:
                results = self.module_list[node_id](module_in, rev=rev)
        except TypeError:
            raise RuntimeError("Are you sure all used Nodes are in the "
                               "Node list?")
        out_dict[self.node_list[node_id].name] = results
        for var_id, value in zip(out_ids, results):
            self._buffers[f'tmp_var_{var_id}'] = value

    if intermediate_outputs:
        return out_dict

    out = [self._buffers[f'tmp_var_{var_id}'] for var_id in output_vars]
    return out[0] if len(out) == 1 else out
def half(
self)
Casts all floating point parameters and buffers to half datatype.
Returns: Module: self
def half(self):
    r"""Cast all floating point parameters and buffers to ``half``.

    Tensors that are not floating point are left untouched.

    Returns:
        Module: self
    """
    def to_half(tensor):
        if tensor.is_floating_point():
            return tensor.half()
        return tensor

    return self._apply(to_half)
def jacobian(
self, *args, **kwargs)
Compute the log jacobian determinant of the whole net.
def jacobian(self, *args, **kwargs):
    '''Compute the log jacobian determinant of the whole net.

    Deprecated alias: emits a warning and forwards all arguments to
    log_jacobian unchanged.'''
    message = ("This function computes the log-jacobian determinant, not the "
               "jacobian as the name suggest. Will be removed in the future.")
    warnings.warn(message)
    result = self.log_jacobian(*args, **kwargs)
    return result
def load_state_dict(
self, state_dict, *args, **kwargs)
Copies parameters and buffers from :attr:state_dict into
this module and its descendants. If :attr:strict is True, then
the keys of :attr:state_dict must exactly match the keys returned
by this module's :meth:~torch.nn.Module.state_dict function.
Arguments:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:state_dict match the keys returned by this module's
:meth:~torch.nn.Module.state_dict function. Default: True
Returns:
NamedTuple with missing_keys and unexpected_keys fields:
* missing_keys is a list of str containing the missing keys
* unexpected_keys is a list of str containing the unexpected keys
def load_state_dict(self, state_dict, *args, **kwargs): state_dict_no_buffers = {} for k,p in state_dict.items(): if k in self._buffers and self._buffers[k] is None: continue state_dict_no_buffers[k] = p return super().load_state_dict(state_dict_no_buffers, *args, **kwargs)
def log_jacobian(
self, x=None, c=None, rev=False, run_forward=True, intermediate_outputs=False)
Compute the log jacobian determinant of the whole net.
def log_jacobian(self, x=None, c=None, rev=False, run_forward=True,
                 intermediate_outputs=False):
    '''Compute the log jacobian determinant of the whole net.

    If run_forward is set, forward(x, c, rev=rev) is executed first, which
    requires x. Otherwise the values left in the buffers are used. The
    contributions of all modules are summed up; with intermediate_outputs
    a dict mapping node names to their individual contribution is returned
    instead of the sum.'''
    if run_forward or c is not None:
        self.condition = c

    jacobian = 0
    use_list = self.indexed_ops_rev if rev else self.indexed_ops

    if run_forward:
        if x is None:
            raise RuntimeError("You need to provide an input if you want "
                               "to run a forward pass")
        self.forward(x, c, rev=rev)

    # Contribution of every executed node, keyed by node name
    jacobian_dict = {}

    # Query all modules with the inputs currently held in the buffers
    for node_id, in_ids, _, cond_ids in use_list:
        module = self.module_list[node_id]
        module_in = [self._buffers[f'tmp_var_{i}'] for i in in_ids]
        if self.module_cond[node_id]:
            module_cond = [self._buffers[f'tmp_var_{i}'] for i in cond_ids]
            module_jacobian = module.jacobian(module_in, c=module_cond, rev=rev)
        else:
            module_jacobian = module.jacobian(module_in, rev=rev)
        jacobian += module_jacobian
        jacobian_dict[self.node_list[node_id].name] = module_jacobian

    if intermediate_outputs:
        return jacobian_dict
    return jacobian
def log_jacobian_numerical(
self, x, c=None, rev=False, h=0.0001)
Approximate log Jacobian determinant via finite differences.
def log_jacobian_numerical(self, x, c=None, rev=False, h=1e-04):
    '''Approximate log Jacobian determinant via finite differences.

    Every flattened input dimension is perturbed by +/- h and the net is
    evaluated through self.forward in the direction selected by rev.

    x is a single tensor or a list/tuple of tensors sharing the batch
    dimension; c is passed on to forward as conditioning data. Returns a
    tensor with one log|det J| per batch entry, created with the dtype and
    device of the input.'''
    multi_input = isinstance(x, (list, tuple))
    if multi_input:
        batch_size = x[0].shape[0]
        # np.prod yields numpy scalars (a float for an empty shape); torch
        # expects plain ints for sizes
        ndim_x_separate = [int(np.prod(x_i.shape[1:])) for x_i in x]
        ndim_x_total = sum(ndim_x_separate)
        x_flat = torch.cat([x_i.reshape(batch_size, -1) for x_i in x], dim=1)
    else:
        batch_size = x.shape[0]
        ndim_x_total = int(np.prod(x.shape[1:]))
        x_flat = x.reshape(batch_size, -1)

    # Allocate next to the input so dtype and device always agree
    J_num = x_flat.new_zeros(batch_size, ndim_x_total, ndim_x_total)
    for i in range(ndim_x_total):
        offset = x_flat.new_zeros(batch_size, ndim_x_total)
        offset[:, i] = h
        if multi_input:
            x_upper = torch.split(x_flat + offset, ndim_x_separate, dim=1)
            x_upper = [x_upper[k].view(*x[k].shape) for k in range(len(x))]
            x_lower = torch.split(x_flat - offset, ndim_x_separate, dim=1)
            x_lower = [x_lower[k].view(*x[k].shape) for k in range(len(x))]
        else:
            x_upper = (x_flat + offset).view(*x.shape)
            x_lower = (x_flat - offset).view(*x.shape)
        # Honour the requested direction (it used to be silently ignored)
        y_upper = self.forward(x_upper, c=c, rev=rev)
        y_lower = self.forward(x_lower, c=c, rev=rev)
        if isinstance(y_upper, (list, tuple)):
            y_upper = torch.cat([y_i.reshape(batch_size, -1) for y_i in y_upper], dim=1)
            y_lower = torch.cat([y_i.reshape(batch_size, -1) for y_i in y_lower], dim=1)
        # Central difference: column i of the Jacobian
        J_num[:, :, i] = (y_upper - y_lower).reshape(batch_size, -1) / (2*h)

    logdet_num = x_flat.new_zeros(batch_size)
    for i in range(batch_size):
        logdet_num[i] = torch.det(J_num[i, :, :]).abs().log()

    return logdet_num
def modules(
self)
Returns an iterator over all modules in the network.
Yields: Module: a module in the network
Note:
Duplicate modules are returned only once. In the following
example, l will be returned only once.
Example::
>>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True)
def modules(self):
    r"""Iterate over all modules in the network, without their names.

    Yields:
        Module: a module in the network

    Note:
        The items come from ``named_modules``, which returns duplicate
        modules only once.

    Example::

        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.modules()):
                print(idx, '->', m)
    """
    yield from (module for _, module in self.named_modules())
def named_buffers(
self, prefix='', recurse=True)
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
Args: prefix (str): prefix to prepend to all buffer names. recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
Yields: (string, torch.Tensor): Tuple containing the name and buffer
Example::
>>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size())
def named_buffers(self, prefix='', recurse=True):
    r"""Iterate over module buffers together with their names.

    Args:
        prefix (str): prefix to prepend to all buffer names.
        recurse (bool): if True, then yields buffers of this module and all
            submodules. Otherwise, yields only buffers that are direct
            members of this module.

    Yields:
        (string, torch.Tensor): Tuple containing the name and buffer

    Example::

        >>> for name, buf in self.named_buffers():
        >>>     if name in ['running_var']:
        >>>         print(buf.size())
    """
    def buffers_of(module):
        return module._buffers.items()

    yield from self._named_members(buffers_of, prefix=prefix, recurse=recurse)
def named_children(
self)
Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
Yields: (string, Module): Tuple containing a name and child module
Example::
>>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module)
def named_children(self):
    r"""Iterate over the direct child modules together with their names.

    Entries of ``self._modules`` that are ``None`` are skipped, and a child
    registered under several names is yielded only for the first one.

    Yields:
        (string, Module): Tuple containing a name and child module

    Example::

        >>> for name, module in model.named_children():
        >>>     if name in ['conv4', 'conv5']:
        >>>         print(module)

    """
    seen = set()
    for child_name, child in self._modules.items():
        if child is None or child in seen:
            continue
        seen.add(child)
        yield child_name, child
def named_modules(
self, memo=None, prefix='')
Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Yields: (string, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
def named_modules(self, memo=None, prefix=''):
    r"""Iterate over all modules in the network together with their names.

    The module itself is yielded first under ``prefix``; afterwards every
    non-``None`` entry of ``self._modules`` is visited recursively, with
    its name appended to ``prefix`` (separated by a dot when ``prefix`` is
    non-empty).

    Args:
        memo (set): modules that have already been yielded; a fresh set is
            created when ``None`` is given. The set is shared with the
            recursive calls and is what suppresses duplicates.
        prefix (str): name under which this module itself is reported.

    Yields:
        (string, Module): Tuple of name and module

    Note:
        A module that appears more than once is yielded a single time.
        In the example below ``l`` is yielded only once.

    Example::

        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.named_modules()):
                print(idx, '->', m)

        0 -> ('', Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        ))
        1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

    """
    if memo is None:
        memo = set()
    if self in memo:
        # Already reported (and descended into) elsewhere in the traversal.
        return

    memo.add(self)
    yield prefix, self

    separator = '.' if prefix else ''
    for child_name, child in self._modules.items():
        if child is not None:
            child_prefix = prefix + separator + child_name
            for named_module in child.named_modules(memo, child_prefix):
                yield named_module
def named_parameters(
self, prefix='', recurse=True)
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
Yields: (string, Parameter): Tuple containing the name and parameter
Example::
>>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size())
def named_parameters(self, prefix='', recurse=True):
    r"""Iterate over the module's parameters together with their names.

    The work is delegated to ``self._named_members``, which is handed a
    getter returning the ``_parameters`` items of a module.

    Args:
        prefix (str): prefix to prepend to all parameter names.
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        (string, Parameter): Tuple containing the name and parameter

    Example::

        >>> for name, param in self.named_parameters():
        >>>    if name in ['bias']:
        >>>        print(param.size())

    """
    def _parameter_items(module):
        return module._parameters.items()

    members = self._named_members(
        _parameter_items,
        prefix=prefix, recurse=recurse)
    for named_param in members:
        yield named_param
def ops_to_indexed(
self, ops)
Helper function to translate the list of variables (origin ID, channel), to variable IDs.
def ops_to_indexed(self, ops):
    '''Helper function to translate the list of variables (origin ID, channel),
    to variable IDs.

    Each entry of ``ops`` is indexed as ``(node id, input variables,
    output variables, conditioning variables)``. Every variable is replaced
    by its position in ``self.variables_ind``.

    Entries whose node id is found in ``self.ind_out``, ``self.ind_in`` or
    ``self.ind_cond`` are not part of the returned list. Instead, the index
    of their first input variable is appended to ``self.return_vars``,
    ``self.input_vars`` or ``self.cond_vars`` respectively (the latter only
    if it is not already present). These three lists are sorted in place
    before returning.

    Returns a list of tuples ``(node id, vars_in, vars_out, vars_cond)``,
    where ``vars_in`` is ``-1`` if one of the input variables is not
    contained in ``self.variables_ind``.
    '''
    position_of = self.variables_ind.index
    indexed = []

    for op in ops:
        node_id = op[0]

        # Unknown input variables are tolerated and flagged with -1; unknown
        # output or conditioning variables still raise ValueError.
        try:
            vars_in = [position_of(v) for v in op[1]]
        except ValueError:
            vars_in = -1
        vars_out = [position_of(v) for v in op[2]]
        vars_cond = [position_of(v) for v in op[3]]

        # Input/output/conditioning nodes are collected in their own lists
        # and are not added to the indexed ops.
        if node_id in self.ind_out:
            self.return_vars.append(position_of(op[1][0]))
        elif node_id in self.ind_in:
            self.input_vars.append(position_of(op[1][0]))
        elif node_id in self.ind_cond:
            cond_var = position_of(op[1][0])
            if cond_var in self.cond_vars:
                print('Is this branch ever reached?')
            else:
                self.cond_vars.append(cond_var)
        else:
            indexed.append((node_id, vars_in, vars_out, vars_cond))

    # Order the collected variables by the origin ID of the variable, so
    # they correspond to the initial node list order.
    def origin_id(var_index):
        return self.variables_ind[var_index][0]

    for collected in (self.return_vars, self.input_vars, self.cond_vars):
        collected.sort(key=origin_id)

    return indexed
def parameters(
self, recurse=True)
Returns an iterator over module parameters.
This is typically passed to an optimizer.
Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
Yields: Parameter: module parameter
Example::
>>> for param in model.parameters(): >>> print(type(param.data), param.size()) <class 'torch.FloatTensor'> (20L,) <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
def parameters(self, recurse=True):
    r"""Iterate over the module's parameters.

    Thin wrapper around :meth:`named_parameters` that discards the names.
    This is typically passed to an optimizer.

    Args:
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        Parameter: module parameter

    Example::

        >>> for param in model.parameters():
        >>>     print(type(param.data), param.size())
        <class 'torch.FloatTensor'> (20L,)
        <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

    """
    for _, parameter in self.named_parameters(recurse=recurse):
        yield parameter
def register_backward_hook(
self, hook)
Registers a backward hook on the module.
The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature::
hook(module, grad_input, grad_output) -> Tensor or None
The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
module has multiple inputs or outputs. The hook should not modify its
arguments, but it can optionally return a new gradient with respect to
input that will be used in place of :attr:`grad_input` in subsequent
computations.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
.. warning ::
The current implementation will not have the presented behavior for complex :class:`Module` that perform many operations. In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only contain the gradients for a subset of the inputs and outputs. For such :class:`Module`, you should use :func:`torch.Tensor.register_hook` directly on a specific input or output to get the required gradients.
def register_backward_hook(self, hook):
    r"""Registers a backward hook on the module.

    The hook is stored in ``self._backward_hooks`` under the id of a newly
    created removable handle.

    The hook will be called every time the gradients with respect to module
    inputs are computed. The hook should have the following signature::

        hook(module, grad_input, grad_output) -> Tensor or None

    The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
    module has multiple inputs or outputs. The hook should not modify its
    arguments, but it can optionally return a new gradient with respect to
    input that will be used in place of :attr:`grad_input` in subsequent
    computations.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    .. warning ::

        The current implementation will not have the presented behavior
        for complex :class:`Module` that perform many operations.
        In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
        contain the gradients for a subset of the inputs and outputs.
        For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
        directly on a specific input or output to get the required gradients.

    """
    registry = self._backward_hooks
    removable = hooks.RemovableHandle(registry)
    registry[removable.id] = hook
    return removable
def register_buffer(
self, name, tensor)
Adds a persistent buffer to the module.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the persistent state.
Buffers can be accessed as attributes using given names.
Args: name (string): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor): buffer to be registered.
Example::
>>> self.register_buffer('running_mean', torch.zeros(num_features))
def register_buffer(self, name, tensor):
    r"""Adds a persistent buffer to the module.

    This is typically used to register a buffer that should not be
    considered a model parameter. For example, BatchNorm's ``running_mean``
    is not a parameter, but is part of the persistent state.

    Buffers can be accessed as attributes using given names.

    Args:
        name (string): name of the buffer. The buffer can be accessed
            from this module using the given name
        tensor (Tensor): buffer to be registered. ``None`` is accepted and
            stored as-is.

    Raises:
        AttributeError: if the module has no ``_buffers`` dict yet, i.e.
            ``Module.__init__()`` has not been called.
        TypeError: if ``name`` is not a string, or ``tensor`` is neither a
            ``torch.Tensor`` nor ``None``.
        KeyError: if ``name`` contains a dot, is empty, or clashes with an
            existing attribute that is not itself a buffer.

    Example::

        >>> self.register_buffer('running_mean', torch.zeros(num_features))

    """
    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
    # Plain ``str`` check: the private ``torch._six.string_classes`` alias
    # used previously is not available in current PyTorch releases.
    if not isinstance(name, str):
        raise TypeError("buffer name should be a string. "
                        "Got {}".format(torch.typename(name)))
    if '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    if name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    # Re-registering an existing buffer is allowed; shadowing any other
    # attribute is not.
    if hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    if tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))
    self._buffers[name] = tensor
def register_forward_hook(
self, hook)
Registers a forward hook on the module.
The hook will be called every time after :func:`forward` has computed an output.
It should have the following signature::
hook(module, input, output) -> None
The hook should not modify the input or output.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
def register_forward_hook(self, hook):
    r"""Registers a forward hook on the module.

    The hook is stored in ``self._forward_hooks`` under the id of a newly
    created removable handle.

    The hook will be called every time after :func:`forward` has computed an output.
    It should have the following signature::

        hook(module, input, output) -> None

    The hook should not modify the input or output.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._forward_hooks
    removable = hooks.RemovableHandle(registry)
    registry[removable.id] = hook
    return removable
def register_forward_pre_hook(
self, hook)
Registers a forward pre-hook on the module.
The hook will be called every time before :func:`forward` is invoked.
It should have the following signature::
hook(module, input) -> None
The hook should not modify the input.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
def register_forward_pre_hook(self, hook):
    r"""Registers a forward pre-hook on the module.

    The hook is stored in ``self._forward_pre_hooks`` under the id of a
    newly created removable handle.

    The hook will be called every time before :func:`forward` is invoked.
    It should have the following signature::

        hook(module, input) -> None

    The hook should not modify the input.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._forward_pre_hooks
    removable = hooks.RemovableHandle(registry)
    registry[removable.id] = hook
    return removable
def register_parameter(
self, name, param)
Adds a parameter to the module.
The parameter can be accessed as an attribute using given name.
Args: name (string): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter): parameter to be added to the module.
def register_parameter(self, name, param):
    r"""Adds a parameter to the module.

    The parameter can be accessed as an attribute using given name.

    Args:
        name (string): name of the parameter. The parameter can be accessed
            from this module using the given name
        param (Parameter): parameter to be added to the module. ``None`` is
            accepted and stored as-is.

    Raises:
        AttributeError: if the module has no ``_parameters`` dict yet, i.e.
            ``Module.__init__()`` has not been called.
        TypeError: if ``name`` is not a string, or ``param`` is neither a
            ``Parameter`` nor ``None``.
        KeyError: if ``name`` contains a dot, is empty, or clashes with an
            existing attribute that is not itself a parameter.
        ValueError: if ``param`` has a ``grad_fn``, i.e. is not a leaf.
    """
    if '_parameters' not in self.__dict__:
        raise AttributeError(
            "cannot assign parameter before Module.__init__() call")
    # Plain ``str`` check: the private ``torch._six.string_classes`` alias
    # used previously is not available in current PyTorch releases.
    if not isinstance(name, str):
        raise TypeError("parameter name should be a string. "
                        "Got {}".format(torch.typename(name)))
    if '.' in name:
        raise KeyError("parameter name can't contain \".\"")
    if name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    # Re-registering an existing parameter is allowed; shadowing any other
    # attribute is not.
    if hasattr(self, name) and name not in self._parameters:
        raise KeyError("attribute '{}' already exists".format(name))

    if param is None:
        self._parameters[name] = None
        return
    if not isinstance(param, Parameter):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or None required)"
                        .format(torch.typename(param), name))
    if param.grad_fn:
        raise ValueError(
            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
            "parameters must be created explicitly. To express '{0}' "
            "as a function of another Tensor, compute the value in "
            "the forward() method.".format(name))
    self._parameters[name] = param
def state_dict(
self, destination=None, prefix='', keep_vars=False)
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.
Returns: dict: a dictionary containing a whole state of the module
Example::
>>> module.state_dict().keys() ['bias', 'weight']
def state_dict(self, destination=None, prefix='', keep_vars=False):
    r"""Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names, each
    prepended with ``prefix``. Entries that are ``None`` are left out.

    Args:
        destination (dict): mapping to fill. When ``None``, a new
            ``OrderedDict`` with an empty ``_metadata`` attribute is created.
        prefix (str): string prepended to every key. ``prefix[:-1]`` is
            used as the key of this module's entry in ``_metadata``.
        keep_vars (bool): if True the parameter/buffer objects themselves
            are stored, otherwise their ``.data``.

    Returns:
        dict:
            a dictionary containing a whole state of the module. If one of
            the hooks in ``self._state_dict_hooks`` returns something other
            than ``None``, that value replaces the dictionary.

    Example::

        >>> module.state_dict().keys()
        ['bias', 'weight']

    """
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = dict(version=self._version)
    destination._metadata[prefix[:-1]] = local_metadata

    # Parameters first, then buffers; both are stored the same way.
    for members in (self._parameters, self._buffers):
        for key, value in members.items():
            if value is None:
                continue
            destination[prefix + key] = value if keep_vars else value.data

    # Children write into the same destination under an extended prefix.
    for child_name, child in self._modules.items():
        if child is not None:
            child.state_dict(destination, prefix + child_name + '.',
                             keep_vars=keep_vars)

    for hook in self._state_dict_hooks.values():
        replacement = hook(self, destination, prefix, local_metadata)
        if replacement is not None:
            destination = replacement
    return destination
def to(
self, *args, **kwargs)
Moves and/or casts the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
Its signature is similar to :meth:`torch.Tensor.to`, but it only accepts
floating point values for :attr:`dtype`. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved to
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
.. note:: This method modifies the module in-place.
Args:
device (:class:`torch.device`): the desired device of the parameters
and buffers in this module
dtype (:class:`torch.dtype`): the desired floating point type of
the floating point parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
Returns: Module: self
Example::
>>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> gpu1 = torch.device("cuda:1") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device("cpu") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16)
def to(self, *args, **kwargs):
    r"""Moves and/or casts the parameters and buffers.

    This can be called as

    .. function:: to(device=None, dtype=None, non_blocking=False)

    .. function:: to(dtype, non_blocking=False)

    .. function:: to(tensor, non_blocking=False)

    Its signature is similar to :meth:`torch.Tensor.to`, but it only accepts
    floating point values for :attr:`dtype`. In addition, this method will
    only cast the floating point parameters and buffers to :attr:`dtype`
    (if given). The integral parameters and buffers will be moved to
    :attr:`device`, if that is given, but with dtypes unchanged. When
    :attr:`non_blocking` is set, it tries to convert/move asynchronously
    with respect to the host if possible, e.g., moving CPU Tensors with
    pinned memory to CUDA devices.

    See below for examples.

    .. note::
        This method modifies the module in-place.

    Args:
        device (:class:`torch.device`): the desired device of the parameters
            and buffers in this module
        dtype (:class:`torch.dtype`): the desired floating point type of
            the floating point parameters and buffers in this module
        tensor (torch.Tensor): Tensor whose dtype and device are the desired
            dtype and device for all parameters and buffers in this module

    Returns:
        Module: self

    Raises:
        TypeError: if a non floating point :attr:`dtype` is requested.

    Example::

        >>> linear = nn.Linear(2, 2)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1913, -0.3420],
                [-0.5113, -0.2325]])
        >>> linear.to(torch.double)
        Linear(in_features=2, out_features=2, bias=True)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1913, -0.3420],
                [-0.5113, -0.2325]], dtype=torch.float64)
        >>> gpu1 = torch.device("cuda:1")
        >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
        Linear(in_features=2, out_features=2, bias=True)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1914, -0.3420],
                [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
        >>> cpu = torch.device("cpu")
        >>> linear.to(cpu)
        Linear(in_features=2, out_features=2, bias=True)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1914, -0.3420],
                [-0.5112, -0.2324]], dtype=torch.float16)

    """
    # Only the first three parsed fields are used. Newer PyTorch releases
    # return additional trailing fields from ``_parse_to``; unpacking the
    # full tuple into three names would then fail, so slice instead. Any
    # extra fields are not applied by this implementation.
    parsed = torch._C._nn._parse_to(*args, **kwargs)
    device, dtype, non_blocking = parsed[:3]

    if dtype is not None and not dtype.is_floating_point:
        raise TypeError('nn.Module.to only accepts floating point '
                        'dtypes, but got desired dtype={}'.format(dtype))

    def convert(t):
        # Integral tensors keep their dtype and are only moved.
        target_dtype = dtype if t.is_floating_point() else None
        return t.to(device, target_dtype, non_blocking)

    return self._apply(convert)
def train(
self, mode=True)
Sets the module in training mode.
This has an effect only on certain modules. See the documentation of
particular modules for details of their behavior in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Returns: Module: self
def train(self, mode=True):
    r"""Sets the module in training mode.

    Stores ``mode`` in ``self.training`` and forwards the same ``mode`` to
    the ``train`` method of every module returned by ``self.children()``.

    This has an effect only on certain modules. See the documentation of
    particular modules for details of their behavior in training/evaluation
    mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    Args:
        mode (bool): value assigned to ``self.training``.

    Returns:
        Module: self
    """
    self.training = mode
    for child in self.children():
        child.train(mode)
    return self
def type(
self, dst_type)
Casts all parameters and buffers to :attr:`dst_type`.
Arguments: dst_type (type or string): the desired type
Returns: Module: self
def type(self, dst_type):
    r"""Casts all parameters and buffers to :attr:`dst_type`.

    The cast is handed to ``self._apply``, which receives a function that
    calls ``.type(dst_type)`` on its argument.

    Arguments:
        dst_type (type or string): the desired type

    Returns:
        Module: self
    """
    def cast(t):
        return t.type(dst_type)

    return self._apply(cast)
def zero_grad(
self)
Sets gradients of all model parameters to zero.
def zero_grad(self):
    r"""Sets gradients of all model parameters to zero.

    Every parameter from ``self.parameters()`` that has a gradient gets it
    detached and zeroed in place; parameters whose ``grad`` is ``None``
    are left untouched.
    """
    for param in self.parameters():
        grad = param.grad
        if grad is None:
            continue
        grad.detach_()
        grad.zero_()
Instance variables
var cond_vars
var ind_cond
var indexed_ops
var indexed_ops_rev
var input_vars
var module_cond
var module_list
var node_list
var return_vars
var variables_ind