Top

FrEIA.framework module

The framework module contains the logic used in building the graph and inferring the order that the nodes have to be executed in forward and backward direction.

'''The framework module contains the logic used in building the graph and
inferring the order that the nodes have to be executed in forward and backward
direction.'''

import warnings
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable

from . import dummy_modules as dummys


class Node:
    '''The Node class represents one transformation in the graph, with an
    arbitrary number of in- and outputs.

    Arguments:
        inputs: Either a single Node (its output 0 is used), a single
            (node, output index) pair, or a list/tuple of such pairs.
        module_type: Callable that creates the module of this node. It is
            called in build_modules as module_type(input_dims, **module_args),
            with an additional dims_c keyword if the node has conditions.
        module_args: Dict of keyword arguments passed on to module_type.
        conditions: A single condition node or a list/tuple of them.
            Defaults to no conditions.
        name: Name used in printed messages. Defaults to the last six
            characters of hex(id(self)).
    '''

    def __init__(self, inputs, module_type, module_args, conditions=None, name=None):
        # Assign the name first so that parse_inputs can mention it in its
        # error messages
        if name:
            self.name = name
        else:
            self.name = hex(id(self))[-6:]

        self.inputs = self.parse_inputs(inputs)
        # Create a fresh list per instance instead of sharing one mutable
        # default list between all nodes without conditions
        if conditions is None:
            self.conditions = []
        elif isinstance(conditions, (list, tuple)):
            self.conditions = conditions
        else:
            self.conditions = [conditions,]

        self.outputs = []
        self.module_type = module_type
        self.module_args = module_args

        self.input_dims = None
        self.module = None
        self.computed = None
        self.computed_rev = None
        self.id = None

        # Shorthands node.out0 ... node.out254 for the pairs (node, index)
        for i in range(255):
            setattr(self, f'out{i}', (self, i))

    def parse_inputs(self, inputs):
        '''Normalize the inputs argument to a list of (node, output index)
        pairs.

        A list/tuple whose first entry is itself a list/tuple is returned
        unchanged, a single pair is wrapped into a list, and a single Node is
        translated to [(node, 0)].

        Raises a RuntimeError if a list/tuple matches neither form, and an
        AssertionError if a non-sequence argument is not a Node.'''

        # Subclasses may call this before they assign self.name, so fall back
        # to the same identifier that __init__ uses for unnamed nodes
        node_name = getattr(self, 'name', hex(id(self))[-6:])

        if isinstance(inputs, (list, tuple)):
            if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
                return inputs
            elif len(inputs) == 2:
                return [inputs,]
            else:
                raise RuntimeError(f"Cannot parse inputs provided to node '{node_name}'.")
        else:
            assert isinstance(inputs, Node), "Received object of invalid type "\
                f"({type(inputs)}) as input for node '{node_name}'."
            return [(inputs, 0),]

    def build_modules(self, verbose=True):
        ''' Returns a list with the dimension of each output of this node,
        recursively calling build_modules of the nodes connected to the input.
        Use this information to initialize the pytorch nn.Module of this node.

        If verbose is set, the inputs and conditions of the node are printed
        after the module has been created.
        '''

        if not self.input_dims:  # Only do it if this hasn't been computed yet
            self.input_dims = [n.build_modules(verbose=verbose)[c]
                               for n, c in self.inputs]
            try:
                if len(self.conditions) > 0:
                    c_dims = [c.build_modules(verbose=verbose)[0] for c in self.conditions]
                    self.module = self.module_type(self.input_dims, dims_c=c_dims,
                                                   **self.module_args)
                else:
                    self.module = self.module_type(self.input_dims,
                                                   **self.module_args)
            except Exception:
                # Tell the user which node failed, then let the original
                # exception propagate with its traceback untouched
                print('Error in node %s' % (self.name))
                raise

            if verbose:
                print(f"Node '{self.name}' takes the following inputs:")
                for d, (n, c) in zip(self.input_dims, self.inputs):
                    print(f"\t Output #{c} of node '{n.name}' with dims {d}")
                for c in self.conditions:
                    print(f"\t conditioned on node '{c.name}' " +
                          f"with dims {c.data.shape}")
                print()

            self.output_dims = self.module.output_dims(self.input_dims)
            self.n_outputs = len(self.output_dims)

        return self.output_dims

    def run_forward(self, op_list):
        '''Determine the order of operations needed to reach this node. Calls
        run_forward of parent nodes recursively. Each operation is appended to
        the global list op_list, in the form (node ID, input variable IDs,
        output variable IDs, condition variable IDs). Variables are
        identified by (node ID, output index) pairs.

        Relies on self.n_outputs, which is set by build_modules.'''

        if not self.computed:

            # Compute all nodes which provide inputs, filter out the
            # channels you need
            self.input_vars = []
            for i, (n, c) in enumerate(self.inputs):
                self.input_vars.append(n.run_forward(op_list)[c])
                # Register self as an output in the input node
                n.outputs.append((self, i))
            # Compute all nodes which provide conditioning
            self.condition_vars = []
            for i, c in enumerate(self.conditions):
                self.condition_vars.append(c.run_forward(op_list)[0])
                # Register self as an output in the condition node
                c.outputs.append((self, i))

            # All outputs could now be computed
            self.computed = [(self.id, i) for i in range(self.n_outputs)]
            op_list.append((self.id, self.input_vars, self.computed, self.condition_vars))

        # Return the variables you have computed (this happens multiple times
        # without recomputing if called repeatedly)
        return self.computed

    def run_backward(self, op_list):
        '''See run_forward, this is the same, only for the reverse computation.
        Need to call run_forward first, otherwise this function will not
        work'''

        assert len(self.outputs) > 0, "Call run_forward first"
        if not self.computed_rev:

            # These are the input variables that must be computed first
            output_vars = [(self.id, i) for i in range(self.n_outputs)]

            # Recursively compute these
            for n, c in self.outputs:
                n.run_backward(op_list)

            # The variables that this node computes are the input variables
            # from the forward pass
            self.computed_rev = self.input_vars
            if len(self.condition_vars) == 0:
                self.condition_vars = [c.run_forward(op_list)[0] for c in self.conditions]
            op_list.append((self.id, output_vars, self.computed_rev, self.condition_vars))

        return self.computed_rev


class InputNode(Node):
    '''Special type of node that represents the input data of the whole net (or
    output when running reverse)'''

    def __init__(self, *dims, name='node'):
        '''Create the node; dims are forwarded to dummys.dummy_data.'''
        # Bookkeeping attributes that the graph traversal reads on every node
        self.outputs, self.conditions = [], []
        self.condition_vars, self.input_vars = [], []
        self.module = None
        self.computed_rev = None
        # This node exposes exactly one variable, also reachable as out0
        self.n_outputs = 1
        self.out0 = (self, 0)

        self.name = name
        self.data = dummys.dummy_data(*dims)

    def build_modules(self, verbose=True):
        '''Return a one-element list holding the shape of the node's data.
        No module is built, verbose is not used.'''
        data_shape = self.data.shape
        return [data_shape]

    def run_forward(self, op_list):
        '''Return the single variable of this node as [(node ID, 0)]. Nothing
        is appended to op_list.'''
        own_variable = (self.id, 0)
        return [own_variable]


class ConditionNode(Node):
    '''Special type of node that represents conditional input to the internal
    networks inside coupling layers'''

    def __init__(self, *dims, name='node'):
        '''Create the node; dims are forwarded to dummys.dummy_data.'''
        # Bookkeeping attributes that the graph traversal reads on every node
        self.outputs, self.conditions = [], []
        self.condition_vars, self.input_vars = [], []
        self.module = None
        self.computed_rev = None
        # This node exposes exactly one variable, also reachable as out0
        self.n_outputs = 1
        self.out0 = (self, 0)

        self.name = name
        self.data = dummys.dummy_data(*dims)

    def build_modules(self, verbose=True):
        '''Return a one-element list holding the shape of the node's data.
        No module is built, verbose is not used.'''
        data_shape = self.data.shape
        return [data_shape]

    def run_forward(self, op_list):
        '''Return the single variable of this node as [(node ID, 0)]. Nothing
        is appended to op_list.'''
        own_variable = (self.id, 0)
        return [own_variable]


class OutputNode(Node):
    '''Special type of node that represents the output of the whole net (or the
    input when running in reverse)'''

    class dummy(nn.Module):
        '''Placeholder module of an output node; it has no parameters.'''

        def __init__(self, *args):
            # Constructor arguments are accepted but not used
            super().__init__()

        # Both methods below are declared without an explicit self, so the
        # returned tuple starts with the instance they are called on.
        def __call__(*args):
            return args

        def output_dims(*args):
            return args

    def __init__(self, inputs, name='node'):
        '''Create the node and register it as an output of every node listed
        in inputs (same formats as accepted by parse_inputs).'''
        self.name = name
        self.id = None
        self.computed = None
        self.module = None
        self.input_dims = None
        self.conditions = []
        self.output_dims = []
        self.module_args = {}
        self.module_type = self.dummy
        self.inputs = self.parse_inputs(inputs)

        # Tell each feeding node that this node consumes one of its outputs
        for position, entry in enumerate(self.inputs):
            entry[0].outputs.append((self, position))

    def run_backward(self, op_list):
        '''Return the single variable of this node as [(node ID, 0)]. Nothing
        is appended to op_list.'''
        own_variable = (self.id, 0)
        return [own_variable]


