2020-1-25 created by AD1024

PyTorch

Machine Learning

The computation graph of PyTorch is define-by-run, which means no one knows what the computation graph looks like until an input is passed to the model. Therefore, some useful graph-level optimizations are hard to apply to models defined in PyTorch: there is no API for traversing the graph, so we cannot operate on the computation graph directly. However, the JIT (just-in-time compiler) of PyTorch provides a way to compile our Python code to TorchScript, an IR (intermediate representation) developed by the PyTorch team and used to lower the surface-level language (Python) to the C++/device level. The JIT can give back a “graph,” which is the AST of the IR, and provides some useful interfaces for walking the tree. By taking advantage of the function `torch.jit._get_trace_graph` (note the leading underscore: it is a private API and may change between releases), we are able to construct the computation graph by examining the properties of the nodes in the AST.

In this note, we take `ResNet18` as an example (since it is very simple). In the REPL, if we run

```
import torch
import torchvision
model = torchvision.models.resnet18()
inp = torch.zeros([64, 3, 7, 7])
# Returns the traced IR graph and the output of model(inp)
trace, out = torch.jit._get_trace_graph(model, inp)
```

we get the IR “graph” and the result of calling `forward` on `model` (a tensor). The main object of interest is `trace`: it is literally a trace of the forward pass, i.e. which operators and variables were used and which parameters the forward function received.

Even though we have the trace, it is not a graph we can traverse directly, so we need to construct the graph from the nodes in the AST. Here are some useful functions:

- `trace.nodes()`: returns a stream (iterator) of AST nodes
- `trace.param_node()`: returns the node whose outputs are the parameters passed to the forward pass

For an AST node object, we have:

- `node.kind()`: returns the operator as `scope::name` (e.g. `prim::Constant`)
- `node.outputs()`: returns a stream of the values that are outputs of the node
- `node.inputs()`: returns a stream of the values that are passed as input arguments to the node

Each input or output value `v` in turn provides:

- `v.unique()`: returns a unique identifier of the value

Therefore, we can construct the edges from the inputs and outputs of the nodes in the AST. Being a computation graph, the result should be a DAG (directed acyclic graph): Node *A* has an outgoing edge to Node *B* if and only if $S_{o}^A \cap S_{i}^B \neq \emptyset$, where $S_{o}^A$ is the set of output values of Node *A* and $S_{i}^B$ is the set of input values of Node *B*.
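The edge rule can be tried out without PyTorch. The sketch below uses a hypothetical stand-in for a trace, where each node is described only by the sets of value ids it consumes and produces (the names `NODES` and `build_edges` are illustrative, not part of any API), and connects A to B whenever the intersection is non-empty.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical stand-in for a traced graph: node name -> (input ids, output ids).
# The ids play the role of `value.unique()` in a real trace.
NODES: Dict[str, Tuple[Set[int], Set[int]]] = {
    "conv": ({0, 1}, {2}),
    "relu": ({2}, {3}),
    "add": ({2, 3}, {4}),
}


def build_edges(nodes: Dict[str, Tuple[Set[int], Set[int]]]) -> List[Tuple[str, str, Set[int]]]:
    """Return (src, dst, shared ids) for every pair whose outputs/inputs intersect."""
    edges = []
    for src, (_, src_outputs) in nodes.items():
        for dst, (dst_inputs, _) in nodes.items():
            shared = src_outputs & dst_inputs  # S_o^src intersected with S_i^dst
            if shared:
                edges.append((src, dst, shared))
    return edges


for src, dst, shared in build_edges(NODES):
    print(f"{src} -> {dst} via {sorted(shared)}")
```

This prints `conv -> relu via [2]`, `conv -> add via [2]` and `relu -> add via [3]`. Like the approach described above, it compares every pair of nodes, so it is quadratic in the number of nodes.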

Fortunately, if we don’t consider control flow and recursion, the problem is straightforward and intuitive thanks to TorchScript: TorchScript is compiled to SSA (static single assignment) form, which means there are no effectful assignments (directly modifying the value of a variable, as in `x += 1`). This also makes the work of compiling the graph back easier (to be discussed in Part II of the note).
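To see what SSA form means concretely, here is a toy renaming pass over straight-line assignments. It is purely illustrative and unrelated to how TorchScript implements SSA; the names `to_ssa` and `program` are hypothetical. Every assignment gets a fresh versioned name, so an effectful update like `x += 1` becomes a new definition that reads the previous version.

```python
import re
from typing import Dict, List, Tuple

# A statement is (target, expression); expressions use plain identifiers.
Stmt = Tuple[str, str]


def to_ssa(stmts: List[Stmt]) -> List[Stmt]:
    """Rename straight-line assignments so that each name is assigned exactly once."""
    version: Dict[str, int] = {}  # variable -> latest version number

    def current(match):
        name = match.group(0)
        # Variables assigned earlier refer to their latest version; others are inputs.
        return f"{name}{version[name]}" if name in version else name

    out = []
    for target, expr in stmts:
        new_expr = re.sub(r"[A-Za-z_]\w*", current, expr)  # rewrite reads first
        version[target] = version.get(target, -1) + 1       # then bump the target
        out.append((f"{target}{version[target]}", new_expr))
    return out


# `x = a; x += 1; y = x * x`, with the in-place update written out
program = [("x", "a"), ("x", "x + 1"), ("y", "x * x")]
for target, expr in to_ssa(program):
    print(f"{target} = {expr}")
```

The output is `x0 = a`, `x1 = x0 + 1`, `y0 = x1 * x1`: each name is defined once, so every value has exactly one producer, which is what makes the edge rule above unambiguous.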

So we can write out this code:

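```
from collections import namedtuple

# An edge stores the name of the target node and the shared value ids,
# matching how `Node.adjacent_nodes` unpacks `self.edges`
Edge = namedtuple('Edge', ['to', 'values'])

parsed_graph = {}
for node in trace.nodes():
    op = node.kind()
    outputs = [o.unique() for o in node.outputs()]
    inputs = [i.unique() for i in node.inputs()]
    # Create a graph node for the current AST node (`Node` and `create_name` are defined below)
    graph_node = Node(create_name(node), op, None, None, inputs, outputs)
    parsed_graph[graph_node.id] = graph_node
    # Reference:
    # https://github.com/waleedka/hiddenlayer/blob/master/hiddenlayer/pytorch_builder.py
    for to in trace.nodes():
        to_inputs = [i.unique() for i in to.inputs()]
        edges = set(outputs) & set(to_inputs)
        if edges:
            graph_node.edges.append(Edge(create_name(to), edges))
```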

To be able to compile it back to a Python script, we need to record some information about each node (such as the shapes of its outputs). So we can define the graph node structure as follows:

```
class Node:
    '''
    A Computation Graph Node Representation
    '''
    def __init__(self, name, op, params, shape, inputs, outputs, output_size=None):
        self.id = name
        self.op = op
        self.params = params
        self.shape = shape
        self.outputs = outputs
        self.inputs = inputs
        self.checkpoint = False
        self.output_size = output_size
        self.edges = []

    def adjacent_nodes(self, graph):
        if self.edges:
            for name, src_name in self.edges:
                yield graph[name], src_name

    def get_output_size(self):
        raise NotImplementedError()

    def to_python(self, ctx: dict, src=False, inline=True):
        raise NotImplementedError()
```

Two methods are left unimplemented because there are two kinds of nodes: `prim`-scope and `aten`-scope nodes. They must be handled differently when compiling back, since `prim` nodes are mainly constants and lists, whereas `aten` nodes are operators (callable objects). To tell nodes apart, we need a function that assigns each node a name. Since in SSA form no two nodes share an output value, we can use the output ids to construct the name.
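Because `node.kind()` returns a `scope::name` string, telling the two kinds of nodes apart only requires splitting on `::`. A minimal sketch; the helper `scope_of` and the `HANDLING` table are hypothetical names used for illustration:

```python
def scope_of(kind: str) -> str:
    """Return the scope part of a `scope::name` kind string, e.g. 'prim' or 'aten'."""
    scope, sep, _ = kind.partition("::")
    if not sep:
        raise ValueError(f"not a scope::name kind: {kind!r}")
    return scope


# How each scope is treated, following the description above
HANDLING = {"prim": "constant or list", "aten": "operator call"}

for kind in ["prim::Constant", "prim::ListConstruct", "aten::relu"]:
    print(kind, "->", HANDLING[scope_of(kind)])
```

In a real builder, the scope would pick which `Node` subclass to instantiate rather than a string label.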

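```
import hashlib

def create_name(node):
    # Output value ids are unique in SSA form, so they identify the node.
    # Joining with a separator keeps e.g. outputs [1, 23] and [123] distinct.
    output_ids = '_'.join(str(i) for i in sorted(o.unique() for o in node.outputs()))
    return node.kind() + '~>' + hashlib.md5(output_ids.encode()).hexdigest()
```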
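The naming idea (node kind plus an MD5 digest of the sorted output ids) can be checked without a real trace. The sketch below reimplements it with the ids joined by `_`, so that outputs `[3, 7]` and `[37]` cannot collide, and feeds it hypothetical stub classes `FakeValue` and `FakeNode` that mimic only the `kind()`, `outputs()` and `unique()` methods it needs.

```python
import hashlib
from typing import List


class FakeValue:
    """Hypothetical stand-in for a traced value; only `unique()` is provided."""

    def __init__(self, uid: int):
        self._uid = uid

    def unique(self) -> int:
        return self._uid


class FakeNode:
    """Hypothetical stand-in for a traced node with a kind and output values."""

    def __init__(self, kind: str, output_ids: List[int]):
        self._kind = kind
        self._outputs = [FakeValue(i) for i in output_ids]

    def kind(self) -> str:
        return self._kind

    def outputs(self):
        return iter(self._outputs)


def create_name(node) -> str:
    """Kind plus an MD5 digest of the sorted, '_'-joined output ids."""
    output_ids = "_".join(str(i) for i in sorted(o.unique() for o in node.outputs()))
    return node.kind() + "~>" + hashlib.md5(output_ids.encode()).hexdigest()


a = create_name(FakeNode("aten::relu", [3, 7]))
b = create_name(FakeNode("aten::relu", [7, 3]))
c = create_name(FakeNode("aten::relu", [37]))
print(a == b, a == c)  # True False: order-insensitive, and [3, 7] differs from [37]
```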

HiddenLayer has an implementation that uses a regex to extract the shape. But it turns out (perhaps thanks to updates to PyTorch) that we can get this information directly from the type information. For a tensor output of an `aten` node,

```
output.type() # can give us the type
output.type().sizes() # can give us the shape
```

and for the output of a `prim`-scope node,

```
output.toIValue() # can give us the value of the output (constant node)
```

Therefore, we can obtain the information using the following functions:

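```
import re
from collections import namedtuple

# Lightweight records for shape and constant-value information
Shape = namedtuple('Shape', ['type', 'sizes'])
Value = namedtuple('Value', ['type', 'value', 'sizes'])

def get_shape(node) -> dict:
    outputs = dict()
    for o in node.outputs():
        typeIs = o.type()
        # Keep only the leading scalar type word, e.g. 'Float(64, 3, 7, 7)' -> 'Float'
        outputs[o.unique()] = Shape(type=re.match(r'\w+', typeIs.str()).group(),
                                    sizes=tuple(typeIs.sizes()))
    return outputs

def get_value(node) -> dict:
    outputs = dict()
    for o in node.outputs():
        typeIs = o.type().str()
        value = o.toIValue()  # the constant value, or None if the value is not a constant
        outputs[o.unique()] = Value(type=typeIs, value=value,
                                    sizes=len(list(node.outputs())) if typeIs.endswith('[]') else 1)
    return outputs
```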
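The `re.match(r'\w+', ...)` call in `get_shape` keeps only the leading scalar-type word of the tensor type string. A quick check on illustrative type strings, written in the style PyTorch prints for tensor types (the exact format varies between PyTorch versions); `scalar_type` is a hypothetical helper name:

```python
import re

# Illustrative tensor type strings; real ones come from `value.type().str()`
samples = [
    "Float(64, 3, 7, 7)",
    "Float(64, 1000, strides=[1000, 1], requires_grad=1, device=cpu)",
    "Long(1)",
]


def scalar_type(type_str: str) -> str:
    """Return the leading word of a type string, as `get_shape` does."""
    return re.match(r"\w+", type_str).group()


for s in samples:
    print(scalar_type(s))  # Float, Float, Long
```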

and we can do

```
graph_node = AtenNode(create_name(node), op, params, get_shape(node), inputs, outputs)
```

to get a graph node.