TensorFlow 集成

保存模型

import tensorflow as tf

class Adder(tf.Module):
    """Minimal ``tf.Module`` exposing one ``tf.function`` method for export.

    The method must be indented inside the class body; at column 0 (as in the
    original snippet) Python raises ``IndentationError`` right after the
    ``class`` header.
    """

    @tf.function
    def add(self, x):
        """Return ``x + x`` (the input added to itself).

        Decorated with ``tf.function`` so a concrete function can be traced
        from it and used as the SavedModel serving signature.
        """
        return x + x

model = Adder()

# Export `model` as a SavedModel into the current directory. The serving
# signature is the concrete function traced from `add` with a float32,
# shape-[] (scalar) input spec.
# NOTE(review): the inference example in this document feeds an NDArray of
# shape [1] while this spec declares a scalar; confirm the runtime accepts it.
tf.saved_model.save(
    model,
    "./",
    signatures=model.add.get_concrete_function(
        tf.TensorSpec([], tf.float32)
    ),
)

使用 matxscript 加载 SavedModel

import matx

# Wrap the SavedModel exported to "./" as a matx op on the TensorFlow backend.
# device=-1 presumably selects CPU (confirm against the matx docs); XLA is
# disabled via use_xla=0, and allow_growth is turned off.
tf_op = matx.script(
    "./",
    backend='TensorFlow',
    device=-1,
    use_xla=0,
    allow_growth=False,
)

进行 trace 和推理

# Example input tensor. The positional arguments are presumably
# (data, shape, dtype) — confirm against the matx.NDArray docs.
ix = matx.NDArray(
    [1],
    [1],
    'float32',
)

def process(x):
    """Invoke `tf_op` with `x` bound to the signature input named "x"."""
    feed = {"x": x}
    return tf_op(feed)

# 1) Eager execution: call the Python function directly.
ret = process(ix)
print(ret)

# 2) Trace `process` using `ix` as the example input, then execute the traced
#    module; its inputs are fed by parameter name.
s = matx.trace(process, ix)
ret = s.run({"x": ix})
print(ret)