class ReversibleGraphNet(nn.Module):
    '''This class represents the invertible net itself. It is a subclass of
    torch.nn.Module and supports the same methods. The forward method has an
    additional option 'rev', with which the net can be computed in reverse.'''

    def __init__(self, node_list, ind_in=None, ind_out=None, verbose=True):
        '''node_list should be a list of all nodes involved, and ind_in,
        ind_out are the indexes of the special nodes InputNode and OutputNode
        in this list.

        ind_in and ind_out are deprecated: if they are left at None, the
        input and output nodes are found by their type. Condition nodes are
        always found by their type. verbose is passed on to build_modules of
        the nodes.'''
        super(ReversibleGraphNet, self).__init__()

        # Gather lists of input, output and condition nodes
        if ind_in is not None:
            warnings.warn("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, " +
                          "input and output nodes are detected automatically.")
            if isinstance(ind_in, int):
                self.ind_in = [ind_in]
            else:
                self.ind_in = ind_in
        else:
            self.ind_in = [i for i in range(len(node_list))
                           if isinstance(node_list[i], InputNode)]
            assert len(self.ind_in) > 0, "No input nodes specified."
        if ind_out is not None:
            warnings.warn("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, " +
                          "input and output nodes are detected automatically.")
            if isinstance(ind_out, int):
                self.ind_out = [ind_out]
            else:
                self.ind_out = ind_out
        else:
            self.ind_out = [i for i in range(len(node_list))
                            if isinstance(node_list[i], OutputNode)]
            assert len(self.ind_out) > 0, "No output nodes specified."
        self.ind_cond = [i for i in range(len(node_list))
                         if isinstance(node_list[i], ConditionNode)]

        # Filled by ops_to_indexed with indices into self.variables_ind
        self.return_vars = []
        self.input_vars = []
        self.cond_vars = []

        # Assign each node a unique ID
        self.node_list = node_list
        for i, n in enumerate(node_list):
            n.id = i
            n.graph = self

        # Recursively build the nodes nn.Modules and determine order of
        # operations
        ops = []
        for i in self.ind_out:
            node_list[i].build_modules(verbose=verbose)
            node_list[i].run_forward(ops)

        # create list of Pytorch variables that are used
        variables = set()
        for o in ops:
            variables = variables.union(set(o[1] + o[2] + o[3]))
        self.variables_ind = list(variables)

        self.indexed_ops = self.ops_to_indexed(ops)

        self.module_list = nn.ModuleList([n.module for n in node_list])
        self.module_cond = [(len(n.conditions) > 0) for n in node_list]
        # One slot per variable; forward() stores the tensors in these slots
        self._buffers = {F'tmp_var_{i}' : None for i in range(len(variables))}

        # Find out the order of operations for reverse calculations
        ops_rev = []
        # list() so that a tuple passed as ind_in can be concatenated too
        for i in list(self.ind_in) + self.ind_cond:
            node_list[i].run_backward(ops_rev)
        self.indexed_ops_rev = self.ops_to_indexed(ops_rev)

    def ops_to_indexed(self, ops):
        '''Helper function to translate the list of variables (origin ID, channel),
        to variable IDs.

        Returns a list of (node ID, input variable IDs, output variable IDs,
        condition variable IDs) tuples. Operations that belong to input,
        output or condition nodes are not part of the result; instead their
        first variable is recorded in self.input_vars, self.return_vars or
        self.cond_vars.'''
        result = []

        for o in ops:
            try:
                vars_in = [self.variables_ind.index(v) for v in o[1]]
            except ValueError:
                vars_in = -1

            vars_out = [self.variables_ind.index(v) for v in o[2]]
            vars_cond = [self.variables_ind.index(v) for v in o[3]]

            # Collect input/output/conditioning nodes in separate lists, but don't
            # add to indexed ops
            if o[0] in self.ind_out:
                self.return_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_in:
                self.input_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_cond:
                if self.variables_ind.index(o[1][0]) not in self.cond_vars:
                    self.cond_vars.append(self.variables_ind.index(o[1][0]))
                else:
                    print('Is this branch ever reached?')
                continue

            result.append((o[0], vars_in, vars_out, vars_cond))

        # Sort input/output/conditioning variables so they correspond to initial
        # node list order
        self.return_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.input_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.cond_vars.sort(key=lambda i: self.variables_ind[i][0])

        return result

    def forward(self, x, c=None, rev=False, intermediate_outputs=False):
        '''Forward or backward computation of the whole net.

        x is a single tensor or a list/tuple with one tensor per input
        variable (per output variable if rev is set). c holds the conditions
        in the same manner, or None if the net has no condition nodes.

        Returns the output tensor, or a list of them if there is more than
        one. With intermediate_outputs, a dict that maps each node name to
        the results of its module is returned instead.'''

        if rev:
            use_list = self.indexed_ops_rev
            input_vars, output_vars = self.return_vars, self.input_vars
        else:
            use_list = self.indexed_ops
            input_vars, output_vars = self.input_vars, self.return_vars

        # Assign input data to respective variables
        if isinstance(x, (list, tuple)):
            assert len(x) == len(input_vars), (
                f"Got list of {len(x)} input tensors for "
                f"{'inverse' if rev else 'forward'} pass, but expected "
                f"{len(input_vars)}."
            )
            for i in range(len(input_vars)):
                self._buffers[F'tmp_var_{input_vars[i]}'] = x[i]
        else:
            assert len(input_vars) == 1, (f"Got single input tensor for "
                                          f"{'inverse' if rev else 'forward'} "
                                          f"pass, but expected list of "
                                          f"{len(input_vars)}.")
            self._buffers[F'tmp_var_{input_vars[0]}'] = x

        # Assign conditioning data to respective variables
        if c is None:
            assert len(self.cond_vars) == 0
        elif isinstance(c, (list, tuple)):
            assert len(c) == len(self.cond_vars), f'{len(c)}, {len(self.cond_vars)}'
            for i in range(len(self.cond_vars)):
                self._buffers[F'tmp_var_{self.cond_vars[i]}'] = c[i]
        else:
            assert len(self.cond_vars) == 1
            self._buffers[F'tmp_var_{self.cond_vars[0]}'] = c

        # Prepare dictionary for intermediate node outputs
        out_dict = {}

        # Run all modules with the given inputs
        for o in use_list:
            try:
                x = [self._buffers[F'tmp_var_{i}'] for i in o[1]]
                if self.module_cond[o[0]]:
                    c = [self._buffers[F'tmp_var_{i}'] for i in o[3]]
                    results = self.module_list[o[0]](x, c=c, rev=rev)
                else:
                    results = self.module_list[o[0]](x, rev=rev)
            except TypeError as e:
                # Keep the original error attached as the explicit cause
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?") from e
            out_dict[self.node_list[o[0]].name] = results
            for i, r in zip(o[2], results):
                self._buffers[F'tmp_var_{i}'] = r

        if intermediate_outputs:
            return out_dict
        else:
            out = [self._buffers[F'tmp_var_{output_vars[i]}']
                   for i in range(len(output_vars))]
            if len(out) == 1:
                return out[0]
            else:
                return out

    def log_jacobian(self, x=None, c=None, rev=False, run_forward=True, intermediate_outputs=False):
        '''Compute the log jacobian determinant of the whole net.

        The values returned by the jacobian method of each module are summed
        up. If run_forward is set, forward(x, c, rev=rev) is called first to
        fill the variables, which requires x to be given. Otherwise the
        variables stored by an earlier call of forward are used.

        With intermediate_outputs, a dict that maps each node name to the
        value of its module is returned instead of the sum.'''
        if run_forward or c is not None:
            self.condition = c
        jacobian = 0

        if rev:
            use_list = self.indexed_ops_rev
        else:
            use_list = self.indexed_ops

        if run_forward:
            if x is None:
                raise RuntimeError("You need to provide an input if you want "
                                   "to run a forward pass")
            self.forward(x, c, rev=rev)

        # Prepare dictionary for intermediate node outputs
        jacobian_dict = {}

        # Run all modules with the given inputs
        for o in use_list:
            x = [self._buffers[F'tmp_var_{i}'] for i in o[1]]
            if self.module_cond[o[0]]:
                c = [self._buffers[F'tmp_var_{i}'] for i in o[3]]
                module_jacobian = self.module_list[o[0]].jacobian(x, c=c, rev=rev)
            else:
                module_jacobian = self.module_list[o[0]].jacobian(x, rev=rev)
            jacobian += module_jacobian
            jacobian_dict[self.node_list[o[0]].name] = module_jacobian

        if intermediate_outputs:
            return jacobian_dict
        else:
            return jacobian

    def jacobian(self, *args, **kwargs):
        '''Compute the log jacobian determinant of the whole net.

        Deprecated alias of log_jacobian, all arguments are passed on.'''
        warnings.warn("This function computes the log-jacobian determinant, not the "
                      "jacobian as the name suggest. Will be removed in the future.")
        return self.log_jacobian(*args, **kwargs)

    def log_jacobian_numerical(self, x, c=None, rev=False, h=1e-04):
        '''Approximate log Jacobian determinant via finite differences.

        Each input dimension is shifted by +h and -h, and the central
        difference of the two network outputs fills one column of the
        Jacobian. x is a single tensor or a list/tuple of tensors, c and rev
        are passed on to forward.

        Returns a tensor with one value log|det J| per sample, using the
        dtype and device of x.'''
        multiple_inputs = isinstance(x, (list, tuple))
        if multiple_inputs:
            batch_size = x[0].shape[0]
            ndim_x_separate = [int(np.prod(x_i.shape[1:])) for x_i in x]
            ndim_x_total = sum(ndim_x_separate)
            # reshape instead of view: the inputs may be non-contiguous
            x_flat = torch.cat([x_i.reshape(batch_size, -1) for x_i in x], dim=1)
        else:
            batch_size = x.shape[0]
            ndim_x_total = int(np.prod(x.shape[1:]))
            x_flat = x.reshape(batch_size, -1)

        # Allocate with the dtype and device of the input
        J_num = x_flat.new_zeros(batch_size, ndim_x_total, ndim_x_total)
        for i in range(ndim_x_total):
            offset = x_flat.new_zeros(batch_size, ndim_x_total)
            offset[:, i] = h
            if multiple_inputs:
                # Cut the flat vector back into the shapes of the inputs
                x_upper = torch.split(x_flat + offset, ndim_x_separate, dim=1)
                x_upper = [part.reshape(x_i.shape) for part, x_i in zip(x_upper, x)]
                x_lower = torch.split(x_flat - offset, ndim_x_separate, dim=1)
                x_lower = [part.reshape(x_i.shape) for part, x_i in zip(x_lower, x)]
            else:
                x_upper = (x_flat + offset).reshape(x.shape)
                x_lower = (x_flat - offset).reshape(x.shape)
            y_upper = self.forward(x_upper, c=c, rev=rev)
            y_lower = self.forward(x_lower, c=c, rev=rev)
            if isinstance(y_upper, (list, tuple)):
                y_upper = torch.cat([y_i.reshape(batch_size, -1) for y_i in y_upper], dim=1)
                y_lower = torch.cat([y_i.reshape(batch_size, -1) for y_i in y_lower], dim=1)
            J_num[:, :, i] = (y_upper - y_lower).reshape(batch_size, -1) / (2*h)
        logdet_num = x_flat.new_zeros(batch_size)
        for i in range(batch_size):
            logdet_num[i] = torch.det(J_num[i, :, :]).abs().log()

        return logdet_num

    def load_state_dict(self, state_dict, *args, **kwargs):
        '''Load state_dict like nn.Module.load_state_dict, but leave out the
        entries whose key names one of the temporary variable slots that
        currently hold None.'''

        state_dict_no_buffers = {}
        for k, p in state_dict.items():
            if k in self._buffers and self._buffers[k] is None:
                continue
            state_dict_no_buffers[k] = p

        return super().load_state_dict(state_dict_no_buffers, *args, **kwargs)


# Testing example
if __name__ == '__main__':
    # Small example graph made of dummy modules: the input is split twice,
    # the branches are transformed separately and merged again.
    inp = InputNode(4, 64, 64, name='input')
    t1 = Node([(inp, 0)], dummys.dummy_mux, {}, name='t1')
    s1 = Node([(t1, 0)], dummys.dummy_2split, {}, name='s1')

    # First output of s1 is transformed, second one is split once more
    t2 = Node([(s1, 0)], dummys.dummy_module, {}, name='t2')
    s2 = Node([(s1, 1)], dummys.dummy_2split, {}, name='s2')
    t3 = Node([(s2, 0)], dummys.dummy_module, {}, name='t3')

    # Merge the branches again in reverse order of the splits
    m1 = Node([(t3, 0), (s2, 1)], dummys.dummy_2merge, {}, name='m1')
    m2 = Node([(t2, 0), (m1, 0)], dummys.dummy_2merge, {}, name='m2')
    outp = OutputNode([(m2, 0)], name='output')

    all_nodes = [inp, outp, t1, s1, t2, s2, t3, m1, m2]

    # Input and output nodes are detected by their type. Passing the
    # deprecated ind_in/ind_out indices would only trigger a warning.
    net = ReversibleGraphNet(all_nodes)

Classes

class ConditionNode

Special type of node that represents conditional input to the internal networks inside coupling layers

class ConditionNode(Node):
    '''Special type of node that represents conditional input to the internal
    networks inside coupling layers'''

    def __init__(self, *dims, name='node'):
        '''Create the node; dims are forwarded to dummys.dummy_data.'''
        # Bookkeeping attributes that the graph traversal reads on every node
        self.outputs, self.conditions = [], []
        self.condition_vars, self.input_vars = [], []
        self.module = None
        self.computed_rev = None
        # This node exposes exactly one variable, also reachable as out0
        self.n_outputs = 1
        self.out0 = (self, 0)

        self.name = name
        self.data = dummys.dummy_data(*dims)

    def build_modules(self, verbose=True):
        '''Return a one-element list holding the shape of the node's data.
        No module is built, verbose is not used.'''
        data_shape = self.data.shape
        return [data_shape]

    def run_forward(self, op_list):
        '''Return the single variable of this node as [(node ID, 0)]. Nothing
        is appended to op_list.'''
        own_variable = (self.id, 0)
        return [own_variable]

Ancestors (in MRO)

Static methods

def __init__(

self, *dims)

Initialize self. See help(type(self)) for accurate signature.

def __init__(self, *dims, name='node'):
    '''Create the node; dims are forwarded to dummys.dummy_data.'''
    # Bookkeeping attributes that the graph traversal reads on every node
    self.outputs, self.conditions = [], []
    self.condition_vars, self.input_vars = [], []
    self.module = None
    self.computed_rev = None
    # This node exposes exactly one variable, also reachable as out0
    self.n_outputs = 1
    self.out0 = (self, 0)
    self.name = name
    self.data = dummys.dummy_data(*dims)

def build_modules(

self, verbose=True)

Returns a list with the dimension of each output of this node, recursively calling build_modules of the nodes connected to the input. Use this information to initialize the pytorch nn.Module of this node.

def build_modules(self, verbose=True):
    '''Return a one-element list holding the shape of the node's data.
    No module is built, verbose is not used.'''
    data_shape = self.data.shape
    return [data_shape]

def parse_inputs(

self, inputs)

def parse_inputs(self, inputs):
    '''Normalize the inputs argument to a list of (node, output index)
    pairs.

    A list/tuple whose first entry is itself a list/tuple is returned
    unchanged, a single pair is wrapped into a list, and a single Node is
    translated to [(node, 0)].

    Raises a RuntimeError if a list/tuple matches neither form, and an
    AssertionError if a non-sequence argument is not a Node.'''
    # The name may not be assigned yet when this is called early in a
    # constructor, so fall back to an identifier derived from id(self)
    node_name = getattr(self, 'name', hex(id(self))[-6:])
    if isinstance(inputs, (list, tuple)):
        if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
            return inputs
        elif len(inputs) == 2:
            return [inputs,]
        else:
            raise RuntimeError(f"Cannot parse inputs provided to node '{node_name}'.")
    else:
        assert isinstance(inputs, Node), "Received object of invalid type "\
            f"({type(inputs)}) as input for node '{node_name}'."
        return [(inputs, 0),]

def run_backward(

self, op_list)

See run_forward, this is the same, only for the reverse computation. Need to call run_forward first, otherwise this function will not work

