You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

49 lines
1.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/*
* @Author: xiewenji 527774126@qq.com
* @Date: 2025-07-05 19:32:31
* @LastEditors: xiewenji 527774126@qq.com
* @LastEditTime: 2025-09-06 11:14:02
* @FilePath: /AI_SO_Test/AIEngineModule/include/Engine.h
* @Description: Declares the Engine class, which loads a TensorRT engine from a file and exposes its binding metadata (count, names, dimensions, input/output direction).
*/
#ifndef Engine_H_
#define Engine_H_
#pragma once
#include <NvInfer.h>
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <functional>
using namespace nvinfer1;
/// Deleter used with std::unique_ptr so TensorRT objects are released correctly.
/// One overload per TensorRT object type owned through unique_ptr; the
/// overload bodies are defined out of line (not in this header).
struct TensorRTDeleter
{
    void operator()(nvinfer1::IRuntime *runtime) const noexcept;
    void operator()(nvinfer1::ICudaEngine *engine) const noexcept;
    // Releases an execution context.
    void operator()(nvinfer1::IExecutionContext *context) const noexcept;
};
/// Holds a TensorRT runtime and an engine, and exposes the engine's binding
/// metadata.
///
/// Not copyable: it uniquely owns TensorRT handles and contains a std::mutex.
class Engine
{
public:
    /// @param gpuId GPU index supplied by the caller; how it is applied is
    ///        defined in the implementation (not visible here).
    /// `explicit` prevents accidental implicit conversion from int to Engine.
    explicit Engine(int gpuId);
    ~Engine();

    // Copying would duplicate unique ownership of TensorRT handles. These were
    // already implicitly deleted (std::mutex / std::unique_ptr members); they
    // are spelled out to make the intent explicit.
    Engine(const Engine &) = delete;
    Engine &operator=(const Engine &) = delete;

    /// Loads an engine from the file at @p enginePath.
    /// @return presumably true on success, false on failure — confirm against
    ///         the implementation.
    bool loadFromFile(const std::string &enginePath);

    /// Binding metadata accessors; @p index selects the binding.
    int getNbBindings() const;
    std::string getBindingName(int index) const;
    nvinfer1::Dims getBindingDims(int index) const;
    bool bindingIsInput(int index) const;

    /// Non-owning access to the held engine (may be nullptr if none is held).
    nvinfer1::ICudaEngine *get() const { return engine_.get(); }

    // NOTE(review): public so callers can lock it; what exactly it guards is
    // decided by callers — confirm and document at the use sites.
    std::mutex mutex_;

private:
    // Declared before runtime_ (as in the original) so that any use of gpuId_
    // while initializing runtime_ still sees an initialized value.
    int gpuId_;
    std::unique_ptr<nvinfer1::IRuntime, TensorRTDeleter> runtime_;

public:
    // Declared AFTER runtime_ on purpose: members are destroyed in reverse
    // declaration order, so engine_ is released before runtime_. The engine is
    // presumably deserialized by runtime_, and TensorRT requires a factory
    // object to outlive the objects it creates.
    std::unique_ptr<nvinfer1::ICudaEngine, TensorRTDeleter> engine_;
};
#endif