Browse Source

Merge branch 'master' of gitee.com:mindspore/graphengine into master

pull/1623/head
宇世康 Gitee 4 years ago
parent
commit
a0bc608735
12 changed files with 329 additions and 24 deletions
  1. +9
    -0
      ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc
  2. +1
    -0
      ge/graph/load/model_manager/task_info/kernel_ex_task_info.h
  3. +9
    -0
      ge/graph/load/model_manager/task_info/kernel_task_info.cc
  4. +1
    -0
      ge/graph/load/model_manager/task_info/kernel_task_info.h
  5. +7
    -0
      ge/graph/passes/multi_batch_clone_pass.cc
  6. +96
    -21
      ge/graph/preprocess/graph_preprocess.cc
  7. +56
    -0
      ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc
  8. +5
    -0
      ge/hybrid/node_executor/aicpu/aicpu_ext_info.h
  9. +54
    -0
      tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc
  10. +77
    -3
      tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc
  11. +9
    -0
      third_party/fwkacllib/inc/cce/fwk_adpt_struct.h
  12. +5
    -0
      third_party/fwkacllib/inc/runtime/kernel.h

+ 9
- 0
ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc View File

@@ -53,6 +53,7 @@ Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDe
"Parse kernel ext info failed, kernel_ext_info_size=%zu.", ext_info.size());
GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed.");
GELOGD("Update aicpu_task ext_info bit_map execute mode to 1.");
topic_type_flag_ = ext_handle->GetTopicTypeFlag();

bool all_shape = false;
(void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape);
@@ -407,6 +408,14 @@ Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const

Status KernelExTaskInfo::Distribute() {
GELOGI("KernelExTaskInfo Distribute Start.");
// Use the fifth and sixth bits of dump_flag_ indicate the value of topic_type.
// xxxxxxxx xxxxxxxx xxxxxxxx xx00xxxx: DEVICE_ONLY
// xxxxxxxx xxxxxxxx xxxxxxxx xx01xxxx: DEVICE_FIRST
// xxxxxxxx xxxxxxxx xxxxxxxx xx10xxxx: HOST_ONLY
// xxxxxxxx xxxxxxxx xxxxxxxx xx11xxxx: HOST_FIRST
if (topic_type_flag_ > 0) {
dump_flag_ = dump_flag_ | topic_type_flag_;
}
rtError_t rt_ret = rtKernelLaunchEx(kernel_buf_, kernel_buf_size_, dump_flag_, stream_);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtKernelLaunchEx failed, ret:0x%X",


+ 1
- 0
ge/graph/load/model_manager/task_info/kernel_ex_task_info.h View File

@@ -76,6 +76,7 @@ class KernelExTaskInfo : public TaskInfo {
vector<void *> io_addrs_;
uint32_t args_offset_ = 0;
int64_t fixed_addr_offset_ = 0;
int32_t topic_type_flag_ = -1;
};
} // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_

+ 9
- 0
ge/graph/load/model_manager/task_info/kernel_task_info.cc View File

@@ -431,6 +431,14 @@ Status KernelTaskInfo::Distribute() {
int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, kBaseInt) : kStrtolFail;
bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_);
if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) {
if (topic_type_flag_ > 0) {
// Use the fifth and sixth bits of dump_flag_ indicate the value of topic_type.
// xxxxxxxx xxxxxxxx xxxxxxxx xx00xxxx: DEVICE_ONLY
// xxxxxxxx xxxxxxxx xxxxxxxx xx01xxxx: DEVICE_FIRST
// xxxxxxxx xxxxxxxx xxxxxxxx xx10xxxx: HOST_ONLY
// xxxxxxxx xxxxxxxx xxxxxxxx xx11xxxx: HOST_FIRST
dump_flag_ = dump_flag_ | topic_type_flag_;
}
GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_);
// blockDim is reserved parameter, set to 1
rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()),
@@ -1116,6 +1124,7 @@ Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) {
GELOGD("Update aicpu_task ext_info session_info session_id is %lu", davinci_model_->GetSessionId());
GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed.");
GELOGD("Update aicpu_task ext_info bit_map execute mode to 1.");
topic_type_flag_ = ext_handle->GetTopicTypeFlag();

bool all_shape = false;
(void)AttrUtils::GetBool(op_desc_, kAicpuAllshape, all_shape);


+ 1
- 0
ge/graph/load/model_manager/task_info/kernel_task_info.h View File

@@ -169,6 +169,7 @@ class KernelTaskInfo : public TaskInfo {
uint16_t io_addr_offset_ = 0;
bool l2_buffer_on_ = false;
bool call_save_dump_ = false;
int32_t topic_type_flag_ = -1;

// aicpu ext_info device mem
void *aicpu_ext_info_addr_ = nullptr;


+ 7
- 0
ge/graph/passes/multi_batch_clone_pass.cc View File

@@ -42,6 +42,7 @@ const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const";
const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex";
const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_";
const char *const kGetNextName = "IteratorV2";
const char *const kMbatchCaseName = "mbatch-switch-name";
} // namespace

inline bool IsGetNextType(const NodePtr &node) {
@@ -943,6 +944,12 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_an
}
}
(void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
if (!AttrUtils::SetStr(node->GetOpDesc(), kMbatchCaseName, case_node_->GetName())) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to node:%s(%s) failed",
kMbatchCaseName, node->GetName().c_str(), node->GetType().c_str());
GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", node->GetName().c_str());
return INTERNAL_ERROR;
}

GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex));
std::vector<std::string> input_dims_str;


+ 96
- 21
ge/graph/preprocess/graph_preprocess.cc View File

@@ -610,7 +610,7 @@ Status ModifyDataNetOutputFormatAndShape(OpDescPtr &op_desc, uint32_t index, For
return SUCCESS;
}

Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, NodePtr &switchn_node) {
Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, NodePtr &mbatch_node, int32_t &index) {
is_dynamic_batch = false;
std::string related_node_name;
if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) {
@@ -621,13 +621,17 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node
data_node->GetName().c_str());
return INTERNAL_ERROR;
}
for (const NodePtr &next_node : data_node->GetOutNodes()) {
if (next_node->GetName() == related_node_name) {
switchn_node = next_node;

auto out_data_nodes_anchors = data_node->GetOutDataNodesAndAnchors();
for (const auto &out_data_node_anchor : out_data_nodes_anchors) {
if (out_data_node_anchor.first->GetName() == related_node_name) {
mbatch_node = out_data_node_anchor.first;
index = out_data_node_anchor.second->GetIdx();
break;
}
}
if (switchn_node == nullptr) {

if (mbatch_node == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E15002", {"opname", "value", "reason"},
{data_node->GetName(), related_node_name, "but can not find it on the graph"});
@@ -680,7 +684,7 @@ Status CheckIfNeedSetNdFormat(const NodePtr &node_ptr) {
// In the dynamic shape process, transnode insertion by FE is advanced to the stage of whole
// graph optimization, GE only sets the final data_type/format/shape information for variable,
// data and netoutput, and no longer inserts the transnode.
Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node, DataType &dt_set) {
Status ProcessInputDtDynShape(NodePtr &node_ptr, NodePtr &switchn_node, DataType &dt_set) {
GE_CHECK_NOTNULL(node_ptr);
auto op_desc = node_ptr->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
@@ -713,19 +717,84 @@ Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr
GELOGI("[Process][InputDynShape] Set input and output size of node [%s] success.", node_ptr->GetName().c_str());
}

if (is_dynamic_batch) {
GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str());
auto switchn_op_desc = switchn_node->GetOpDesc();
GE_CHECK_NOTNULL(switchn_op_desc);
auto switchn_input = switchn_op_desc->MutableInputDesc(0);
GE_CHECK_NOTNULL(switchn_input);
switchn_input->SetDataType(dt_set);
for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) {
const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i);
GE_CHECK_NOTNULL(switchn_output);
switchn_output->SetDataType(dt_set);
return SUCCESS;
}

Status UpdateInputOutputDataType(NodePtr &mbatch_node, DataType &dt_set, int32_t index) {
  // Sets dt_set on the mbatch node's input tensor at `index`. For a SwitchN
  // node every output tensor is updated as well, keeping the data type
  // consistent across all batch branches.
  auto op_desc = mbatch_node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  auto input_desc = op_desc->MutableInputDesc(index);
  GE_CHECK_NOTNULL(input_desc);
  input_desc->SetDataType(dt_set);

  if (mbatch_node->GetType() == SWITCHN) {
    const uint32_t out_num = mbatch_node->GetAllOutDataAnchorsSize();
    for (uint32_t out_idx = 0; out_idx < out_num; ++out_idx) {
      const GeTensorDescPtr &output_desc = op_desc->MutableOutputDesc(out_idx);
      GE_CHECK_NOTNULL(output_desc);
      output_desc->SetDataType(dt_set);
    }
  }

  GELOGD("Update input and output data type of node[name: %s, type: %s, input index: %d] to %s.",
         mbatch_node->GetName().c_str(), mbatch_node->GetType().c_str(), index,
         TypeUtils::DataTypeToSerialString(dt_set).c_str());

  return SUCCESS;
}

Status UpdateSubgraphDataOfCase(NodePtr &mbatch_node, DataType &dt_set, int32_t index) {
  // For a Case node, propagates dt_set to the Data node of every subgraph
  // whose parent-node-index attribute matches the given parent input index.
  // Non-Case nodes need no subgraph handling and succeed immediately.
  if (mbatch_node->GetType() != CASE) {
    return SUCCESS;
  }

  for (const auto &subgraph : NodeUtils::GetAllSubgraphs(*mbatch_node)) {
    GE_CHECK_NOTNULL(subgraph);
    for (auto &graph_node : subgraph->GetDirectNode()) {
      GE_CHECK_NOTNULL(graph_node);
      if (graph_node->GetType() != DATA) {
        continue;
      }

      auto op_desc = graph_node->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      // Only Data nodes wired to this parent input take the new data type.
      int32_t parent_node_index = 0;
      const bool has_index = AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index);
      if (!has_index || (parent_node_index != index)) {
        continue;
      }

      auto input_desc = op_desc->MutableInputDesc(0);
      GE_CHECK_NOTNULL(input_desc);
      input_desc->SetDataType(dt_set);
      auto output_desc = op_desc->MutableOutputDesc(0);
      GE_CHECK_NOTNULL(output_desc);
      output_desc->SetDataType(dt_set);
      GELOGD("Update input and output data type of node[name: %s, type: %s, parent_node_index: %d] in subgraph %s "
             "to %s.", op_desc->GetName().c_str(), op_desc->GetType().c_str(), parent_node_index,
             subgraph->GetName().c_str(), TypeUtils::DataTypeToSerialString(dt_set).c_str());
    }
  }

  return SUCCESS;
}

Status ProcessMbatchScene(NodePtr &mbatch_node, DataType &dt_set, int32_t index) {
  // Dynamic-batch scene: pushes the user-requested data type onto the mbatch
  // node (SwitchN or Case) and, for Case, into its subgraph Data nodes.
  GELOGI("The node [%s] dtype set fp16.", mbatch_node->GetName().c_str());

  const Status node_ret = UpdateInputOutputDataType(mbatch_node, dt_set, index);
  if (node_ret != SUCCESS) {
    GELOGE(FAILED, "Update input and output data type of node[name: %s, type: %s] to %s failed.",
           mbatch_node->GetName().c_str(), mbatch_node->GetType().c_str(),
           TypeUtils::DataTypeToSerialString(dt_set).c_str());
    return FAILED;
  }

  const Status subgraph_ret = UpdateSubgraphDataOfCase(mbatch_node, dt_set, index);
  if (subgraph_ret != SUCCESS) {
    GELOGE(FAILED, "Update input and output data type of Data node[parent_node_index: %d] in subgraphs of "
           "node[name: %s, type: %s] to %s failed.", index, mbatch_node->GetName().c_str(),
           mbatch_node->GetType().c_str(), TypeUtils::DataTypeToSerialString(dt_set).c_str());
    return FAILED;
  }

  return SUCCESS;
}

@@ -786,21 +855,27 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) {
DataType dt_set = TypeUtils::SerialStringToDataType(set_dt_str);
GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str());
bool is_dynamic_batch = false;
NodePtr switchn_node = nullptr;
if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, switchn_node)) {
NodePtr mbatch_node = nullptr;
int32_t index = 0;
if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, mbatch_node, index)) {
GELOGE(INTERNAL_ERROR, "CheckIfDynamicBatchScene failed");
return FAILED;
}
if (ProcessInputDtDynShape(node_ptr, is_dynamic_batch, switchn_node, dt_set) != SUCCESS) {
if (ProcessInputDtDynShape(node_ptr, mbatch_node, dt_set) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed");
return FAILED;
}
if (is_dynamic_batch && ProcessMbatchScene(mbatch_node, dt_set, index) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "ProcessMbatchScene failed");
return FAILED;
}

