# matx._ffi.base source code

# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Acknowledgement: This file is inspired by TVM.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=invalid-name, import-outside-toplevel
"""Base library for FFI."""
import sys
import os
import ctypes
import numpy as np
import hashlib
from . import libinfo

# ----------------------------
# library loading
# ----------------------------
# Tuples of Python / NumPy types grouped by category.
string_types = (str,)
integer_types = (int, np.int32)
# Every integer type plus the supported floating-point types.
numeric_types = (*integer_types, float, np.float32)

# Decode the ``bytes`` value of a ``ctypes.c_char_p`` result back into a
# Python ``str`` (needed on Python 3, where ``.value`` is ``bytes``).
if sys.platform == "win32":
    def _py_str(x):
        """Decode *x* as UTF-8, falling back to the process ANSI code page.

        The fallback encoding name is ``"cp"`` followed by the value returned
        by ``kernel32.GetACP()``.
        """
        try:
            return x.decode('utf-8')
        except UnicodeDecodeError:
            fallback = 'cp' + str(ctypes.cdll.kernel32.GetACP())
        return x.decode(fallback)

    py_str = _py_str
else:
    def py_str(x):
        """Decode *x* as UTF-8."""
        return x.decode('utf-8')


def _load_lib():
    """Locate and load the core MATX shared library.

    The first path returned by ``libinfo.find_lib_path()`` is opened with
    ``ctypes.RTLD_GLOBAL`` while the process working directory is temporarily
    switched to that library's folder; the previous working directory is
    always restored, even if loading fails.

    Returns
    -------
    tuple
        ``(lib, lib_file_name, lib_sha1)``: the ``ctypes.CDLL`` handle, the
        base file name of the loaded library, and the hex SHA-1 digest of the
        library file's contents.
    """
    candidates = libinfo.find_lib_path()
    lib_file = candidates[0]
    lib_dir = os.path.abspath(os.path.dirname(lib_file))
    saved_cwd = os.getcwd()
    try:
        # NOTE(review): the chdir is presumably so the loader can resolve
        # files relative to the library directory -- confirm.
        os.chdir(lib_dir)
        lib = ctypes.CDLL(lib_file, ctypes.RTLD_GLOBAL)
        with open(lib_file, 'rb') as handle:
            digest = hashlib.sha1(handle.read()).hexdigest()
        # The error-message getter returns a C string; make ctypes hand back bytes.
        lib.MATXScriptAPIGetLastError.restype = ctypes.c_char_p
    finally:
        os.chdir(saved_cwd)
    return lib, os.path.basename(lib_file), digest


def _load_cuda_lib(base_lib):
    """Best-effort load of the optional MATX CUDA shared library.

    A missing or unloadable CUDA library does not abort the import. Instead,
    the reason is recorded in the core library through
    ``base_lib.MATXScriptSetDeviceDriverError`` with device code ``2`` (GPU),
    and ``None`` is returned in place of the handle.

    Parameters
    ----------
    base_lib : ctypes.CDLL
        The already-loaded core library, used to record the driver error.

    Returns
    -------
    tuple
        ``(lib, lib_file_name, lib_sha1)``.
        - Library file not found: all three are ``None``.
        - File found but fails to load: ``lib`` and ``lib_sha1`` are ``None``,
          while ``lib_file_name`` is still the file's base name.
    """
    if sys.platform.startswith('win32'):
        cuda_lib_name = "libmatx_cuda.dll"
    elif sys.platform.startswith('darwin'):
        cuda_lib_name = "libmatx_cuda.dylib"
    else:
        cuda_lib_name = "libmatx_cuda.so"
    lib_path = libinfo.find_lib_path(name=cuda_lib_name, optional=True)
    if not lib_path:
        msg = f"{cuda_lib_name} is not compiled!!!"
        base_lib.MATXScriptSetDeviceDriverError(2, ctypes.c_char_p(msg.encode('utf-8')))
        return None, None, None
    lib_pp = os.path.abspath(os.path.dirname(lib_path[0]))
    cwd = os.getcwd()
    try:
        os.chdir(lib_pp)
        # RTLD_LOCAL: unlike the core library, the CUDA library's symbols are
        # not made globally available.
        lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_LOCAL)
        with open(lib_path[0], 'rb') as lib_f:
            lib_sha1 = hashlib.sha1(lib_f.read()).hexdigest()
        lib.MATXScriptAPIGetLastError.restype = ctypes.c_char_p
    except Exception:  # pylint: disable=broad-except
        # Loading CUDA is optional, so record the failure instead of raising.
        # ``Exception`` (not a bare ``except:``) lets KeyboardInterrupt and
        # SystemExit propagate during import.
        import traceback
        msg = traceback.format_exc()
        # cpu=1 gpu=2 ...
        base_lib.MATXScriptSetDeviceDriverError(2, ctypes.c_char_p(msg.encode('utf-8')))
        lib = None
        lib_sha1 = None
    finally:
        os.chdir(cwd)
    return lib, os.path.basename(lib_path[0]), lib_sha1