def run_backward(self, op_list):
    '''Reverse counterpart of run_forward: determine the order of operations
    for the reverse computation and append them to op_list. run_forward has
    to be called first, otherwise this function will not work.'''
    assert len(self.outputs) > 0, "Call run_forward first"

    # Already scheduled: hand back the stored result
    if self.computed_rev:
        return self.computed_rev

    # In reverse direction, the outputs of this node are what it consumes
    consumed = [(self.id, k) for k in range(self.n_outputs)]

    # Every node fed by this one has to be scheduled before it
    for successor, _ in self.outputs:
        successor.run_backward(op_list)

    # What this node produces in reverse are its forward input variables
    self.computed_rev = self.input_vars
    if len(self.condition_vars) == 0:
        self.condition_vars = [cond.run_forward(op_list)[0]
                               for cond in self.conditions]

    operation = (self.id, consumed, self.computed_rev, self.condition_vars)
    op_list.append(operation)
    return self.computed_rev

def run_forward(

self, op_list)

Determine the order of operations needed to reach this node. Calls run_forward of parent nodes recursively. Each operation is appended to the global list op_list, in the form (node ID, input variable IDs, output variable IDs, condition variable IDs)

def run_forward(self, op_list):
    '''Return the single variable of this node as [(node ID, 0)]. Nothing
    is appended to op_list.'''
    own_variable = (self.id, 0)
    return [own_variable]

Instance variables

var computed_rev

var condition_vars

var conditions

var data

var input_vars

var module

var n_outputs

var name

var out0

var outputs

class InputNode

Special type of node that represents the input data of the whole net (or output when running reverse)

class InputNode(Node):
    '''Special type of node that represents the input data of the whole net (or
    output when running reverse)'''

    def __init__(self, *dims, name='node'):
        '''Create the node; dims are forwarded to dummys.dummy_data.'''
        # Bookkeeping attributes that the graph traversal reads on every node
        self.outputs, self.conditions = [], []
        self.condition_vars, self.input_vars = [], []
        self.module = None
        self.computed_rev = None
        # This node exposes exactly one variable, also reachable as out0
        self.n_outputs = 1
        self.out0 = (self, 0)

        self.name = name
        self.data = dummys.dummy_data(*dims)

    def build_modules(self, verbose=True):
        '''Return a one-element list holding the shape of the node's data.
        No module is built, verbose is not used.'''
        data_shape = self.data.shape
        return [data_shape]

    def run_forward(self, op_list):
        '''Return the single variable of this node as [(node ID, 0)]. Nothing
        is appended to op_list.'''
        own_variable = (self.id, 0)
        return [own_variable]

Ancestors (in MRO)

Static methods

def __init__(

self, *dims)

Initialize self. See help(type(self)) for accurate signature.

def __init__(self, *dims, name='node'):
    '''Create the node; dims are forwarded to dummys.dummy_data.'''
    # Bookkeeping attributes that the graph traversal reads on every node
    self.outputs, self.conditions = [], []
    self.condition_vars, self.input_vars = [], []
    self.module = None
    self.computed_rev = None
    # This node exposes exactly one variable, also reachable as out0
    self.n_outputs = 1
    self.out0 = (self, 0)
    self.name = name
    self.data = dummys.dummy_data(*dims)

def build_modules(

self, verbose=True)

Returns a list with the dimension of each output of this node, recursively calling build_modules of the nodes connected to the input. Use this information to initialize the pytorch nn.Module of this node.

def build_modules(self, verbose=True):
    '''Return a one-element list holding the shape of the node's data.
    No module is built, verbose is not used.'''
    data_shape = self.data.shape
    return [data_shape]

def parse_inputs(

self, inputs)

def parse_inputs(self, inputs):
    '''Normalize the inputs argument to a list of (node, output index)
    pairs.

    A list/tuple whose first entry is itself a list/tuple is returned
    unchanged, a single pair is wrapped into a list, and a single Node is
    translated to [(node, 0)].

    Raises a RuntimeError if a list/tuple matches neither form, and an
    AssertionError if a non-sequence argument is not a Node.'''
    # The name may not be assigned yet when this is called early in a
    # constructor, so fall back to an identifier derived from id(self)
    node_name = getattr(self, 'name', hex(id(self))[-6:])
    if isinstance(inputs, (list, tuple)):
        if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
            return inputs
        elif len(inputs) == 2:
            return [inputs,]
        else:
            raise RuntimeError(f"Cannot parse inputs provided to node '{node_name}'.")
    else:
        assert isinstance(inputs, Node), "Received object of invalid type "\
            f"({type(inputs)}) as input for node '{node_name}'."
        return [(inputs, 0),]

def run_backward(

self, op_list)

See run_forward, this is the same, only for the reverse computation. Need to call run_forward first, otherwise this function will not work

def run_backward(self, op_list):
    '''Reverse counterpart of run_forward: determine the order of operations
    for the reverse computation and append them to op_list. run_forward has
    to be called first, otherwise this function will not work.'''
    assert len(self.outputs) > 0, "Call run_forward first"

    # Already scheduled: hand back the stored result
    if self.computed_rev:
        return self.computed_rev

    # In reverse direction, the outputs of this node are what it consumes
    consumed = [(self.id, k) for k in range(self.n_outputs)]

    # Every node fed by this one has to be scheduled before it
    for successor, _ in self.outputs:
        successor.run_backward(op_list)

    # What this node produces in reverse are its forward input variables
    self.computed_rev = self.input_vars
    if len(self.condition_vars) == 0:
        self.condition_vars = [cond.run_forward(op_list)[0]
                               for cond in self.conditions]

    operation = (self.id, consumed, self.computed_rev, self.condition_vars)
    op_list.append(operation)
    return self.computed_rev

def run_forward(

self, op_list)

Determine the order of operations needed to reach this node. Calls run_forward of parent nodes recursively. Each operation is appended to the global list op_list, in the form (node ID, input variable IDs, output variable IDs, condition variable IDs)

def run_forward(self, op_list):
    '''Return the single variable of this node as [(node ID, 0)]. Nothing
    is appended to op_list.'''
    own_variable = (self.id, 0)
    return [own_variable]

Instance variables

var computed_rev

var condition_vars

var conditions

var data

var input_vars

var module

var n_outputs

var name

var out0

var outputs

class Node

The Node class represents one transformation in the graph, with an arbitrary number of in- and outputs.

class Node:
    '''The Node class represents one transformation in the graph, with an
    arbitrary number of in- and outputs.

    Arguments:
        inputs: Either a single Node (its output 0 is used), a single
            (node, output index) pair, or a list/tuple of such pairs.
        module_type: Callable that creates the module of this node. It is
            called in build_modules as module_type(input_dims, **module_args),
            with an additional dims_c keyword if the node has conditions.
        module_args: Dict of keyword arguments passed on to module_type.
        conditions: A single condition node or a list/tuple of them.
            Defaults to no conditions.
        name: Name used in printed messages. Defaults to the last six
            characters of hex(id(self)).
    '''

    def __init__(self, inputs, module_type, module_args, conditions=None, name=None):
        # Assign the name first so that parse_inputs can mention it in its
        # error messages
        if name:
            self.name = name
        else:
            self.name = hex(id(self))[-6:]

        self.inputs = self.parse_inputs(inputs)
        # Create a fresh list per instance instead of sharing one mutable
        # default list between all nodes without conditions
        if conditions is None:
            self.conditions = []
        elif isinstance(conditions, (list, tuple)):
            self.conditions = conditions
        else:
            self.conditions = [conditions,]

        self.outputs = []
        self.module_type = module_type
        self.module_args = module_args

        self.input_dims = None
        self.module = None
        self.computed = None
        self.computed_rev = None
        self.id = None

        # Shorthands node.out0 ... node.out254 for the pairs (node, index)
        for i in range(255):
            setattr(self, f'out{i}', (self, i))

    def parse_inputs(self, inputs):
        '''Normalize the inputs argument to a list of (node, output index)
        pairs.

        A list/tuple whose first entry is itself a list/tuple is returned
        unchanged, a single pair is wrapped into a list, and a single Node is
        translated to [(node, 0)].

        Raises a RuntimeError if a list/tuple matches neither form, and an
        AssertionError if a non-sequence argument is not a Node.'''

        # Subclasses may call this before they assign self.name, so fall back
        # to the same identifier that __init__ uses for unnamed nodes
        node_name = getattr(self, 'name', hex(id(self))[-6:])

        if isinstance(inputs, (list, tuple)):
            if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
                return inputs
            elif len(inputs) == 2:
                return [inputs,]
            else:
                raise RuntimeError(f"Cannot parse inputs provided to node '{node_name}'.")
        else:
            assert isinstance(inputs, Node), "Received object of invalid type "\
                f"({type(inputs)}) as input for node '{node_name}'."
            return [(inputs, 0),]

    def build_modules(self, verbose=True):
        ''' Returns a list with the dimension of each output of this node,
        recursively calling build_modules of the nodes connected to the input.
        Use this information to initialize the pytorch nn.Module of this node.

        If verbose is set, the inputs and conditions of the node are printed
        after the module has been created.
        '''

        if not self.input_dims:  # Only do it if this hasn't been computed yet
            self.input_dims = [n.build_modules(verbose=verbose)[c]
                               for n, c in self.inputs]
            try:
                if len(self.conditions) > 0:
                    c_dims = [c.build_modules(verbose=verbose)[0] for c in self.conditions]
                    self.module = self.module_type(self.input_dims, dims_c=c_dims,
                                                   **self.module_args)
                else:
                    self.module = self.module_type(self.input_dims,
                                                   **self.module_args)
            except Exception:
                # Tell the user which node failed, then let the original
                # exception propagate with its traceback untouched
                print('Error in node %s' % (self.name))
                raise

            if verbose:
                print(f"Node '{self.name}' takes the following inputs:")
                for d, (n, c) in zip(self.input_dims, self.inputs):
                    print(f"\t Output #{c} of node '{n.name}' with dims {d}")
                for c in self.conditions:
                    print(f"\t conditioned on node '{c.name}' " +
                          f"with dims {c.data.shape}")
                print()

            self.output_dims = self.module.output_dims(self.input_dims)
            self.n_outputs = len(self.output_dims)

        return self.output_dims

    def run_forward(self, op_list):
        '''Determine the order of operations needed to reach this node. Calls
        run_forward of parent nodes recursively. Each operation is appended to
        the global list op_list, in the form (node ID, input variable IDs,
        output variable IDs, condition variable IDs). Variables are
        identified by (node ID, output index) pairs.

        Relies on self.n_outputs, which is set by build_modules.'''

        if not self.computed:

            # Compute all nodes which provide inputs, filter out the
            # channels you need
            self.input_vars = []
            for i, (n, c) in enumerate(self.inputs):
                self.input_vars.append(n.run_forward(op_list)[c])
                # Register self as an output in the input node
                n.outputs.append((self, i))
            # Compute all nodes which provide conditioning
            self.condition_vars = []
            for i, c in enumerate(self.conditions):
                self.condition_vars.append(c.run_forward(op_list)[0])
                # Register self as an output in the condition node
                c.outputs.append((self, i))

            # All outputs could now be computed
            self.computed = [(self.id, i) for i in range(self.n_outputs)]
            op_list.append((self.id, self.input_vars, self.computed, self.condition_vars))

        # Return the variables you have computed (this happens multiple times
        # without recomputing if called repeatedly)
        return self.computed

    def run_backward(self, op_list):
        '''See run_forward, this is the same, only for the reverse computation.
        Need to call run_forward first, otherwise this function will not
        work'''

        assert len(self.outputs) > 0, "Call run_forward first"
        if not self.computed_rev:

            # These are the input variables that must be computed first
            output_vars = [(self.id, i) for i in range(self.n_outputs)]

            # Recursively compute these
            for n, c in self.outputs:
                n.run_backward(op_list)

            # The variables that this node computes are the input variables
            # from the forward pass
            self.computed_rev = self.input_vars
            if len(self.condition_vars) == 0:
                self.condition_vars = [c.run_forward(op_list)[0] for c in self.conditions]
            op_list.append((self.id, output_vars, self.computed_rev, self.condition_vars))

        return self.computed_rev

Ancestors (in MRO)

  • Node
  • builtins.object

Static methods

def __init__(

self, inputs, module_type, module_args, conditions=[], name=None)

Initialize self. See help(type(self)) for accurate signature.

def __init__(self, inputs, module_type, module_args, conditions=[], name=None):
    '''Set up a node of the graph; its module is not instantiated here.

    inputs:      the node(s) feeding this node, in any form accepted by
                 parse_inputs.
    module_type: callable that creates the module of this node (stored, not
                 called here).
    module_args: dict of keyword arguments stored for module_type.
    conditions:  a list/tuple of conditioning nodes, or a single object,
                 which is wrapped in a list.
    name:        name of the node; if empty or None, the last six characters
                 of hex(id(self)) are used.
    '''
    self.inputs = self.parse_inputs(inputs)
    # Copy list/tuple arguments, so that neither the shared default list nor
    # the container of the caller is aliased by this node
    if isinstance(conditions, (list, tuple)):
        self.conditions = list(conditions)
    else:
        self.conditions = [conditions,]
    self.outputs = []
    self.module_type = module_type
    self.module_args = module_args
    self.input_dims = None
    self.module = None
    self.computed = None
    self.computed_rev = None
    self.id = None
    if name:
        self.name = name
    else:
        self.name = hex(id(self))[-6:]
    # Shorthands node.out0 ... node.out254 for the (node, output index) pairs;
    # setattr does the same as the former exec() call without compiling code
    for i in range(255):
        setattr(self, f'out{i}', (self, i))

