
/*
 * MML framework
 */
#ifndef LIB_AI_MML_INTERFACE_API_H
#define LIB_AI_MML_INTERFACE_API_H

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Marks a symbol as part of the exported MML API. The build may pre-define
// MML_NATIVE_API to override this; otherwise GCC/Clang use default visibility,
// and other compilers get an empty definition instead of unparseable syntax.
#ifndef MML_NATIVE_API
#if defined(__GNUC__) || defined(__clang__)
#define MML_NATIVE_API __attribute__((visibility("default")))
#else
#define MML_NATIVE_API
#endif
#endif
namespace mml_framework {
/// Target (device/backend) type used by the lite backend.
/// Enumerators are listed in numeric order; every value is explicit.
enum class MMLTargetType : int {
  kUnk = 0,
  kHost = 1,
  kX86 = 2,
  kCUDA = 3,
  kARM = 4,
  kOpenCL = 5,
  kAny = 6,    // any target
  kFPGA = 7,
  kNPU = 8,
  kXPU = 9,
  NUM = 10,    // number of fields
};

/**
 * @brief MML string holder: a character pointer paired with a size.
 *
 * NOTE(review): whether `data` is owned by this struct and whether it is
 * NUL-terminated is not specified by this header; confirm before relying on
 * either.
 */
struct MML_NATIVE_API MMLString {
  const char* data;  // start of the character data
  size_t size;       // presumably the number of bytes at `data`
};

// MML配置相关参数
struct MML_NATIVE_API MMLConfig {
  enum Precision {
    FP32 = 0
  };
  // Machine类型，目前支持BML（GBDT和LR回归与分类）、PaddleLite
  enum MachineType {
    BML = 0, PaddleMobile = 1, PaddleLite = 2
  };

  // 模型输入的数据精度
  Precision precision = FP32;

  // 模型文件路径
  std::string modelUrl;

  // 后端类型
  MachineType machine_type = MachineType::PaddleMobile;

  // 模型加密类型，mml加密，后端加密，业务加密
  enum DecryptType {
    Business = 0, Backend = 1, MML = 2
  };

  // 加解密信息
  struct MML_NATIVE_API DecryptInfo {
    DecryptType type;
    std::string decrypt_key;
    int api_level{};
  public:
    DecryptInfo() { type = Business; }
  };

    DecryptInfo decrypt_info = DecryptInfo();

    // paddle可支持在模型之外添加的前后处理类型
    enum PrePostType {
        NONE_PRE_POST = 0,  // none
        UINT8_255 = 1       // 输入数据归一化
    };

    // opencl tune config
    typedef enum {
        CL_TUNE_NONE = 0,
        CL_TUNE_RAPID = 1,
        CL_TUNE_NORMAL = 2,
        CL_TUNE_EXHAUSTIVE = 3
    } PaddleLiteCLTuneMode;

    typedef enum {
        CL_PRECISION_AUTO = 0,
        CL_PRECISION_FP32 = 1,
        CL_PRECISION_FP16 = 2,
        CL_NO_DEFINE = 3
    } PaddleLiteCLPrecisionType;

    // PaddleMobile special config
    struct MML_NATIVE_API PaddleMobileConfig {
        // 是否使用gpu
        bool is_gpu = false;
        // 是否进行优化融合
        bool optimize = true;
        // 是否量化
        bool quantification = false;
        // batch size
    int batch_size = 1;
    bool lod_mode = false;
    bool load_when_predict = true;
    // 额外的前后处理操作
    PrePostType pre_post_type = NONE_PRE_POST;
  };

  // BML special config
  struct MML_NATIVE_API BMLConfig {

  };

  // Paddle lite config
  struct MML_NATIVE_API PaddleLiteConfig {
    typedef enum {
      LITE_POWER_HIGH = 0,
      LITE_POWER_LOW = 1,
      LITE_POWER_FULL = 2,
      LITE_POWER_NO_BIND = 3,
      LITE_POWER_RAND_HIGH = 4,
      LITE_POWER_RAND_LOW = 5
    } PaddleLitePowerMode;
    PaddleLitePowerMode powermode = {PaddleLitePowerMode::LITE_POWER_NO_BIND};
    int threads = {1};

