| @@ -32,7 +32,7 @@ CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, con | |||
| return reduce_min; | |||
| } | |||
| bool NeedOptmize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) { | |||
| bool NeedOptimize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) { | |||
| if (dtype != kNumberTypeFloat32) { | |||
| MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need optimize!"; | |||
| return false; | |||
| @@ -84,7 +84,7 @@ std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::v | |||
| for (size_t item = 0; item < shape.size(); ++item) { | |||
| if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), item)) { | |||
| if (keep_dims) { | |||
| // If keep_dims is true, curretn dimesion set to 1 | |||
| // If keep_dims is true, current dimension set to 1 | |||
| shape_first.push_back(1); | |||
| } | |||
| } else { | |||
| @@ -110,28 +110,31 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN | |||
| CheckCNodeInputSize(cnode, 2); | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | |||
| auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0); | |||
| if (!AnfAlgo::HasNodeAttr(kAttrAxis, cnode)) { | |||
| MS_LOG(INFO) << "ReduceMin has no axis, no need optimize!"; | |||
| auto prim = AnfAlgo::GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (!prim->HasAttr(kAttrAxis) || !prim->HasAttr(kAttrKeepDims)) { | |||
| MS_LOG(INFO) << "ReduceMin has no axis or keep_dims, no need optimize!"; | |||
| return nullptr; | |||
| } | |||
| auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis); | |||
| if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) { | |||
| MS_LOG(INFO) << "ReduceMin has no keep_dims, no need optimize!"; | |||
| auto axis_value = prim->GetAttr(kAttrAxis); | |||
| MS_EXCEPTION_IF_NULL(axis_value); | |||
| if (!axis_value->isa<ValueSequeue>()) { | |||
| return nullptr; | |||
| } | |||
| auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis); | |||
| auto keep_dims = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims); | |||
| if (!NeedOptmize(dtype, shape, axis)) { | |||
| if (!NeedOptimize(dtype, shape, axis)) { | |||
| MS_LOG(INFO) << "No need optimize for this ReduceMin. " << cnode->DebugString(); | |||
| return nullptr; | |||
| } | |||
| // Create reduce_min1 | |||
| CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode); | |||
| std::vector<int> axis_fisrt = CalFirstAxis(shape, axis); | |||
| std::vector<size_t> shape_first = GetInferShape(shape, axis_fisrt, keep_dims); | |||
| std::vector<int> axis_first = CalFirstAxis(shape, axis); | |||
| std::vector<size_t> shape_first = GetInferShape(shape, axis_first, keep_dims); | |||
| AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_fisrt), reduce_min1); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_first), reduce_min1); | |||
| // Create reduce_min2 | |||
| CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode); | |||