// check if need to set format
string set_format;
bool ret = ge::AttrUtils::GetStr(node_ptr->GetOpDesc(), ATTR_ATC_USER_DEFINE_FORMAT, set_format);
if (ret && (!set_format.empty()) && TypeUtils::SerialStringToFormat(set_format) == FORMAT_NC1HWC0) {
GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str());
if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) {
if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, mbatch_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed");
return FAILED;
}


+ 56
- 0
ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc View File

@@ -24,6 +24,12 @@ namespace hybrid {
namespace {
// if dim count is not reach kMaxShapeDims(8), use INT64_MIN to mark dim end.
constexpr int64_t kDimEndFlag = INT64_MIN;
// Translation table from the aicpu framework topic type (FWKExtTopicType) to
// the RTS kernel-launch flag bits. DEVICE_ONLY maps to 0, i.e. no extra flag
// bit is set for the default scheduling target.
const std::map<int32_t, int32_t> kTopicTypeToRtsFlagMap {
{static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_ONLY), 0},
{static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_FIRST), RT_KERNEL_DEVICE_FIRST},
{static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_ONLY), RT_KERNEL_HOST_ONLY},
{static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_FIRST), RT_KERNEL_HOST_FIRST}
};
}

Status AicpuExtInfoHandler::Parse(const std::string &ext_info) {
@@ -72,6 +78,9 @@ Status AicpuExtInfoHandler::Parse(const std::string &ext_info) {
case aicpu::FWKAdapter::FWK_ADPT_EXT_UPDATE_ADDR:
GE_CHK_STATUS_RET(ParseExtUpdateAddr(aicpu_ext_info), "[Parse][ExtUpdateAddr] failed.");
break;
case aicpu::FWKAdapter::FWK_ADPT_EXT_TOPIC_TYPE:
GE_CHK_STATUS_RET(ParseExtTopicType(aicpu_ext_info), "[Parse][ExtTopicType] failed.");
break;
default:
GELOGD("Node[%s] ignore infoType=%d, infoLen=%u.",
node_name_.c_str(), aicpu_ext_info->infoType, aicpu_ext_info->infoLen);
@@ -207,6 +216,44 @@ Status AicpuExtInfoHandler::ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info) {
return SUCCESS;
}

Status AicpuExtInfoHandler::ParseExtTopicType(AicpuExtInfo *aicpu_ext_info) {
  // Parses the topic-type ext-info entry: a single int32 selecting where the
  // aicpu kernel should be scheduled (device/host, only/first). The value is
  // translated to RTS launch-flag bits and cached in topic_type_flag_ for the
  // caller to OR into the kernel launch dump_flag_.
  // Returns ACL_ERROR_GE_PARAM_INVALID on a malformed entry length or an
  // unknown topic type.
  if (aicpu_ext_info->infoLen != sizeof(int32_t)) {
    REPORT_INNER_ERROR("E19999",
                       "Node[%s] parse topic_type info failed as infoLen must be %zu but %u.",
                       node_name_.c_str(), sizeof(int32_t), aicpu_ext_info->infoLen);
    GELOGE(ACL_ERROR_GE_PARAM_INVALID,
           "[Check][DataLen]Node[%s] parse topic_type info failed as infoLen must be %zu but %u.",
           node_name_.c_str(), sizeof(int32_t), aicpu_ext_info->infoLen);
    return ACL_ERROR_GE_PARAM_INVALID;
  }
  GE_CHECK_NOTNULL(aicpu_ext_info->infoMsg);
  auto type = *reinterpret_cast<const int32_t *>(aicpu_ext_info->infoMsg);

  topic_type_flag_ = TopicTypeToRtsFlag(type);
  if (topic_type_flag_ == -1) {
    REPORT_INNER_ERROR("E19999", "Node[%s] parse ext topic type failed as need %d %d %d %d but %d.",
                       node_name_.c_str(),
                       aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_ONLY,
                       aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_FIRST,
                       aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_ONLY,
                       aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_FIRST,
                       type);
    // Fix: log previously said "parse ext shape type" (copy-paste slip) —
    // this path validates the topic type, matching REPORT_INNER_ERROR above.
    GELOGE(ACL_ERROR_GE_PARAM_INVALID,
           "[Check][Type]Node[%s] parse ext topic type failed as need %d %d %d %d but %d.",
           node_name_.c_str(),
           aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_ONLY,
           aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_FIRST,
           aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_ONLY,
           aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_FIRST,
           type);
    return ACL_ERROR_GE_PARAM_INVALID;
  }

  GELOGI("Node[%s] parse ext topic type info success infoLen=%u, topic_type=%d, rts_flag=%d.",
         node_name_.c_str(), aicpu_ext_info->infoLen, type, topic_type_flag_);
  return SUCCESS;
}

Status AicpuExtInfoHandler::UpdateExecuteMode(bool flag) {
if (bit_map_ == nullptr) {
GELOGD("There is no bit_map in ext_info, no need update.");
@@ -341,5 +388,14 @@ void AicpuExtInfoHandler::GetShapeAndType(const AicpuShapeAndType *shape_and_typ
data_type = static_cast<DataType>(shape_and_type->type);
shape = GeShape(dims);
}

int32_t AicpuExtInfoHandler::TopicTypeToRtsFlag(int32_t topic_type) {
  // Looks up the RTS launch-flag bits for the given aicpu topic type;
  // -1 marks a topic type that is not in the translation table.
  const auto iter = kTopicTypeToRtsFlagMap.find(topic_type);
  return (iter == kTopicTypeToRtsFlagMap.end()) ? -1 : iter->second;
}
} // namespace hybrid
} // namespace ge

+ 5
- 0
ge/hybrid/node_executor/aicpu/aicpu_ext_info.h View File

@@ -62,6 +62,7 @@ class AicpuExtInfoHandler {
Status GetOutputShapeAndType(uint32_t output_index, GeShape &shape, DataType &data_type);

bool IsNeedRefreshIOAddr();
int32_t GetTopicTypeFlag() const { return topic_type_flag_; }

private:

@@ -71,6 +72,7 @@ class AicpuExtInfoHandler {
Status ParseExtSessionInfo(AicpuExtInfo *aicpu_ext_info);
Status ParseExtBitMap(AicpuExtInfo *aicpu_ext_info);
Status ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info);
Status ParseExtTopicType(AicpuExtInfo *aicpu_ext_info);

static Status UpdateShapeAndType(const GeShape &shape,
DataType data_type,
@@ -81,6 +83,8 @@ class AicpuExtInfoHandler {
DataType &data_type);

private:
int32_t TopicTypeToRtsFlag(int32_t topic_type);

const std::string node_name_;
const uint32_t input_num_;
const uint32_t output_num_;
@@ -88,6 +92,7 @@ class AicpuExtInfoHandler {
AicpuSessionInfo *session_info_ = nullptr;
uint64_t *bit_map_ = nullptr;
uint32_t *update_addr_ = nullptr;
int32_t topic_type_flag_ = -1;

std::unique_ptr<uint8_t[]> ext_info_;
size_t ext_info_len_ = 0;


+ 54
- 0
tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc View File

@@ -155,4 +155,58 @@ TEST_F(UtestKernelExTaskInfo, parse_update_addr) {
KernelExTaskInfo kernel_ex_task_info;
EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
}

// ext_info bytes (little-endian int32 triples): infoType=7 — presumably
// FWK_ADPT_EXT_TOPIC_TYPE, confirm against fwk_adpt_struct.h — infoLen=4,
// topic_type=0 (FWK_ADPT_TOPIC_DEVICE_ONLY). A valid topic type must parse.
TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_1) {
const string ext_info = {7,0,0,0,4,0,0,0,0,0,0,0};
const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
AttrUtils::SetBool(op_desc, "_AllShape", true);

KernelExTaskInfo kernel_ex_task_info;
EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
}

// Same layout as success_1 but topic_type=1 (FWK_ADPT_TOPIC_DEVICE_FIRST).
TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_2) {
const string ext_info = {7,0,0,0,4,0,0,0,1,0,0,0};
const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
AttrUtils::SetBool(op_desc, "_AllShape", true);

KernelExTaskInfo kernel_ex_task_info;
EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
}

// Same layout as success_1 but topic_type=2 (FWK_ADPT_TOPIC_HOST_ONLY).
TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_3) {
const string ext_info = {7,0,0,0,4,0,0,0,2,0,0,0};
const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
AttrUtils::SetBool(op_desc, "_AllShape", true);

KernelExTaskInfo kernel_ex_task_info;
EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
}

// Same layout as success_1 but topic_type=3 (FWK_ADPT_TOPIC_HOST_FIRST).
TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_4) {
const string ext_info = {7,0,0,0,4,0,0,0,3,0,0,0};
const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
AttrUtils::SetBool(op_desc, "_AllShape", true);

KernelExTaskInfo kernel_ex_task_info;
EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
}