# ``readline`` is imported only for its import-time side effects (it is not
# referenced anywhere else in this module); its absence is not an error.
try:
    import readline  # pylint: disable=unused-import
except ImportError:
    pass
# library instance
# Order matters: the core library is loaded first because _load_cuda_lib
# reports CUDA loading failures through it.
_LIB, _LIB_NAME, _LIB_SHA1 = _load_lib()
_CUDA_LIB, _CUDA_LIB_NAME, _CUDA_LIB_SHA1 = _load_cuda_lib(_LIB)

# Whether we are runtime only: true when the loaded core library's file name
# contains the substring "runtime".
_RUNTIME_ONLY = "runtime" in _LIB_NAME


def USE_CXX11_ABI():
    """Return the result of the core library's ``MATXScriptAPI_USE_CXX11_ABI``.

    Returns
    -------
    The C function's return value as converted by ctypes. Unless a
    ``restype`` is configured elsewhere, ctypes' default restype is ``int``.
    NOTE(review): presumably a flag for whether the library was built with
    the C++11 std::string ABI -- confirm against the C API.
    """
    abi_query = _LIB.MATXScriptAPI_USE_CXX11_ABI
    return abi_query()


def load_lib_by_name(libname):
    """Load a shared library, given its name without the platform extension.

    ``.dll`` (Windows), ``.dylib`` (macOS) or ``.so`` (other platforms) is
    appended to ``libname``, and the file is searched for with
    ``libinfo.find_lib_path``. The first match is opened with
    ``ctypes.RTLD_LOCAL`` while the working directory is temporarily switched
    to the library's folder; the previous working directory is always
    restored.

    Parameters
    ----------
    libname : str
        Library base name, e.g. ``"libfoo"``.

    Returns
    -------
    tuple
        ``(lib, lib_file_name, lib_sha1)``: the ``ctypes.CDLL`` handle, the
        base file name, and the hex SHA-1 digest of the file contents.

    Raises
    ------
    RuntimeError
        If no matching library file is found.
    """
    if sys.platform.startswith('win32'):
        suffix = ".dll"
    elif sys.platform.startswith('darwin'):
        suffix = ".dylib"
    else:
        suffix = ".so"
    file_name = libname + suffix

    found = libinfo.find_lib_path(name=file_name, optional=True)
    if not found:
        raise RuntimeError(f"{file_name} is not compiled!!!")

    lib_file = found[0]
    lib_dir = os.path.abspath(os.path.dirname(lib_file))
    saved_cwd = os.getcwd()
    try:
        os.chdir(lib_dir)
        lib = ctypes.CDLL(lib_file, ctypes.RTLD_LOCAL)
        with open(lib_file, 'rb') as handle:
            digest = hashlib.sha1(handle.read()).hexdigest()
    finally:
        os.chdir(saved_cwd)
    return lib, os.path.basename(lib_file), digest


