@@ -32,6 +32,7 @@
#include "graph/common/local_context.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "hybrid/hybrid_davinci_model.h"
#include "graph/utils/graph_utils.h"

namespace ge {
thread_local uint32_t device_count = 0;

@@ -49,9 +50,13 @@ const std::string kCmdTypeProfModelSubscribe = "prof_model_subscribe";
const std::string kCmdTypeProfModelUnsubscribe = "prof_model_cancel_subscribe";
const char *const kBatchLoadBuf = "batchLoadsoFrombuf";
const char *const kDeleteCustOp = "deleteCustOp";
const char *const kUsedStreamNum = "used_stream_num";
const char *const kStreamResource = "stream";
const char *const kEventResource = "event";
const int kTimeSpecNano = 1000000000;
const int kTimeSpecMiro = 1000000;
const int kSessionMaxBias = 100;
const int kMaxEventNum = 1024;

struct CustAicpuSoBuf {
  uint64_t kernelSoBuf;
  uint32_t kernelSoBufLen;

@@ -69,7 +74,7 @@ std::mutex ModelManager::exeception_infos_mutex_;

std::shared_ptr<ModelManager> ModelManager::GetInstance() {
  static const std::shared_ptr<ModelManager> instance_ptr =
      shared_ptr<ModelManager>(new (std::nothrow) ModelManager(), ModelManager::FinalizeForPtr);
  return instance_ptr;
}

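// Usage sketch (hypothetical caller): the singleton is created once with
// FinalizeForPtr as its custom deleter, so callers only ever hold the
// shared_ptr and never delete the manager themselves:
//
//   auto manager = ModelManager::GetInstance();
//   GE_CHECK_NOTNULL(manager);  // new (std::nothrow) may have returned nullptr
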
@@ -119,7 +124,7 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u
  }

  rt_ret =
      rtMemcpy(devicebase, sizeof(STR_FWK_OP_KERNEL), &param_base, sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "memory copy to device failed. ret: 0x%X", rt_ret);
    GE_IF_BOOL_EXEC(aicpu_kernel_addr != nullptr, GE_CHK_RT(rtFree(aicpu_kernel_addr)));

@@ -368,6 +373,177 @@ void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciMode
  hybrid_model_map_[id] = hybrid_model;
}

Status ModelManager::CheckStreamAndEventResource(const GeModelPtr &ge_model) {
  GE_CHK_BOOL_EXEC(ge_model != nullptr, return FAILED, "ge model ptr is null");
  int64_t value = 0;
  bool ret = ge::AttrUtils::GetInt(ge_model, ATTR_MODEL_STREAM_NUM, value);
  int64_t need_stream_num = ret ? value : 0;
  ret = ge::AttrUtils::GetInt(ge_model, ATTR_MODEL_EVENT_NUM, value);
  int64_t need_event_num = ret ? value : 0;

  int64_t hccl_follow_stream = 0;
  Status status = CalculateFollowStream(ge_model, hccl_follow_stream);
  if (status != SUCCESS) {
    GELOGE(FAILED, "Calculate hccl follow stream failed");
    return FAILED;
  }
  need_stream_num = need_stream_num + hccl_follow_stream;

  int64_t free_stream_num = 0;
  int64_t free_event_num = 0;
  status = GetFreeStreamAndEvent(free_stream_num, free_event_num);
  if (status != SUCCESS) {
    GELOGE(FAILED, "Get free stream and event failed");
    return FAILED;
  }
  if (need_stream_num > free_stream_num) {
    status = ReleaseRsource(need_stream_num, free_stream_num, kStreamResource);
    if (status != SUCCESS) {
      GELOGE(FAILED, "Release stream resource failed");
      return FAILED;
    }
  }
  if (need_event_num > free_event_num) {
    status = ReleaseRsource(need_event_num, free_event_num, kEventResource);
    if (status != SUCCESS) {
      GELOGE(FAILED, "Release event resource failed");
      return FAILED;
    }
  }
  return SUCCESS;
}

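// A hedged usage sketch (hypothetical call site): the loader would run this
// check before model initialization, so that a device short on streams or
// events first unloads the models holding the most of them:
//
//   if (ModelManager::GetInstance()->CheckStreamAndEventResource(ge_model) != SUCCESS) {
//     GELOGE(FAILED, "Not enough stream/event resource to load model");
//     return FAILED;
//   }
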
Status ModelManager::CalculateFollowStream(const GeModelPtr &ge_model, int64_t &hccl_follow_stream_num) {
  const auto &model_def = ge_model->GetModelTaskDefPtr();
  GE_CHECK_NOTNULL(model_def);
  Graph graph = ge_model->GetGraph();
  ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
  GE_CHECK_NOTNULL(compute_graph);

  map<uint32_t, OpDescPtr> op_list;
  for (auto &node : compute_graph->GetDirectNode()) {
    OpDescPtr op_desc = node->GetOpDesc();
    if (op_desc == nullptr) {
      GELOGE(PARAM_INVALID, "Op desc is nullptr");
      return PARAM_INVALID;
    }
    op_list.emplace(op_desc->GetId(), op_desc);
  }

  std::multimap<int64_t, int64_t> main_follow_num;
  for (int i = 0; i < model_def->task_size(); i++) {
    const domi::TaskDef &task = model_def->task(i);
    if (static_cast<rtModelTaskType_t>(task.type()) == RT_MODEL_TASK_HCCL) {
      auto hccl_def = task.kernel_hccl();
      OpDescPtr hccl_op_desc = op_list.at(hccl_def.op_index());
      int64_t main_stream_id = hccl_op_desc->GetStreamId();
      int64_t follow_stream_num = 0;
      if (!ge::AttrUtils::GetInt(hccl_op_desc, kUsedStreamNum, follow_stream_num)) {
        GELOGW("Get used_stream_num failed, op is %s", hccl_op_desc->GetName().c_str());
      }
      main_follow_num.emplace(main_stream_id, follow_stream_num);
    }
  }
  hccl_follow_stream_num = CalFollowStreamSum(main_follow_num);
  return SUCCESS;
}

int64_t ModelManager::CalFollowStreamSum(const std::multimap<int64_t, int64_t> &hccl_stream_map) {
  int64_t need_follow_stream_num = 0;
  std::map<int64_t, int64_t> max_follow_stream_map;
  for (auto &it : hccl_stream_map) {
    auto max_it = max_follow_stream_map.find(it.first);
    if (max_it == max_follow_stream_map.end()) {
      // First follow-stream count seen for this main stream.
      max_follow_stream_map.emplace(it.first, it.second);
    } else if (it.second > max_it->second) {
      // std::map::emplace does not overwrite an existing key, so assign directly.
      max_it->second = it.second;
    }
  }
  for (auto &follow_it : max_follow_stream_map) {
    need_follow_stream_num = need_follow_stream_num + follow_it.second;
  }
  return need_follow_stream_num;
}

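// A worked example of the reduction above (illustrative values): for the input
// multimap {(main stream 0, 3), (main stream 0, 5), (main stream 1, 2)}, the
// per-main-stream maxima are {0: 5, 1: 2}, so the follow streams to reserve
// sum to 5 + 2 = 7:
//
//   std::multimap<int64_t, int64_t> demo{{0, 3}, {0, 5}, {1, 2}};
//   // CalFollowStreamSum(demo) == 5 + 2 == 7
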
Status ModelManager::ReleaseRsource(int64_t need_resource, int64_t free_resource, const string &resource_kind) {
  while (need_resource > free_resource) {
    uint32_t max_stream_model_id = 0;
    uint32_t max_event_model_id = 0;
    GetMaxStreamAndEventModel(max_stream_model_id, max_event_model_id);
    GELOGD("The max stream num model is: %u, the max event num model is: %u", max_stream_model_id,
           max_event_model_id);
    std::lock_guard<std::mutex> lock(resource_mutex_);
    if (resource_kind == kStreamResource) {
      Status ret = Unload(max_stream_model_id);
      if (ret != SUCCESS) {
        GELOGE(FAILED, "Unload max stream model failed, model id: %u", max_stream_model_id);
        return FAILED;
      }
      // Read the count before erasing; the unloaded model's streams become free again.
      int64_t released_stream_num = stream_map_.at(max_stream_model_id);
      free_resource = free_resource + released_stream_num;
      stream_map_.erase(max_stream_model_id);
      GELOGD("Unload model for stream, model id: %u, stream num: %ld", max_stream_model_id, released_stream_num);
    }

    if (resource_kind == kEventResource) {
      Status ret = Unload(max_event_model_id);
      if (ret != SUCCESS) {
        GELOGE(FAILED, "Unload max event model failed, model id: %u", max_event_model_id);
        return FAILED;
      }
      // Read the count before erasing; the unloaded model's events become free again.
      int64_t released_event_num = event_map_.at(max_event_model_id);
      free_resource = free_resource + released_event_num;
      event_map_.erase(max_event_model_id);
      GELOGD("Unload model for event, model id: %u, event num: %ld", max_event_model_id, released_event_num);
    }
  }
  return SUCCESS;
}

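// Behavior sketch (hypothetical numbers): with need_resource == 300 and
// free_resource == 100, each iteration unloads the model holding the most
// streams (or events) and credits that count back to free_resource. If the
// evicted models hold 150 and 80 streams, free_resource grows
// 100 -> 250 -> 330 and the loop exits; if a chosen model cannot be unloaded,
// the whole check fails instead of spinning. Driven internally as:
//
//   status = ReleaseRsource(need_stream_num, free_stream_num, kStreamResource);
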
Status ModelManager::GetFreeStreamAndEvent(int64_t &free_stream, int64_t &free_event) {
  uint32_t max_stream_count = 0;
  uint32_t max_task_count = 0;
  rtError_t ret = rtGetMaxStreamAndTask(RT_NORMAL_STREAM, &max_stream_count, &max_task_count);
  if (ret != RT_ERROR_NONE) {
    GELOGE(FAILED, "Get max stream and task count failed");
    return FAILED;
  }
  GELOGD("Allowed max stream count: %u, max task count per stream: %u", max_stream_count, max_task_count);
  std::lock_guard<std::mutex> lock(resource_mutex_);
  int64_t stream_sum = 0;
  int64_t event_sum = 0;
  for (auto &it : stream_map_) {
    stream_sum = stream_sum + it.second;
  }
  for (auto &it : event_map_) {
    event_sum = event_sum + it.second;
  }
  free_stream = max_stream_count - stream_sum;
  free_event = kMaxEventNum - event_sum;
  return SUCCESS;
}

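// Worked example of the accounting above (illustrative values): if the runtime
// reports max_stream_count == 1024 and loaded models hold 300 streams in total,
// free_stream == 1024 - 300 == 724; with kMaxEventNum == 1024 and 200 events
// in use, free_event == 1024 - 200 == 824.
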
void ModelManager::GetMaxStreamAndEventModel(uint32_t &max_stream_model, uint32_t &max_event_model) {
  std::lock_guard<std::mutex> lock(resource_mutex_);
  int64_t max_stream_num = 0;
  for (auto &it : stream_map_) {
    if (it.second > max_stream_num) {
      max_stream_num = it.second;
      max_stream_model = it.first;
    }
  }
  int64_t max_event_num = 0;
  for (auto &it : event_map_) {
    if (it.second > max_event_num) {
      max_event_num = it.second;
      max_event_model = it.first;
    }
  }
}

void ModelManager::InsertModelResource(uint32_t model_id, int64_t stream_num, int64_t event_num) {
  std::lock_guard<std::mutex> lock(resource_mutex_);
  stream_map_.emplace(model_id, stream_num);
  event_map_.emplace(model_id, event_num);
}

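// Bookkeeping sketch (hypothetical values): a model's stream/event demand is
// recorded at load time and must be erased again on unload (see ReleaseRsource
// above) so the free-resource arithmetic in GetFreeStreamAndEvent stays
// consistent:
//
//   InsertModelResource(model_id, /*stream_num=*/8, /*event_num=*/4);
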
Status ModelManager::DeleteModel(uint32_t id) {
  std::lock_guard<std::mutex> lock(map_mutex_);

@@ -459,8 +635,7 @@ Status ModelManager::GetCurDynamicDims(const vector<vector<int64_t>> &user_real_
                                        vector<int64_t> &cur_dynamic_dims) {
  GELOGD("Start get cur dynamic dims.");
  if (user_real_input_dims.size() != user_input_dims.size()) {
GELOGE(INTERNAL_ERROR, "The input count of user: %zu should be equal to the data count of graph: %zu", |
|
|
|
user_real_input_dims.size(), user_input_dims.size()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
@@ -516,6 +691,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputT
    DataBuffer data;
    data.data = inputs[i].data;
    data.length = inputs[i].length;
    data.placement = inputs[i].placement;
    input_data.blobs.push_back(data);
  }
  if (!GetLocalOmgContext().user_input_dims.empty() && GetLocalOmgContext().need_multi_batch) {

@@ -527,7 +703,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputT
      return INTERNAL_ERROR;
    }
    DataBuffer data;
    data.data = new (std::nothrow) int64_t[cur_dynamic_dims.size()];
    GE_CHECK_NOTNULL(data.data);
    uint64_t length = static_cast<uint64_t>(cur_dynamic_dims.size() * sizeof(int64_t));
    GE_CHK_BOOL_EXEC(memcpy_s(data.data, length, cur_dynamic_dims.data(), length) == EOK, return INTERNAL_ERROR,

@@ -630,11 +806,13 @@ Status ModelManager::Stop(uint32_t model_id) {
///
Status ModelManager::HandleCommand(const Command &command) {
  static const std::map<std::string, std::function<uint32_t(const Command &)>> cmds = {
      {kCmdTypeDump, HandleDumpCommand},
      {kCmdTypeProfInit, HandleProfInitCommand},
      {kCmdTypeProfFinalize, HandleProfFinalizeCommand},
      {kCmdTypeProfStart, HandleProfStartCommand},
      {kCmdTypeProfStop, HandleProfStopCommand},
      {kCmdTypeProfModelSubscribe, HandleProfModelSubscribeCommand},
      {kCmdTypeProfModelUnsubscribe, HandleProfModelUnsubscribeCommand}};

  auto iter = cmds.find(command.cmd_type);
  if (iter == cmds.end()) {

@@ -645,17 +823,16 @@ Status ModelManager::HandleCommand(const Command &command) {
  }
}

Status ModelManager::GetModelByCmd(const Command &command, std::shared_ptr<DavinciModel> &davinci_model) {
  if (command.cmd_params.size() < kCmdParSize) {
    GELOGE(PARAM_INVALID, "When the cmd_type is '%s', the size of cmd_params must be larger than 2.",
           command.cmd_type.c_str());
    return PARAM_INVALID;
  }

  std::string map_key = command.cmd_params[0];
  std::string value = command.cmd_params[1];
  if (map_key == PROFILE_MODEL_ID) {
    int32_t model_id = 0;
    try {
      model_id = std::stoi(value);

@@ -692,8 +869,8 @@ Status ModelManager::HandleProfModelSubscribeCommand(const Command &command) {
    return ret;
  }

  if (ProfilingManager::Instance().ProfModelSubscribe(command.module_index, static_cast<void *>(davinci_model.get())) !=
      SUCCESS) {
    GELOGE(FAILED, "Handle prof model subscribe failed.");
    return FAILED;
  }

@@ -1011,8 +1188,7 @@ Status ModelManager::GetInputOutputDescInfoForZeroCopy(const uint32_t model_id,
Status ModelManager::GetAIPPInfo(const uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info) {
  std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
  GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"GetAIPPInfo failed, invalid model_id is %u.", |
|
|
|
model_id); |
|
|
|
"GetAIPPInfo failed, invalid model_id is %u.", model_id); |
|
|
|
|
|
|
|
return davinci_model->GetAIPPInfo(index, aipp_info); |
|
|
|
} |
|
|
|
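// The same look-up-and-validate pattern repeats for the getters below: GetModel
// returns the instance registered at load time, and a null result is mapped to
// ACL_ERROR_GE_EXEC_MODEL_ID_INVALID. Hypothetical caller sketch:
//
//   AippConfigInfo aipp_info;
//   Status status = ModelManager::GetInstance()->GetAIPPInfo(model_id, 0, aipp_info);
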
@@ -1020,8 +1196,7 @@ Status ModelManager::GetAIPPInfo(const uint32_t model_id, uint32_t index, AippCo
Status ModelManager::GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index) {
  std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
  GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"GetAIPPInfo failed, invalid model_id is %u.", |
|
|
|
model_id); |
|
|
|
"GetAIPPInfo failed, invalid model_id is %u.", model_id); |
|
|
|
|
|
|
|
return davinci_model->GetAippType(index, type, aipp_index); |
|
|
|
} |
|
|
|
@@ -1047,8 +1222,8 @@ Status ModelManager::GenSessionId(uint64_t &session_id) {
Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr<ModelListener> listener,
                                      void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) {
  GE_CHK_BOOL_RET_STATUS(model.key.empty() || mmAccess2(model.key.c_str(), M_F_OK) == EN_OK, ACL_ERROR_GE_PARAM_INVALID,
                         "input key file path %s is invalid, %s", model.key.c_str(), strerror(errno));
  GenModelId(&model_id);

  shared_ptr<DavinciModel> davinci_model = nullptr;

@@ -1142,8 +1317,8 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d
                                    const std::vector<uint32_t> &input_queue_ids,
                                    const std::vector<uint32_t> &output_queue_ids) {
  GE_CHK_BOOL_RET_STATUS(model_data.key.empty() || mmAccess2(model_data.key.c_str(), M_F_OK) == EN_OK,
ACL_ERROR_GE_PARAM_INVALID, "input key file path %s is not valid, %s", |
|
|
|
model_data.key.c_str(), strerror(errno)); |
|
|
|
ACL_ERROR_GE_PARAM_INVALID, "input key file path %s is not valid, %s", model_data.key.c_str(), |
|
|
|
strerror(errno)); |
|
|
|
|
|
|
|
ModelHelper model_helper; |
|
|
|
Status ret = model_helper.LoadModel(model_data); |
|
|
|
@@ -1344,8 +1519,8 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
  }
  allocated_mem.push_back(d_so_name);
  GE_CHK_RT(rtMemcpy(d_aicpu_data, aicpu_data_length, aicpu_data, aicpu_data_length, RT_MEMCPY_HOST_TO_DEVICE));
  GE_CHK_RT(rtMemcpy(d_so_name, so_name.size(), reinterpret_cast<const void *>(so_name.c_str()), so_name.size(),
                     RT_MEMCPY_HOST_TO_DEVICE));

  CustAicpuSoBuf cust_aicpu_so_buf;
  cust_aicpu_so_buf.kernelSoBuf = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_aicpu_data));

@@ -1379,8 +1554,8 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
    return RT_ERROR_TO_GE_STATUS(status);
  }
  allocated_mem.push_back(batch_args);
  GE_CHK_RT(rtMemcpy(batch_args, batch_args_size, static_cast<void *>(&batch_cust_so), batch_args_size,
                     RT_MEMCPY_HOST_TO_DEVICE));

  GE_CHK_RT(rtStreamCreate(&stream, 0));
  GE_CHK_RT(rtCpuKernelLaunch(nullptr, kernel_name.c_str(), 1, batch_args, batch_args_size, nullptr, stream));

@@ -1473,8 +1648,7 @@ void ModelManager::GenModelId(uint32_t *id) {
Status ModelManager::GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info) {
  std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
  GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"GetOrigInputInfo failed, invalid model_id is %u.", |
|
|
|
model_id); |
|
|
|
"GetOrigInputInfo failed, invalid model_id is %u.", model_id); |
|
|
|
|
|
|
|
return davinci_model->GetOrigInputInfo(index, orig_input_info); |
|
|
|
} |
|
|
|
|