You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
158 lines
4.0 KiB
158 lines
4.0 KiB
/*
|
|
* @Author: xiewenji 527774126@qq.com
|
|
* @Date: 2025-07-05 19:32:41
|
|
* @LastEditors: xiewenji 527774126@qq.com
|
|
* @LastEditTime: 2025-09-06 11:13:54
|
|
* @FilePath: /AI_SO_Test/AIEngineModule/src/Engine.cpp
|
|
* @Description: TensorRT engine wrapper: loads a serialized engine file on a chosen GPU and exposes its I/O tensor metadata.
|
|
*/
|
|
#include "Engine.h"
|
|
|
|
#include <fstream>
|
|
#include <iostream>
|
|
#include <cuda_runtime_api.h>
|
|
|
|
using namespace nvinfer1;
|
|
|
|
class Logger : public ILogger
|
|
{
|
|
public:
|
|
void log(Severity severity, const char *msg) noexcept override
|
|
{
|
|
if (severity <= Severity::kWARNING)
|
|
std::cout << "[TensorRT] " << msg << std::endl;
|
|
}
|
|
} gLogger;
|
|
// Custom deleters for TensorRT objects held in std::unique_ptr.
|
|
void TensorRTDeleter::operator()(nvinfer1::IRuntime *ptr) const noexcept
|
|
{
|
|
// if (ptr)
|
|
// {
|
|
// printf("************ IRuntime \n");
|
|
// // delete ptr;
|
|
// }
|
|
}
|
|
void TensorRTDeleter::operator()(nvinfer1::IExecutionContext *ptr) const noexcept
|
|
{
|
|
// if (!ptr)
|
|
// return;
|
|
|
|
// try
|
|
// {
|
|
// printf("************ IExecutionContext \n");
|
|
// // 检查有效性、上下文等
|
|
// delete ptr;
|
|
// }
|
|
// catch (const std::exception &e)
|
|
// {
|
|
// std::cerr << "[Deleter] Exception: " << e.what() << "\n";
|
|
// }
|
|
// catch (...)
|
|
// {
|
|
// std::cerr << "[Deleter] Unknown error during deletion\n";
|
|
// }
|
|
}
|
|
void TensorRTDeleter::operator()(nvinfer1::ICudaEngine *ptr) const noexcept
|
|
{
|
|
// if (ptr)
|
|
// {
|
|
// printf("************ ICudaEngine \n");
|
|
// delete ptr;
|
|
// }
|
|
// if (ptr)
|
|
// {
|
|
// printf("************ ICudaEngine \n");
|
|
// try
|
|
// {
|
|
// // delete ptr; //
|
|
// }
|
|
// catch (...)
|
|
// {
|
|
// std::cerr << "[Deleter] Exception deleting ICudaEngine\n";
|
|
// }
|
|
// }
|
|
}
|
|
Engine::Engine(int gpuId) : gpuId_(gpuId) {}
|
|
|
|
Engine::~Engine()
|
|
{
|
|
// printf("========Engine= start ====\n");
|
|
// engine_.reset();
|
|
// runtime_.reset();
|
|
// printf("========Engine= engine_.reset();====\n");
|
|
}
|
|
|
|
bool Engine::loadFromFile(const std::string &enginePath)
|
|
{
|
|
|
|
try
|
|
{
|
|
|
|
cudaSetDevice(gpuId_);
|
|
std::ifstream file(enginePath, std::ios::binary);
|
|
if (!file)
|
|
{
|
|
std::cerr << "[Engine] Failed to open engine file: " << enginePath << std::endl;
|
|
return false;
|
|
}
|
|
|
|
file.seekg(0, file.end);
|
|
size_t size = file.tellg();
|
|
file.seekg(0, file.beg);
|
|
std::vector<char> engineData(size);
|
|
file.read(engineData.data(), size);
|
|
runtime_ = std::unique_ptr<IRuntime, TensorRTDeleter>(createInferRuntime(gLogger));
|
|
if (!runtime_)
|
|
{
|
|
std::cerr << "[Engine] Failed to create runtime." << std::endl;
|
|
return false;
|
|
}
|
|
|
|
engine_ = std::unique_ptr<ICudaEngine, TensorRTDeleter>(
|
|
runtime_->deserializeCudaEngine(engineData.data(), size));
|
|
|
|
if (!engine_)
|
|
{
|
|
std::cerr << "[Engine] Failed to deserialize engine." << std::endl;
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
catch (const std::exception &e)
|
|
{
|
|
std::cerr << "[Engine] Exception during loadFromFile: " << e.what() << std::endl;
|
|
return false;
|
|
}
|
|
catch (...)
|
|
{
|
|
std::cerr << "[Engine] Unknown exception during loadFromFile." << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
int Engine::getNbBindings() const
|
|
{
|
|
return engine_ ? engine_->getNbIOTensors() : 0;
|
|
}
|
|
|
|
std::string Engine::getBindingName(int index) const
|
|
{
|
|
return engine_ ? engine_->getIOTensorName(index) : "";
|
|
}
|
|
|
|
/**
 * Shape of the I/O tensor at `index`, as reported by the engine.
 *
 * @param index Tensor index, expected in [0, getNbBindings()).
 * @return The tensor shape, or a value-initialized Dims when no engine is
 *         loaded or TensorRT reports no tensor at that index.
 */
Dims Engine::getBindingDims(int index) const
{
    if (!engine_)
        return Dims{};

    // getIOTensorName() returns nullptr for an out-of-range index; the original
    // copied it into std::string, which is undefined behavior for nullptr.
    const char *name = engine_->getIOTensorName(index);
    if (name == nullptr)
        return Dims{};

    return engine_->getTensorShape(name);
}
|
|
|
|
bool Engine::bindingIsInput(int index) const
|
|
{
|
|
if (!engine_)
|
|
return false;
|
|
std::string name = engine_->getIOTensorName(index);
|
|
return engine_->getTensorIOMode(name.c_str()) == TensorIOMode::kINPUT;
|
|
} |