Browse Source

!3203 GPU fix cast fusion bug

Merge pull request !3203 from VectorSL/fix-cast-fusion
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
28f873e9ad
2 changed files with 58 additions and 48 deletions
  1. +26
    -25
      mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
  2. +32
    -23
      mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc

+ 26
- 25
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc View File

@@ -30,8 +30,7 @@ const BaseRef ReplaceBNCastFusion::DefinePattern() const {
VectorRef in_cast = VectorRef({prim::kPrimCast, x_}); VectorRef in_cast = VectorRef({prim::kPrimCast, x_});
VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNorm, in_cast, scale_, bias_, mean_, var_}); VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNorm, in_cast, scale_, bias_, mean_, var_});
VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2, index_}); VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2, index_});
VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget});
return out_cast;
return tupleget;
} }
const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
@@ -40,19 +39,9 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(equiv);
auto tuple = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 1);
MS_EXCEPTION_IF_NULL(index_node);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 0);
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0); auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0); auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
if (item_idx != 0) {
return nullptr;
}
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 1); auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 1);
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 2); auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 2);
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 3); auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 3);
@@ -65,14 +54,32 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
MS_EXCEPTION_IF_NULL(bias); MS_EXCEPTION_IF_NULL(bias);
MS_EXCEPTION_IF_NULL(mean); MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var); MS_EXCEPTION_IF_NULL(var);
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto manager = graph->manager(); auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
manager->Replace(utils::cast<CNodePtr>(node), utils::cast<CNodePtr>(tuple));
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto outlist = GetRealNodeUsedList(graph, fbn2);
for (size_t i = 0; i < outlist->size(); i++) {
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
if (item_idx == 0) {
auto cast = GetRealNodeUsedList(graph, outlist->at(i).first);
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
return nullptr;
}
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(i).first));
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
}
}
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
outputs_type.clear();
outputs_shape.clear();
auto output_num = AnfAlgo::GetOutputTensorNum(fbn2); auto output_num = AnfAlgo::GetOutputTensorNum(fbn2);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2, i)); outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2, i));
@@ -80,13 +87,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
} }
outputs_type[0] = kNumberTypeFloat16; outputs_type[0] = kNumberTypeFloat16;
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2.get()); AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2.get());
outputs_type.clear();
outputs_shape.clear();
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get());
return tuple;
return node;
} }
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

+ 32
- 23
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc View File

@@ -30,8 +30,7 @@ const BaseRef ReplaceBNGradCastFusion::DefinePattern() const {
VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_}); VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_});
VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_cast, x_, scale_, mean_, var_}); VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_cast, x_, scale_, mean_, var_});
VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_}); VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_});
VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget});
return out_cast;
return tupleget;
} }
const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
@@ -40,21 +39,16 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(equiv);
auto tuple = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 1);
MS_EXCEPTION_IF_NULL(index_node);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
if (item_idx != 0) {
return nullptr;
}
auto fbn2g = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 0);
auto fbn2g = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto dy_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 0); auto dy_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 0);
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0); auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1); auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
auto x_type = AnfAlgo::GetOutputInferDataType(x_, 0);
// if x_type is fp32, the cast is nessery.
if (x_type == kNumberTypeFloat32) {
return nullptr;
}
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 2); auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 2);
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 3); auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 3);
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 4); auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 4);
@@ -66,13 +60,32 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(x_); MS_EXCEPTION_IF_NULL(x_);
MS_EXCEPTION_IF_NULL(mean); MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var); MS_EXCEPTION_IF_NULL(var);
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto manager = graph->manager(); auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
auto outlist = GetRealNodeUsedList(graph, fbn2g);
for (size_t i = 0; i < outlist->size(); i++) {
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
if (item_idx == 0) {
auto cast = GetRealNodeUsedList(graph, outlist->at(i).first);
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
return nullptr;
}
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(i).first));
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
}
}
outputs_type.clear();
outputs_shape.clear();
manager->Replace(utils::cast<CNodePtr>(dy_after), utils::cast<CNodePtr>(dy_before)); manager->Replace(utils::cast<CNodePtr>(dy_after), utils::cast<CNodePtr>(dy_before));
manager->Replace(utils::cast<CNodePtr>(node), utils::cast<CNodePtr>(tuple));
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto output_num = AnfAlgo::GetOutputTensorNum(fbn2g); auto output_num = AnfAlgo::GetOutputTensorNum(fbn2g);
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2g, i)); outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2g, i));
@@ -80,12 +93,8 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
} }
outputs_type[0] = kNumberTypeFloat16; outputs_type[0] = kNumberTypeFloat16;
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2g.get()); AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2g.get());
outputs_type.clear();
outputs_shape.clear();
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get());
return tuple;
return node;
} }
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

Loading…
Cancel
Save