# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, wildcard-import
from . import _hooks
from . import runtime
from .contrib import cpp_extension
from .contrib.cpp_extension import get_cflags, get_link_flags
from . import toolchain
from . import pipeline
from .toolchain import ToolChain
from . import extension
from .runtime import msgpack_loads, msgpack_dumps
from . import text
from . import vision
from . import tools
# APIs
__all__ = [
    # functions
    "list_sort",
    "pmap",
    "pstarmap",
    "load_so",
    "trace",
    "script",
    "script_embedded_class",
    "save",
    "load",
    "get_cflags",
    "get_link_flags",
    # alias
    "runtime",
    "toolchain",
    "pipeline",
    "cpp_extension",
    "ToolChain",
    "extension",
    # version info
    "__version__",
    "__branch__",
    "__commit_id__",
    # matx._ffi
    # NOTE(review): `to_packed_func` is imported from matx._ffi at module level
    # but is not exported here — confirm whether it belongs in this list.
    "TError",
    "register_func",
    "get_global_func",
    # matx.runtime
    "Object",
    "NDArray",
    "array",
    "List",
    "Dict",
    "Set",
    "Tuple",
    "OpaqueObject",
    "to_runtime_object",
    "File",
    "Regex",
    "Trie",
    "list_heapify",
    "list_heap_replace",
    "list_nth_element",
    "list_heap_pushpop",
    "serialize",
    "deserialize",
    # matx.native
    "native",
    "NativeObject",
    "make_native_object_creator",
    "make_native_object",
    "make_native_function",
    "load_native",
    # modules
    "pypi",
    "ir",
    "ir_module",
    "contrib",
    "error",
    "Device",
]
# version info
__version__ = "1.8.1"
__branch__ = 'main'
__commit_id__ = '97dfa3d'
# top-level alias
# matx._ffi
from ._ffi.base import TError
from ._ffi import register_func, get_global_func, to_packed_func
# top-level alias
# matx.runtime
from .runtime.object import Object
from .runtime.ndarray import NDArray
from .runtime import ndarray as array
from .runtime import List, Dict, Set, Tuple
from .runtime import OpaqueObject
from .runtime.object_generic import to_runtime_object
from .runtime.file import File
from .runtime.regex import Regex
from .runtime.trie import Trie
from .runtime._container._list import heapify as list_heapify
from .runtime._container._list import heap_replace as list_heap_replace
from .runtime._container._list import nth_element as list_nth_element
from .runtime._container._list import heap_pushpop as list_heap_pushpop
from .runtime.cpp_logging import set_cpp_logging_level, get_cpp_logging_level
from .runtime.cpp_logging import FATAL, ERROR, WARNING, INFO, DEBUG
from .runtime.picke import serialize, deserialize
# matx.native
from . import native
from .native import NativeObject
from .native import make_native_object_creator
from .native import make_native_object
from .native import make_native_function
from .native import call_native_function
from .native import load_native
# matx.pypi
from . import pypi
# matx.ir
from . import ir
from .script import module as ir_module
# matx.contrib
from . import contrib
# matx.error
from . import error
from .typing import *
from .pipeline.ops import DeviceOp as Device
# compiling api
def list_sort(l, compare=None):
    """Sort ``l`` in place, optionally using a three-way comparison function.

    Parameters
    ----------
    l : list or matx List
        The sequence to sort; it is modified in place.
    compare : callable, optional
        ``compare(a, b)`` returning a negative, zero or positive number.
        For a matx ``List``, a compiled ``OpKernel`` comparator is handed to
        the native ``ListSort``; any other comparator is applied in Python.

    Raises
    ------
    ValueError
        If ``l`` is neither a Python ``list`` nor a matx ``List``.
    """
    is_py_list = isinstance(l, list)
    if not is_py_list and not isinstance(l, List):
        raise ValueError("type of first arg must be list")

    if compare is None:
        if is_py_list:
            l.sort()
        else:
            from .runtime import _ffi_api
            _ffi_api.ListSort(l)
        return

    from functools import cmp_to_key

    if is_py_list:
        l.sort(key=cmp_to_key(compare))
        return

    from .pipeline.ops import OpKernel
    if isinstance(compare, OpKernel):
        # Compiled comparator: let the runtime sort the List natively.
        from .runtime import _ffi_api
        _ffi_api.ListSort(l, compare)
        return

    # Python comparator: sort a detached copy, then refill the List in place.
    ordered = sorted(l, key=cmp_to_key(compare))
    l.clear()
    for element in ordered:
        l.append(element)
