|
|
|
@@ -94,7 +94,8 @@ class ArrayReduceGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
int input_dim_length = SizeToInt(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size()); |
|
|
|
|
|
|
|
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueTuple>()) { |
|
|
|
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueTuple>() || |
|
|
|
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueList>()) { |
|
|
|
auto attr_axis = GetAttr<std::vector<int>>(kernel_node, "axis"); |
|
|
|
if (attr_axis.empty()) { |
|
|
|
axis_.push_back(-1); |
|
|
|
|