|
|
|
@@ -50,9 +50,22 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { |
|
|
|
MS_EXCEPTION_IF_NULL(new_node); |
|
|
|
new_node->set_scope(relu->scope()); |
|
|
|
|
|
|
|
// ReluV2's 2rd output is mask whose data type is uint8 and value is 0 or 1, so shape is an empty vector |
|
|
|
// ReluV2's 2rd output is mask whose data type is uint8 |
|
|
|
TypeId mask_dtype = kNumberTypeUInt8; |
|
|
|
std::vector<size_t> mask_shape; |
|
|
|
std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0); |
|
|
|
if (mask_shape.size() != 4) { |
|
|
|
MS_LOG(WARNING) << "relu's infer shape size not equal 4"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0); |
|
|
|
if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) { |
|
|
|
mask_shape[1] = (mask_shape[1] + 31) / 32; |
|
|
|
mask_shape.push_back(4); |
|
|
|
} else { |
|
|
|
mask_shape[1] = (mask_shape[1] + 15) / 16; |
|
|
|
mask_shape.push_back(2); |
|
|
|
} |
|
|
|
|
|
|
|
auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype}; |
|
|
|
auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get()); |
|
|
|
@@ -91,6 +104,9 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP |
|
|
|
MS_EXCEPTION_IF_NULL(relu); |
|
|
|
|
|
|
|
auto relu_v2 = CreateReluV2(graph, relu); |
|
|
|
if (relu_v2 == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> relu_v2_node_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); |
|
|
|
|
|
|
|
|