Browse Source

add more transform format insert transdata

tags/v0.5.0-beta
WilliamLian 5 years ago
parent
commit
5d25bf7ca2
5 changed files with 11 additions and 12 deletions
  1. +1
    -1
      mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
  2. +4
    -7
      mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
  3. +3
    -3
      mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
  4. +1
    -1
      mindspore/ccsrc/utils/utils.h
  5. +2
    -0
      mindspore/ops/_op_impl/tbe/trans_data.py

+ 1
- 1
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc View File

@@ -70,7 +70,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) {
kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) {
priority_matched_format = !is_init ? pre_output_format : priority_matched_format;
is_init = true;
}


+ 4
- 7
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc View File

@@ -31,6 +31,7 @@ namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
std::vector<AnfNodePtr> trans_inputs;
@@ -110,13 +111,9 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
MS_EXCEPTION_IF_NULL(input_node);
AnfAlgo::SetNodeInput(node, input_node, index);
}
if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) {
MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index)
<< "when inserting the transdata node " << node->DebugString();
}
std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
std::string dest_format = AnfAlgo::GetInputFormat(node, index);
if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
<< " To DefaultFormat , index: " << index;
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
@@ -133,7 +130,7 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
<< node->DebugString();
}
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
}
@@ -154,7 +151,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
}
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
} else {
// No need insert trans op.


+ 3
- 3
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc View File

@@ -97,7 +97,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
std::string convert_format;
for (const auto &do_mask : do_mask_node_list) {
auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0);
if (special_format.empty() && kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end()) {
if (special_format.empty() && kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end()) {
special_format = do_mask_data_format;
}
if (format_counter.find(do_mask_data_format) == format_counter.end()) {
@@ -111,7 +111,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
convert_format = kOpFormat_DEFAULT;
break;
}
if (kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end() &&
if (kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end() &&
special_format != do_mask_data_format) {
convert_format = kOpFormat_DEFAULT;
break;
@@ -133,7 +133,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
if (counter < iter.second) {
convert_format = iter.first;
}
if (counter == iter.second && kNeedTransFormatSet.find(convert_format) == kNeedTransFormatSet.end()) {
if (counter == iter.second && kHWSpecialFormatSet.find(convert_format) == kHWSpecialFormatSet.end()) {
convert_format = iter.first;
}
}


+ 1
- 1
mindspore/ccsrc/utils/utils.h View File

@@ -265,7 +265,7 @@ const std::set<std::string> kOptOperatorSet = {
kApplyRMSPropOpName,
};

const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_C04};



+ 2
- 0
mindspore/ops/_op_impl/tbe/trans_data.py View File

@@ -58,6 +58,8 @@ trans_data_op_info = TBERegOp("TransData") \
.dtype_format(DataType.F32_HWCN, DataType.F32_FracZ) \
.dtype_format(DataType.F32_HWCN, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_HWCN) \
.dtype_format(DataType.F32_Default, DataType.F32_NCHW) \
.dtype_format(DataType.F32_HWCN, DataType.F32_Default) \
.get_op_info()




Loading…
Cancel
Save