// topic_type=4 (FWK_ADPT_TOPIC_INVALID) has no RTS-flag mapping, so parsing
// the ext info must fail.
TEST_F(UtestKernelExTaskInfo, parse_topic_type_failed_1) {
const string ext_info = {7,0,0,0,4,0,0,0,4,0,0,0};
const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
AttrUtils::SetBool(op_desc, "_AllShape", true);

KernelExTaskInfo kernel_ex_task_info;
EXPECT_NE(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
}

// infoLen=2 does not match sizeof(int32_t), so the topic-type length check in
// ParseExtTopicType must reject the ext info.
TEST_F(UtestKernelExTaskInfo, parse_topic_type_failed_2) {
const string ext_info = {7,0,0,0,2,0,0,0,2,0,0,0};
const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
AttrUtils::SetBool(op_desc, "_AllShape", true);

KernelExTaskInfo kernel_ex_task_info;
EXPECT_NE(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
}
} // namespace ge

+ 77
- 3
tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc View File

@@ -72,8 +72,8 @@ ComputeGraphPtr BuildGraph3() {
return builder.GetGraph();
}

ComputeGraphPtr BuildGraph4() {
auto builder = ut::GraphBuilder("g4");
ComputeGraphPtr BuildGraph5() {
auto builder = ut::GraphBuilder("g5");
auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3});
auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10});
auto add = builder.AddNode("add", ADD, 2, 1);
@@ -85,6 +85,41 @@ ComputeGraphPtr BuildGraph4() {
return builder.GetGraph();
}