    typedef enum {
      LITE_MODEL_FROM_DIR = 0,
      LITE_MODEL_FROM_FILE = 1,
      LITE_MODEL_FROM_BUFFER = 2,
      LITE_MODEL_FROM_MODELBUFFER = 3,
    } PaddleLiteModelType;
    PaddleLiteModelType model_type = {PaddleLiteModelType::LITE_MODEL_FROM_DIR};

    struct MML_NATIVE_API PaddleLiteModelBuffer {
      char *model_buffer;
      size_t model_buffer_size;
      char *param_buffer;
        size_t param_buffer_size;
    };

      struct MML_NATIVE_API PaddleLiteModel {
          MMLString model_from_file; // used when LITE_MODEL_FROM_FILE
          MMLString model_from_buffer; // used whtn LITE_MODEL_FROM_BUFFER
          PaddleLiteModelBuffer model_buffer; // used when LITE_MODEL_FROM_MODELBUFFER
      };
      PaddleLiteModel model{};

      struct MML_NATIVE_API OpenClConfig {
          // binary cache config
          std::string opencl_binary_path;
          std::string opencl_binary_name;
          // tune config
          std::string opencl_tuned_path;
          std::string opencl_tuned_name;
          PaddleLiteCLTuneMode opencl_tuned_mode{CL_TUNE_NORMAL};
          int opencl_tuned_times{4};
          // cl precision_type config
          PaddleLiteCLPrecisionType precision_type{PaddleLiteCLPrecisionType::CL_NO_DEFINE};
      };
      OpenClConfig opencl_config{};
      
      struct MML_NATIVE_API MetalConfig {
          bool metal_use_mps = {true};
          bool metal_use_aggressive = {false};
          const char *metal_lib_path = {nullptr};
          void *metal_device = {nullptr};
          bool metal_use_memory_reuse = {false};
      };

      MetalConfig metal_config{};
  };

  // 具体machine的配置
  struct MML_NATIVE_API MachineConfig {
    PaddleMobileConfig paddle_mobile_config;
    BMLConfig bml_config;
    PaddleLiteConfig paddle_lite_config;
  };
  // 专属于某个machine的配置
  MachineConfig machine_config = {PaddleMobileConfig()};

  // 用于Dev阶段的Config, Release阶段请不要配置.
  struct MML_NATIVE_API DevConfig {
    // 具备写权限的可保存日志的路径.
    std::string profile_save_path;
  };

  // 专属于某个machine的配置
  DevConfig dev_config = {};
};

/**
 * Kind of predictor held by an MMLMachineService.
 */
enum MmlMachineType {
  PADDLE_MOBILE = 0,  // PaddleMobile predictor
  PADDLE_LITE = 1,    // Paddle-Lite predictor
};

// Helper types used by MMLTensor's shape and LoD accessors. Both rely on
// <vector> and <cstdint> being included explicitly by this header.

// Tensor shape: one extent per dimension (see MMLTensor::Resize / shape()).
using shape_t = std::vector<std::int64_t>;
// LoD information (see MMLTensor::SetLoD / lod()); presumably Paddle-style
// per-level offset lists — confirm against the backend.
using lod_t = std::vector<std::vector<std::uint64_t>>;

// MML tensor: a handle wrapping a backend-specific tensor object.
//
// NOTE(review): the copy operations are compiler-generated and copy the raw
// `tensor` pointer. Since the destructor releases the data when autoRelease is
// true, destroying two copies that both keep autoRelease == true may release the
// same data twice. Avoid copying (or clear autoRelease on all but one copy)
// until copy semantics are defined explicitly.
class MML_NATIVE_API MMLTensor {
public:
  // Wraps the backend tensor pointed to by `raw`.
  // NOTE(review): presumably `raw` is stored in `tensor`; confirm in the
  // implementation.
  explicit MMLTensor(void *raw);

  // Const overload of the wrapping constructor.
  explicit MMLTensor(const void *raw);

