# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Acknowledgement: This file is inspired by TVM.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=invalid-name, import-outside-toplevel
"""Base library for FFI."""
import sys
import os
import ctypes
import numpy as np
import hashlib
from . import libinfo
# ----------------------------
# library loading
# ----------------------------
# Tuples of types, each usable as the second argument of isinstance().
string_types = (str,)
integer_types = (int, np.int32)
# Every integer type, plus the floating point ones.
numeric_types = (*integer_types, float, np.float32)
# this function is needed for python3
# to convert a ctypes.c_char_p .value (bytes) back to a python str
if sys.platform == "win32":
    def _py_str(x):
        """Decode bytes as UTF-8, falling back to the Windows code page."""
        try:
            return x.decode('utf-8')
        except UnicodeDecodeError:
            pass
        # Not valid UTF-8: retry with the code page reported by kernel32.
        active_codepage = ctypes.cdll.kernel32.GetACP()
        return x.decode('cp' + str(active_codepage))
    py_str = _py_str
else:
    def py_str(x):
        """Decode bytes as UTF-8."""
        return x.decode('utf-8')
def _load_lib():
    """Find and load the main native library.

    Returns
    -------
    lib : ctypes.CDLL
        Handle of the first library reported by ``libinfo.find_lib_path``.
    name : str
        File name (without directory) of the loaded library.
    sha1 : str
        Hex SHA-1 digest of the library file contents.
    """
    primary = libinfo.find_lib_path()[0]
    lib_dir = os.path.abspath(os.path.dirname(primary))
    saved_cwd = os.getcwd()
    try:
        # Load from inside the library's own directory, then restore the
        # caller's working directory whatever happens.
        os.chdir(lib_dir)
        handle = ctypes.CDLL(primary, ctypes.RTLD_GLOBAL)
        with open(primary, 'rb') as lib_f:
            digest = hashlib.sha1(lib_f.read()).hexdigest()
        # The error getter returns a C string rather than the default int.
        handle.MATXScriptAPIGetLastError.restype = ctypes.c_char_p
    finally:
        os.chdir(saved_cwd)
    return handle, os.path.basename(primary), digest
def _load_cuda_lib(base_lib):
    """Find and load the optional CUDA library.

    A missing or unloadable library is not fatal: the problem is recorded
    through ``base_lib.MATXScriptSetDeviceDriverError`` and ``None`` is
    returned in place of the handle.

    Parameters
    ----------
    base_lib : ctypes.CDLL
        The already loaded main library, used to record driver errors.

    Returns
    -------
    lib : ctypes.CDLL or None
        Handle of the CUDA library, or None when it is missing or failed to load.
    name : str or None
        File name of the library, or None when no library file was found.
    sha1 : str or None
        Hex SHA-1 digest of the library file, or None when it was not loaded.
    """
    if sys.platform.startswith('win32'):
        cuda_lib_name = "libmatx_cuda.dll"
    elif sys.platform.startswith('darwin'):
        cuda_lib_name = "libmatx_cuda.dylib"
    else:
        cuda_lib_name = "libmatx_cuda.so"
    lib_path = libinfo.find_lib_path(name=cuda_lib_name, optional=True)
    if not lib_path:
        msg = f"{cuda_lib_name} is not compiled!!!"
        # cpu=1 gpu=2 ...
        base_lib.MATXScriptSetDeviceDriverError(2, ctypes.c_char_p(msg.encode('utf-8')))
        return None, None, None
    lib_pp = os.path.abspath(os.path.dirname(lib_path[0]))
    cwd = os.getcwd()
    try:
        os.chdir(lib_pp)
        lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_LOCAL)
        with open(lib_path[0], 'rb') as lib_f:
            lib_sha1 = hashlib.sha1(lib_f.read()).hexdigest()
        lib.MATXScriptAPIGetLastError.restype = ctypes.c_char_p
    except Exception:  # pylint: disable=broad-except
        # Only ordinary failures are turned into a driver error;
        # KeyboardInterrupt and SystemExit propagate to the caller.
        import traceback
        msg = traceback.format_exc()
        # cpu=1 gpu=2 ...
        base_lib.MATXScriptSetDeviceDriverError(2, ctypes.c_char_p(msg.encode('utf-8')))
        lib = None
        lib_sha1 = None
    finally:
        # Always restore the caller's working directory.
        os.chdir(cwd)
    return lib, os.path.basename(lib_path[0]), lib_sha1
# readline is optional: when the module cannot be imported the failure is
# ignored and loading continues.
# NOTE(review): presumably imported only for its import-time side effects
# (the name is never used here) — confirm why it precedes library loading.
try:
    import readline  # pylint: disable=unused-import
except ImportError:
    pass
