You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

158 lines
4.0 KiB

/*
* @Author: xiewenji 527774126@qq.com
* @Date: 2025-07-05 19:32:41
* @LastEditors: xiewenji 527774126@qq.com
* @LastEditTime: 2025-09-06 11:13:54
* @FilePath: /AI_SO_Test/AIEngineModule/src/Engine.cpp
* @Description: TensorRT engine wrapper: deserializes a serialized engine file on a selected GPU and exposes its I/O tensor metadata.
*/
#include "Engine.h"
#include <fstream>
#include <iostream>
#include <cuda_runtime_api.h>
using namespace nvinfer1;
/**
 * @brief TensorRT logger that echoes important messages to stdout.
 *
 * Messages whose severity is at or below Severity::kWARNING (in TensorRT's
 * ordering: warnings and errors) are printed with a "[TensorRT] " prefix;
 * anything with a higher severity value (info, verbose) is dropped.
 */
class Logger : public ILogger
{
public:
    void log(Severity severity, const char *msg) noexcept override
    {
        const bool suppressed = severity > Severity::kWARNING;
        if (suppressed)
        {
            return;
        }
        std::cout << "[TensorRT] " << msg << std::endl;
    }
} gLogger;
// Custom deleter implementations
/**
 * @brief Deleter used for the IRuntime held in Engine::runtime_.
 *
 * NOTE(review): deletion is deliberately disabled, so the IRuntime is never
 * released (it is leaked). Since TensorRT 8.0 such objects are released with
 * `delete ptr;`. Re-enable only after confirming that engines are destroyed
 * before their runtime and that this is not invoked during process shutdown.
 */
void TensorRTDeleter::operator()(nvinfer1::IRuntime *ptr) const noexcept
{
    static_cast<void>(ptr); // intentionally unused: release is disabled (see note above)
}
/**
 * @brief Deleter for IExecutionContext objects.
 *
 * NOTE(review): deletion is deliberately disabled, so any context passed here
 * is never released (it is leaked). Since TensorRT 8.0 such objects are
 * released with `delete ptr;`. A context should be destroyed before the
 * engine that created it; confirm that ordering before re-enabling.
 */
void TensorRTDeleter::operator()(nvinfer1::IExecutionContext *ptr) const noexcept
{
    static_cast<void>(ptr); // intentionally unused: release is disabled (see note above)
}
/**
 * @brief Deleter used for the ICudaEngine held in Engine::engine_.
 *
 * NOTE(review): deletion is deliberately disabled, so the engine (and any
 * device memory it owns) is never released (it is leaked). Since TensorRT 8.0
 * engines are released with `delete ptr;`. Re-enable only after confirming the
 * engine is destroyed before its runtime and not during process shutdown.
 */
void TensorRTDeleter::operator()(nvinfer1::ICudaEngine *ptr) const noexcept
{
    static_cast<void>(ptr); // intentionally unused: release is disabled (see note above)
}
/**
 * @brief Create an Engine bound to a GPU index.
 *
 * Only records the index; the device is selected and the engine is loaded
 * later by loadFromFile().
 *
 * @param gpuId CUDA device index passed to cudaSetDevice() during loading.
 */
Engine::Engine(int gpuId)
    : gpuId_(gpuId)
{
}
/**
 * @brief Release the TensorRT objects held by this Engine.
 *
 * engine_ is reset before runtime_ explicitly, so the engine is always handed
 * to its deleter before the runtime that deserialized it, regardless of the
 * order in which the two members are declared in the class.
 *
 * NOTE(review): TensorRTDeleter's operators are currently no-ops, so these
 * resets only drop the pointers; the objects themselves are not freed.
 */
Engine::~Engine()
{
    engine_.reset();
    runtime_.reset();
}
bool Engine::loadFromFile(const std::string &enginePath)
{
try
{
cudaSetDevice(gpuId_);
std::ifstream file(enginePath, std::ios::binary);
if (!file)
{
std::cerr << "[Engine] Failed to open engine file: " << enginePath << std::endl;
return false;
}
file.seekg(0, file.end);
size_t size = file.tellg();
file.seekg(0, file.beg);
std::vector<char> engineData(size);
file.read(engineData.data(), size);
runtime_ = std::unique_ptr<IRuntime, TensorRTDeleter>(createInferRuntime(gLogger));
if (!runtime_)
{
std::cerr << "[Engine] Failed to create runtime." << std::endl;
return false;
}
engine_ = std::unique_ptr<ICudaEngine, TensorRTDeleter>(
runtime_->deserializeCudaEngine(engineData.data(), size));
if (!engine_)
{
std::cerr << "[Engine] Failed to deserialize engine." << std::endl;
return false;
}
return true;
}
catch (const std::exception &e)
{
std::cerr << "[Engine] Exception during loadFromFile: " << e.what() << std::endl;
return false;
}
catch (...)
{
std::cerr << "[Engine] Unknown exception during loadFromFile." << std::endl;
return false;
}
}
/**
 * @brief Number of I/O tensors exposed by the loaded engine.
 * @return engine_->getNbIOTensors(), or 0 when no engine is loaded.
 */
int Engine::getNbBindings() const
{
    if (!engine_)
    {
        return 0;
    }
    return engine_->getNbIOTensors();
}
/**
 * @brief Name of the I/O tensor at the given index.
 *
 * @param index Tensor index in [0, getNbBindings()).
 * @return The tensor name. Returns an empty string in three cases: no engine
 *         is loaded, the index is out of range, or TensorRT returns no name.
 */
std::string Engine::getBindingName(int index) const
{
    if (!engine_ || index < 0 || index >= engine_->getNbIOTensors())
    {
        return "";
    }
    // Constructing std::string from a null pointer is undefined behavior,
    // so guard the returned name as well.
    const char *name = engine_->getIOTensorName(index);
    return name ? std::string(name) : std::string();
}
/**
 * @brief Shape of the I/O tensor at the given index, as reported by
 *        ICudaEngine::getTensorShape().
 *
 * @param index Tensor index in [0, getNbBindings()).
 * @return The tensor shape. Returns a value-initialized Dims{} in three cases:
 *         no engine is loaded, the index is out of range, or TensorRT returns
 *         no name.
 */
Dims Engine::getBindingDims(int index) const
{
    if (!engine_ || index < 0 || index >= engine_->getNbIOTensors())
    {
        return Dims{};
    }
    // Use the name pointer directly (no std::string copy) and refuse a null
    // name instead of constructing a string from it.
    const char *name = engine_->getIOTensorName(index);
    if (name == nullptr)
    {
        return Dims{};
    }
    return engine_->getTensorShape(name);
}
/**
 * @brief Whether the I/O tensor at the given index is an engine input.
 *
 * @param index Tensor index in [0, getNbBindings()).
 * @return true if getTensorIOMode() reports TensorIOMode::kINPUT. Returns
 *         false otherwise, and also when no engine is loaded, the index is
 *         out of range, or TensorRT returns no name.
 */
bool Engine::bindingIsInput(int index) const
{
    if (!engine_ || index < 0 || index >= engine_->getNbIOTensors())
    {
        return false;
    }
    // Guard against a null name rather than constructing a std::string from it.
    const char *name = engine_->getIOTensorName(index);
    if (name == nullptr)
    {
        return false;
    }
    return engine_->getTensorIOMode(name) == TensorIOMode::kINPUT;
}