PyTorch 集成¶
简介¶
matxscript内置了对pytorch的支持,通过matx.script()将一个Pytorch实例包装成一个InferenceOp,可以用于被trace的pipeline中。
使用方式¶
从 ScriptModule(ScriptFunction) 构建¶
import torch
class MyCell(torch.nn.Module):
def __init__(self):
    super(MyCell, self).__init__()
    self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
    new_h = torch.tanh(self.linear(x) + h)
    return new_h, new_h
device = torch.device("cuda:0")
my_cell = MyCell().to(device)
script_model = torch.jit.trace(my_cell, (torch.rand(3, 4, device=device), torch.rand(3, 4, device=device)))
通过给定ScriptModule和设备(device id),我们可以将一个ScriptModule封装成matx op
import matx
infer_op = matx.script(script_model, device=0)
import matx
infer_op = matx.script("model", backend='pytorch', device=0)
x = matx.array.rand([3, 4])
h = matx.array.rand([3, 4])
def process(x, h):
    return infer_op(x, h)
r = process(x, h)
print(r)
mod = matx.trace(process, x, h)
r = mod.run({'x': x, 'h': h})
print(r)
从 nn.Module 构建¶
沿用MyCell模型,我们将其直接封装为InferenceOp。
infer_op = matx.script(my_cell, device=0)
同样infer_op可以进行调用及pipeline trace,在pipeline trace的过程中,InferenceOp内部会调用torch.jit.trace将nn.Module转换为ScriptModule, 因此nn.Module构造InferenceOp在本质上和ScriptModule没有什么不同,需要值得注意的是,在使用nn.Module构造的InfereceOp进行trace时,需要用户保证该nn.Module可以使用torch.jit.trace转换为ScriptModule。
注意事项¶
- InferenceOp需要指定device id,通用在session加载时也需要指定device id,其关系如下: - InferenceOp device为cpu,则不关心session加载的device,InferenceOp在cpu上执行。 
- InferenceOp device为gpu,session加载的device为gpu,则忽略InferenceOp的id号,InfereceOp在session加载的device上执行。 
- InferenceOp device为gpu,session加载的device为cpu,行为未定义。 
 
- 目前要求pytorch model输出的tensor是contiguous的,对于pytorch model非contiguous的tensor,pytorch model内部调用tensor.contiguous在输出前对其进行转换。