  MMLTensor() = default;

  // Opaque pointer to the underlying backend tensor; nullptr when empty.
  void *tensor = nullptr;

  // Resizes the tensor to `shape`.
  void Resize(const shape_t &shape);

  // Sets the tensor's LoD information.
  void SetLoD(const lod_t &lod);

  /// Read-only data. Supported T: int, float, int8_t, uint8_t, int64_t.
  template<typename T>
  const T *data() const;

  // Writable data for target `type` (defaults to MMLTargetType::kHost).
  template<typename T>
  T *mutable_data(MMLTargetType type = MMLTargetType::kHost) const;

  // Used on iOS when an id<MTLTexture> is passed in.
  void* mutable_metal_data(void* ptr) const;

  shape_t shape() const;

  lod_t lod() const;

  // When true, the destructor releases the wrapped data.
  bool autoRelease = true;

  // If autoRelease is true, rawData is released on destruction.
  virtual ~MMLTensor();

  // Releases rawData; no need to call it manually when autoRelease is true.
  void release();
};

// MML配置相关参数
struct MML_NATIVE_API MMLData {
  enum RawDataType {
    FLOAT32,
    UINT8,
    PADDLE_LITE_TENSOR
  };

  struct MML_NATIVE_API RawDataShape {
    int n = 1;
    int c = 1;
    int h;
    int w;

  public:
    MML_NATIVE_API RawDataShape() = default;

    MML_NATIVE_API RawDataShape(int n, int c, int h, int w) {
      this->n = n;
      this->c = c;
      this->h = h;
      this->w = w;
    }
  };

  // 待预测的数据，对应input_tensor.data
  void *rawData = nullptr;
  // rawData长度，单位byte
  long dataLength;
  RawDataType rawDataType;
  // 对应input_tensor.shape，非必须
  RawDataShape rawDataShape;
  // 若为true 析构时自动调用release()
  bool autoRelease = true;
  //tensor
  MMLTensor *mmlTensor = 0;

  MMLData() = default;

  // 如果autoRelease为true，则会在析构时释放rawData
  virtual ~MMLData();

  // delete rawData、mmlTensor，如果autoRelease为true，则不需要主动调用
  void release();
};

/**
 * Interface for user-supplied pre/post-processing hooks.
 *
 * Register an implementation with MMLMachineService::setInterceptProcessor().
 * run() then invokes preProcess() before predicting and postProcess()
 * afterwards.
 *
 * The order of the virtual members below is part of the vtable layout; keep it.
 */
class MML_NATIVE_API MMLDataProcessor {
public:
  /**
   * Pre-processing hook.
   * @param preProcessInputData  data to be pre-processed
   * @param preProcessOutputData destination for the pre-processed data
   * @return status code (presumably an ErrorCode value; confirm)
   */
  virtual int preProcess(const MMLData &preProcessInputData,
                         MMLData *preProcessOutputData) = 0;

  /**
   * Post-processing hook.
   * @param postProcessInputData  data to be post-processed
   * @param postProcessOutputData destination for the post-processed data
   * @return status code (presumably an ErrorCode value; confirm)
   */
  virtual int postProcess(const MMLData &postProcessInputData,
                          MMLData *postProcessOutputData) = 0;

  virtual ~MMLDataProcessor() = default;
};

/**
 * 判别当前手机的OpenCl是否支持
 * @return
 */
bool MML_NATIVE_API IsOpenCLBackendValid();

/**
 * Manager for an MMLMachine. An inference engine (e.g. PaddleMobile or BML)
 * that has loaded a particular model is called an MMLMachine.
 *
 * NOTE(review): the class holds raw pointers (machineHandle, mProcessorImpl)
 * that the destructor deletes when autoRelease is true, but its copy
 * operations are compiler-generated. Copying an instance would therefore lead
 * to a double delete. Avoid copies (CreateMMLMachineService returns a
 * shared_ptr) until copy semantics are defined explicitly.
 */
class MML_NATIVE_API MMLMachineService {
public:
  // Input data for input index `i`.
  std::unique_ptr<mml_framework::MMLData> getInputData(int i);

