|
|
|
@@ -33,6 +33,30 @@ const BaseRef ReplaceBNGradCastFusion::DefinePattern() const { |
|
|
|
return tupleget;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Post-processes the users of `kernel` after its inputs have been rewired.
//
// For every user of `kernel` whose index input (input 1) is a value node
// holding 0, the first consumer of that user is inspected. If that consumer is
// a "Cast" node, the Cast is replaced in the graph by the user node itself and
// the user's inferred output type is overwritten with float16 while its
// inferred shape is kept.
//
// graph:  function graph that owns `kernel`; it must have a manager.
// kernel: node whose users are inspected.
//
// Raises (through MS_EXCEPTION_IF_NULL) when the graph, its manager, a user
// list, or an index node is null, or when the index input is not a value node.
//
// NOTE(review): users of `kernel` are presumably TupleGetItem nodes - confirm
// against the caller.
// NOTE(review): the `const` qualifier on the `void` return type has no effect;
// it is kept so the signature matches any existing declaration.
const void HandleOutput(const FuncGraphPtr &graph, const mindspore::CNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(kernel);
  auto outlist = GetRealNodeUsedList(graph, kernel);
  MS_EXCEPTION_IF_NULL(outlist);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (size_t j = 0; j < outlist->size(); j++) {
    // Take a copy of the node pointer so it stays valid across manager->Replace.
    const auto getitem = outlist->at(j).first;
    auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(getitem), 1);
    MS_EXCEPTION_IF_NULL(index_node);
    auto value_node = index_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    int item_idx = GetValue<int>(value_node->value());
    if (item_idx != 0) {
      continue;
    }
    auto getitem_users = GetRealNodeUsedList(graph, getitem);
    MS_EXCEPTION_IF_NULL(getitem_users);
    // A node without consumers has no Cast to fold; skip it instead of letting
    // at(0) throw std::out_of_range.
    if (getitem_users->empty()) {
      continue;
    }
    const auto first_user = getitem_users->at(0).first;
    if (AnfAlgo::GetCNodeName(first_user) != "Cast") {
      continue;
    }
    manager->Replace(utils::cast<CNodePtr>(first_user), utils::cast<CNodePtr>(getitem));
    // The Cast is gone, so this node now produces float16 directly.
    std::vector<TypeId> outputs_type;
    std::vector<std::vector<size_t>> outputs_shape;
    outputs_type.push_back(kNumberTypeFloat16);
    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(getitem, 0));
    AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, getitem.get());
  }
}
|
|
|
|
|
|
|
|
const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
|
|
|
const EquivPtr &equiv) const {
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
@@ -40,26 +64,17 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con |
|
|
|
MS_EXCEPTION_IF_NULL(equiv);
|
|
|
|
|
|
|
|
auto fbn2g = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
|
|
|
|
|
|
|
auto dy_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 0);
|
|
|
|
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
|
|
|
|
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
|
|
|
|
auto x_type = AnfAlgo::GetOutputInferDataType(x_, 0);
|
|
|
|
MS_EXCEPTION_IF_NULL(x_);
|
|
|
|
// if x_type is fp32, the cast is necessary.
|
|
|
|
if (x_type == kNumberTypeFloat32) {
|
|
|
|
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 2);
|
|
|
|
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 3);
|
|
|
|
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 4);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fbn2g);
|
|
|
|
MS_EXCEPTION_IF_NULL(dy_after);
|
|
|
|
MS_EXCEPTION_IF_NULL(dy_before);
|
|
|
|
MS_EXCEPTION_IF_NULL(scale);
|
|
|
|
MS_EXCEPTION_IF_NULL(x_);
|
|
|
|
MS_EXCEPTION_IF_NULL(mean);
|
|
|
|
MS_EXCEPTION_IF_NULL(var);
|
|
|
|
std::vector<TypeId> outputs_type;
|
|
|
|
std::vector<std::vector<size_t>> outputs_shape;
|
|
|
|
auto manager = graph->manager();
|
|
|
|
@@ -83,25 +98,7 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, kernel.get());
|
|
|
|
}
|
|
|
|
// 3. handle the output of fusedbatchnormgrad: tuplegetitem
|
|
|
|
auto outlist = GetRealNodeUsedList(graph, kernel);
|
|
|
|
for (size_t j = 0; j < outlist->size(); j++) {
|
|
|
|
outputs_type.clear();
|
|
|
|
outputs_shape.clear();
|
|
|
|
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(j).first), 1);
|
|
|
|
auto value_node = index_node->cast<ValueNodePtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(value_node);
|
|
|
|
int item_idx = GetValue<int>(value_node->value());
|
|
|
|
if (item_idx == 0) {
|
|
|
|
auto cast = GetRealNodeUsedList(graph, outlist->at(j).first);
|
|
|
|
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(j).first));
|
|
|
|
outputs_type.push_back(kNumberTypeFloat16);
|
|
|
|
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(j).first, 0));
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(j).first.get());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
HandleOutput(graph, kernel);
|
|
|
|
}
|
|
|
|
manager->Replace(utils::cast<CNodePtr>(dy_after), utils::cast<CNodePtr>(dy_before));
|
|
|
|
return node;
|
|
|
|
|