def build_modules(

self, verbose=True)

Returns a list with the dimension of each output of this node, recursively calling build_modules of the nodes connected to the input. Use this information to initialize the pytorch nn.Module of this node.

def build_modules(self, verbose=True):
    ''' Determine the output dimensions of this node and create its module.

    The input dimensions are collected by calling build_modules recursively
    on every node connected to the input (and on the condition nodes, if
    any). They are passed to module_type to initialize the pytorch nn.Module
    of this node. The result is stored, so further calls only return the
    known output dimensions.
    '''
    # Already built during an earlier call
    if self.input_dims:
        return self.output_dims

    self.input_dims = [src.build_modules(verbose=verbose)[channel]
                       for src, channel in self.inputs]

    try:
        if len(self.conditions) > 0:
            cond_dims = [cond.build_modules(verbose=verbose)[0]
                         for cond in self.conditions]
            self.module = self.module_type(self.input_dims, dims_c=cond_dims,
                                           **self.module_args)
        else:
            self.module = self.module_type(self.input_dims,
                                           **self.module_args)
    except Exception as e:
        # Name the offending node before passing the error on
        print('Error in node %s' % (self.name))
        raise e

    if verbose:
        print(f"Node '{self.name}' takes the following inputs:")
        for (src, channel), dims in zip(self.inputs, self.input_dims):
            print(f"\t Output #{channel} of node '{src.name}' with dims {dims}")
        for cond in self.conditions:
            print(f"\t conditioned on node '{cond.name}' " +
                  f"with dims {cond.data.shape}")
        print()

    out_dims = self.module.output_dims(self.input_dims)
    self.output_dims = out_dims
    self.n_outputs = len(out_dims)
    return self.output_dims

def parse_inputs(

self, inputs)

def parse_inputs(self, inputs):
    '''Bring the inputs argument of a node into the form of a list of
    (node, output index) pairs.

    Accepted are a list/tuple whose first element is itself a list/tuple
    (returned unchanged), a single pair given as a list/tuple of length two
    (wrapped in a list), or a single Node, which stands for its output 0.

    Raises a RuntimeError for any other list/tuple (including an empty one)
    and fails an assertion for objects that are not a Node.'''
    # The node name may not be assigned yet when this method is called from
    # a constructor, so fall back to a placeholder for the error messages
    node_name = getattr(self, 'name', '<unnamed>')

    if isinstance(inputs, (list, tuple)):
        if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
            return inputs
        elif len(inputs) == 2:
            return [inputs,]
        else:
            raise RuntimeError(f"Cannot parse inputs provided to node '{node_name}'.")
    else:
        assert isinstance(inputs, Node), "Received object of invalid type "\
            f"({type(inputs)}) as input for node '{node_name}'."
        return [(inputs, 0),]

def run_backward(

self, op_list)

See run_forward, this is the same, only for the reverse computation. Need to call run_forward first, otherwise this function will not work

def run_backward(self, op_list):
    '''Schedule this node for the reverse computation.

    Counterpart of run_forward: every node that consumes an output of this
    node is scheduled first, then the operation of this node is appended to
    op_list in the form (node ID, forward output variable IDs, forward input
    variable IDs, condition variable IDs). run_forward has to be called
    beforehand, otherwise this function will not work.'''
    assert len(self.outputs) > 0, "Call run_forward first"

    # Scheduled before: hand back the stored result
    if self.computed_rev:
        return self.computed_rev

    # In reverse direction the forward outputs are what this node consumes
    reverse_inputs = [(self.id, k) for k in range(self.n_outputs)]

    # Nodes further downstream have to be inverted before this one
    for consumer, _ in self.outputs:
        consumer.run_backward(op_list)

    # Running in reverse yields the variables that entered the node in the
    # forward pass
    self.computed_rev = self.input_vars
    if not self.condition_vars:
        self.condition_vars = [cond.run_forward(op_list)[0]
                               for cond in self.conditions]
    op_list.append((self.id, reverse_inputs, self.computed_rev,
                    self.condition_vars))
    return self.computed_rev

def run_forward(

self, op_list)

Determine the order of operations needed to reach this node. Calls run_forward of parent nodes recursively. Each operation is appended to the global list op_list, in the form (node ID, input variable IDs, output variable IDs, condition variable IDs)

def run_forward(self, op_list):
    '''Work out which operations have to run before this node can be
    evaluated. run_forward is called recursively on all input and condition
    nodes, so their operations end up on the global list op_list ahead of the
    entry of this node, which has the form (node ID, input variable IDs,
    output variable IDs, condition variable IDs).'''
    # Scheduled before: return the stored variables instead of scheduling
    # the node a second time
    if self.computed:
        return self.computed

    # Schedule every node providing an input and pick the channel that is
    # actually consumed
    self.input_vars = []
    for slot, (src, channel) in enumerate(self.inputs):
        self.input_vars.append(src.run_forward(op_list)[channel])
        # Tell the input node that this node reads from it
        src.outputs.append((self, slot))

    # Same for the nodes providing conditioning (always their output 0)
    self.condition_vars = []
    for slot, cond in enumerate(self.conditions):
        self.condition_vars.append(cond.run_forward(op_list)[0])
        # Tell the condition node that this node reads from it
        cond.outputs.append((self, slot))

    # Everything this node depends on is scheduled, so its outputs can be
    # computed now
    self.computed = [(self.id, k) for k in range(self.n_outputs)]
    op_list.append((self.id, self.input_vars, self.computed, self.condition_vars))
    return self.computed

Instance variables

var computed

var computed_rev

var id

var input_dims

var inputs

var module

var module_args

var module_type

var outputs

class OutputNode

Special type of node that represents the output of the whole net (or the input when running in reverse)

class OutputNode(Node):
    '''Special type of node that represents the output of the whole net (or
    the input when running in reverse)'''

    class dummy(nn.Module):
        '''Placeholder module without any computation. Its methods declare no
        explicit self, so they return every positional argument they are
        called with, the module instance included.'''

        def __init__(self, *args):
            super().__init__()

        def __call__(*args):
            return args

        def output_dims(*args):
            return args

    def __init__(self, inputs, name='node'):
        # The module of an output node is always the placeholder above
        self.module_type = self.dummy
        self.module_args = {}
        self.module = None

        self.inputs = self.parse_inputs(inputs)
        self.conditions = []
        self.input_dims = None
        self.output_dims = []

        self.computed = None
        self.id = None
        self.name = name

        # Register this node as a consumer in every node it reads from
        for position, entry in enumerate(self.inputs):
            entry[0].outputs.append((self, position))

    def run_backward(self, op_list):
        '''Return the variable (node ID, 0) of this node; op_list is left
        untouched.'''
        own_variable = (self.id, 0)
        return [own_variable]

Ancestors (in MRO)

Class variables

var dummy

Static methods

def __init__(

self, inputs, name='node')

Initialize self. See help(type(self)) for accurate signature.

def __init__(self, inputs, name='node'):
    '''Set up an output node reading from inputs (any form accepted by
    parse_inputs) and register it as a consumer in the nodes it reads from.'''
    # The module of an output node is always the placeholder class 'dummy'
    self.module_type = self.dummy
    self.module_args = {}
    self.module = None

    self.inputs = self.parse_inputs(inputs)
    self.conditions = []
    self.input_dims = None
    self.output_dims = []

    self.computed = None
    self.id = None
    self.name = name

    # Register this node as a consumer in every node it reads from
    for position, entry in enumerate(self.inputs):
        entry[0].outputs.append((self, position))

def build_modules(

self, verbose=True)

Returns a list with the dimension of each output of this node, recursively calling build_modules of the nodes connected to the input. Use this information to initialize the pytorch nn.Module of this node.

def build_modules(self, verbose=True):
    ''' Determine the output dimensions of this node and create its module.

    The input dimensions are collected by calling build_modules recursively
    on every node connected to the input (and on the condition nodes, if
    any). They are passed to module_type to initialize the pytorch nn.Module
    of this node. The result is stored, so further calls only return the
    known output dimensions.
    '''
    # Already built during an earlier call
    if self.input_dims:
        return self.output_dims

    self.input_dims = [src.build_modules(verbose=verbose)[channel]
                       for src, channel in self.inputs]

    try:
        if len(self.conditions) > 0:
            cond_dims = [cond.build_modules(verbose=verbose)[0]
                         for cond in self.conditions]
            self.module = self.module_type(self.input_dims, dims_c=cond_dims,
                                           **self.module_args)
        else:
            self.module = self.module_type(self.input_dims,
                                           **self.module_args)
    except Exception as e:
        # Name the offending node before passing the error on
        print('Error in node %s' % (self.name))
        raise e

    if verbose:
        print(f"Node '{self.name}' takes the following inputs:")
        for (src, channel), dims in zip(self.inputs, self.input_dims):
            print(f"\t Output #{channel} of node '{src.name}' with dims {dims}")
        for cond in self.conditions:
            print(f"\t conditioned on node '{cond.name}' " +
                  f"with dims {cond.data.shape}")
        print()

    out_dims = self.module.output_dims(self.input_dims)
    self.output_dims = out_dims
    self.n_outputs = len(out_dims)
    return self.output_dims

def parse_inputs(

self, inputs)

def parse_inputs(self, inputs):
    '''Bring the inputs argument of a node into the form of a list of
    (node, output index) pairs.

    Accepted are a list/tuple whose first element is itself a list/tuple
    (returned unchanged), a single pair given as a list/tuple of length two
    (wrapped in a list), or a single Node, which stands for its output 0.

    Raises a RuntimeError for any other list/tuple (including an empty one)
    and fails an assertion for objects that are not a Node.'''
    # The node name may not be assigned yet when this method is called from
    # a constructor, so fall back to a placeholder for the error messages
    node_name = getattr(self, 'name', '<unnamed>')

    if isinstance(inputs, (list, tuple)):
        if len(inputs) > 0 and isinstance(inputs[0], (list, tuple)):
            return inputs
        elif len(inputs) == 2:
            return [inputs,]
        else:
            raise RuntimeError(f"Cannot parse inputs provided to node '{node_name}'.")
    else:
        assert isinstance(inputs, Node), "Received object of invalid type "\
            f"({type(inputs)}) as input for node '{node_name}'."
        return [(inputs, 0),]

def run_backward(

self, op_list)

See run_forward, this is the same, only for the reverse computation. Need to call run_forward first, otherwise this function will not work

def run_backward(self, op_list):
    '''Return the variable (node ID, 0) of this node; op_list is left
    untouched.'''
    own_variable = (self.id, 0)
    return [own_variable]

def run_forward(

self, op_list)

Determine the order of operations needed to reach this node. Calls run_forward of parent nodes recursively. Each operation is appended to the global list op_list, in the form (node ID, input variable IDs, output variable IDs, condition variable IDs)

def run_forward(self, op_list):
    '''Work out which operations have to run before this node can be
    evaluated. run_forward is called recursively on all input and condition
    nodes, so their operations end up on the global list op_list ahead of the
    entry of this node, which has the form (node ID, input variable IDs,
    output variable IDs, condition variable IDs).'''
    # Scheduled before: return the stored variables instead of scheduling
    # the node a second time
    if self.computed:
        return self.computed

    # Schedule every node providing an input and pick the channel that is
    # actually consumed
    self.input_vars = []
    for slot, (src, channel) in enumerate(self.inputs):
        self.input_vars.append(src.run_forward(op_list)[channel])
        # Tell the input node that this node reads from it
        src.outputs.append((self, slot))

    # Same for the nodes providing conditioning (always their output 0)
    self.condition_vars = []
    for slot, cond in enumerate(self.conditions):
        self.condition_vars.append(cond.run_forward(op_list)[0])
        # Tell the condition node that this node reads from it
        cond.outputs.append((self, slot))

    # Everything this node depends on is scheduled, so its outputs can be
    # computed now
    self.computed = [(self.id, k) for k in range(self.n_outputs)]
    op_list.append((self.id, self.input_vars, self.computed, self.condition_vars))
    return self.computed

Instance variables

var computed

var conditions

var id

var inputs

var name

var output_dims

class ReversibleGraphNet

This class represents the invertible net itself. It is a subclass of torch.nn.Module and supports the same methods. The forward method has an additional option 'rev', with which the net can be computed in reverse.

