|
|
|
@@ -37,13 +37,16 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A |
|
|
|
const EquivPtr &equiv) const {
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(equiv);
|
|
|
|
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
|
|
|
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
|
|
|
|
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
|
|
|
|
MS_EXCEPTION_IF_NULL(fbn2);
|
|
|
|
MS_EXCEPTION_IF_NULL(x_after);
|
|
|
|
MS_EXCEPTION_IF_NULL(x_before);
|
|
|
|
// only deal with x_after with fp32: x 16->32->bn->16->32
|
|
|
|
if (AnfAlgo::GetOutputInferDataType(x_after, 0) == kNumberTypeFloat16) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
std::vector<TypeId> outputs_type;
|
|
|
|
std::vector<std::vector<size_t>> outputs_shape;
|
|
|
|
auto manager = graph->manager();
|
|
|
|
|