| @@ -39,6 +39,7 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | ||||
| auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); | auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); | ||||
| if (axis_addr->isa<ValueTuple>()) { | if (axis_addr->isa<ValueTuple>()) { | ||||
| auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS); | auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS); | ||||
| if (attr_axis.size() > shape_.size()) { | if (attr_axis.size() > shape_.size()) { | ||||
| @@ -47,18 +48,24 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| axis_.push_back(shape_.size() - 1); | axis_.push_back(shape_.size() - 1); | ||||
| } else { | } else { | ||||
| for (auto axis : attr_axis) { | for (auto axis : attr_axis) { | ||||
| while (axis < 0) { | |||||
| axis += SizeToInt(shape_.size()); | |||||
| } | |||||
| if (IntToSize(axis) >= (shape_.size())) { | if (IntToSize(axis) >= (shape_.size())) { | ||||
| MS_LOG(EXCEPTION) << "axis value is oversize."; | MS_LOG(EXCEPTION) << "axis value is oversize."; | ||||
| } | } | ||||
| axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); | |||||
| axis_.push_back(IntToSize(axis)); | |||||
| } | } | ||||
| } | } | ||||
| } else if (axis_addr->isa<Int32Imm>()) { | } else if (axis_addr->isa<Int32Imm>()) { | ||||
| int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS); | int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS); | ||||
| if (axis >= 0 && IntToSize(axis) >= shape_.size()) { | |||||
| while (axis < 0) { | |||||
| axis += SizeToInt(shape_.size()); | |||||
| } | |||||
| if (IntToSize(axis) >= shape_.size()) { | |||||
| MS_LOG(EXCEPTION) << "axis value is oversize."; | MS_LOG(EXCEPTION) << "axis value is oversize."; | ||||
| } | } | ||||
| axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); | |||||
| axis_.push_back(IntToSize(axis)); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; | MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; | ||||
| } | } | ||||