class ReversibleGraphNet(nn.Module):
    '''This class represents the invertible net itself. It is a subclass of
    torch.nn.Module and supports the same methods. The forward method has an
    additional option 'rev', with which the net can be computed in reverse.'''

    def __init__(self, node_list, ind_in=None, ind_out=None, verbose=True):
        '''node_list should be a list of all nodes involved. ind_in and
        ind_out (deprecated) are the indexes of the special nodes InputNode
        and OutputNode in this list, given as an int or a list of ints; if
        left at None, these nodes are detected by their type. verbose is
        passed on to build_modules of the nodes.'''
        super(ReversibleGraphNet, self).__init__()

        # Gather lists of input, output and condition nodes
        if ind_in is not None:
            warnings.warn("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, " +
                          "input and output nodes are detected automatically.")
            if isinstance(ind_in, int):
                self.ind_in = list([ind_in])
            else:
                self.ind_in = ind_in
        else:
            self.ind_in = [i for i in range(len(node_list))
                           if isinstance(node_list[i], InputNode)]
            assert len(self.ind_in) > 0, "No input nodes specified."
        if ind_out is not None:
            warnings.warn("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, " +
                          "input and output nodes are detected automatically.")
            if isinstance(ind_out, int):
                self.ind_out = list([ind_out])
            else:
                self.ind_out = ind_out
        else:
            self.ind_out = [i for i in range(len(node_list))
                            if isinstance(node_list[i], OutputNode)]
            assert len(self.ind_out) > 0, "No output nodes specified."
        self.ind_cond = [i for i in range(len(node_list))
                         if isinstance(node_list[i], ConditionNode)]

        # Filled by ops_to_indexed with the indexes of the variables that
        # hold network outputs, network inputs and conditions
        self.return_vars = []
        self.input_vars = []
        self.cond_vars = []

        # Assign each node a unique ID
        self.node_list = node_list
        for i, n in enumerate(node_list):
            n.id = i
            n.graph = self

        # Recursively build the nodes nn.Modules and determine order of
        # operations
        ops = []
        for i in self.ind_out:
            node_list[i].build_modules(verbose=verbose)
            node_list[i].run_forward(ops)

        # create list of Pytorch variables that are used
        variables = set()
        for o in ops:
            variables = variables.union(set(o[1] + o[2] + o[3]))
        self.variables_ind = list(variables)

        self.indexed_ops = self.ops_to_indexed(ops)

        self.module_list = nn.ModuleList([n.module for n in node_list])
        self.module_cond = [(len(n.conditions) > 0) for n in node_list]
        # One (initially empty) slot per variable; forward() stores the
        # tensors passed between the modules here
        self._buffers = {F'tmp_var_{i}' : None for i in range(len(variables))}

        # Find out the order of operations for reverse calculations
        ops_rev = []
        for i in self.ind_in + self.ind_cond:
            node_list[i].run_backward(ops_rev)
        self.indexed_ops_rev = self.ops_to_indexed(ops_rev)

    def ops_to_indexed(self, ops):
        '''Helper function to translate the list of variables (origin ID, channel),
        to variable IDs.

        Returns the operations as tuples (node ID, input variable indexes,
        output variable indexes, condition variable indexes). Operations of
        input, output and condition nodes are left out of the result; their
        variables are recorded in self.input_vars, self.return_vars and
        self.cond_vars instead.'''
        result = []

        for o in ops:
            try:
                vars_in = [self.variables_ind.index(v) for v in o[1]]
            except ValueError:
                # At least one input variable is not known
                vars_in = -1

            vars_out = [self.variables_ind.index(v) for v in o[2]]
            vars_cond = [self.variables_ind.index(v) for v in o[3]]

            # Collect input/output/conditioning nodes in separate lists, but don't
            # add to indexed ops
            if o[0] in self.ind_out:
                self.return_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_in:
                self.input_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_cond:
                if self.variables_ind.index(o[1][0]) not in self.cond_vars:
                    self.cond_vars.append(self.variables_ind.index(o[1][0]))
                else:
                    print('Is this branch ever reached?')
                continue

            result.append((o[0], vars_in, vars_out, vars_cond))

        # Sort input/output/conditioning variables so they correspond to initial
        # node list order
        self.return_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.input_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.cond_vars.sort(key=lambda i: self.variables_ind[i][0])

        return result

    def forward(self, x, c=None, rev=False, intermediate_outputs=False):
        '''Forward or backward computation of the whole net.

        x is a single tensor or a list/tuple with one tensor per input node
        (per output node if rev is set), c the conditioning data in the same
        manner. Returns the output tensor, or a list of them if there are
        several. With intermediate_outputs, a dict mapping node names to the
        results of their modules is returned instead.'''

        if rev:
            use_list = self.indexed_ops_rev
            input_vars, output_vars = self.return_vars, self.input_vars
        else:
            use_list = self.indexed_ops
            input_vars, output_vars = self.input_vars, self.return_vars

        # Assign input data to respective variables
        if isinstance(x, (list, tuple)):
            assert len(x) == len(input_vars), (
                f"Got list of {len(x)} input tensors for "
                f"{'inverse' if rev else 'forward'} pass, but expected "
                f"{len(input_vars)}."
            )
            for i in range(len(input_vars)):
                self._buffers[F'tmp_var_{input_vars[i]}'] = x[i]
        else:
            assert len(input_vars) == 1, (f"Got single input tensor for "
                                          f"{'inverse' if rev else 'forward'} "
                                          f"pass, but expected list of "
                                          f"{len(input_vars)}.")
            self._buffers[F'tmp_var_{input_vars[0]}'] = x

        # Assign conditioning data to respective variables
        if c is None:
            assert len(self.cond_vars) == 0
        elif isinstance(c, (list, tuple)):
            assert len(c) == len(self.cond_vars), f'{len(c)}, {len(self.cond_vars)}'
            for i in range(len(self.cond_vars)):
                self._buffers[F'tmp_var_{self.cond_vars[i]}'] = c[i]
        else:
            assert len(self.cond_vars) == 1
            self._buffers[F'tmp_var_{self.cond_vars[0]}'] = c

        # Prepare dictionary for intermediate node outputs
        out_dict = {}

        # Run all modules with the given inputs
        for o in use_list:
            try:
                x = [self._buffers[F'tmp_var_{i}'] for i in o[1]]
                if self.module_cond[o[0]]:
                    c = [self._buffers[F'tmp_var_{i}'] for i in o[3]]
                    results = self.module_list[o[0]](x, c=c, rev=rev)
                else:
                    results = self.module_list[o[0]](x, rev=rev)
            except TypeError as err:
                # Keep the original error attached as the cause
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?") from err
            out_dict[self.node_list[o[0]].name] = results
            for i, r in zip(o[2], results):
                self._buffers[F'tmp_var_{i}'] = r

        if intermediate_outputs:
            return out_dict
        else:
            out = [self._buffers[F'tmp_var_{output_vars[i]}']
                   for i in range(len(output_vars))]
            if len(out) == 1:
                return out[0]
            else:
                return out

    def log_jacobian(self, x=None, c=None, rev=False, run_forward=True, intermediate_outputs=False):
        '''Compute the log jacobian determinant of the whole net.

        The values returned by the jacobian method of every module are
        summed up. With run_forward (which requires x), forward() is called
        first to fill the variables the modules read from; otherwise the
        values left by an earlier pass are used. With intermediate_outputs,
        a dict mapping node names to the value of their module is returned
        instead of the sum.'''
        if run_forward or c is not None:
            self.condition = c
        jacobian = 0

        if rev:
            use_list = self.indexed_ops_rev
        else:
            use_list = self.indexed_ops

        if run_forward:
            if x is None:
                raise RuntimeError("You need to provide an input if you want "
                                   "to run a forward pass")
            self.forward(x, c, rev=rev)

        # Prepare dictionary for intermediate node outputs
        jacobian_dict = {}

        # Run all modules with the given inputs
        for o in use_list:
            x = [self._buffers[F'tmp_var_{i}'] for i in o[1]]
            if self.module_cond[o[0]]:
                c = [self._buffers[F'tmp_var_{i}'] for i in o[3]]
                module_jacobian = self.module_list[o[0]].jacobian(x, c=c, rev=rev)
            else:
                module_jacobian = self.module_list[o[0]].jacobian(x, rev=rev)
            jacobian += module_jacobian
            jacobian_dict[self.node_list[o[0]].name] = module_jacobian

        if intermediate_outputs:
            return jacobian_dict
        else:
            return jacobian

    def jacobian(self, *args, **kwargs):
        '''Compute the log jacobian determinant of the whole net.
        Deprecated alias of log_jacobian, which emits a warning.'''
        warnings.warn("This function computes the log-jacobian determinant, not the "
                      "jacobian as the name suggest. Will be removed in the future.")
        return self.log_jacobian(*args, **kwargs)

    def log_jacobian_numerical(self, x, c=None, rev=False, h=1e-04):
        '''Approximate log Jacobian determinant via finite differences.

        Every input dimension is shifted by +h and -h in turn, and the
        central difference of the network outputs gives one column of the
        Jacobian. x is a single tensor or a list/tuple of tensors, c the
        conditioning data, rev selects the direction in which the net is
        evaluated. Returns one value per sample of the batch.'''
        if isinstance(x, (list, tuple)):
            batch_size = x[0].shape[0]
            # Plain Python ints, as needed by range() and torch.split()
            ndim_x_separate = [int(np.prod(x_i.shape[1:])) for x_i in x]
            ndim_x_total = sum(ndim_x_separate)
            x_flat = torch.cat([x_i.reshape(batch_size, -1) for x_i in x], dim=1)
        else:
            batch_size = x.shape[0]
            ndim_x_total = int(np.prod(x.shape[1:]))
            x_flat = x.reshape(batch_size, -1)

        # Same dtype and device as the input
        J_num = x_flat.new_zeros(batch_size, ndim_x_total, ndim_x_total)
        for i in range(ndim_x_total):
            offset = x_flat.new_zeros(batch_size, ndim_x_total)
            offset[:,i] = h
            if isinstance(x, (list, tuple)):
                # Cut the shifted flat vector back into the original tensors
                x_upper = torch.split(x_flat + offset, ndim_x_separate, dim=1)
                x_upper = [x_upper[j].reshape(*x[j].shape) for j in range(len(x))]
                x_lower = torch.split(x_flat - offset, ndim_x_separate, dim=1)
                x_lower = [x_lower[j].reshape(*x[j].shape) for j in range(len(x))]
            else:
                x_upper = (x_flat + offset).view(*x.shape)
                x_lower = (x_flat - offset).view(*x.shape)
            # Evaluate in the requested direction
            y_upper = self.forward(x_upper, c=c, rev=rev)
            y_lower = self.forward(x_lower, c=c, rev=rev)
            if isinstance(y_upper, (list, tuple)):
                y_upper = torch.cat([y_i.reshape(batch_size, -1) for y_i in y_upper], dim=1)
                y_lower = torch.cat([y_i.reshape(batch_size, -1) for y_i in y_lower], dim=1)
            J_num[:,:,i] = (y_upper - y_lower).reshape(batch_size, -1) / (2*h)
        logdet_num = x_flat.new_zeros(batch_size)
        for i in range(batch_size):
            logdet_num[i] = torch.det(J_num[i,:,:]).abs().log()

        return logdet_num

    def load_state_dict(self, state_dict, *args, **kwargs):
        '''Load a state dict, leaving out every entry whose key belongs to a
        variable slot in self._buffers that currently holds None. All other
        arguments are passed on to load_state_dict of nn.Module.'''

        state_dict_no_buffers = {}
        for k,p in state_dict.items():
            if k in self._buffers and self._buffers[k] is None:
                continue
            state_dict_no_buffers[k] = p

        return super().load_state_dict(state_dict_no_buffers, *args, **kwargs)

Ancestors (in MRO)

Class variables

var dump_patches

Static methods

def __init__(

self, node_list, ind_in=None, ind_out=None, verbose=True)

node_list should be a list of all nodes involved, and ind_in, ind_out are the indexes of the special nodes InputNode and OutputNode in this list.

