PyTorch's computation graph is define-by-run, which means nobody knows what the computation graph looks like until an input is passed to the model. Therefore, some graph-level optimizations that would be useful are hard to apply to models defined in PyTorch, since there is no stable API that lets you traverse the graph and operate on it directly. However, the JIT (just-in-time compiler) of PyTorch can compile our Python code to TorchScript, an IR (Intermediate Representation) developed by the PyTorch team to lower the surface-level language (Python) toward the C++ / device level. The JIT can give back a “graph”, which is the AST of the IR, and provides some useful interfaces for walking the tree. By taking advantage of the function torch.jit._get_trace_graph, we are able to construct the computation graph by examining the properties of the nodes in the AST.
In this note, we take ResNet18 as an example (since it is very simple). In the REPL, if we run
import torch
import torchvision

model = torchvision.models.resnet18()
inp = torch.zeros([64, 3, 7, 7])
# trace is the IR graph; out is the result of the forward pass
trace, out = torch.jit._get_trace_graph(model, inp)
we get the IR “graph” and the result of calling forward on the model (a tensor). Our main concern is trace: it is literally a trace of the forward pass, i.e. which operators and variables were used and which parameters the forward function received.
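As a quick way to see what the trace contains (assuming a PyTorch version in which torch.jit._get_trace_graph returns a torch._C.Graph), we can simply print it to get the textual IR:

print(trace)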
Even though we have the trace, it is not yet a graph that we can traverse directly, so we need to construct the graph from the nodes of the AST. Here are some useful functions:

- trace.nodes(): returns an iterator over the AST nodes
- trace.param_node(): the parameters passed to the forward pass

For an AST node object, we have:

- node.kind(): returns the operator (scope::name)
- node.outputs(): returns an iterator over the AST nodes that are the output variables of the node
- node.inputs(): returns an iterator over the AST nodes that are passed as input arguments to the node
- node.unique(): returns a unique identifier of the node

Therefore, we can recover the edge information from the inputs and outputs of the nodes in the AST (the short sketch below shows these interfaces in action). As a computation graph, the result should be a DAG (directed acyclic graph): Node A has an outgoing edge to Node B if and only if $S_{o}^A \cap S_{i}^B \neq \emptyset$, where $S_{o}^A$ is the set of output nodes of Node A and $S_{i}^B$ is the set of input nodes of Node B.
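As a quick illustration of these interfaces (a small sketch reusing the trace obtained above), we can walk the trace and print each node's operator together with the identifiers of its inputs and outputs:

for node in trace.nodes():
    print(node.kind(),
          [i.unique() for i in node.inputs()],
          '->',
          [o.unique() for o in node.outputs()])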
Fortunately, if we don’t consider control flow and recursion, the problem is very straightforward and intuitive thanks to TorchScript, because TorchScript is compiled into SSA (Static Single Assignment) form, which means there are no effectful assignments (modifying the value of a variable in place, like x += 1). This makes the work of compiling the graph back to source easier (more on that in Part II of the note).
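For instance, an effectful assignment is rewritten so that every assignment defines a fresh variable (a schematic illustration, not actual TorchScript output):

# effectful form
x += 1
y = x * 2

# SSA form: every assignment defines a new variable, so values are never overwritten
x1 = x0 + 1
y0 = x1 * 2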
So we can write out this code:
# Reference:
# https://github.com/waleedka/hiddenlayer/blob/master/hiddenlayer/pytorch_builder.py
parsed_graph = {}
for node in trace.nodes():
    op = node.kind()
    outputs = [o.unique() for o in node.outputs()]
    inputs = [i.unique() for i in node.inputs()]
    graph_node = ...  # create a Node for the current AST node (see the end of this note)
    parsed_graph[graph_node.id] = graph_node
    # connect this node to every node that consumes one of its outputs
    for to in trace.nodes():
        to_inputs = [i.unique() for i in to.inputs()]
        edges = set(outputs) & set(to_inputs)
        if edges:
            graph_node.edges.append(Edge(create_name(to), edges))
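The Edge type is not defined in this note; a minimal assumption that matches how it is consumed by adjacent_nodes below is a namedtuple holding the target node's name and the SSA value ids shared between the two nodes:

from collections import namedtuple

# hypothetical container: the target node's name plus the intersecting SSA value ids
Edge = namedtuple('Edge', ['name', 'values'])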
In order to be able to compile it back to a Python script, we need to record some information about each node (like the shapes of its outputs). So we can define the graph node structure as follows:
class Node:
    '''
    A Computation Graph Node Representation
    '''
    def __init__(self, name, op, params, shape, inputs, outputs, output_size=None):
        self.id = name              # unique name produced by create_name (defined below)
        self.op = op                # operator kind, e.g. aten::_convolution
        self.params = params
        self.shape = shape          # shapes / values of the outputs
        self.outputs = outputs      # SSA ids of the output values
        self.inputs = inputs        # SSA ids of the input values
        self.checkpoint = False
        self.output_size = output_size
        self.edges = []             # outgoing edges (Edge namedtuples)

    def adjacent_nodes(self, graph):
        if self.edges:
            for name, values in self.edges:
                yield graph[name], values

    def get_output_size(self):
        raise NotImplementedError()

    def to_python(self, ctx: dict, src=False, inline=True):
        raise NotImplementedError()
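With this structure in place, walking the parsed graph is straightforward; a small usage sketch (reusing the parsed_graph dictionary built above):

for node in parsed_graph.values():
    for neighbour, shared_values in node.adjacent_nodes(parsed_graph):
        print(node.id, '->', neighbour.id, 'via SSA values', shared_values)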
Two methods are left unimplemented, because we have two kinds of nodes: those in the prim scope and those in the aten scope. These two kinds have to be handled differently when compiling back, since prim-scope nodes are mainly constants and lists, whereas aten nodes are operators (callable objects). In order to tell nodes apart, we also need a function that assigns them names. Since the outputs of nodes are never identical in SSA form, we can use that information to construct the name.
import hashlib
from functools import reduce

def create_name(node):
    '''Build a unique name from the operator kind and the md5 of its sorted SSA output ids.'''
    return node.kind() \
        + '~>' \
        + hashlib.md5(reduce(
            lambda x, y: x + y,
            [str(x) for x in sorted(y.unique() for y in node.outputs())]
        ).encode()).hexdigest()
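As a quick check (reusing the trace from above, where every node has at least one output), each traced node now gets a distinct name of the form scope::op~><md5 digest>:

for node in trace.nodes():
    print(create_name(node))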
There is an implementation in HiddenLayer that uses a regex to match out the shape, but it turns out (perhaps because of an update to PyTorch) that we can actually get this information from the type annotations. For a tensor output (of an aten node),
output.type() # can give us the type
output.type().sizes() # can give us the shape
and for the output of a prim-scope node,
output.toIValue() # can give us the value of the output (constant node)
Therefore, we can obtain the information using the following functions:
import re

def get_shape(node) -> dict:
    '''Collect the (element type, shape) of every output of an aten node.'''
    outputs = dict()
    for o in node.outputs():
        typeIs = o.type()
        outputs[o.unique()] = Shape(type=re.match(r'\w+', typeIs.str()).group(),
                                    sizes=tuple(typeIs.sizes()))
    return outputs

def get_value(node) -> dict:
    '''Collect the constant value of every output of a prim node.'''
    outputs = dict()
    for o in node.outputs():
        typeIs = o.type().str()
        value = o.toIValue()
        outputs[o.unique()] = Value(type=typeIs, value=value,
                                    sizes=len(list(node.outputs())) if typeIs.endswith('[]') else 1)
    return outputs
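The Shape and Value containers used above are not defined in this note; a minimal assumption that matches the keyword arguments is a pair of namedtuples:

from collections import namedtuple

# hypothetical containers matching the keyword arguments used in get_shape / get_value
Shape = namedtuple('Shape', ['type', 'sizes'])
Value = namedtuple('Value', ['type', 'value', 'sizes'])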
Finally, we can do
graph_node = AtenNode(create_name(node), op, params, get_shape(node), inputs, outputs)
to get a graph node.
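AtenNode (and its prim counterpart) are not defined in this note either; a minimal assumption is that they are thin subclasses of Node, chosen by the scope of the operator, roughly:

class AtenNode(Node):
    '''aten-scope node: a callable operator such as aten::_convolution.'''

class PrimNode(Node):
    '''prim-scope node: constants and list constructions such as prim::Constant.'''

# inside the construction loop above, dispatch on the scope of the operator;
# to_python / get_output_size are overridden per kind in Part II of the note
if node.kind().startswith('prim::'):
    graph_node = PrimNode(create_name(node), op, params, get_value(node), inputs, outputs)
else:
    graph_node = AtenNode(create_name(node), op, params, get_shape(node), inputs, outputs)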