# library instance
# Loaded once at import time; each loader returns (handle, file name, SHA-1).
_LIB, _LIB_NAME, _LIB_SHA1 = _load_lib()
# The CUDA library is optional: all three values are None when it is not
# found; the handle and SHA-1 are None when it is found but fails to load.
_CUDA_LIB, _CUDA_LIB_NAME, _CUDA_LIB_SHA1 = _load_cuda_lib(_LIB)
# Whether we are runtime only
# (decided by whether the loaded library's file name contains "runtime")
_RUNTIME_ONLY = "runtime" in _LIB_NAME
def USE_CXX11_ABI():
    """Return the value reported by the native ``MATXScriptAPI_USE_CXX11_ABI`` call."""
    abi_flag = _LIB.MATXScriptAPI_USE_CXX11_ABI()
    return abi_flag
def load_lib_by_name(libname):
    """Find and load a native library given its name without extension.

    Parameters
    ----------
    libname : str
        Library file name without the platform specific suffix.

    Returns
    -------
    lib : ctypes.CDLL
        Handle of the loaded library.
    name : str
        File name (without directory) of the loaded library.
    sha1 : str
        Hex SHA-1 digest of the library file contents.

    Raises
    ------
    RuntimeError
        If no matching library file is found.
    """
    platform = sys.platform
    if platform.startswith('win32'):
        suffix = ".dll"
    elif platform.startswith('darwin'):
        suffix = ".dylib"
    else:
        suffix = ".so"
    libname = libname + suffix

    candidates = libinfo.find_lib_path(name=libname, optional=True)
    nothing_found = candidates is None or len(candidates) == 0
    if nothing_found:
        raise RuntimeError(f"{libname} is not compiled!!!")

    target = candidates[0]
    lib_dir = os.path.abspath(os.path.dirname(target))
    saved_cwd = os.getcwd()
    try:
        # Load from inside the library's own directory, then restore the
        # caller's working directory whatever happens.
        os.chdir(lib_dir)
        handle = ctypes.CDLL(target, ctypes.RTLD_LOCAL)
        with open(target, 'rb') as lib_f:
            digest = hashlib.sha1(lib_f.read()).hexdigest()
    finally:
        os.chdir(saved_cwd)
    return handle, os.path.basename(target), digest
# ----------------------------
# helper function in ctypes.
# ----------------------------
def c_str(string):
    """Convert a python string into a ctypes ``char *``.

    Parameters
    ----------
    string : str
        The python string to convert.

    Returns
    -------
    out : ctypes.c_char_p
        Char pointer holding the UTF-8 encoded text, ready for the C API.
    """
    utf8_bytes = string.encode('utf-8')
    return ctypes.c_char_p(utf8_bytes)
def c_array(ctype, values):
    """Build a ctypes array out of a python sequence.

    Parameters
    ----------
    ctype : ctypes data type
        Element type of the resulting array.
    values : tuple or list
        The elements to copy into the array.

    Returns
    -------
    out : ctypes array
        A new array of ``len(values)`` elements of type ``ctype``.
    """
    array_type = ctype * len(values)
    return array_type(*values)
def decorate(func, fwrapped):
    """Forward to ``decorator.decorate``, importing the package at call time.

    Parameters
    ----------
    func : function
        The original function.
    fwrapped : function
        The wrapped function.

    Returns
    -------
    out : function
        Whatever ``decorator.decorate(func, fwrapped)`` returns.
    """
    # Imported inside the function so the package is only needed when called.
    import decorator as _decorator
    return _decorator.decorate(func, fwrapped)
# -----------------------------------------
# Base code for structured error handling.
# -----------------------------------------
# Maps an error type name (str) to the class used to construct that error.
# Filled in by register_error and looked up by get_last_ffi_error.
ERROR_TYPE = {}
class TError(RuntimeError):
    """Default error thrown by packed functions.

    TError is the fallback used when an error message coming from the FFI
    does not name a registered error type.
    """
def register_error(func_name=None, cls=None):
    """Register an error class so the ffi error handler can recognize it.

    Usable as a bare decorator, as a decorator factory taking the error
    name, or as a direct call with both the name and the class.

    Parameters
    ----------
    func_name : str or function or class
        The name to register the error under. When a callable is given it
        is taken as the class to register and its ``__name__`` is the name.
    cls : class
        The error class to register.

    Returns
    -------
    fregister : function or class
        The registered class, or the register function when no class was
        supplied.

    Examples
    --------
    .. code-block:: python

      @matx.error.register_error
      class MyError(RuntimeError):
          pass
    """
    if callable(func_name):
        # Bare decorator usage: the first argument is the class itself.
        cls, func_name = func_name, func_name.__name__

    def register(mycls):
        """internal register function"""
        key = func_name if isinstance(func_name, str) else mycls.__name__
        ERROR_TYPE[key] = mycls
        return mycls

    return register if cls is None else register(cls)
