You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

241 lines
6.4 KiB

/*
* @Author: xiewenji 527774126@qq.com
* @Date: 2025-09-03 09:48:19
* @LastEditors: xiewenji 527774126@qq.com
* @LastEditTime: 2025-09-24 10:29:29
* @FilePath: /AI_SO_Test/AIEngineModule/include_base/AI_Factory.h
* @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
*/
/*
* @Author: xiewenji 527774126@qq.com
* @Date: 2025-07-04 14:41:18
* @LastEditors: xiewenji 527774126@qq.com
* @LastEditTime: 2025-09-03 10:27:59
* @FilePath: /AI_SO_Test/AIEngineModule/include_base/AI_Engine_Base.h
* @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
*/
#ifndef AI_Factory_H_
#define AI_Factory_H_
#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <vector>

#include <opencv2/opencv.hpp>
// GPU 可用信息
// GPU devices available to a model.
// Holds up to kMaxGPU device IDs; a slot holding a negative value is unused.
struct GPU_Config
{
    static constexpr int kMaxGPU = 4; // capacity of UseGPUList
    int UseGPUList[kMaxGPU];          // GPU device IDs in use; -1 marks an empty slot

    // Defaults to devices 0 and 1; the remaining slots are marked unused (-1).
    GPU_Config()
    {
        for (int &id : UseGPUList)
        {
            id = -1;
        }
        UseGPUList[0] = 0;
        UseGPUList[1] = 1;
    }

    // Returns how many slots hold a non-negative device ID.
    int GetNum() const
    {
        int num = 0;
        for (const int id : UseGPUList)
        {
            if (id >= 0)
            {
                num++;
            }
        }
        return num;
    }

    // Copies every slot of `tem` into this configuration.
    // Taken by const reference: the source is only read.
    void copy(const GPU_Config &tem)
    {
        for (int i = 0; i < kMaxGPU; i++)
        {
            UseGPUList[i] = tem.UseGPUList[i];
        }
    }
};
// AI 模型 基础接口
class AIModel_Base
{
public:
enum Model_Input_Type
{
Input_NULL,
Input_HW,
Input_CHW,
Input_HWC,
};
// 模型相关参数
struct AIModelRun_Config
{
int Stream_num; // 流的个数
bool IsClass; // 是分类
std::string strPath; // 模型类型
std::string strName; // 名称,别名
Model_Input_Type inputType; // 数据格式类型
GPU_Config gpuconfig;
AIModelRun_Config()
{
strPath = "";
strName = "";
inputType = Input_NULL;
Stream_num = 1;
IsClass = false;
}
void Copy(AIModelRun_Config tem)
{
this->Stream_num = tem.Stream_num;
this->strPath = tem.strPath;
this->strName = tem.strName;
this->inputType = tem.inputType;
this->IsClass = tem.IsClass;
this->gpuconfig.copy(tem.gpuconfig);
}
void print(std::string str)
{
printf("%s:Stream_num %d Path:%s inputType %d\n", str.c_str(), Stream_num, strPath.c_str(), inputType);
}
};
struct AI_Image
{
int channel;
int width;
int height;
AI_Image()
{
channel = 0;
width = 0;
height = 0;
}
void copy(AI_Image tem)
{
this->channel = tem.channel;
this->height = tem.height;
this->width = tem.width;
}
void print()
{
printf("[c w h] %d %d %d\n", channel, width, height);
}
};
public:
virtual ~AIModel_Base() = default;
static std::shared_ptr<AIModel_Base> GetInstance();
// 初始化函数
virtual int Init(AIModelRun_Config config) = 0;
virtual int AIDet(const cv::Mat &inImg, cv::Mat &outimg) = 0;
virtual int AIDet(const cv::Mat &inImg, cv::Mat &outimg0, cv::Mat &outimg1) = 0;
virtual int AIClass(const cv::Mat &inImg, float *fmaxScore) = 0;
public:
bool m_bInitSuccess = false;
AI_Image input_0;
AI_Image input_1;
AI_Image output_0;
AI_Image output_1;
AI_Image output_2;
};
// 管理所有 模型
// Holds the handles of all AI models used by the application.
class AIFactory
{
public:
    AIFactory();
    ~AIFactory();

