多维数组(NDArray)

NDArray 是我们表示多维矩阵的数据类型,我们目前只实现了简单的数据装载和转换操作。 类似numpy,matx实现了自己的NDArray数据结构来表示多维数组。目前NDArray主要定位为各个深度学习框架(pytorch/tensorflow/tvm)的tensor结构进行桥接数据结构,我们并未在NDArray上定义完备的算子。

构造

构造参数列表

Args

Type

Description

arr

List

list对象,指定构造出的NDArray的内容。

shape

List

list对象,指定构造出的NDArray的shape,可以为[](空list),为[]时,构造出的NDArray shape和arr相同。

dtype

str

NDArray存储的数据类型,目前支持的类型:int32 int64 float32 float64 uint8 bool

device

str

NDArray存储的设配信息,目前支持类型:cpu cuda:%d gpu:%d,默认为cpu

示例1:指定shape,将传入的一维list变换为指定shape的多维NDArray

>>> import matx
>>> nd = matx.NDArray([1,2,3,4], [2, 2], "int32")
>>> nd
[
[ 1 2 ]
[ 3 4 ]
]

>>> nd.shape()
[2, 2]
>>> nd.dtype()
'int32'
>>>

示例2:不指定shape,按照传入的list shape构造NDArray

>>> import matx
>>> nd = matx.NDArray([[1,2],[3,4]], [], "int32")
>>> nd
[
[ 1 2 ]
[ 3 4 ]
]

>>> nd.shape()
[2, 2]
>>> nd.dtype()
'int32'
>>>

更多见api文档