Browse Source

!853 get input info of fusionop by visitkernel

Merge pull request !853 from Etone.Chan/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
8cc5942966
2 changed files with 7 additions and 16 deletions
  1. +6
    -15
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
  2. +1
    -1
      mindspore/ops/_op_impl/tbe/relu6_grad.py

+ 6
- 15
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc View File

@@ -270,17 +270,9 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_data_type;
for (const auto &input : inputs_list) {
if (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == prim::kPrimTupleGetItem->name()) {
auto tuple_getitem = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
inputs_format.push_back(AnfAlgo::GetOutputFormat(
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
} else {
inputs_format.push_back(AnfAlgo::GetOutputFormat(input, 0));
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(input, 0));
}
auto real_input = AnfAlgo::VisitKernel(input, 0);
inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second));
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second));
}
// outputs format and data type
std::vector<std::string> outputs_format;
@@ -375,11 +367,10 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
}
}

void GetFusionScopeInputNodeList(session::KernelGraph *kernel_graph,
void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
auto manager = kernel_graph->manager();
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);

for (auto &buffer_fusion_info : *buffer_fusion_infos) {
@@ -643,7 +634,7 @@ void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
GetFusionScopeInputNodeList(kernel_graph, buffer_fusion_infos);
GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
for (auto &buffer_fusion_info : *buffer_fusion_infos) {
buffer_fusion_info.second.kernel_build_info =


+ 1
- 1
mindspore/ops/_op_impl/tbe/relu6_grad.py View File

@@ -17,7 +17,7 @@
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

relu6_grad_op_info = TBERegOp("ReLU6Grad") \
.fusion_type("ELEMWISE") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("relu6_grad.so") \
.compute_cost(10) \


Loading…
Cancel
Save