From Models to Computation Graphs (Part I)

2020-1-25 created by AD1024
PyTorch
Machine Learning

Define-By-Run

PyTorch builds its computation graph in a define-by-run fashion: the graph does not exist until an input is passed through the model. As a result, some useful graph-level optimizations are hard to apply to models defined in PyTorch, because there is no public API for traversing the graph and so no way to operate on it directly. However, PyTorch's JIT (just-in-time compiler) can compile Python code to TorchScript, an IR (intermediate representation) developed by the PyTorch team to lower the surface language (Python) to the C++ / device level. The JIT can hand back a “graph,” which is the AST of the IR, together with some useful interfaces for walking that tree. Using the function torch.jit._get_trace_graph (a private API, as the leading underscore indicates), we can reconstruct the computation graph by examining the properties of the nodes in the AST.
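To make "define-by-run" concrete, here is a toy sketch in plain Python. It is not how PyTorch records its graph; the Tape class and the forward function are made up for illustration. It only shows that when operations are recorded as they execute, the recorded sequence depends on the input.

```python
# Toy illustration only: this is NOT PyTorch's recording mechanism.
class Tape:
    """Records the name of every operation as it is executed."""

    def __init__(self):
        self.ops = []

    def apply(self, name, fn, *args):
        self.ops.append(name)
        return fn(*args)


def forward(tape, x):
    # Data-dependent control flow: which ops run depends on the value of x.
    y = tape.apply("mul", lambda a, b: a * b, x, 2)
    if y > 0:
        y = tape.apply("add", lambda a, b: a + b, y, 1)
    else:
        y = tape.apply("neg", lambda a: -a, y)
    return y


tape_pos, tape_neg = Tape(), Tape()
forward(tape_pos, 3)
forward(tape_neg, -3)
print(tape_pos.ops)  # ['mul', 'add']
print(tape_neg.ops)  # ['mul', 'neg']
```

The two runs record different operation sequences, which is why the graph cannot be known before an input is supplied.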

First Step - Getting the “Graph”

In this note, we take ResNet18 as an example (since it is very simple). In the REPL, if we run

import torch
import torchvision
model = torchvision.models.resnet18()
inp   = torch.zeros([64, 3, 7, 7])
trace, out = torch.jit._get_trace_graph(model, inp)  # out is the result of the forward pass

we get the IR “graph” and the result of calling forward on the model (a tensor). What we care about here is trace. It is literally a trace of the forward pass: which operators and variables were used, and which parameters the forward function received. The later snippets refer to the traced graph as graph.

Second Step - Parsing to a real Graph

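Even though we have the trace, it is not a graph we can traverse directly, so we need to construct one from the nodes in the AST. The calls used in the rest of this note are graph.nodes(), which iterates over the AST nodes; node.kind(), which gives the operator name; node.inputs() and node.outputs(), which iterate over the values a node consumes and produces; and unique(), which gives the id of such a value.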

Using the inputs and outputs of the nodes in the AST, we can derive the edges of the graph. A computation graph should be a DAG (directed acyclic graph), and Node A has an outgoing edge to Node B if and only if $S_{o}^A \cap S_{i}^B \neq \emptyset$, where $S_{o}^A$ is the set of outputs of Node A and $S_{i}^B$ is the set of inputs of Node B.
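As a small worked instance of this rule, here is a sketch with a hypothetical three-node graph. The node names and the integer ids are made up; the integers stand in for the unique ids of values.

```python
# Hypothetical three-node graph; integers stand for the unique ids of values.
nodes = {
    "conv": {"inputs": {0, 1}, "outputs": {2}},
    "relu": {"inputs": {2}, "outputs": {3}},
    "add": {"inputs": {2, 3}, "outputs": {4}},
}


def has_edge(a, b):
    """True when some output of node a is consumed as an input of node b."""
    return bool(nodes[a]["outputs"] & nodes[b]["inputs"])


edges = sorted((a, b) for a in nodes for b in nodes if has_edge(a, b))
print(edges)  # [('conv', 'add'), ('conv', 'relu'), ('relu', 'add')]
```

Value 2 is produced by conv and consumed by both relu and add, so conv gets two outgoing edges; value 3 links relu to add.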

Fortunately, if we ignore control flow and recursion, the problem is straightforward and intuitive thanks to TorchScript. TorchScript is compiled to an SSA (static single assignment) form, which means there are no effectful assignments (modifying the value of a variable in place, like x += 1); every value is defined exactly once. This also makes it easier to compile the graph back to code (discussed in Part II of this note).
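To illustrate what the SSA property means, here is a minimal sketch that renames variables in straight-line code so that each name is assigned exactly once. The to_ssa function and its statement format are invented for this example; it ignores control flow, just as the paragraph above does, and says nothing about how TorchScript itself performs the conversion.

```python
def to_ssa(statements):
    """Rename assignment targets so that every name is assigned exactly once.

    Each statement is (target, list of operand names). Operands that were
    never assigned (inputs, literals) are passed through unchanged.
    """
    version = {}  # variable -> latest version number
    result = []
    for target, operands in statements:
        # Operands refer to the latest version of each variable seen so far.
        renamed = [f"{v}_{version[v]}" if v in version else v for v in operands]
        version[target] = version.get(target, -1) + 1
        result.append((f"{target}_{version[target]}", renamed))
    return result


# x = input; x += 1; y = x * x
program = [("x", ["input"]), ("x", ["x", "1"]), ("y", ["x", "x"])]
print(to_ssa(program))
# [('x_0', ['input']), ('x_1', ['x_0', '1']), ('y_0', ['x_1', 'x_1'])]
```

After renaming, the in-place update x += 1 becomes a fresh value x_1, so every value has a single producer. That is what lets a value id identify the node that created it.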

So we can write out this code:

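parsed_graph = {}
for node in graph.nodes():
    op = node.kind()
    outputs = [o.unique() for o in node.outputs()]
    inputs  = [i.unique() for i in node.inputs()]
    # Create a node for the current AST node. Node and create_name are
    # defined later in this note; params and shape are left empty here.
    graph_node = Node(create_name(node), op, None, None, inputs, outputs)
    parsed_graph[graph_node.id] = graph_node
    # Reference:
    # https://github.com/waleedka/hiddenlayer/blob/master/hiddenlayer/pytorch_builder.py
    for to in graph.nodes():
        to_inputs = [i.unique() for i in to.inputs()]
        edges = set(outputs) & set(to_inputs)
        if edges:
            # Stored as (target node name, shared value ids), the form that
            # Node.adjacent_nodes unpacks.
            graph_node.edges.append((create_name(to), edges))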
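The nested loop above compares every pair of nodes. Because each value has exactly one producer under SSA, the same edges can also be found with a lookup table from value id to producer. The following is a sketch of that alternative on stand-in data; the tuples in trace_nodes and the build_edges helper are hypothetical and are not part of the original code.

```python
from collections import defaultdict

# Stand-in records: (node name, input value ids, output value ids).
trace_nodes = [
    ("a", [0], [1]),
    ("b", [1], [2]),
    ("c", [1, 2], [3]),
]


def build_edges(trace_nodes):
    # Map every value id to the node that produces it (unique under SSA).
    producer = {}
    for name, _, outputs in trace_nodes:
        for value_id in outputs:
            producer[value_id] = name

    # For each consumer, look up the producer of every input value.
    edges = defaultdict(set)  # (source, target) -> shared value ids
    for name, inputs, _ in trace_nodes:
        for value_id in inputs:
            src = producer.get(value_id)  # None for graph inputs
            if src is not None:
                edges[(src, name)].add(value_id)
    return dict(edges)


print(build_edges(trace_nodes))
# {('a', 'b'): {1}, ('a', 'c'): {1}, ('b', 'c'): {2}}
```

Value 0 has no producer in the list, so it is treated as a graph input and yields no edge.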

Structure of Graph Node

To be able to compile the graph back to a Python script, we need to record some information about each node (such as the shapes of its outputs). So we can define the graph node structure as follows:

class Node:
    '''
        A Computation Graph Node Representation
    '''
    def __init__(self, name, op, params, shape, inputs, outputs, output_size=None):
        self.id = name
        self.op = op
        self.params = params
        self.shape = shape
        self.outputs = outputs
        self.inputs  = inputs
        self.checkpoint = False
        self.output_size = output_size
        self.edges = []
    
    def adjacent_nodes(self, graph):
        # Each edge unpacks to (target node name, ids of the shared values).
        for name, src_name in self.edges:
            yield graph[name], src_name
    
    def get_output_size(self):
        raise NotImplementedError()
    
    def to_python(self, ctx: dict, src=False, inline=True):
        raise NotImplementedError()

Two methods are left unimplemented because there are two kinds of nodes: those in the prim scope and those in the aten scope. They have to be handled differently when compiling the graph back, since prim nodes are mainly constants and lists, whereas aten nodes are operators (callable objects). To tell nodes apart, we also need a function that assigns them names. Since the outputs of two nodes are never identical in SSA, we can use the outputs to construct the name.

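import hashlib

def create_name(node):
    # Under SSA the output ids of a node are unique to it, so the node kind
    # plus a hash of its sorted output ids identifies the node.
    ids = ''.join(str(x) for x in sorted(y.unique() for y in node.outputs()))
    return node.kind() + '~>' + hashlib.md5(ids.encode()).hexdigest()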
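To see how the naming and the prim/aten split behave without a real trace, here is a sketch using stand-in objects. FakeValue, FakeNode and the namespace helper are hypothetical; the stand-ins expose only kind(), outputs() and unique(), the calls that create_name relies on. create_name is restated so that the block runs on its own.

```python
import hashlib


class FakeValue:
    """Stand-in for an output value; only unique() is provided."""

    def __init__(self, uid):
        self._uid = uid

    def unique(self):
        return self._uid


class FakeNode:
    """Stand-in for an AST node; only kind() and outputs() are provided."""

    def __init__(self, kind, output_ids):
        self._kind = kind
        self._outputs = [FakeValue(i) for i in output_ids]

    def kind(self):
        return self._kind

    def outputs(self):
        return iter(self._outputs)


def create_name(node):
    ids = "".join(str(i) for i in sorted(o.unique() for o in node.outputs()))
    return node.kind() + "~>" + hashlib.md5(ids.encode()).hexdigest()


def namespace(name):
    """Return the part before '::', e.g. 'prim' or 'aten'."""
    return name.split("::", 1)[0]


conv = FakeNode("aten::conv2d", [7])
const = FakeNode("prim::Constant", [3])
print(create_name(conv))
print(namespace(create_name(conv)), namespace(create_name(const)))
```

Since the output ids are sorted before hashing, the name does not depend on the order in which a node lists its outputs, and the prefix before "::" is enough to decide which kind of graph node to build.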

Getting Values and Shapes of an Output Node

HiddenLayer has an implementation that uses a regex to extract the shape. But it turns out (perhaps because of a PyTorch update) that we can get this directly from the type information. For a tensor output of an aten node,

output.type() # can give us the type
output.type().sizes() # can give us the shape

and for a prim scope node output

output.toIValue() # can give us the value of the output (constant node)

Therefore, we can obtain the information using the following functions:

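import re
from collections import namedtuple

# Shape and Value are not defined in this note; namedtuples with the
# fields used below are assumed here.
Shape = namedtuple('Shape', ['type', 'sizes'])
Value = namedtuple('Value', ['type', 'value', 'sizes'])

def get_shape(node) -> dict:
    # Map each output id to its element type and tensor sizes.
    outputs = dict()
    for o in node.outputs():
        typeIs = o.type()
        outputs[o.unique()] = Shape(type=re.match(r'\w+', typeIs.str()).group(),
                                    sizes=tuple(typeIs.sizes()))
    return outputs

def get_value(node) -> dict:
    # Map each output id to its type string and constant value.
    outputs = dict()
    for o in node.outputs():
        typeIs = o.type().str()
        value  = o.toIValue()
        outputs[o.unique()] = Value(type=typeIs, value=value,
                                    sizes=len(list(node.outputs())) if typeIs.endswith('[]') else 1)
    return outputs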

and we can do

graph_node = AtenNode(create_name(node), op, params, get_shape(node), inputs, outputs)

to get a graph node.
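The string handling in get_shape and get_value can be tried out without PyTorch. In the sketch below the type strings are assumed examples of what the printed type might look like; the real strings depend on the PyTorch version. The helper names are made up, and sizes_from_string is only a rough illustration of a regex-based fallback, not HiddenLayer's actual code.

```python
import re


def type_name(type_str):
    """Leading identifier of a type string, as get_shape extracts it."""
    return re.match(r"\w+", type_str).group()


def is_list_type(type_str):
    """The same list test that get_value applies."""
    return type_str.endswith("[]")


def sizes_from_string(type_str):
    """Regex fallback: parse '(64, 64, 4, 4)' out of a printed type."""
    m = re.search(r"\(([\d\s,]+)\)", type_str)
    return tuple(int(s) for s in m.group(1).split(",")) if m else None


# Assumed example strings; real output varies between PyTorch versions.
print(type_name("Float(64, 64, 4, 4)"))          # Float
print(is_list_type("int[]"))                     # True
print(sizes_from_string("Float(64, 64, 4, 4)"))  # (64, 64, 4, 4)
```

Reading the sizes from the type object, as get_shape does, avoids depending on the printed format, which is the reason to prefer it over the regex.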