PyTorch 集成

简介

matxscript内置了对pytorch的支持,通过matx.script()将一个Pytorch实例包装成一个InferenceOp,可以用于被trace的pipeline中。

使用方式

从 ScriptModule(ScriptFunction) 构建

1. Define a nn.Module and call torch.jit.trace
import torch

class MyCell(torch.nn.Module):
def __init__(self):
    super(MyCell, self).__init__()
    self.linear = torch.nn.Linear(4, 4)

def forward(self, x, h):
    new_h = torch.tanh(self.linear(x) + h)
    return new_h, new_h

device = torch.device("cuda:0")
my_cell = MyCell().to(device)
script_model = torch.jit.trace(my_cell, (torch.rand(3, 4, device=device), torch.rand(3, 4, device=device)))
2. Construct InferenceOp

通过给定ScriptModule和设备(device id),我们可以将一个ScriptModule封装成matx op

2.1 从已有实例构建
import matx

infer_op = matx.script(script_model, device=0)
2.2 从本地文件构建
import matx

infer_op = matx.script("model", backend='pytorch', device=0)
3. Now we can use infer_op as a normal matx op or call it in pipeline for trace. Notice that the inputs for calling infer_op are the same as ScriptModule, but users have to substitute torch.tensor with matx.NDArray.
x = matx.array.rand([3, 4])
h = matx.array.rand([3, 4])

def process(x, h):
    return infer_op(x, h)

r = process(x, h)
print(r)
mod = matx.trace(process, x, h)
r = mod.run({'x': x, 'h': h})
print(r)

从 nn.Module 构建

沿用MyCell模型,我们将其直接封装为InferenceOp。

infer_op = matx.script(my_cell, device=0)

同样infer_op可以进行调用及pipeline trace,在pipeline trace的过程中,InferenceOp内部会调用torch.jit.trace将nn.Module转换为ScriptModule, 因此nn.Module构造InferenceOp在本质上和ScriptModule没有什么不同,需要值得注意的是,在使用nn.Module构造的InfereceOp进行trace时,需要用户保证该nn.Module可以使用torch.jit.trace转换为ScriptModule。

注意事项

  1. InferenceOp需要指定device id,通用在session加载时也需要指定device id,其关系如下:

    1. InferenceOp device为cpu,则不关心session加载的device,InferenceOp在cpu上执行。

    2. InferenceOp device为gpu,session加载的device为gpu,则忽略InferenceOp的id号,InfereceOp在session加载的device上执行。

    3. InferenceOp device为gpu,session加载的device为cpu,行为未定义。

  2. 目前要求pytorch model输出的tensor是contiguous的,对于pytorch model非contiguous的tensor,pytorch model内部调用tensor.contiguous在输出前对其进行转换。