def pmap(func, data):
    """Map ``func`` over every element of ``data``.

    While tracing, a ``ParallelMap`` interpreter op is emitted instead of
    running anything. Outside of tracing:

    * a compiled op (``OpKernel`` / runtime ``ObjectBase``) is dispatched to
      the runtime's ``ParallelMap`` with the default session handle;
    * a plain Python callable is applied element by element, returning a
      ``list`` for list/``List`` input and a ``tuple`` for tuple/``Tuple``
      input.

    Raises
    ------
    TypeError
        If ``func`` is not traceable while tracing, or ``data`` has an
        unsupported type.
    """
    from . import pipeline
    from .pipeline._base import TXObject
    from .pipeline import _ffi_api
    from .pipeline._tracing_state import tracing

    op_types = (pipeline.ops.OpKernel, runtime.object.ObjectBase)

    if tracing():
        from .pipeline import builtin_op
        parallel_map = builtin_op.get_interpreter_op("ParallelMap")
        if not isinstance(func, op_types):
            raise TypeError(f"matx.pmap: the first argument '{func}' is not a traceable op")
        traceable_data = (pipeline.symbol.BaseSymbol, runtime.List, runtime.Tuple, list, tuple)
        if not isinstance(data, traceable_data):
            raise TypeError(f"matx.pmap: the second argument '{data}' is not supported")
        return parallel_map(func, data)

    if isinstance(func, op_types):
        # Compiled op: hand the whole batch to the runtime.
        session = TXObject.default_sess.c_handle
        return _ffi_api.ParallelMap(func, data, session)

    # Python mode: evaluate eagerly, preserving the container kind.
    if isinstance(data, (list, runtime.List)):
        return [func(item) for item in data]
    if isinstance(data, (tuple, runtime.Tuple)):
        return tuple(func(item) for item in data)
    raise TypeError(f"expect the second argument is list or tuple, but get '{data}'")
def pstarmap(func, data):
    """Like ``pmap``, but unpack each element of ``data`` as ``func(*element)``.

    While tracing, a ``ParallelStarMap`` interpreter op is emitted instead of
    running anything. Outside of tracing:

    * a compiled op (``OpKernel`` / runtime ``ObjectBase``) is dispatched to
      the runtime's ``ParallelStarMap`` with the default session handle;
    * a plain Python callable is applied element by element, returning a
      ``list`` for list/``List`` input and a ``tuple`` for tuple/``Tuple``
      input.

    Raises
    ------
    TypeError
        If ``func`` is not traceable while tracing, or ``data`` has an
        unsupported type.
    """
    from . import pipeline
    from .pipeline._base import TXObject
    from .pipeline import _ffi_api
    from .pipeline._tracing_state import tracing

    op_types = (pipeline.ops.OpKernel, runtime.object.ObjectBase)

    if tracing():
        from .pipeline import builtin_op
        parallel_star_map = builtin_op.get_interpreter_op("ParallelStarMap")
        if not isinstance(func, op_types):
            raise TypeError(f"matx.pstarmap: the first argument '{func}' is not a traceable op")
        traceable_data = (pipeline.symbol.BaseSymbol, runtime.List, runtime.Tuple, list, tuple)
        if not isinstance(data, traceable_data):
            raise TypeError(f"matx.pstarmap: the second argument '{data}' is not supported")
        return parallel_star_map(func, data)

    if isinstance(func, op_types):
        # Compiled op: hand the whole batch to the runtime.
        session = TXObject.default_sess.c_handle
        return _ffi_api.ParallelStarMap(func, data, session)

    # Python mode: evaluate eagerly, unpacking each element as arguments.
    if isinstance(data, (list, runtime.List)):
        return [func(*item) for item in data]
    if isinstance(data, (tuple, runtime.Tuple)):
        return tuple(func(*item) for item in data)
    raise TypeError(f"expect the second argument is list or tuple, but get '{data}'")
class Future:
    """Already-resolved result holder returned by ``apply_async`` in Python mode.

    The value is computed before the object is constructed. ``get()`` and
    calling the instance directly both return it.
    """

    def __init__(self, x):
        # Double underscore => name-mangled to ``_Future__x``.
        self.__x = x

    def get(self):
        """Return the stored result."""
        return self.__x

    def __call__(self):
        """Return the stored result (same as ``get()``)."""
        return self.__x
def apply_async(func, *args):
    """Schedule ``func(*args)`` and return a handle to its result.

    While tracing, an ``ApplyAsync`` interpreter op is emitted. Outside of
    tracing:

    * a plain Python callable is called immediately, and its result is
      wrapped in a ``Future``;
    * a compiled op (``OpKernel`` / runtime ``ObjectBase``) is dispatched to
      the runtime's ``ApplyAsync`` with the default session handle.

    Raises
    ------
    TypeError
        If ``func`` is not a traceable op while tracing.
    """
    from . import pipeline
    from .pipeline._base import TXObject
    from .pipeline import _ffi_api
    from .pipeline._tracing_state import tracing
    if tracing():
        from .pipeline import builtin_op
        async_op = builtin_op.get_interpreter_op("ApplyAsync")
        if not isinstance(func, (pipeline.ops.OpKernel, runtime.object.ObjectBase)):
            # Name this API (not pmap) so users can locate the failing call.
            raise TypeError(
                f"matx.apply_async: the first argument '{func}' is not a traceable op"
            )
        return async_op(func, *args)
    if not isinstance(func, (pipeline.ops.OpKernel, runtime.object.ObjectBase)):
        # Python mode: run eagerly and wrap the already-computed value.
        return Future(func(*args))
    sess_handle = TXObject.default_sess.c_handle
    return _ffi_api.ApplyAsync(func, *args, sess_handle)
