| @@ -32,7 +32,7 @@ CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, con | |||||
| return reduce_min; | return reduce_min; | ||||
| } | } | ||||
| bool NeedOptmize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) { | |||||
| bool NeedOptimize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) { | |||||
| if (dtype != kNumberTypeFloat32) { | if (dtype != kNumberTypeFloat32) { | ||||
| MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need optimize!"; | MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need optimize!"; | ||||
| return false; | return false; | ||||
| @@ -84,7 +84,7 @@ std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::v | |||||
| for (size_t item = 0; item < shape.size(); ++item) { | for (size_t item = 0; item < shape.size(); ++item) { | ||||
| if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), item)) { | if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), item)) { | ||||
| if (keep_dims) { | if (keep_dims) { | ||||
| // If keep_dims is true, curretn dimesion set to 1 | |||||
| // If keep_dims is true, current dimension set to 1 | |||||
| shape_first.push_back(1); | shape_first.push_back(1); | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -110,28 +110,31 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN | |||||
| CheckCNodeInputSize(cnode, 2); | CheckCNodeInputSize(cnode, 2); | ||||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | ||||
| auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0); | auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0); | ||||
| if (!AnfAlgo::HasNodeAttr(kAttrAxis, cnode)) { | |||||
| MS_LOG(INFO) << "ReduceMin has no axis, no need optimize!"; | |||||
| auto prim = AnfAlgo::GetCNodePrimitive(cnode); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (!prim->HasAttr(kAttrAxis) || !prim->HasAttr(kAttrKeepDims)) { | |||||
| MS_LOG(INFO) << "ReduceMin has no axis or keep_dims, no need optimize!"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis); | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) { | |||||
| MS_LOG(INFO) << "ReduceMin has no keep_dims, no need optimize!"; | |||||
| auto axis_value = prim->GetAttr(kAttrAxis); | |||||
| MS_EXCEPTION_IF_NULL(axis_value); | |||||
| if (!axis_value->isa<ValueSequeue>()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis); | |||||
| auto keep_dims = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims); | auto keep_dims = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims); | ||||
| if (!NeedOptmize(dtype, shape, axis)) { | |||||
| if (!NeedOptimize(dtype, shape, axis)) { | |||||
| MS_LOG(INFO) << "No need optimize for this ReduceMin. " << cnode->DebugString(); | MS_LOG(INFO) << "No need optimize for this ReduceMin. " << cnode->DebugString(); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // Create reduce_min1 | // Create reduce_min1 | ||||
| CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode); | CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode); | ||||
| std::vector<int> axis_fisrt = CalFirstAxis(shape, axis); | |||||
| std::vector<size_t> shape_first = GetInferShape(shape, axis_fisrt, keep_dims); | |||||
| std::vector<int> axis_first = CalFirstAxis(shape, axis); | |||||
| std::vector<size_t> shape_first = GetInferShape(shape, axis_first, keep_dims); | |||||
| AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get()); | AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get()); | ||||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_fisrt), reduce_min1); | |||||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_first), reduce_min1); | |||||
| // Create reduce_min2 | // Create reduce_min2 | ||||
| CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode); | CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode); | ||||