TensorFlow 集成
保存模型
import tensorflow as tf
class Adder(tf.Module):
    """Minimal TensorFlow module exposing a single ``add`` operation."""

    @tf.function
    def add(self, x):
        """Return ``x + x``.

        Args:
            x: The value to add to itself.

        Returns:
            The result of ``x + x``.
        """
        return x + x
# Instantiate the module and export it as a SavedModel into the current
# directory. The exported signature is the concrete `add` function obtained
# for a float32 input with shape [] (a scalar).
model = Adder()
scalar_float32 = tf.TensorSpec([], tf.float32)
add_signature = model.add.get_concrete_function(scalar_float32)
tf.saved_model.save(model, "./", signatures=add_signature)
使用 matxscript 加载 SavedModel
import matx
# Wrap the SavedModel stored in the current directory as a callable matx op.
# NOTE(review): device=-1 presumably selects CPU execution — confirm against
# the matx documentation.
tf_op = matx.script(
    "./",
    backend='TensorFlow',
    device=-1,
    use_xla=0,
    allow_growth=False,
)
进行 trace 和推理
# Sample input used for both the eager call and the trace below.
# NOTE(review): the positional arguments look like (data, shape, dtype), i.e.
# a one-element float32 array — confirm against the matx.NDArray signature.
ix = matx.NDArray([1], [1], 'float32')
def process(x):
    """Run the loaded TensorFlow op on ``x``.

    Args:
        x: The input passed to ``tf_op`` under the key ``"x"``.

    Returns:
        Whatever ``tf_op`` returns for that input.
    """
    return tf_op({"x": x})
# Eager check: call the op through the plain Python wrapper and show the result.
ret = process(ix)
print(ret)

# Trace `process` using the sample input, then execute the traced object.
# NOTE(review): the feed key "x" presumably has to match the parameter name
# of `process` — confirm against the matx.trace documentation.
s = matx.trace(process, ix)
traced_inputs = {"x": ix}
ret = s.run(traced_inputs)
print(ret)