/*
* MapIndex Data1 subgraph1 subgraph2
* \ /
* Case ===> Data2 Data3
* |
* Netoutput
*/
ComputeGraphPtr BuildGraph4() {
  // Builds the parent graph for the mbatch Case scenario: a Data node tagged
  // with the fp16 user data type and the "mbatch-switch-name" attribute that
  // points at the Case node, wired MapIndex/Data -> Case -> NetOutput.
  auto builder = ut::GraphBuilder("mbatch_Case");

  auto data_node = builder.AddNode("data1", DATA, 1, 1);
  auto data_op_desc = data_node->GetOpDesc();
  AttrUtils::SetStr(data_op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, "DT_FLOAT16");
  AttrUtils::SetStr(data_op_desc, "mbatch-switch-name", "case1");
  AttrUtils::SetInt(data_op_desc, ATTR_NAME_INDEX, 0);

  auto map_index_node = builder.AddNode("mapindex1", "MapIndex", 0, 1);
  auto case_node = builder.AddNode("case1", CASE, 2, 1);
  auto net_output_node = builder.AddNode("netoutput1", NETOUTPUT, 1, 0);

  builder.AddDataEdge(map_index_node, 0, case_node, 0);
  builder.AddDataEdge(data_node, 0, case_node, 1);
  builder.AddDataEdge(case_node, 0, net_output_node, 0);

  return builder.GetGraph();
}