def __init__(self, node_list, ind_in=None, ind_out=None, verbose=True):
    '''node_list should be a list of all nodes involved. ind_in and ind_out
    (deprecated) are the indexes of the special nodes InputNode and
    OutputNode in this list; if left at None, these nodes are detected by
    their type.'''
    super(ReversibleGraphNet, self).__init__()

    deprecation_note = ("Use of 'ind_in' and 'ind_out' for ReversibleGraphNet is deprecated, " +
                        "input and output nodes are detected automatically.")

    # Indexes of the input nodes: detected by type, unless given explicitly
    if ind_in is None:
        self.ind_in = [idx for idx, node in enumerate(node_list)
                       if isinstance(node, InputNode)]
        assert len(self.ind_in) > 0, "No input nodes specified."
    else:
        warnings.warn(deprecation_note)
        self.ind_in = [ind_in] if isinstance(ind_in, int) else ind_in

    # Indexes of the output nodes, handled the same way
    if ind_out is None:
        self.ind_out = [idx for idx, node in enumerate(node_list)
                        if isinstance(node, OutputNode)]
        assert len(self.ind_out) > 0, "No output nodes specified."
    else:
        warnings.warn(deprecation_note)
        self.ind_out = [ind_out] if isinstance(ind_out, int) else ind_out

    # Condition nodes are always detected by type
    self.ind_cond = [idx for idx, node in enumerate(node_list)
                     if isinstance(node, ConditionNode)]

    # Filled by ops_to_indexed
    self.return_vars = []
    self.input_vars = []
    self.cond_vars = []

    # The position in the list serves as the unique ID of a node
    self.node_list = node_list
    for idx, node in enumerate(node_list):
        node.id = idx
        node.graph = self

    # Starting from the outputs, build the nn.Modules of the nodes
    # recursively and record the order of operations
    ops = []
    for idx in self.ind_out:
        out_node = node_list[idx]
        out_node.build_modules(verbose=verbose)
        out_node.run_forward(ops)

    # Every variable showing up as input, output or condition of an operation
    variables = set()
    for op in ops:
        variables = variables.union(set(op[1] + op[2] + op[3]))
    self.variables_ind = list(variables)

    self.indexed_ops = self.ops_to_indexed(ops)

    self.module_list = nn.ModuleList([node.module for node in node_list])
    self.module_cond = [len(node.conditions) > 0 for node in node_list]
    # One (initially empty) slot per variable
    self._buffers = {F'tmp_var_{k}': None for k in range(len(variables))}

    # Order of operations for the reverse direction, starting from the
    # input and condition nodes
    ops_rev = []
    for idx in self.ind_in + self.ind_cond:
        node_list[idx].run_backward(ops_rev)
    self.indexed_ops_rev = self.ops_to_indexed(ops_rev)

def add_module(

self, name, module)

Adds a child module to the current module.

The module can be accessed as an attribute using the given name.

Args: name (string): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module.

def add_module(self, name, module):
    r"""Register ``module`` as a child of this module under ``name``.

    Afterwards the child is stored in ``self._modules`` and can be accessed
    as an attribute using the given name.

    Args:
        name (string): name of the child module
        module (Module): child module to be added (``None`` is accepted)

    Raises:
        TypeError: if ``module`` is neither a Module nor ``None``, or if
            ``name`` is not a string
        KeyError: if ``name`` is already taken by an attribute that is not a
            child module, contains a dot, or is empty
    """
    # Each check raises, so they can be written as independent guards
    if not isinstance(module, Module) and module is not None:
        raise TypeError("{} is not a Module subclass".format(
            torch.typename(module)))
    if not isinstance(name, torch._six.string_classes):
        raise TypeError("module name should be a string. Got {}".format(
            torch.typename(name)))
    if hasattr(self, name) and name not in self._modules:
        raise KeyError("attribute '{}' already exists".format(name))
    if '.' in name:
        raise KeyError("module name can't contain \".\"")
    if name == '':
        raise KeyError("module name can't be empty string \"\"")

    self._modules[name] = module

def apply(

self, fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:torch-nn-init).

Args: fn (:class:Module -> None): function to be applied to each submodule

Returns: Module: self

Example::

>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.data.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1.,  1.],
        [ 1.,  1.]])
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1.,  1.],
        [ 1.,  1.]])
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
def apply(self, fn):
    r"""Call ``fn`` on every submodule and finally on this module itself.

    The submodules are those returned by ``.children()``; each of them
    applies ``fn`` recursively in the same manner, so a module is always
    visited after all of its descendants. A typical use is the
    initialization of the parameters of a model.

    Args:
        fn (:class:`Module` -> None): function to be applied to each submodule

    Returns:
        Module: self

    Example::

        >>> def init_weights(m):
        >>>     if type(m) == nn.Linear:
        >>>         m.weight.data.fill_(1.0)
        >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        >>> net.apply(init_weights)
    """
    for child in self.children():
        child.apply(fn)
    fn(self)
    return self

def buffers(

self, recurse=True)

Returns an iterator over module buffers.

Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields: torch.Tensor: module buffer

Example::

