Browse Source

git update relplace_bn_grad_cast

tags/v0.7.0-beta
VectorSL 5 years ago
parent
commit
9dd0282a17
3 changed files with 27 additions and 44 deletions
  1. +0
    -2
      mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.cc
  2. +0
    -12
      mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
  3. +27
    -30
      mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc

+ 0
- 2
mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.cc View File

@@ -36,13 +36,11 @@ const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const Anf
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto A = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto B = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 1);
MS_EXCEPTION_IF_NULL(A);
MS_EXCEPTION_IF_NULL(B);
int num_input = AnfAlgo::GetNodeAttr<int>(node, "n");
if (num_input == 2) {
auto prim = std::make_shared<Primitive>(prim::kPrimTensorAdd->name());
MS_EXCEPTION_IF_NULL(prim);


+ 0
- 12
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc View File

@@ -38,27 +38,16 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 1);
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 2);
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 3);
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 4);
MS_EXCEPTION_IF_NULL(fbn2);
MS_EXCEPTION_IF_NULL(x_after);
MS_EXCEPTION_IF_NULL(x_before);
MS_EXCEPTION_IF_NULL(scale);
MS_EXCEPTION_IF_NULL(bias);
MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var);
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto outlist = GetRealNodeUsedList(graph, fbn2);
for (size_t i = 0; i < outlist->size(); i++) {
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
@@ -76,7 +65,6 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
}
}
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
outputs_type.clear();
outputs_shape.clear();


+ 27
- 30
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc View File

@@ -33,6 +33,30 @@ const BaseRef ReplaceBNGradCastFusion::DefinePattern() const {
return tupleget;
}
const void HandleOutput(const FuncGraphPtr &graph, const mindspore::CNodePtr &kernel) {
auto outlist = GetRealNodeUsedList(graph, kernel);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
for (size_t j = 0; j < outlist->size(); j++) {
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(j).first), 1);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
if (item_idx == 0) {
auto cast = GetRealNodeUsedList(graph, outlist->at(j).first);
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
continue;
}
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(j).first));
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(j).first, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(j).first.get());
}
}
}
const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
@@ -40,26 +64,17 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(equiv);
auto fbn2g = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto dy_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 0);
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
auto x_type = AnfAlgo::GetOutputInferDataType(x_, 0);
MS_EXCEPTION_IF_NULL(x_);
// if x_type is fp32, the cast is necessary.
if (x_type == kNumberTypeFloat32) {
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32) {
return nullptr;
}
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 2);
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 3);
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 4);
MS_EXCEPTION_IF_NULL(fbn2g);
MS_EXCEPTION_IF_NULL(dy_after);
MS_EXCEPTION_IF_NULL(dy_before);
MS_EXCEPTION_IF_NULL(scale);
MS_EXCEPTION_IF_NULL(x_);
MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var);
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto manager = graph->manager();
@@ -83,25 +98,7 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, kernel.get());
}
// 3. handle the output of fusedbatchnormgrad: tuplegetitem
auto outlist = GetRealNodeUsedList(graph, kernel);
for (size_t j = 0; j < outlist->size(); j++) {
outputs_type.clear();
outputs_shape.clear();
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(j).first), 1);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
if (item_idx == 0) {
auto cast = GetRealNodeUsedList(graph, outlist->at(j).first);
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
continue;
}
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(j).first));
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(j).first, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(j).first.get());
}
}
HandleOutput(graph, kernel);
}
manager->Replace(utils::cast<CNodePtr>(dy_after), utils::cast<CNodePtr>(dy_before));
return node;


Loading…
Cancel
Save