# ----------------------------
# helper function in ctypes.
# ----------------------------
def c_str(string):
    """Convert a Python ``str`` into a ``ctypes.c_char_p``.

    The text is UTF-8 encoded before being wrapped.

    Parameters
    ----------
    string : str
        Python text to convert.

    Returns
    -------
    ctypes.c_char_p
        A char pointer to the UTF-8 bytes, suitable for passing to the C API.
    """
    encoded = string.encode('utf-8')
    return ctypes.c_char_p(encoded)


def c_array(ctype, values):
    """Build a fixed-length ctypes array holding ``values``.

    Parameters
    ----------
    ctype : ctypes data type
        Element type of the resulting array (e.g. ``ctypes.c_int``).

    values : tuple or list
        Elements to copy into the array. The sequence must support ``len()``.

    Returns
    -------
    out : ctypes array
        An array of type ``ctype * len(values)`` initialised from ``values``.
    """
    array_type = ctype * len(values)
    return array_type(*values)


def decorate(func, fwrapped):
    """Delegate to ``decorator.decorate(func, fwrapped)``.

    The third-party ``decorator`` package is imported lazily, when this
    helper is called, so it is only required if the helper is actually used.

    Parameters
    ----------
    func : function
        The original function.

    fwrapped : function
        The wrapping function. It is passed as the second ("caller") argument
        of ``decorator.decorate``; per that package's documentation it is
        invoked as ``fwrapped(func, *args, **kwargs)``.

    Returns
    -------
    Whatever ``decorator.decorate`` returns for these arguments.
    """
    import decorator as _decorator_pkg
    return _decorator_pkg.decorate(func, fwrapped)


# -----------------------------------------
# Base code for structured error handling.
# -----------------------------------------
# Registry mapping an error-type name (str) to the class used to construct
# that error; entries are added by ``register_error``.
ERROR_TYPE = dict()


class TError(RuntimeError):
    """Default error thrown by packed functions.

    TError is raised when no more specific error type is registered for the
    error name found in the C API error message.
    """
def register_error(func_name=None, cls=None):
    """Register an error class so it can be recognized by the ffi error handler.

    Parameters
    ----------
    func_name : str or function or class
        The name of the error function.

    cls : function
        The function to create the class

    Returns
    -------
    fregister : function
        Register function if f is not specified.

    Examples
    --------
    .. code-block:: python

      @matx.error.register_error
      class MyError(RuntimeError):
          pass

      err_inst = matx.error.create_ffi_error("MyError: xyz")
      assert isinstance(err_inst, MyError)
    """
    # Used directly as a decorator: the class itself was passed positionally.
    if callable(func_name):
        cls = func_name
        func_name = cls.__name__

    def register(mycls):
        """internal register function"""
        err_name = func_name if isinstance(func_name, str) else mycls.__name__
        ERROR_TYPE[err_name] = mycls
        return mycls

    if cls is None:
        return register
    return register(cls)


def _valid_error_name(name):
    """Check whether name is a valid error name (alphanumerics, '_' and '.')."""
    return all(x.isalnum() or x in "_." for x in name)


def _find_error_type(line):
    """Find the error name given the first line of the error message.

    Parameters
    ----------
    line : str
        The first line of error message.

    Returns
    -------
    name : str or None
        The error name, or None if none can be extracted.
    """
    if sys.platform == "win32":
        # Stack traces aren't logged on Windows due to a DMLC limitation,
        # so we should try to get the underlying error another way.
        # DMLC formats errors "[timestamp] file:line: ErrorMessage"
        # ErrorMessage is usually formatted "ErrorType: message"
        # We can try to extract the error type using the final ":"
        end_pos = line.rfind(":")
        if end_pos == -1:
            return None
        start_pos = line.rfind(":", 0, end_pos)
        if start_pos == -1:
            err_name = line[:end_pos].strip()
        else:
            err_name = line[start_pos + 1: end_pos].strip()
        if _valid_error_name(err_name):
            return err_name
        return None

    end_pos = line.find(":")
    if end_pos == -1:
        return None
    err_name = line[:end_pos]
    if _valid_error_name(err_name):
        return err_name
    return None