>>> for buf in model.buffers():
>>>     print(type(buf.data), buf.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
def buffers(self, recurse=True):
    r"""Iterate over the buffers of the module, without their names.

    Args:
        recurse (bool): passed on to ``named_buffers``. If True, the buffers
            of this module and all submodules are yielded, otherwise only
            buffers that are direct members of this module.

    Yields:
        torch.Tensor: module buffer

    Example::

        >>> for buf in model.buffers():
        >>>     print(type(buf.data), buf.size())
    """
    yield from (entry for _, entry in self.named_buffers(recurse=recurse))

def children(

self)

Returns an iterator over immediate children modules.

Yields: Module: a child module

def children(self):
    r"""Iterate over the immediate child modules, without their names.

    Yields:
        Module: a child module
    """
    yield from (child for _, child in self.named_children())

def cpu(

self)

Moves all model parameters and buffers to the CPU.

Returns: Module: self

def cpu(self):
    r"""Move all model parameters and buffers to the CPU.

    Returns:
        Module: self
    """
    def to_cpu(tensor):
        return tensor.cpu()

    return self._apply(to_cpu)

def cuda(

self, device=None)

Moves all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.

Arguments: device (int, optional): if specified, all parameters will be copied to that device

Returns: Module: self

def cuda(self, device=None):
    r"""Move the module to the GPU.

    Hands :meth:`_apply` a function that calls ``.cuda(device)`` on its
    argument.  As the original notes, this should be called before an
    optimizer is constructed, because the moved parameters and buffers are
    different objects afterwards.

    Arguments:
        device (int, optional): passed positionally to each ``.cuda`` call.

    Returns:
        Whatever :meth:`_apply` returns.
    """
    def _to_gpu(tensor):
        return tensor.cuda(device)

    return self._apply(_to_gpu)

def double(

self)

Casts all floating point parameters and buffers to double datatype.

Returns: Module: self

def double(self):
    r"""Cast everything floating point to ``double``.

    Hands :meth:`_apply` a function that returns ``t.double()`` when
    ``t.is_floating_point()`` is true and ``t`` itself otherwise.

    Returns:
        Whatever :meth:`_apply` returns.
    """
    def _as_double(tensor):
        if tensor.is_floating_point():
            return tensor.double()
        return tensor

    return self._apply(_as_double)

def eval(

self)

Sets the module in evaluation mode.

This only has an effect on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. :class:Dropout, :class:BatchNorm, etc.

def eval(self):
    r"""Switch the module to evaluation mode.

    Equivalent to ``self.train(False)``.  Only some modules behave
    differently between training and evaluation mode (the original names
    :class:`Dropout` and :class:`BatchNorm` as examples); see their own
    documentation for details.

    Returns:
        Whatever :meth:`train` returns.
    """
    training = False
    return self.train(training)

def extra_repr(

self)

Set the extra representation of the module

To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.

def extra_repr(self):
    r"""Return the extra text to include in the module's representation.

    This default returns an empty string.  Re-implement it in your own
    modules to print customized information; single-line and multi-line
    strings are both acceptable.
    """
    nothing_extra = ''
    return nothing_extra

def float(

self)

Casts all floating point parameters and buffers to float datatype.

Returns: Module: self

def float(self):
    r"""Cast everything floating point to ``float``.

    Hands :meth:`_apply` a function that returns ``t.float()`` when
    ``t.is_floating_point()`` is true and ``t`` itself otherwise.

    Returns:
        Whatever :meth:`_apply` returns.
    """
    def _as_float(tensor):
        if tensor.is_floating_point():
            return tensor.float()
        return tensor

    return self._apply(_as_float)

def forward(

self, x, c=None, rev=False, intermediate_outputs=False)

Forward or backward computation of the whole net.

def forward(self, x, c=None, rev=False, intermediate_outputs=False):
    '''Run the whole net once, forwards or (with ``rev=True``) in reverse.

    Inputs, conditions and every module result are stored in
    ``self._buffers`` under keys of the form ``tmp_var_<variable id>``.

    Args:
        x: A single input, or a list/tuple with one entry per input
            variable of the chosen direction.
        c: Optional conditioning data; a single object or a list/tuple with
            one entry per entry of ``self.cond_vars``.
        rev: If true, ``self.indexed_ops_rev`` is executed and the roles of
            ``self.input_vars`` and ``self.return_vars`` are swapped.
        intermediate_outputs: If true, return the per-node results instead
            of the final outputs.

    Returns:
        With ``intermediate_outputs`` a dict mapping node names to the
        results of their modules.  Otherwise the value(s) stored for the
        output variables: the bare value if there is exactly one, else a
        list.

    Raises:
        RuntimeError: If gathering a module's inputs or calling the module
            raises a ``TypeError``.
    '''
    schedule = self.indexed_ops_rev if rev else self.indexed_ops
    if rev:
        sources, sinks = self.return_vars, self.input_vars
    else:
        sources, sinks = self.input_vars, self.return_vars

    store = self._buffers
    direction = 'inverse' if rev else 'forward'

    # Hand the input data to the variables it belongs to.
    if not isinstance(x, (list, tuple)):
        assert len(sources) == 1, (f"Got single input tensor for "
                                   f"{direction} "
                                   f"pass, but expected list of "
                                   f"{len(sources)}.")
        store[f'tmp_var_{sources[0]}'] = x
    else:
        assert len(x) == len(sources), (
            f"Got list of {len(x)} input tensors for "
            f"{direction} pass, but expected "
            f"{len(sources)}."
        )
        for position, var in enumerate(sources):
            store[f'tmp_var_{var}'] = x[position]

    # Same for the conditioning data.
    if c is None:
        assert len(self.cond_vars) == 0
    elif isinstance(c, (list, tuple)):
        assert len(c) == len(self.cond_vars), f'{len(c)}, {len(self.cond_vars)}'
        for position, var in enumerate(self.cond_vars):
            store[f'tmp_var_{var}'] = c[position]
    else:
        assert len(self.cond_vars) == 1
        store[f'tmp_var_{self.cond_vars[0]}'] = c

    # Results of every node, keyed by node name.
    out_dict = {}

    # Execute the modules in the scheduled order.
    for op in schedule:
        try:
            # op[1] may not be iterable (see ops_to_indexed); the resulting
            # TypeError is translated below as well.
            module_inputs = [store[f'tmp_var_{var}'] for var in op[1]]
            node_idx = op[0]
            if self.module_cond[node_idx]:
                module_conds = [store[f'tmp_var_{var}'] for var in op[3]]
                results = self.module_list[node_idx](module_inputs,
                                                     c=module_conds, rev=rev)
            else:
                results = self.module_list[node_idx](module_inputs, rev=rev)
        except TypeError:
            raise RuntimeError("Are you sure all used Nodes are in the "
                               "Node list?")
        out_dict[self.node_list[op[0]].name] = results
        for var, value in zip(op[2], results):
            store[f'tmp_var_{var}'] = value

    if intermediate_outputs:
        return out_dict

    outputs = [store[f'tmp_var_{var}'] for var in sinks]
    return outputs[0] if len(outputs) == 1 else outputs

def half(

self)

Casts all floating point parameters and buffers to half datatype.

Returns: Module: self

def half(self):
    r"""Cast everything floating point to ``half``.

    Hands :meth:`_apply` a function that returns ``t.half()`` when
    ``t.is_floating_point()`` is true and ``t`` itself otherwise.

    Returns:
        Whatever :meth:`_apply` returns.
    """
    def _as_half(tensor):
        if tensor.is_floating_point():
            return tensor.half()
        return tensor

    return self._apply(_as_half)

def jacobian(

self, *args, **kwargs)

Compute the log jacobian determinant of the whole net.

def jacobian(self, *args, **kwargs):
    '''Compute the log jacobian determinant of the whole net.

    Misnamed alias kept for compatibility: it emits a warning and then
    forwards all arguments to ``log_jacobian``, returning its result.
    '''
    notice = ("This function computes the log-jacobian determinant, not the "
              "jacobian as the name suggest. Will be removed in the future.")
    warnings.warn(notice)
    result = self.log_jacobian(*args, **kwargs)
    return result

def load_state_dict(

self, state_dict, *args, **kwargs)

Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.

Arguments: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:state_dict match the keys returned by this module's :meth:~torch.nn.Module.state_dict function. Default: True

Returns: NamedTuple with missing_keys and unexpected_keys fields: * missing_keys is a list of str containing the missing keys * unexpected_keys is a list of str containing the unexpected keys

def load_state_dict(self, state_dict, *args, **kwargs):
    '''Load ``state_dict`` through the parent class, minus unset buffers.

    Every entry whose key names a buffer of this module that is currently
    ``None`` is dropped; the remaining entries, together with ``*args`` and
    ``**kwargs``, are passed on to the parent class' ``load_state_dict``.

    Returns:
        Whatever the parent class' ``load_state_dict`` returns.
    '''
    filtered = {
        key: value
        for key, value in state_dict.items()
        if not (key in self._buffers and self._buffers[key] is None)
    }
    return super().load_state_dict(filtered, *args, **kwargs)

def log_jacobian(

self, x=None, c=None, rev=False, run_forward=True, intermediate_outputs=False)

Compute the log jacobian determinant of the whole net.

def log_jacobian(self, x=None, c=None, rev=False, run_forward=True, intermediate_outputs=False):
    '''Compute the log jacobian determinant of the whole net.

    Adds up the values returned by the ``jacobian`` method of every module
    in the execution list of the chosen direction.  The module inputs are
    read from ``self._buffers``; with ``run_forward`` they are refreshed
    first by calling ``self.forward(x, c, rev=rev)``.

    Args:
        x: Input for the forward pass; required when ``run_forward`` is set.
        c: Conditioning data.  Stored in ``self.condition`` whenever
            ``run_forward`` is set or ``c`` is not ``None``.
        rev: Use ``self.indexed_ops_rev`` instead of ``self.indexed_ops``.
        run_forward: Run a pass through the net before collecting.
        intermediate_outputs: Return the per-node values instead of the sum.

    Returns:
        The sum of all module values, or, with ``intermediate_outputs``, a
        dict mapping node names to the individual values.

    Raises:
        RuntimeError: If ``run_forward`` is set but ``x`` is ``None``.
    '''
    if run_forward or c is not None:
        self.condition = c

    total = 0
    schedule = self.indexed_ops_rev if rev else self.indexed_ops

    if run_forward:
        if x is None:
            raise RuntimeError("You need to provide an input if you want "
                               "to run a forward pass")
        self.forward(x, c, rev=rev)

    # Per-node contributions, keyed by node name.
    per_node = {}

    for op in schedule:
        module_inputs = [self._buffers[f'tmp_var_{var}'] for var in op[1]]
        node_idx = op[0]
        if self.module_cond[node_idx]:
            module_conds = [self._buffers[f'tmp_var_{var}'] for var in op[3]]
            contribution = self.module_list[node_idx].jacobian(
                module_inputs, c=module_conds, rev=rev)
        else:
            contribution = self.module_list[node_idx].jacobian(
                module_inputs, rev=rev)
        total += contribution
        per_node[self.node_list[op[0]].name] = contribution

    return per_node if intermediate_outputs else total

def log_jacobian_numerical(

self, x, c=None, rev=False, h=0.0001)

Approximate log Jacobian determinant via finite differences.

def log_jacobian_numerical(self, x, c=None, rev=False, h=1e-04):
    '''Approximate log Jacobian determinant via finite differences.

    Every input dimension is perturbed by ``+h`` and ``-h``, the net is run
    on both perturbed inputs, and the central difference of the flattened
    outputs fills one column of the numerical Jacobian.  This costs two
    passes through the net per input dimension.

    Args:
        x: A tensor, or a list/tuple of tensors, whose first dimension is
            the batch dimension.
        c: Conditioning data, handed to ``self.forward`` unchanged.
        rev: Direction in which the net is run; handed to ``self.forward``.
        h: Step size of the central difference.

    Returns:
        A tensor with one entry per batch element, holding the log of the
        absolute determinant of that sample's numerical Jacobian.
    '''
    multiple_inputs = isinstance(x, (list, tuple))
    if multiple_inputs:
        batch_size = x[0].shape[0]
        # Plain ints: torch.split expects integer section sizes.
        ndim_x_separate = [int(np.prod(x_i.shape[1:])) for x_i in x]
        ndim_x_total = sum(ndim_x_separate)
        x_flat = torch.cat([x_i.reshape(batch_size, -1) for x_i in x], dim=1)
    else:
        batch_size = x.shape[0]
        ndim_x_total = int(np.prod(x.shape[1:]))
        x_flat = x.reshape(batch_size, -1)

    def _unflatten(flat):
        # Restore the structure and shapes of the original input.
        if multiple_inputs:
            pieces = torch.split(flat, ndim_x_separate, dim=1)
            return [piece.reshape(*x_i.shape) for piece, x_i in zip(pieces, x)]
        return flat.reshape(*x.shape)

    def _flatten(y):
        # Bring the net output into the form (batch, features).
        if isinstance(y, (list, tuple)):
            return torch.cat([y_i.reshape(batch_size, -1) for y_i in y], dim=1)
        return y.reshape(batch_size, -1)

    # Allocated from x_flat so dtype and device follow the input.
    J_num = x_flat.new_zeros(batch_size, ndim_x_total, ndim_x_total)
    for i in range(ndim_x_total):
        offset = x_flat.new_zeros(batch_size, ndim_x_total)
        offset[:, i] = h
        # rev must be forwarded, otherwise the forward direction would be
        # differentiated regardless of what the caller asked for.
        y_upper = _flatten(self.forward(_unflatten(x_flat + offset), c=c, rev=rev))
        y_lower = _flatten(self.forward(_unflatten(x_flat - offset), c=c, rev=rev))
        J_num[:, :, i] = (y_upper - y_lower) / (2 * h)

    logdet_num = x_flat.new_zeros(batch_size)
    for i in range(batch_size):
        logdet_num[i] = torch.det(J_num[i, :, :]).abs().log()
    return logdet_num

def modules(

self)

Returns an iterator over all modules in the network.

Yields: Module: a module in the network

Note: Duplicate modules are returned only once. In the following example, l will be returned only once.

Example::

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
        print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
def modules(self):
    r"""Iterate over all modules in the network, dropping their names.

    Thin wrapper around :meth:`named_modules`.

    Yields:
        The second element of each pair produced by :meth:`named_modules`.

    Note:
        The original documents that duplicate modules are returned only
        once; in the example below ``l`` shows up a single time.

    Example::

        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.modules()):
                print(idx, '->', m)

        0 -> Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        )
        1 -> Linear(in_features=2, out_features=2, bias=True)
    """
    yield from (member for _, member in self.named_modules())

def named_buffers(

self, prefix='', recurse=True)

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Args: prefix (str): prefix to prepend to all buffer names. recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields: (string, torch.Tensor): Tuple containing the name and buffer

Example::

>>> for name, buf in self.named_buffers():
>>>    if name in ['running_var']:
>>>        print(buf.size())
def named_buffers(self, prefix='', recurse=True):
    r"""Iterate over module buffers together with their names.

    Delegates to :meth:`_named_members`, handing it a function that returns
    the items of a module's ``_buffers`` mapping.

    Args:
        prefix (str): forwarded to :meth:`_named_members`.
        recurse (bool): forwarded to :meth:`_named_members`.

    Yields:
        Every element produced by :meth:`_named_members`.

    Example::

        >>> for name, buf in self.named_buffers():
        >>>    if name in ['running_var']:
        >>>        print(buf.size())
    """
    def _own_buffers(module):
        return module._buffers.items()

    yield from self._named_members(_own_buffers, prefix=prefix, recurse=recurse)

def named_children(

self)

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields: (string, Module): Tuple containing a name and child module

Example::

>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
def named_children(self):
    r"""Iterate over the immediate child modules together with their names.

    Walks ``self._modules`` in order, skipping ``None`` entries and any
    module object that has already been yielded under an earlier name.

    Yields:
        (string, Module): Tuple containing a name and child module

    Example::

        >>> for name, module in model.named_children():
        >>>     if name in ['conv4', 'conv5']:
        >>>         print(module)
    """
    already_seen = set()
    for key, child in self._modules.items():
        if child is None or child in already_seen:
            continue
        already_seen.add(child)
        yield key, child

def named_modules(

self, memo=None, prefix='')

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Yields: (string, Module): Tuple of name and module

Note: Duplicate modules are returned only once. In the following example, l will be returned only once.

Example::

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
        print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
def named_modules(self, memo=None, prefix=''):
    r"""Iterate over all modules in the network together with their names.

    Yields this module under ``prefix`` and then recurses into every
    non-``None`` entry of ``self._modules``, extending the prefix with the
    entry's name (separated by ``'.'`` unless the prefix is empty).

    Args:
        memo (set, optional): modules that were already yielded; they are
            skipped.  A fresh set is created when omitted.
        prefix (str): name under which this module itself is yielded.

    Yields:
        (string, Module): Tuple of name and module

    Note:
        Because of ``memo``, duplicate modules are returned only once; in
        the example below ``l`` shows up a single time.

    Example::

        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.named_modules()):
                print(idx, '->', m)

        0 -> ('', Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        ))
        1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
    """
    if memo is None:
        memo = set()
    if self in memo:
        return
    memo.add(self)
    yield prefix, self

    for key, child in self._modules.items():
        if child is not None:
            separator = '.' if prefix else ''
            child_prefix = prefix + separator + key
            yield from child.named_modules(memo, child_prefix)

def named_parameters(

self, prefix='', recurse=True)

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields: (string, Parameter): Tuple containing the name and parameter

Example::

>>> for name, param in self.named_parameters():
>>>    if name in ['bias']:
>>>        print(param.size())
def named_parameters(self, prefix='', recurse=True):
    r"""Iterate over module parameters together with their names.

    Delegates to :meth:`_named_members`, handing it a function that returns
    the items of a module's ``_parameters`` mapping.

    Args:
        prefix (str): forwarded to :meth:`_named_members`.
        recurse (bool): forwarded to :meth:`_named_members`.

    Yields:
        Every element produced by :meth:`_named_members`.

    Example::

        >>> for name, param in self.named_parameters():
        >>>    if name in ['bias']:
        >>>        print(param.size())
    """
    def _own_parameters(module):
        return module._parameters.items()

    yield from self._named_members(_own_parameters, prefix=prefix, recurse=recurse)

def ops_to_indexed(

self, ops)

Helper function to translate the list of variables (origin ID, channel), to variable IDs.

def ops_to_indexed(self, ops):
    '''Helper function to translate the list of variables (origin ID, channel),
    to variable IDs.

    Each op is read as ``(node id, input variables, output variables,
    conditioning variables)``; every variable is replaced by its position in
    ``self.variables_ind``.  Ops whose node id is listed in ``self.ind_out``,
    ``self.ind_in`` or ``self.ind_cond`` are not part of the result; the
    index of their first input variable is appended to ``self.return_vars``,
    ``self.input_vars`` or ``self.cond_vars`` instead.  Those three lists
    are finally sorted by the first component of the variables they refer
    to.

    Returns:
        A list of ``(node id, input indices, output indices, conditioning
        indices)`` tuples.  The input indices are ``-1`` if one of the
        inputs is not contained in ``self.variables_ind``.
    '''
    def _index_of(variable):
        return self.variables_ind.index(variable)

    def _origin(var_index):
        return self.variables_ind[var_index][0]

    indexed = []
    for op in ops:
        try:
            in_ids = [_index_of(v) for v in op[1]]
        except ValueError:
            in_ids = -1
        out_ids = [_index_of(v) for v in op[2]]
        cond_ids = [_index_of(v) for v in op[3]]

        # Input, output and conditioning nodes are only recorded in their
        # dedicated lists; they do not become indexed ops.
        node_id = op[0]
        if node_id in self.ind_out:
            self.return_vars.append(_index_of(op[1][0]))
        elif node_id in self.ind_in:
            self.input_vars.append(_index_of(op[1][0]))
        elif node_id in self.ind_cond:
            cond_index = _index_of(op[1][0])
            if cond_index in self.cond_vars:
                print('Is this branch ever reached?')
            else:
                self.cond_vars.append(cond_index)
        else:
            indexed.append((node_id, in_ids, out_ids, cond_ids))

    # Bring the three lists into the order of the initial node list.
    for var_list in (self.return_vars, self.input_vars, self.cond_vars):
        var_list.sort(key=_origin)
    return indexed

def parameters(

self, recurse=True)

Returns an iterator over module parameters.

This is typically passed to an optimizer.

Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields: Parameter: module parameter

Example::

>>> for param in model.parameters():
>>>     print(type(param.data), param.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
def parameters(self, recurse=True):
    r"""Iterate over module parameters, dropping their names.

    Thin wrapper around :meth:`named_parameters`; the result is typically
    passed to an optimizer.

    Args:
        recurse (bool): forwarded unchanged to :meth:`named_parameters`.

    Yields:
        The second element of each pair produced by
        :meth:`named_parameters`.

    Example::

        >>> for param in model.parameters():
        >>>     print(type(param.data), param.size())
    """
    named = self.named_parameters(recurse=recurse)
    yield from (entry for _, entry in named)

def register_backward_hook(

self, hook)

Registers a backward hook on the module.

The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature::

hook(module, grad_input, grad_output) -> Tensor or None

The :attr:grad_input and :attr:grad_output may be tuples if the module has multiple inputs or outputs. The hook should not modify its arguments, but it can optionally return a new gradient with respect to input that will be used in place of :attr:grad_input in subsequent computations.

Returns: :class:torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

.. warning ::

The current implementation will not have the presented behavior
for complex :class:`Module` that perform many operations.
In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
contain the gradients for a subset of the inputs and outputs.
For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
directly on a specific input or output to get the required gradients.
def register_backward_hook(self, hook):
    r"""Register a backward hook on the module.

    The hook is stored in ``self._backward_hooks`` under the id of a newly
    created :class:`torch.utils.hooks.RemovableHandle`.  According to the
    original documentation it is called every time the gradients with
    respect to the module inputs are computed, with the signature::

        hook(module, grad_input, grad_output) -> Tensor or None

    :attr:`grad_input` and :attr:`grad_output` may be tuples if the module
    has multiple inputs or outputs.  The hook should not modify its
    arguments, but may return a new gradient with respect to the input that
    replaces :attr:`grad_input` in subsequent computations.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    .. warning ::

        The original documentation warns that this does not behave as
        described for complex :class:`Module` that perform many operations:
        :attr:`grad_input` and :attr:`grad_output` may then only contain the
        gradients for a subset of the inputs and outputs.  For such a
        :class:`Module`, use :func:`torch.Tensor.register_hook` directly on
        a specific input or output instead.
    """
    registry = self._backward_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle

def register_buffer(

self, name, tensor)

Adds a persistent buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the persistent state.

Buffers can be accessed as attributes using given names.

Args: name (string): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor): buffer to be registered.

Example::

>>> self.register_buffer('running_mean', torch.zeros(num_features))
def register_buffer(self, name, tensor):
    r"""Adds a persistent buffer to the module.

    This is typically used to register a buffer that should not be
    considered a model parameter. For example, BatchNorm's ``running_mean``
    is not a parameter, but is part of the persistent state.

    The buffer is stored in ``self._buffers`` under ``name``.

    Args:
        name (string): name of the buffer.
        tensor (Tensor): buffer to be registered; ``None`` is accepted too.

    Raises:
        AttributeError: if ``_buffers`` has not been set up on the instance
            yet, i.e. before ``Module.__init__()`` was called.
        TypeError: if ``name`` is not a string, or if ``tensor`` is neither
            ``None`` nor a :class:`torch.Tensor`.
        KeyError: if ``name`` contains a ``"."``, is empty, or collides with
            an existing attribute that is not a buffer.

    Example::

        >>> self.register_buffer('running_mean', torch.zeros(num_features))
    """
    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
    # Plain ``str`` instead of ``torch._six.string_classes``: the ``torch._six``
    # module is gone from current PyTorch, which made this check itself fail.
    if not isinstance(name, str):
        raise TypeError("buffer name should be a string. "
                        "Got {}".format(torch.typename(name)))
    if '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    if name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    if hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    if tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))
    self._buffers[name] = tensor

def register_forward_hook(

self, hook)

Registers a forward hook on the module.

The hook will be called every time after :func:forward has computed an output. It should have the following signature::

hook(module, input, output) -> None

The hook should not modify the input or output.

Returns: :class:torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

def register_forward_hook(self, hook):
    r"""Register a forward hook on the module.

    The hook is stored in ``self._forward_hooks`` under the id of a newly
    created :class:`torch.utils.hooks.RemovableHandle`.  According to the
    original documentation it is called every time after :func:`forward`
    has computed an output, with the signature::

        hook(module, input, output) -> None

    The hook should not modify the input or output.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._forward_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle

def register_forward_pre_hook(

self, hook)

Registers a forward pre-hook on the module.

The hook will be called every time before :func:forward is invoked. It should have the following signature::

hook(module, input) -> None

The hook should not modify the input.

Returns: :class:torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

def register_forward_pre_hook(self, hook):
    r"""Register a forward pre-hook on the module.

    The hook is stored in ``self._forward_pre_hooks`` under the id of a
    newly created :class:`torch.utils.hooks.RemovableHandle`.  According to
    the original documentation it is called every time before
    :func:`forward` is invoked, with the signature::

        hook(module, input) -> None

    The hook should not modify the input.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._forward_pre_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle

def register_parameter(

self, name, param)

Adds a parameter to the module.

The parameter can be accessed as an attribute using given name.

Args: name (string): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter): parameter to be added to the module.

