|
|
|
@@ -32,7 +32,7 @@ CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, con |
|
|
|
return reduce_min; |
|
|
|
} |
|
|
|
|
|
|
|
bool NeedOptmize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) { |
|
|
|
bool NeedOptimize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) { |
|
|
|
if (dtype != kNumberTypeFloat32) { |
|
|
|
MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need optimize!"; |
|
|
|
return false; |
|
|
|
@@ -84,7 +84,7 @@ std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::v |
|
|
|
for (size_t item = 0; item < shape.size(); ++item) { |
|
|
|
if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), item)) { |
|
|
|
if (keep_dims) { |
|
|
|
// If keep_dims is true, curretn dimesion set to 1 |
|
|
|
// If keep_dims is true, current dimension set to 1 |
|
|
|
shape_first.push_back(1); |
|
|
|
} |
|
|
|
} else { |
|
|
|
@@ -110,28 +110,31 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN |
|
|
|
CheckCNodeInputSize(cnode, 2); |
|
|
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); |
|
|
|
auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0); |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrAxis, cnode)) { |
|
|
|
MS_LOG(INFO) << "ReduceMin has no axis, no need optimize!"; |
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
if (!prim->HasAttr(kAttrAxis) || !prim->HasAttr(kAttrKeepDims)) { |
|
|
|
MS_LOG(INFO) << "ReduceMin has no axis or keep_dims, no need optimize!"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis); |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) { |
|
|
|
MS_LOG(INFO) << "ReduceMin has no keep_dims, no need optimize!"; |
|
|
|
auto axis_value = prim->GetAttr(kAttrAxis); |
|
|
|
MS_EXCEPTION_IF_NULL(axis_value); |
|
|
|
if (!axis_value->isa<ValueSequeue>()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis); |
|
|
|
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims); |
|
|
|
|
|
|
|
if (!NeedOptmize(dtype, shape, axis)) { |
|
|
|
if (!NeedOptimize(dtype, shape, axis)) { |
|
|
|
MS_LOG(INFO) << "No need optimize for this ReduceMin. " << cnode->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// Create reduce_min1 |
|
|
|
CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode); |
|
|
|
std::vector<int> axis_fisrt = CalFirstAxis(shape, axis); |
|
|
|
std::vector<size_t> shape_first = GetInferShape(shape, axis_fisrt, keep_dims); |
|
|
|
std::vector<int> axis_first = CalFirstAxis(shape, axis); |
|
|
|
std::vector<size_t> shape_first = GetInferShape(shape, axis_first, keep_dims); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_fisrt), reduce_min1); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_first), reduce_min1); |
|
|
|
|
|
|
|
// Create reduce_min2 |
|
|
|
CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode); |
|
|
|
|