def _valid_error_name(name):
    """Check whether name is a valid error name."""
    return all(x.isalnum() or x in "_." for x in name)
def _find_error_type(line):
    """Find the error name given the first line of the error message.
    Parameters
    ----------
    line : str
        The first line of error message.
    Returns
    -------
    name : str The error name
    """
    if sys.platform == "win32":
        # Stack traces aren't logged on Windows due to a DMLC limitation,
        # so we should try to get the underlying error another way.
        # DMLC formats errors "[timestamp] file:line: ErrorMessage"
        # ErrorMessage is usually formatted "ErrorType: message"
        # We can try to extract the error type using the final ":"
        end_pos = line.rfind(":")
        if end_pos == -1:
            return None
        start_pos = line.rfind(":", 0, end_pos)
        if start_pos == -1:
            err_name = line[:end_pos].strip()
        else:
            err_name = line[start_pos + 1: end_pos].strip()
        if _valid_error_name(err_name):
            return err_name
        return None
    end_pos = line.find(":")
    if end_pos == -1:
        return None
    err_name = line[:end_pos]
    if _valid_error_name(err_name):
        return err_name
    return None
def c2pyerror(err_msg):
    """Translate C API error message to python style.

    Lines following a "Stack trace" header that start with two spaces are
    collected as the stack trace, reversed, and emitted first under a
    python style "Traceback" header; all other lines form the message.

    Parameters
    ----------
    err_msg : str
        The error message.

    Returns
    -------
    new_msg : str
        Translated message.
    err_type : str or None
        Detected error type, or None when none is found or the message
        is empty.
    """
    arr = err_msg.split("\n")
    if arr[-1] == "":
        arr.pop()
    # An empty message leaves no first line to inspect; report no error
    # type instead of failing with an IndexError.
    err_type = _find_error_type(arr[0]) if arr else None
    trace_mode = False
    stack_trace = []
    message = []
    for line in arr:
        if trace_mode:
            if line.startswith("  "):
                stack_trace.append(line)
                continue
            # First non-indented line ends the stack trace section and is
            # handled as a regular line below.
            trace_mode = False
        if line.startswith("Stack trace"):
            trace_mode = True
        else:
            message.append(line)
    out_msg = ""
    if stack_trace:
        out_msg += "Traceback (most recent call last):\n"
        out_msg += "\n".join(reversed(stack_trace)) + "\n"
    out_msg += "\n".join(message)
    return out_msg, err_type
def py2cerror(err_msg):
    """Translate python style error message to C style.

    Lines following a line containing "Traceback" that start with two
    spaces are collected as the stack trace, reversed, and appended after
    the message under a "Stack trace:" header.

    Parameters
    ----------
    err_msg : str
        The error message.

    Returns
    -------
    new_msg : str
        Translated message.
    """
    arr = err_msg.split("\n")
    if arr[-1] == "":
        arr.pop()
    trace_mode = False
    stack_trace = []
    message = []
    for line in arr:
        if trace_mode:
            if line.startswith("  "):
                stack_trace.append(line)
                continue
            # First non-indented line ends the traceback section and is
            # handled as a regular line below.
            trace_mode = False
        if line.find("Traceback") != -1:
            trace_mode = True
        else:
            message.append(line)
    # Remove the first error name if there are two of them.
    # RuntimeError: MyErrorName: message => MyErrorName: message
    # The message may be empty (empty input, or a traceback with nothing
    # after it), in which case there is no head line to rewrite.
    if message:
        head_arr = message[0].split(":", 3)
        if len(head_arr) >= 3 and _valid_error_name(head_arr[1].strip()):
            head_arr[1] = head_arr[1].strip()
            message[0] = ":".join(head_arr[1:])
    # reverse the stack trace.
    out_msg = "\n".join(message)
    if stack_trace:
        out_msg += "\nStack trace:\n"
        out_msg += "\n".join(reversed(stack_trace)) + "\n"
    return out_msg
def get_last_ffi_error():
    """Build an error object from the result of MATXScriptAPIGetLastError.

    Returns
    -------
    err : object
        An instance of the class registered in ERROR_TYPE for the detected
        error type, or a TError when the type is unknown.
    """
    raw_msg = _LIB.MATXScriptAPIGetLastError()
    py_err_msg, err_type = c2pyerror(py_str(raw_msg))
    # Registered names do not carry the package prefix, so drop it.
    prefix = "matx.error."
    if err_type is not None and err_type.startswith(prefix):
        err_type = err_type[len(prefix):]
    error_cls = ERROR_TYPE.get(err_type, TError)
    return error_cls(py_err_msg)
def check_call(ret):
    """Raise the last FFI error when a C API call reports failure.

    Wrap every API call with this function.

    Parameters
    ----------
    ret : int
        Return value of the API call; anything other than 0 is a failure.
    """
    if ret != 0:
        error = get_last_ffi_error()
        raise error