ComputeGraphPtr BuildGraph4_Subgraph(string graph_name) {
  // Builds a minimal Case subgraph: a single Data node whose
  // parent-node-index attribute binds it to parent input 1.
  auto builder = ut::GraphBuilder(graph_name);
  auto data_node = builder.AddNode(graph_name + "_data1", DATA, 1, 1);
  AttrUtils::SetInt(data_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);
  return builder.GetGraph();
}

TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) {
ge::GraphPrepare graph_prepare;
graph_prepare.compute_graph_ = BuildGraph1();
@@ -135,7 +170,7 @@ TEST_F(UtestGraphPreproces, test_update_input_output1) {

TEST_F(UtestGraphPreproces, check_ref_op_data_succ) {
GraphPrepare graph_preparer;
ComputeGraphPtr graph_test = BuildGraph4();
ComputeGraphPtr graph_test = BuildGraph5();
NodePtr add_node = nullptr;
for (auto &node : graph_test->GetAllNodes()) {
if (node->GetName() == "add") {
@@ -149,4 +184,43 @@ TEST_F(UtestGraphPreproces, check_ref_op_data_succ) {
EXPECT_EQ(ret, SUCCESS);
}

// Verifies that UpdateInputOutputByOptions pushes the user-defined fp16 data
// type from the Data node through the mbatch Case node and into the matching
// Data node of each attached subgraph.
TEST_F(UtestGraphPreproces, test_update_dtype_mbatch_case) {
ge::GraphPrepare graph_prepare;
graph_prepare.compute_graph_ = BuildGraph4();
auto parent_graph = graph_prepare.compute_graph_;
auto subgraph1 = BuildGraph4_Subgraph("subgraph1");
auto subgraph2 = BuildGraph4_Subgraph("subgraph2");

auto data1 = parent_graph->FindNode("data1");
auto data_desc = data1->GetOpDesc();

// Attach both subgraphs to the Case node; each subgraph must know its parent
// node and parent graph before being registered on the root graph.
auto case_node = parent_graph->FindNode("case1");
EXPECT_NE(case_node, nullptr);
case_node->GetOpDesc()->AddSubgraphName("subgraph1");
case_node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph1");
subgraph1->SetParentNode(case_node);
subgraph1->SetParentGraph(parent_graph);
EXPECT_EQ(parent_graph->AddSubgraph("subgraph1", subgraph1), GRAPH_SUCCESS);

case_node->GetOpDesc()->AddSubgraphName("subgraph2");
case_node->GetOpDesc()->SetSubgraphInstanceName(1, "subgraph2");
subgraph2->SetParentNode(case_node);
subgraph2->SetParentGraph(parent_graph);
EXPECT_EQ(parent_graph->AddSubgraph("subgraph2", subgraph2), GRAPH_SUCCESS);

Status ret = graph_prepare.UpdateInputOutputByOptions();
EXPECT_EQ(ret, SUCCESS);

// Expected dtype enum value 1 — presumably DT_FLOAT16; confirm against the
// ge::DataType definition.
auto case_desc = case_node->GetOpDesc();
auto case_input = case_desc->MutableInputDesc(1);
EXPECT_EQ(case_input->GetDataType(), 1);

// Subgraph Data node bound to parent input 1 must have both its input and
// output tensors updated.
auto sub1_data1 = subgraph1->FindNode("subgraph1_data1");
EXPECT_NE(sub1_data1, nullptr);
auto data1_desc = sub1_data1->GetOpDesc();
auto data1_input = data1_desc->MutableInputDesc(0);
EXPECT_EQ(data1_input->GetDataType(), 1);
auto data1_output = data1_desc->MutableOutputDesc(0);
EXPECT_EQ(data1_output->GetDataType(), 1);
}
}

+ 9
- 0
third_party/fwkacllib/inc/cce/fwk_adpt_struct.h View File

@@ -61,9 +61,18 @@ enum FWKTaskExtInfoType {
FWK_ADPT_EXT_OP_NAME,
FWK_ADPT_EXT_SESSION_INFO,
FWK_ADPT_EXT_BITMAP,
FWK_ADPT_EXT_TOPIC_TYPE,
FWK_ADPT_EXT_INVALID
};

// Topic scheduling preference carried in an aicpu task's FWK_ADPT_EXT_TOPIC_TYPE
// ext-info entry: which side (device/host) the kernel should run on, and
// whether that side is mandatory ("ONLY") or merely preferred ("FIRST").
enum FWKExtTopicType {
FWK_ADPT_TOPIC_DEVICE_ONLY = 0,
FWK_ADPT_TOPIC_DEVICE_FIRST,
FWK_ADPT_TOPIC_HOST_ONLY,
FWK_ADPT_TOPIC_HOST_FIRST,
FWK_ADPT_TOPIC_INVALID
};

enum FWKExtUpdateAddrType {
FWK_ADPT_UPDATE_NULL = 0,
FWK_ADPT_UPDATE_INPUT,


+ 5
- 0
third_party/fwkacllib/inc/runtime/kernel.h View File

@@ -191,6 +191,11 @@ typedef void (*rtCallback_t)(void *fnData);
#define RT_FUSION_KERNEL_DUMPFLAG (0x04)
#define RT_KERNEL_CUSTOM_AICPU (0x08)

// STARS topic scheduler sqe : topic_type
// These values occupy the fifth and sixth bits of the kernel-launch flag;
// DEVICE_ONLY is the implicit default (0x00, no bit set).
#define RT_KERNEL_DEVICE_FIRST (0X10)
#define RT_KERNEL_HOST_ONLY (0X20)
#define RT_KERNEL_HOST_FIRST (0X30)

/**
* @ingroup rt_kernel
* @brief kernel mode


Loading…
Cancel
Save