|
|
|
@@ -67,41 +67,47 @@ cl_float4 ReduceOpenCLKernel::GenC4Mask() { |
|
|
|
} |
|
|
|
|
|
|
|
bool ReduceOpenCLKernel::IsHWCReduce() { |
|
|
|
return !reduce_axes_[0] && reduce_axes_[1] && reduce_axes_[2] && reduce_axes_[3]; |
|
|
|
return !reduce_axes_[kNHWC_N] && reduce_axes_[kNHWC_H] && reduce_axes_[kNHWC_W] && reduce_axes_[kNHWC_C]; |
|
|
|
} |
|
|
|
|
|
|
|
bool ReduceOpenCLKernel::IsHWReduce() { |
|
|
|
return !reduce_axes_[0] && reduce_axes_[1] && reduce_axes_[2] && !reduce_axes_[3]; |
|
|
|
return !reduce_axes_[kNHWC_N] && reduce_axes_[kNHWC_H] && reduce_axes_[kNHWC_W] && !reduce_axes_[kNHWC_C]; |
|
|
|
} |
|
|
|
|
|
|
|
bool ReduceOpenCLKernel::IsWCReduce() { |
|
|
|
return !reduce_axes_[0] && !reduce_axes_[1] && reduce_axes_[2] && reduce_axes_[3]; |
|
|
|
return !reduce_axes_[kNHWC_N] && !reduce_axes_[kNHWC_H] && reduce_axes_[kNHWC_W] && reduce_axes_[kNHWC_C]; |
|
|
|
} |
|
|
|
|
|
|
|
bool ReduceOpenCLKernel::IsHReduce() { |
|
|
|
return !reduce_axes_[0] && reduce_axes_[1] && !reduce_axes_[2] && !reduce_axes_[3]; |
|
|
|
return !reduce_axes_[kNHWC_N] && reduce_axes_[kNHWC_H] && !reduce_axes_[kNHWC_W] && !reduce_axes_[kNHWC_C]; |
|
|
|
} |
|
|
|
|
|
|
|
bool ReduceOpenCLKernel::IsWReduce() { |
|
|
|
return !reduce_axes_[0] && !reduce_axes_[1] && reduce_axes_[2] && !reduce_axes_[3]; |
|
|
|
return !reduce_axes_[kNHWC_N] && !reduce_axes_[kNHWC_H] && reduce_axes_[kNHWC_W] && !reduce_axes_[kNHWC_C]; |
|
|
|
} |
|
|
|
|
|
|
|
bool ReduceOpenCLKernel::IsCReduce() { |
|
|
|
return !reduce_axes_[0] && !reduce_axes_[1] && !reduce_axes_[2] && reduce_axes_[3]; |
|
|
|
return !reduce_axes_[kNHWC_N] && !reduce_axes_[kNHWC_H] && !reduce_axes_[kNHWC_W] && reduce_axes_[kNHWC_C]; |
|
|
|
} |
|
|
|
|
|
|
|
int ReduceOpenCLKernel::SetShapeSizeIs0Axes() { |
|
|
|
// axes is input tensor |
|
|
|
auto *axes_tensor = in_tensors_.at(1); |
|
|
|
auto input_shape_size = in_tensors_.at(0)->shape().size(); |
|
|
|
if (input_shape_size == 0) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
CHECK_NULL_RETURN(axes_tensor->data()); |
|
|
|
|
|
|
|
auto reduction_indices = reinterpret_cast<int *>(axes_tensor->data())[0]; |
|
|
|
|
|
|
|
if (reduction_indices == -1) { |
|
|
|
reduce_axes_[1] = true; |
|
|
|
reduce_axes_[2] = true; |
|
|
|
reduce_axes_[3] = true; |
|
|
|
} else if (reduction_indices == 1 || reduction_indices == 2 || reduction_indices == 3) { |
|
|
|
reduce_axes_[kNHWC_H] = true; |
|
|
|
reduce_axes_[kNHWC_W] = true; |
|
|
|
reduce_axes_[kNHWC_C] = true; |
|
|
|
} else if (reduction_indices == kNHWC_H || reduction_indices == kNHWC_W || reduction_indices == kNHWC_C) { |
|
|
|
reduction_indices = reduction_indices + (C4NUM % input_shape_size); |
|
|
|
reduce_axes_[reduction_indices] = true; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "in Reduce: axes tensor's reduction_indices should be -1, 1, 2, 3"; |
|
|
|
@@ -136,16 +142,16 @@ int ReduceOpenCLKernel::SetShapeSizeIs1Axes() { |
|
|
|
reduce_axes_[axis] = true; |
|
|
|
} |
|
|
|
if (num_axes == 1) { |
|
|
|
if (reduce_axes_[1] && inShape.W == 1) { |
|
|
|
reduce_axes_[2] = true; |
|
|
|
} else if (reduce_axes_[2]) { |
|
|
|
if (reduce_axes_[kNHWC_H] && inShape.W == 1) { |
|
|
|
reduce_axes_[kNHWC_W] = true; |
|
|
|
} else if (reduce_axes_[kNHWC_W]) { |
|
|
|
if (inShape.H == 1) { |
|
|
|
reduce_axes_[1] = true; |
|
|
|
reduce_axes_[kNHWC_H] = true; |
|
|
|
} else if (inShape.C == 1) { |
|
|
|
reduce_axes_[3] = true; |
|
|
|
reduce_axes_[kNHWC_C] = true; |
|
|
|
} |
|
|
|
} else if (reduce_axes_[3] && inShape.W == 1) { |
|
|
|
reduce_axes_[3] = true; |
|
|
|
} else if (reduce_axes_[kNHWC_C] && inShape.W == 1) { |
|
|
|
reduce_axes_[kNHWC_C] = true; |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
@@ -200,7 +206,7 @@ int ReduceOpenCLKernel::CheckSpecs() { |
|
|
|
if (IsReduceAxesSupport() != RET_OK) { |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
if ((IsWCReduce() || IsWReduce()) && !reduce_param->keep_dims_) { |
|
|
|
if (IsWCReduce() && !reduce_param->keep_dims_) { |
|
|
|
MS_LOG(WARNING) << "reduce axis (2,3) should keep dims"; |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
|