def c2pyerror(err_msg):
    """Translate C API error message to python style.

    Lines after a "Stack trace" header that start with the indent prefix are
    collected as stack frames. They are emitted in reverse order under a
    Python-style "Traceback (most recent call last):" header, followed by the
    remaining message lines.

    Parameters
    ----------
    err_msg : str
        The error message.

    Returns
    -------
    new_msg : str
        Translated message.

    err_type : str or None
        Detected error type (None when the message is empty or has no
        recognizable type prefix).
    """
    arr = err_msg.split("\n")
    if arr[-1] == "":
        arr.pop()
    # An empty message leaves no lines; guard instead of indexing arr[0].
    err_type = _find_error_type(arr[0]) if arr else None
    trace_mode = False
    stack_trace = []
    message = []
    for line in arr:
        if trace_mode:
            # NOTE(review): indent prefix reproduced from this copy of the
            # source, whose whitespace appears collapsed; confirm upstream.
            if line.startswith(" "):
                stack_trace.append(line)
            else:
                trace_mode = False
        if not trace_mode:
            if line.startswith("Stack trace"):
                trace_mode = True
            else:
                message.append(line)
    out_msg = ""
    if stack_trace:
        out_msg += "Traceback (most recent call last):\n"
        out_msg += "\n".join(reversed(stack_trace)) + "\n"
    out_msg += "\n".join(message)
    return out_msg, err_type


def py2cerror(err_msg):
    """Translate python style error message to C style.

    Indented lines following a "Traceback" line are collected as stack frames
    and appended (reversed) after a "Stack trace:" header. A leading
    "RuntimeError: MyErrorName: message" is reduced to "MyErrorName: message".

    Parameters
    ----------
    err_msg : str
        The error message.

    Returns
    -------
    new_msg : str
        Translated message.
    """
    arr = err_msg.split("\n")
    if arr[-1] == "":
        arr.pop()
    trace_mode = False
    stack_trace = []
    message = []
    for line in arr:
        if trace_mode:
            # NOTE(review): indent prefix reproduced from this copy of the
            # source, whose whitespace appears collapsed; confirm upstream.
            if line.startswith(" "):
                stack_trace.append(line)
            else:
                trace_mode = False
        if not trace_mode:
            if "Traceback" in line:
                trace_mode = True
            else:
                message.append(line)
    # Remove the first error name if there are two of them.
    # RuntimeError: MyErrorName: message => MyErrorName: message
    # Guarded: an empty message or a traceback-only input has no message[0].
    if message:
        head_arr = message[0].split(":", 3)
        if len(head_arr) >= 3 and _valid_error_name(head_arr[1].strip()):
            head_arr[1] = head_arr[1].strip()
            message[0] = ":".join(head_arr[1:])
    # reverse the stack trace.
    out_msg = "\n".join(message)
    if stack_trace:
        out_msg += "\nStack trace:\n"
        out_msg += "\n".join(reversed(stack_trace)) + "\n"
    return out_msg


def get_last_ffi_error():
    """Create error object given result of MATXAPIGetLastError.

    The error type parsed from the message (with any "matx.error." prefix
    removed) selects the class registered in ERROR_TYPE; TError is the
    fallback.

    Returns
    -------
    err : object
        The error object based on the err_msg
    """
    raw_msg = _LIB.MATXScriptAPIGetLastError()
    # The restype is c_char_p, which yields None for a NULL pointer; treat it
    # as an empty message instead of crashing inside py_str.
    c_err_msg = py_str(raw_msg) if raw_msg is not None else ""
    py_err_msg, err_type = c2pyerror(c_err_msg)
    prefix = "matx.error."
    if err_type is not None and err_type.startswith(prefix):
        err_type = err_type[len(prefix):]
    return ERROR_TYPE.get(err_type, TError)(py_err_msg)


def check_call(ret):
    """Check the return value of C API call

    This function will raise exception when error occurs.
    Wrap every API call with this function

    Parameters
    ----------
    ret : int
        return value from API calls

    Raises
    ------
    Exception
        The error built by get_last_ffi_error() when ret is non-zero.
    """
    if ret != 0:
        raise get_last_ffi_error()