  // Output data for output index `i`.
  std::unique_ptr<const mml_framework::MMLData> getOutputData(int i);

  std::vector<std::string> getInputNames();

  std::vector<std::string> getOutputNames();

  // Input data for the input called `name`.
  std::unique_ptr<mml_framework::MMLData> getInputByName(const std::string &name);

private:
  // Pointer to the MachinePredictor.
  void *machineHandle = nullptr;
  // MachinePredictor kind (see MmlMachineType). Explicitly initialized so a
  // service that has not been load()ed never holds an indeterminate value;
  // PADDLE_MOBILE (0) is what value-initialization already produced.
  MmlMachineType mmlMachineType = PADDLE_MOBILE;
  // Pre/post-processing callbacks.
  MMLDataProcessor *mProcessorImpl = nullptr;

  /**
   * Runs prediction only, without pre/post-processing.
   *
   * @param modelInputData
   * @param modelOutputData
   * @return status code
   */
  int predict(MMLData &modelInputData, MMLData *modelOutputData);

  int predict();

public:
  /**
   * If true, machineHandle is released on destruction.
   */
  bool autoRelease = true;

  /**
   * Installs the pre/post-processing callbacks.
   *
   * When autoRelease is true the destructor deletes mProcessorImpl, so
   * `processorImpl` presumably should be heap-allocated and not deleted by the
   * caller — confirm against the implementation.
   *
   * @param processorImpl
   */
  void setInterceptProcessor(MMLDataProcessor *processorImpl);

  /**
   * Runs a prediction. If an InterceptProcessor is set, its pre-processing
   * callback runs first, then predict, then its post-processing callback.
   * Without an InterceptProcessor, predict runs directly.
   *
   * @param inputData
   * @param outputData
   * @return status code
   */
  int run(MMLData &inputData, MMLData *outputData);

  int run();

  /**
   * Creates the MachinePredictor described by `config` and loads the model.
   *
   * @param config
   * @return status code
   */
  int load(const MMLConfig &config);

  /**
   * Call after the inputs have been set: runs the model and saves inputs and
   * outputs to the configured path (MMLConfig::DevConfig::profile_save_path,
   * presumably).
   *
   * @return status code
   */
  int profile();

  /**
   * @return version of the underlying predictor.
   */
  std::string version();

  /**
   * If autoRelease is true, deletes mProcessorImpl and the predictor that
   * machineHandle points to.
   */
  virtual ~MMLMachineService();

  /**
   * Releases mProcessorImpl and machineHandle; no need to call it manually
   * when autoRelease is true.
   */
  void release();
};

/**
 * Status codes of the MML API.
 *  - SUCCESS / ERR_PARAM: success, invalid parameter.
 *  - LOAD_ERR_*: model loading failures (other / decryption / unsupported
 *    machine type); presumably returned by load().
 *  - RUN_ERR_*: prediction failures (other / pre-processing / post-processing /
 *    predict / machine type / machine handle); presumably returned by run().
 */
enum ErrorCode {
  SUCCESS = 0,
  ERR_PARAM = -1,

  LOAD_ERR_OTHER = -11,
  LOAD_ERR_DECRYPT = -12,
  LOAD_ERR_MACHINE_TYPE = -13,

  RUN_ERR_OTHER = -20,
  RUN_ERR_PREPROECESS = -21,   // misspelled; kept for compatibility
  RUN_ERR_POSTPROECESS = -22,  // misspelled; kept for compatibility
  RUN_ERR_PREDICT = -23,
  RUN_ERR_MACHINE_TYPE = -24,
  RUN_ERR_MACHINE_HANDLE = -25,

  // Correctly spelled aliases (same values) for new code.
  RUN_ERR_PREPROCESS = RUN_ERR_PREPROECESS,
  RUN_ERR_POSTPROCESS = RUN_ERR_POSTPROECESS,
};

std::shared_ptr<MMLMachineService> CreateMMLMachineService(MMLConfig &config);

}

#endif //LIB_AI_MML_INTERFACE_API_H