def register_parameter(self, name, param):
    r"""Adds a parameter to the module.

    The parameter is stored in ``self._parameters`` under ``name``.

    Args:
        name (string): name of the parameter.
        param (Parameter): parameter to be added to the module; ``None`` is
            accepted too.

    Raises:
        AttributeError: if ``_parameters`` has not been set up on the
            instance yet, i.e. before ``Module.__init__()`` was called.
        TypeError: if ``name`` is not a string, or if ``param`` is neither
            ``None`` nor a :class:`torch.nn.Parameter`.
        KeyError: if ``name`` contains a ``"."``, is empty, or collides with
            an existing attribute that is not a parameter.
        ValueError: if ``param`` has a ``grad_fn``, i.e. is not a leaf.
    """
    if '_parameters' not in self.__dict__:
        raise AttributeError(
            "cannot assign parameter before Module.__init__() call")
    # Plain ``str`` instead of ``torch._six.string_classes``: the ``torch._six``
    # module is gone from current PyTorch, which made this check itself fail.
    if not isinstance(name, str):
        raise TypeError("parameter name should be a string. "
                        "Got {}".format(torch.typename(name)))
    if '.' in name:
        raise KeyError("parameter name can't contain \".\"")
    if name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    if hasattr(self, name) and name not in self._parameters:
        raise KeyError("attribute '{}' already exists".format(name))

    if param is None:
        self._parameters[name] = None
    # Fully qualified, because a bare ``Parameter`` is not imported here.
    elif not isinstance(param, torch.nn.Parameter):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or None required)"
                        .format(torch.typename(param), name))
    elif param.grad_fn:
        raise ValueError(
            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
            "parameters must be created explicitly. To express '{0}' "
            "as a function of another Tensor, compute the value in "
            "the forward() method.".format(name))
    else:
        self._parameters[name] = param

def share_memory(

self)

def share_memory(self):
    r"""Call ``share_memory_()`` on everything :meth:`_apply` visits.

    Returns:
        Whatever :meth:`_apply` returns.
    """
    def _share(tensor):
        return tensor.share_memory_()

    return self._apply(_share)

def state_dict(

self, destination=None, prefix='', keep_vars=False)

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.

Returns: dict: a dictionary containing a whole state of the module

Example::

>>> module.state_dict().keys()
['bias', 'weight']
def state_dict(self, destination=None, prefix='', keep_vars=False):
    r"""Collect the whole state of the module in a dictionary.

    Every parameter and buffer that is not ``None`` is stored under
    ``prefix + name``, child modules add their own entries recursively
    under ``prefix + child name + '.'``, and finally every hook in
    ``self._state_dict_hooks`` is run.

    Args:
        destination (dict, optional): dictionary to fill.  A new
            ``OrderedDict`` with an empty ``_metadata`` attribute is created
            when omitted.
        prefix (str): prepended to every key written by this module.
        keep_vars (bool): store the parameter and buffer objects themselves
            instead of their ``.data``.

    Returns:
        dict:
            the filled dictionary, or the last non-``None`` value returned
            by a hook

    Example::

        >>> module.state_dict().keys()
        ['bias', 'weight']
    """
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    # The metadata key is the prefix without its trailing separator.
    local_metadata = dict(version=self._version)
    destination._metadata[prefix[:-1]] = local_metadata

    for members in (self._parameters, self._buffers):
        for name, member in members.items():
            if member is None:
                continue
            destination[prefix + name] = member if keep_vars else member.data

    for name, module in self._modules.items():
        if module is None:
            continue
        module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)

    for hook in self._state_dict_hooks.values():
        replacement = hook(self, destination, prefix, local_metadata)
        if replacement is not None:
            destination = replacement
    return destination

def to(

self, *args, **kwargs)

Moves and/or casts the parameters and buffers.

This can be called as

.. function:: to(device=None, dtype=None, non_blocking=False)

.. function:: to(dtype, non_blocking=False)

.. function:: to(tensor, non_blocking=False)

Its signature is similar to :meth:torch.Tensor.to, but only accepts floating point desired :attr:dtype s. In addition, this method will only cast the floating point parameters and buffers to :attr:dtype (if given). The integral parameters and buffers will be moved to :attr:device, if that is given, but with dtypes unchanged. When :attr:non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

.. note:: This method modifies the module in-place.

Args: device (:class:torch.device): the desired device of the parameters and buffers in this module dtype (:class:torch.dtype): the desired floating point type of the floating point parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

Returns: Module: self

Example::

>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)
def to(self, *args, **kwargs):
    r"""Moves and/or casts the parameters and buffers.

    This can be called as

    .. function:: to(device=None, dtype=None, non_blocking=False)

    .. function:: to(dtype, non_blocking=False)

    .. function:: to(tensor, non_blocking=False)

    Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
    floating point desired :attr:`dtype` s. In addition, this method will
    only cast the floating point parameters and buffers to :attr:`dtype`
    (if given). The integral parameters and buffers will be moved to
    :attr:`device`, if that is given, but with dtypes unchanged. When
    :attr:`non_blocking` is set, it tries to convert/move asynchronously
    with respect to the host if possible, e.g., moving CPU Tensors with
    pinned memory to CUDA devices.

    If the installed PyTorch's argument parser also reports a memory format,
    that format is applied to 4-D and 5-D tensors only; all other tensors are
    converted without it.

    See below for examples.

    .. note::
        This method modifies the module in-place.

    Args:
        device (:class:`torch.device`): the desired device of the parameters
            and buffers in this module
        dtype (:class:`torch.dtype`): the desired floating point type of
            the floating point parameters and buffers in this module
        tensor (torch.Tensor): Tensor whose dtype and device are the desired
            dtype and device for all parameters and buffers in this module

    Returns:
        Module: self

    Raises:
        TypeError: if a non floating point :attr:`dtype` is requested.

    Example::

        >>> linear = nn.Linear(2, 2)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1913, -0.3420],
                [-0.5113, -0.2325]])
        >>> linear.to(torch.double)
        Linear(in_features=2, out_features=2, bias=True)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1913, -0.3420],
                [-0.5113, -0.2325]], dtype=torch.float64)
        >>> gpu1 = torch.device("cuda:1")
        >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
        Linear(in_features=2, out_features=2, bias=True)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1914, -0.3420],
                [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
        >>> cpu = torch.device("cpu")
        >>> linear.to(cpu)
        Linear(in_features=2, out_features=2, bias=True)
        >>> linear.weight
        Parameter containing:
        tensor([[ 0.1914, -0.3420],
                [-0.5112, -0.2324]], dtype=torch.float16)
    """
    # The private parser returns (device, dtype, non_blocking) on older
    # PyTorch releases and appends a memory format on newer ones. Unpacking
    # exactly three names would raise on the longer tuple, so slice instead.
    parsed = torch._C._nn._parse_to(*args, **kwargs)
    device, dtype, non_blocking = parsed[:3]
    memory_format = parsed[3] if len(parsed) > 3 else None

    if dtype is not None and not dtype.is_floating_point:
        raise TypeError('nn.Module.to only accepts floating point '
                        'dtypes, but got desired dtype={}'.format(dtype))

    def convert(t):
        # Only floating point tensors are cast; integral ones keep their dtype
        # and are merely moved to the requested device.
        target_dtype = dtype if t.is_floating_point() else None
        if memory_format is not None and t.dim() in (4, 5):
            return t.to(device, target_dtype, non_blocking,
                        memory_format=memory_format)
        return t.to(device, target_dtype, non_blocking)

    return self._apply(convert)

def train(

self, mode=True)

Sets the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. :class:Dropout, :class:BatchNorm, etc.

Returns: Module: self

def train(self, mode=True):
    r"""Switches the module, and every child module, to the given mode.

    The flag is stored on ``self.training`` and then handed to ``train`` on
    each module yielded by ``self.children()``. Only some modules behave
    differently depending on this flag (e.g. :class:`Dropout`,
    :class:`BatchNorm`); consult their documentation for details.

    Args:
        mode: value to store in ``training``; defaults to ``True``
            (training mode).

    Returns:
        Module: self
    """
    # Record the flag on this module first, then push it down one level;
    # each child is responsible for forwarding it to its own children.
    self.training = mode
    submodules = self.children()
    for child in submodules:
        child.train(mode)
    return self

def type(

self, dst_type)

Casts all parameters and buffers to :attr:dst_type.

Arguments: dst_type (type or string): the desired type

Returns: Module: self

def type(self, dst_type):
    r"""Casts all parameters and buffers to :attr:`dst_type`.

    Every tensor handed out by ``self._apply`` is replaced by the result of
    calling its ``type(dst_type)`` method.

    Arguments:
        dst_type (type or string): the desired type

    Returns:
        Module: self
    """
    def cast(tensor):
        # Delegate the actual conversion to the tensor itself.
        return tensor.type(dst_type)

    return self._apply(cast)

def zero_grad(

self)

Sets gradients of all model parameters to zero.

def zero_grad(self):
    r"""Sets gradients of all model parameters to zero.

    Each parameter yielded by ``self.parameters()`` that has a gradient gets
    that gradient detached in place and then zeroed in place. Parameters
    whose ``grad`` is ``None`` are left untouched.
    """
    for param in self.parameters():
        grad = param.grad
        if grad is None:
            # Nothing has been accumulated for this parameter.
            continue
        grad.detach_()
        grad.zero_()

Instance variables

var cond_vars

var ind_cond

var indexed_ops

var indexed_ops_rev

var input_vars

var module_cond

var module_list

var node_list

var return_vars

var variables_ind