def load_so(dso_path=""):
    """Load a compiled shared library through ``load_native``.

    Parameters
    ----------
    dso_path : str, os.PathLike or tuple
        Either a single library path, or a pair from which one entry is
        chosen by ABI: index 1 when ``USE_CXX11_ABI()`` is true, index 0
        otherwise. ``os.PathLike`` values (e.g. ``pathlib.Path``) are
        converted with ``os.fspath``. Values of any other type are ignored,
        as before.
    """
    import os
    from ._ffi.base import USE_CXX11_ABI

    # Path objects used to fall through both branches and be silently ignored.
    if isinstance(dso_path, os.PathLike):
        dso_path = os.fspath(dso_path)

    if isinstance(dso_path, str):
        load_native(dso_path)
    elif isinstance(dso_path, tuple):
        # Pick the build that matches the C++ ABI reported by the runtime.
        chosen = dso_path[1] if USE_CXX11_ABI() else dso_path[0]
        if isinstance(chosen, os.PathLike):
            chosen = os.fspath(chosen)
        load_native(chosen)
def trace(func, *args, **kwargs):
    """Trace a function and return an executable module that will be optimized using just-in-time compilation.

    Parameters
    ----------
    func : callable
        A Python function, or matx Symbol(s), that will be run with ``args``.
        ``func`` arguments and return values must be Operators returned from Script.
    args :
        Positional inputs for ``func``; passed through to ``pipeline.trace``.
    kwargs :
        Keyword inputs for ``func``; passed through to ``pipeline.trace``.

    Returns
    -------
    module : JITModule
        An executable module.
    """
    # Plain delegation: exceptions from pipeline.trace propagate unchanged, so
    # the former ``try/except BaseException: raise`` wrapper was redundant.
    return pipeline.trace(func, *args, **kwargs)
def script(compiling_obj, *args, backend=None, **kwargs):
    """Entry function for compiling.

    Given a Python object (function or simple class), compile it into a matx
    object that mostly keeps the behavior of the original Python object.

    Parameters
    ----------
    compiling_obj : function or class
        The Python object to compile. An ``OpKernel``, or an object already
        flagged as compiled, is returned unchanged.
    args
        Extra positional arguments, forwarded to the selected script backend.
    backend : str or None
        'torch' / 'pytorch' or 'tensorflow' (case-insensitive) to force a
        backend. Any other value triggers auto-detection through the
        framework extensions; if none claims the object, the matx toolchain
        is used.
    kwargs
        Keyword arguments forwarded to the selected script backend.

    Returns
    -------
    obj : OpKernel
        The compiled object.
    """
    from . import extension

    # Already-compiled inputs are passed straight through.
    if isinstance(compiling_obj, pipeline.ops.OpKernel):
        return compiling_obj
    compiled_flag = getattr(compiling_obj, "__FLAG_COMPILED_OBJECT__", None)
    if compiled_flag is toolchain.FLAG_COMPILED_OBJECT:
        return compiling_obj

    requested = backend.lower() if isinstance(backend, str) else backend

    # An explicitly requested framework backend wins.
    if requested in ('torch', 'pytorch'):
        return extension.pytorch.script(compiling_obj, *args, **kwargs)
    if requested in ('tensorflow',):
        return extension.tensorflow.script(compiling_obj, *args, **kwargs)

    # Otherwise let each framework extension claim the object, in order.
    if extension.pytorch.check(compiling_obj):
        return extension.pytorch.script(compiling_obj, *args, **kwargs)
    if extension.tensorflow.check(compiling_obj):
        return extension.tensorflow.script(compiling_obj, *args, **kwargs)

    # Fall back to the native matx toolchain.
    return toolchain.script(compiling_obj, *args, **kwargs)
def script_embedded_class(code, is_path=False):
    """Compile an embedded class definition via ``toolchain.script_embedded_class``.

    Both arguments are forwarded positionally. Presumably ``code`` is source
    text, or a file path when ``is_path`` is true; confirm against the
    toolchain implementation.
    """
    compiled = toolchain.script_embedded_class(code, is_path)
    return compiled
def save(jit_module, folder, force_override=False):
    """Save ``jit_module`` into ``folder`` by delegating to ``pipeline.save``.

    ``force_override`` is forwarded unchanged. It presumably permits
    overwriting an existing folder; see ``pipeline.save``.
    """
    outcome = pipeline.save(jit_module, folder, force_override)
    return outcome
def load(folder, device):
    """Load a module from ``folder`` by delegating to ``pipeline.load``.

    ``device`` is forwarded unchanged. Its accepted values are defined by
    ``pipeline.load``.
    """
    loaded = pipeline.load(folder, device)
    return loaded