|
|
|
@@ -22,7 +22,7 @@ namespace kernel { |
|
|
|
template <typename T> |
|
|
|
void L2NormalizeCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
epsilon_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon"); |
|
|
|
epsilon_ = static_cast<T>(AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon")); |
|
|
|
axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis")); |
|
|
|
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); |
|
|
|
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); |
|
|
|
@@ -35,9 +35,6 @@ void L2NormalizeCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
template <typename T> |
|
|
|
void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t reduce_size, const int dims, |
|
|
|
std::unique_ptr<T[]> *denominator_addr) { |
|
|
|
T temp = (T)0.0; |
|
|
|
T epsilon = (T)epsilon_; |
|
|
|
T denominator = (T)0.0; |
|
|
|
// Calculate transpose axes and stride |
|
|
|
size_t stride = 1; |
|
|
|
std::vector<size_t> axes(input_shape_.size()); |
|
|
|
@@ -60,6 +57,8 @@ void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t |
|
|
|
TransposeIterator tran_base_iter(std::move(transpose_shape), std::move(axes), input_shape_); |
|
|
|
|
|
|
|
auto task = [&](size_t start, size_t end) { |
|
|
|
T temp = (T)0.0; |
|
|
|
T denominator = (T)0.0; |
|
|
|
auto iter = tran_base_iter; |
|
|
|
iter.SetPos(start * stride); |
|
|
|
for (size_t i = start; i < end; ++i) { |
|
|
|
@@ -71,7 +70,7 @@ void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t |
|
|
|
denominator += temp * temp; |
|
|
|
iter.GenNextPos(); |
|
|
|
} |
|
|
|
denominator = (denominator > epsilon) ? denominator : epsilon; |
|
|
|
denominator = (denominator > epsilon_) ? denominator : epsilon_; |
|
|
|
(*denominator_addr)[i] = sqrt(denominator); |
|
|
|
} |
|
|
|
}; |
|
|
|
@@ -146,7 +145,7 @@ void L2NormalizeCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) { |
|
|
|
if (axis_ < -dims || axis_ >= dims) { |
|
|
|
MS_LOG(EXCEPTION) << "Attr axis_ " << axis_ << " must be in " << -dims << "~" << dims; |
|
|
|
} |
|
|
|
if (epsilon_ == 0.0) { |
|
|
|
if (epsilon_ == (T)0.0) { |
|
|
|
MS_LOG(EXCEPTION) << "Attr epsilon can not be zero."; |
|
|
|
} |
|
|
|
} |
|
|
|
|