    // Returns a shared AIFactory instance (defined outside this header;
    // presumably a singleton — confirm).
    static std::shared_ptr<AIFactory> GetInstance();

    // Initializes all models using the given GPU configuration.
    int InitALLAIModle(GPU_Config gpuconfig);

    // Model handles. Trailing notes such as "L127 L255" / "L0" / "1023" are kept
    // from the original; their meaning is not visible here (presumably the test
    // patterns each model is used on — TODO confirm).
    std::shared_ptr<AIModel_Base> AI_defect_NF;         // BOE test
    std::shared_ptr<AIModel_Base> AI_defect_Type2;      // second detection model
    std::shared_ptr<AIModel_Base> AI_defect_UP;         // L127 L255
    std::shared_ptr<AIModel_Base> AI_defect_Chess;      // L127 L255
    std::shared_ptr<AIModel_Base> AI_defect_YX_1;       // L0
    std::shared_ptr<AIModel_Base> AI_defect_YX_2;       // L127 L255
    std::shared_ptr<AIModel_Base> AI_defect_Cls;        // 1023
    std::shared_ptr<AIModel_Base> AI_defect_Cls_L0;
    std::shared_ptr<AIModel_Base> AI_defect_zf;         // L127 L255
    std::shared_ptr<AIModel_Base> AI_defect_127Cell;    // L127 L255
    std::shared_ptr<AIModel_Base> AI_defect_RE_POL;     // L127 L255
    std::shared_ptr<AIModel_Base> AI_defect_RE_AD;      // L127 L255
    std::shared_ptr<AIModel_Base> AI_defect_Edge_Big;   // L127 L255
    std::shared_ptr<AIModel_Base> AI_defect_Edge_Samll; // L127 L255
    std::shared_ptr<AIModel_Base> AI_defect_LackPol;    // missing-POL detection
    // NOTE(review): original comment also said "missing-POL detection", identical
    // to AI_defect_LackPol — likely copy-pasted; confirm this model's purpose.
    std::shared_ptr<AIModel_Base> AI_defect_MarkLine;
    std::shared_ptr<AIModel_Base> AI_defect_Edge_QX;    // edge defect detection

private:
    // Whether all models have finished loading.
    bool m_bInitSucc{false};
};
// 多线程推理
// Multi-threaded inference runner.
// Start() launches worker threads (default 2), SubmitTask() enqueues work and
// PopResult() retrieves finished tasks. The member functions are defined
// outside this header.
class AIMulThreadRunBase
{
public:
    // One unit of inference work and, once processed, its results.
    struct AITask
    {
        cv::Rect roi;                         // region of interest
        cv::Size srcroi;                      // presumably the source ROI size — TODO confirm
        int id = 0;                           // caller-assigned id (was uninitialized)
        cv::Mat input;                        // input image
        std::shared_ptr<cv::Mat> output;      // output image
        std::shared_ptr<AIModel_Base> engine; // model used to run this task
        bool bclass = false;                  // presumably true for a classification task
        int cls_label = 0;                    // classification label
        float cls_score = 0.0f;               // classification score
        int userflag = 0;                     // opaque caller-defined flag
    };

public:
    AIMulThreadRunBase();
    ~AIMulThreadRunBase();
    void Start(int num_threads = 2);
    void Stop();
    void SubmitTask(std::shared_ptr<AITask> task);
    bool PopResult(std::shared_ptr<AITask> &result);

    // Returns the number of tasks currently being processed.
    int GetProcessingCount() const { return m_processing_count_.load(); }

    // Number of tasks currently executing. Brace-initialized: before C++20 a
    // default-constructed std::atomic<int> holds an indeterminate value.
    std::atomic<int> m_detnum{0};

private:
    void ThreadLoop();

private:
    std::queue<std::shared_ptr<AITask>> m_tasks_;   // pending tasks
    std::queue<std::shared_ptr<AITask>> m_results_; // finished tasks
    std::vector<std::thread> m_workers_;
    std::mutex m_task_mutex_;
    std::mutex m_result_mutex_;
    std::mutex m_AI_mutex_;
    std::condition_variable m_task_cv_;
    std::atomic<bool> m_running_{false};        // explicit initial state (see m_detnum)
    std::atomic<int> m_processing_count_{0};    // number of tasks currently executing
};
#endif