From: @wanyiming Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -0,0 +1,341 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/cpu/ctcloss_cpu_kernel.h" | |||||
| #include "runtime/device/cpu/cpu_device_address.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| void CTCLossCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| CheckParam(kernel_node); | |||||
| probs_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| indice_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| labels_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||||
| if (probs_shape_.size() != 3) { | |||||
| MS_LOG(EXCEPTION) << "Probs dims: " << probs_shape_.size() << " not support."; | |||||
| } | |||||
| if (labels_dims_.size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "Labels dims: " << labels_dims_.size() << " not support."; | |||||
| } | |||||
| if (indice_dims_.size() != 2) { | |||||
| MS_LOG(EXCEPTION) << "Labels indice dims: " << indice_dims_.size() << " not support."; | |||||
| } | |||||
| preprocess_collapse_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "preprocess_collapse_repeated"); | |||||
| ctc_merge_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "ctc_merge_repeated"); | |||||
| ignore_longer_outputs_than_inputs_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "ignore_longer_outputs_than_inputs"); | |||||
| max_time_ = probs_shape_[0]; | |||||
| batch_size_ = probs_shape_[1]; | |||||
| num_class_ = probs_shape_[2]; | |||||
| blank_index_ = num_class_ - 1; | |||||
| } | |||||
| bool CTCLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| if (dtype_ == kNumberTypeFloat16) { | |||||
| LaunchKernel<float16>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeFloat32) { | |||||
| LaunchKernel<float>(inputs, outputs); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| template <typename T> | |||||
| inline T LogSumExp(T logprob1, T logprob2) { | |||||
| T kLogZero_ = -std::numeric_limits<T>::infinity(); | |||||
| if (logprob1 == kLogZero_) { | |||||
| return logprob2; | |||||
| } else if (logprob2 == kLogZero_) { | |||||
| return logprob1; | |||||
| } else { | |||||
| return (logprob1 > logprob2) ? logprob1 + log1p(exp(logprob2 - logprob1)) | |||||
| : logprob2 + log1p(exp(logprob1 - logprob2)); | |||||
| } | |||||
| } | |||||
| template <typename TT> | |||||
| void CTCLossCPUKernel::CalculateFwdVar(const std::vector<uint32_t> &label_with_blank, | |||||
| const std::vector<std::vector<TT>> &y, | |||||
| std::vector<std::vector<TT>> *log_alpha_b) { | |||||
| int U = label_with_blank.size(); | |||||
| int T = (*log_alpha_b)[0].size(); | |||||
| TT kLogZero_ = -std::numeric_limits<TT>::infinity(); | |||||
| (*log_alpha_b)[0][0] = log(y[blank_index_][0]); | |||||
| auto label_0 = (label_with_blank.size() > 1) ? label_with_blank[1] : blank_index_; | |||||
| if (label_with_blank.size() > 1) { | |||||
| (*log_alpha_b)[1][0] = log(y[label_0][0]); | |||||
| } | |||||
| for (int t = 1; t < T; ++t) { | |||||
| int low = std::max(0, U - (2 * (T - t))); | |||||
| int high = std::min(U, 2 * (t + 1)); | |||||
| for (int u = low; u < high; ++u) { | |||||
| auto sum_log_alpha_b = kLogZero_; | |||||
| if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) { | |||||
| sum_log_alpha_b = (*log_alpha_b)[u][t - 1]; | |||||
| } | |||||
| if (u > 0) { | |||||
| sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 1][t - 1]); | |||||
| } | |||||
| if (u > 1) { | |||||
| bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u - 2]); | |||||
| if (label_with_blank[u] != blank_index_ && !matching_labels_merge) { | |||||
| sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 2][t - 1]); | |||||
| } | |||||
| } | |||||
| (*log_alpha_b)[u][t] = log(y[label_with_blank[u]][t]) + sum_log_alpha_b; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename TT> | |||||
| void CTCLossCPUKernel::CalculateBwdVar(const std::vector<uint32_t> &label_with_blank, | |||||
| const std::vector<std::vector<TT>> &y, | |||||
| std::vector<std::vector<TT>> *log_beta_b) { | |||||
| int T = (*log_beta_b)[0].size(); | |||||
| int U = label_with_blank.size(); | |||||
| if (U > 1) { | |||||
| for (int u = U - 2; u < U; ++u) { | |||||
| (*log_beta_b)[u][T - 1] = TT(0); | |||||
| } | |||||
| } else { | |||||
| (*log_beta_b)[0][T - 1] = TT(0); | |||||
| (*log_beta_b)[0][T - 2] = TT(0); | |||||
| } | |||||
| for (int t = T - 2; t >= 0; --t) { | |||||
| int low = std::max(0, U - (2 * (T - t))); | |||||
| int high = std::min(U, 2 * (t + 1)); | |||||
| for (int u = low; u < high; ++u) { | |||||
| if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) { | |||||
| (*log_beta_b)[u][t] = | |||||
| LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u][t + 1] + TT(log(y[label_with_blank[u]][t + 1]))); | |||||
| } | |||||
| if (u + 1 < U) { | |||||
| (*log_beta_b)[u][t] = | |||||
| LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 1][t + 1] + TT(log(y[label_with_blank[u + 1]][t + 1]))); | |||||
| } | |||||
| if (u + 2 < U) { | |||||
| bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u + 2]); | |||||
| if (label_with_blank[u] != blank_index_ && !matching_labels_merge) { | |||||
| (*log_beta_b)[u][t] = | |||||
| LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 2][t + 1] + TT(log(y[label_with_blank[u + 2]][t + 1]))); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename TT> | |||||
| void CTCLossCPUKernel::CalculateGrad(const std::vector<uint32_t> &label_with_blank, | |||||
| const std::vector<std::vector<TT>> &y, | |||||
| const std::vector<std::vector<TT>> &log_alpha_b, | |||||
| const std::vector<std::vector<TT>> &log_beta_b, const TT log_pzx, | |||||
| std::vector<std::vector<TT>> *dy) { | |||||
| auto dy_b = dy; | |||||
| TT kLogZero_ = -std::numeric_limits<TT>::infinity(); | |||||
| if (log_pzx == kLogZero_) { | |||||
| MS_LOG(INFO) << "No valid path found"; | |||||
| return; | |||||
| } | |||||
| size_t L = y.size(); | |||||
| size_t T = y[0].size(); | |||||
| size_t U = label_with_blank.size(); | |||||
| for (size_t t = 0; t < T; ++t) { | |||||
| std::vector<TT> prob_sum(L, kLogZero_); | |||||
| for (size_t u = 0; u < U; ++u) { | |||||
| uint32_t l = label_with_blank[u]; | |||||
| prob_sum[l] = LogSumExp(prob_sum[l], log_alpha_b[u][t] + log_beta_b[u][t]); | |||||
| } | |||||
| for (size_t l = 0; l < L; ++l) { | |||||
| (*dy_b)[l][t] = y[l][t] - exp(prob_sum[l] - log_pzx); | |||||
| } | |||||
| } | |||||
| } | |||||
| void CTCLossCPUKernel::GenLableWithBlank(uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label, | |||||
| std::vector<std::vector<uint32_t>> *label_with_blank) { | |||||
| for (size_t b = 0; b < batch_size_; ++b) { | |||||
| std::vector<uint32_t> l; | |||||
| const std::vector<uint32_t> &label = batch_label[b]; | |||||
| bool has_blank = false; | |||||
| for (size_t i = 0; i < label.size(); ++i) { | |||||
| if (i == 0 || !preprocess_collapse_repeated_ || label[i] != label[i - 1]) { | |||||
| if (label[i] >= num_class_ - 1) { | |||||
| has_blank = true; | |||||
| } else { | |||||
| if (has_blank) { | |||||
| MS_LOG(EXCEPTION) << "Invalid labels(index >= num_class - 1) should not appear between two valid labels"; | |||||
| } | |||||
| l.push_back(label[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!ignore_longer_outputs_than_inputs_) { | |||||
| if (l.size() > seq_len[b]) { | |||||
| MS_LOG(EXCEPTION) << "Input time(sequence length) should greater than output size(label length), but gets " | |||||
| << seq_len[b] << "< " << l.size(); | |||||
| } | |||||
| } | |||||
| (*label_with_blank)[b].reserve(2 * l.size() + 1); | |||||
| for (auto l_i : l) { | |||||
| (*label_with_blank)[b].push_back(blank_index_); | |||||
| (*label_with_blank)[b].push_back(l_i); | |||||
| } | |||||
| (*label_with_blank)[b].push_back(blank_index_); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void InnerSoftMax(T *inputs_addr, std::vector<std::vector<T>> *softmax_probs, const uint32_t sequence_length, | |||||
| size_t num_class, size_t batch_size, size_t b) { | |||||
| for (size_t t = 0; t < sequence_length; ++t) { | |||||
| T maxCoeff(T(0)); | |||||
| T sumCoeff(T(0)); | |||||
| for (size_t c = 0; c < num_class; ++c) { | |||||
| if (inputs_addr[t * batch_size * num_class + b * num_class + c] > maxCoeff) { | |||||
| maxCoeff = inputs_addr[t * batch_size * num_class + b * num_class + c]; | |||||
| } | |||||
| } | |||||
| for (size_t c = 0; c < num_class; ++c) { | |||||
| sumCoeff += exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff); | |||||
| (*softmax_probs)[c][t] = exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff); | |||||
| } | |||||
| for (size_t c = 0; c < num_class; ++c) { | |||||
| (*softmax_probs)[c][t] /= sumCoeff; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void MatrixfromVector(uint32_t row, uint32_t col, std::vector<std::vector<T>> *array2D, const T init_value) { | |||||
| array2D->resize(row); | |||||
| for (size_t i = 0; i < row; ++i) { | |||||
| (*array2D)[i].resize(col, init_value); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void CTCLossCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||||
| auto inputs_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||||
| auto labels_indices_addr = reinterpret_cast<uint64_t *>(inputs[1]->addr); | |||||
| auto labels_values_addr = reinterpret_cast<uint32_t *>(inputs[2]->addr); | |||||
| auto sequence_length_addr = reinterpret_cast<uint32_t *>(inputs[3]->addr); | |||||
| auto loss_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||||
| auto gradient_addr = reinterpret_cast<T *>(outputs[1]->addr); | |||||
| std::vector<std::vector<uint32_t>> label_batch; | |||||
| std::vector<std::vector<uint32_t>> labels_with_blank; | |||||
| std::vector<uint64_t> each_label_length; | |||||
| label_batch.resize(batch_size_); | |||||
| labels_with_blank.resize(batch_size_); | |||||
| each_label_length.resize(batch_size_, 0); | |||||
| T kLogZero_ = -std::numeric_limits<T>::infinity(); | |||||
| // check validation of sequence length | |||||
| for (size_t b = 0; b < batch_size_; ++b) { | |||||
| if (sequence_length_addr[b] < uint32_t(0)) { | |||||
| MS_LOG(EXCEPTION) << "Sequence length should > 0, but gets " << sequence_length_addr[b]; | |||||
| } | |||||
| if (sequence_length_addr[b] > max_time_) { | |||||
| MS_LOG(EXCEPTION) << "Max time should be greater than sequence length, but gets " << max_time_ << " < " | |||||
| << sequence_length_addr[b]; | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < indice_dims_[0]; ++i) { | |||||
| each_label_length[labels_indices_addr[i * 2]]++; | |||||
| } | |||||
| // convert label format of label_value and label_indices to batch_label | |||||
| uint64_t cum_sum = 0; | |||||
| for (size_t b = 0; b < batch_size_; ++b) { | |||||
| std::vector<uint32_t> *b_value = &label_batch[b]; | |||||
| for (size_t l = 0; l < each_label_length[b]; ++l) { | |||||
| b_value->push_back(labels_values_addr[cum_sum + l]); | |||||
| } | |||||
| cum_sum += each_label_length[b]; | |||||
| } | |||||
| // convert label to label with blank | |||||
| GenLableWithBlank(sequence_length_addr, label_batch, &labels_with_blank); | |||||
| for (size_t b = 0; b < batch_size_; ++b) { | |||||
| std::vector<uint32_t> label_with_blank = labels_with_blank[b]; | |||||
| // y_b [num_class, sequence_length] | |||||
| std::vector<std::vector<T>> y_b; | |||||
| std::vector<std::vector<T>> dy; | |||||
| std::vector<std::vector<T>> log_alpha_b; | |||||
| std::vector<std::vector<T>> log_beta_b; | |||||
| MatrixfromVector(num_class_, sequence_length_addr[b], &y_b, kLogZero_); | |||||
| MatrixfromVector(y_b.size(), y_b[0].size(), &dy, T(0)); | |||||
| MatrixfromVector(label_with_blank.size(), sequence_length_addr[b], &log_alpha_b, kLogZero_); | |||||
| MatrixfromVector(label_with_blank.size(), sequence_length_addr[b], &log_beta_b, kLogZero_); | |||||
| InnerSoftMax(inputs_addr, &y_b, sequence_length_addr[b], num_class_, batch_size_, b); | |||||
| CalculateFwdVar(label_with_blank, y_b, &log_alpha_b); | |||||
| CalculateBwdVar(label_with_blank, y_b, &log_beta_b); | |||||
| T log_pzx = kLogZero_; | |||||
| for (size_t u = 0; u < label_with_blank.size(); ++u) { | |||||
| log_pzx = LogSumExp(log_pzx, log_alpha_b[u][0] + log_beta_b[u][0]); | |||||
| } | |||||
| loss_addr[b] = -log_pzx; | |||||
| CalculateGrad(label_with_blank, y_b, log_alpha_b, log_beta_b, log_pzx, &dy); | |||||
| for (size_t t = 0; t < sequence_length_addr[b]; ++t) { | |||||
| for (size_t c = 0; c < num_class_; ++c) { | |||||
| gradient_addr[t * batch_size_ * num_class_ + b * num_class_ + c] = dy[c][t]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void CTCLossCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 4) { | |||||
| MS_LOG(EXCEPTION) << "CTCLossCPUKernel needs 4 inputs, but gets " << input_num; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 2) { | |||||
| MS_LOG(EXCEPTION) << "CTCLossCPUKernel expects 2 outputs, but gets" << output_num; | |||||
| } | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,93 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_ | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include <limits> | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class CTCLossCPUKernel : public CPUKernel { | |||||
| public: | |||||
| CTCLossCPUKernel() = default; | |||||
| ~CTCLossCPUKernel() override = default; | |||||
| void InitKernel(const CNodePtr &kernel_node) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) override; | |||||
| void GenLableWithBlank(uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label, | |||||
| std::vector<std::vector<uint32_t>> *label_with_blank); | |||||
| template <typename T> | |||||
| void CalculateFwdVar(const std::vector<uint32_t> &label_with_blank, const std::vector<std::vector<T>> &y, | |||||
| std::vector<std::vector<T>> *log_alpha_b); | |||||
| template <typename T> | |||||
| void CalculateBwdVar(const std::vector<uint32_t> &label_with_blank, const std::vector<std::vector<T>> &y, | |||||
| std::vector<std::vector<T>> *log_beta_b); | |||||
| template <typename T> | |||||
| void CalculateGrad(const std::vector<uint32_t> &label_with_blank, const std::vector<std::vector<T>> &y, | |||||
| const std::vector<std::vector<T>> &log_alpha_b, const std::vector<std::vector<T>> &log_beta_b, | |||||
| const T log_pzx, std::vector<std::vector<T>> *dy); | |||||
| template <typename T> | |||||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||||
| private: | |||||
| void CheckParam(const CNodePtr &kernel_node); | |||||
| std::vector<size_t> probs_shape_; | |||||
| std::vector<size_t> indice_dims_; | |||||
| std::vector<size_t> labels_dims_; | |||||
| size_t num_class_; | |||||
| size_t max_time_; | |||||
| size_t batch_size_; | |||||
| uint32_t blank_index_; | |||||
| TypeId dtype_{kTypeUnknown}; | |||||
| bool preprocess_collapse_repeated_; | |||||
| bool ctc_merge_repeated_; | |||||
| bool ignore_longer_outputs_than_inputs_; | |||||
| }; | |||||
| MS_REG_CPU_KERNEL(CTCLoss, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeInt64) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| CTCLossCPUKernel); | |||||
| MS_REG_CPU_KERNEL(CTCLoss, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeInt64) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| CTCLossCPUKernel); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_ | |||||
| @@ -50,9 +50,6 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | [](const int64_t &value) { return static_cast<int>(value); }); | ||||
| (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori), | (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori), | ||||
| [](const int64_t &value) { return static_cast<int>(value); }); | [](const int64_t &value) { return static_cast<int>(value); }); | ||||
| if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) { | |||||
| MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!"; | |||||
| } | |||||
| if (stride_ori[0] != 1 || stride_ori[1] != 1) { | if (stride_ori[0] != 1 || stride_ori[1] != 1) { | ||||
| MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!"; | MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!"; | ||||
| } | } | ||||
| @@ -62,10 +59,10 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { | if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { | ||||
| MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!"; | MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!"; | ||||
| } | } | ||||
| int stride = stride_ori[2]; | |||||
| int dilation = dilation_ori[2]; | |||||
| dnnl::memory::dims strides{stride, stride}; | |||||
| dnnl::memory::dims dilates{dilation - 1, dilation - 1}; | |||||
| std::vector<int> stride{stride_ori[2], stride_ori[3]}; | |||||
| dnnl::memory::dims strides{stride_ori[2], stride_ori[3]}; | |||||
| dnnl::memory::dims dilates{dilation_ori[2] - 1, dilation_ori[3] - 1}; | |||||
| std::vector<int> int_padding_l; | std::vector<int> int_padding_l; | ||||
| std::vector<int> int_padding_r; | std::vector<int> int_padding_r; | ||||
| const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | ||||
| @@ -50,20 +50,16 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | [](const int64_t &value) { return static_cast<int>(value); }); | ||||
| (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori), | (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori), | ||||
| [](const int64_t &value) { return static_cast<int>(value); }); | [](const int64_t &value) { return static_cast<int>(value); }); | ||||
| if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { | |||||
| MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel only support equal stride, and stride must be 2d!"; | |||||
| } | |||||
| if (dilation_ori.size() != 4) { | if (dilation_ori.size() != 4) { | ||||
| MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation must be 4d!"; | MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation must be 4d!"; | ||||
| } | } | ||||
| if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { | if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { | ||||
| MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1 in N axis and C axis!"; | MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1 in N axis and C axis!"; | ||||
| } | } | ||||
| int stride = stride_ori[0]; | |||||
| int dilation = dilation_ori[2]; | |||||
| dnnl::memory::dims strides{stride, stride}; | |||||
| dnnl::memory::dims dilates{dilation - 1, dilation - 1}; | |||||
| std::vector<int> stride{stride_ori[0], stride_ori[1]}; | |||||
| dnnl::memory::dims strides{stride_ori[0], stride_ori[1]}; | |||||
| dnnl::memory::dims dilates{dilation_ori[2] - 1, dilation_ori[3] - 1}; | |||||
| const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | ||||
| std::vector<int> int_padding_l; | std::vector<int> int_padding_l; | ||||
| std::vector<int> int_padding_r; | std::vector<int> int_padding_r; | ||||
| @@ -51,19 +51,18 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | [](const int64_t &value) { return static_cast<int>(value); }); | ||||
| (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori), | (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori), | ||||
| [](const int64_t &value) { return static_cast<int>(value); }); | [](const int64_t &value) { return static_cast<int>(value); }); | ||||
| if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { | |||||
| MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel only support equal stride, and stride must be 2d!"; | |||||
| } | |||||
| if (dilation_ori.size() != 4) { | if (dilation_ori.size() != 4) { | ||||
| MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation must be 4d!"; | MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation must be 4d!"; | ||||
| } | } | ||||
| if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { | if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { | ||||
| MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1 in N axis and C axis!"; | MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1 in N axis and C axis!"; | ||||
| } | } | ||||
| int stride = stride_ori[0]; | |||||
| int dilation = dilation_ori[2]; | |||||
| dnnl::memory::dims strides{stride, stride}; | |||||
| dnnl::memory::dims dilates{dilation - 1, dilation - 1}; | |||||
| std::vector<int> stride{stride_ori[0], stride_ori[1]}; | |||||
| dnnl::memory::dims strides{stride_ori[0], stride_ori[1]}; | |||||
| dnnl::memory::dims dilates{dilation_ori[2] - 1, dilation_ori[3] - 1}; | |||||
| std::vector<int> int_padding_l; | std::vector<int> int_padding_l; | ||||
| std::vector<int> int_padding_r; | std::vector<int> int_padding_r; | ||||
| const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | ||||
| @@ -23,8 +23,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, | void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, | ||||
| const std::vector<size_t> &src_shape, const std::vector<size_t> &kernel_size, int stride, | |||||
| std::vector<int> *padding_l, std::vector<int> *padding_r) { | |||||
| const std::vector<size_t> &src_shape, const std::vector<size_t> &kernel_size, | |||||
| const std::vector<int> &stride, std::vector<int> *padding_l, | |||||
| std::vector<int> *padding_r) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| if (src_shape.size() < 2) { | if (src_shape.size() < 2) { | ||||
| MS_LOG(EXCEPTION) << "set pad only support src dim >= 2!"; | MS_LOG(EXCEPTION) << "set pad only support src dim >= 2!"; | ||||
| @@ -37,10 +38,10 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa | |||||
| if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) { | if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) { | ||||
| for (size_t i = 0; i < weight_height.size(); ++i) { | for (size_t i = 0; i < weight_height.size(); ++i) { | ||||
| auto wh = weight_height[i]; | auto wh = weight_height[i]; | ||||
| int re = wh % stride; | |||||
| int re = wh % stride[i]; | |||||
| int pad_along; | int pad_along; | ||||
| if (re == 0) { | if (re == 0) { | ||||
| pad_along = std::max(SizeToInt(kernel_size[i]) - stride, 0); | |||||
| pad_along = std::max(SizeToInt(kernel_size[i]) - stride[i], 0); | |||||
| } else { | } else { | ||||
| pad_along = std::max(SizeToInt(kernel_size[i]) - re, 0); | pad_along = std::max(SizeToInt(kernel_size[i]) - re, 0); | ||||
| } | } | ||||
| @@ -60,8 +61,8 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa | |||||
| (void)std::transform(pad_me.begin(), pad_me.end(), std::back_inserter(pad), | (void)std::transform(pad_me.begin(), pad_me.end(), std::back_inserter(pad), | ||||
| [](const int64_t &value) { return static_cast<int>(value); }); | [](const int64_t &value) { return static_cast<int>(value); }); | ||||
| padding_l->emplace_back(pad[0]); | padding_l->emplace_back(pad[0]); | ||||
| padding_l->emplace_back(pad[1]); | |||||
| padding_r->emplace_back(pad[2]); | |||||
| padding_l->emplace_back(pad[2]); | |||||
| padding_r->emplace_back(pad[1]); | |||||
| padding_r->emplace_back(pad[3]); | padding_r->emplace_back(pad[3]); | ||||
| } | } | ||||
| } | } | ||||
| @@ -35,7 +35,7 @@ class MKLCPUKernel : public CPUKernel { | |||||
| bool BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape, | bool BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape, | ||||
| std::vector<size_t> *dst_shape); | std::vector<size_t> *dst_shape); | ||||
| void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector<size_t> &src_shape, | void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector<size_t> &src_shape, | ||||
| const std::vector<size_t> &kernel_size, int stride, std::vector<int> *padding_l, | |||||
| const std::vector<size_t> &kernel_size, const std::vector<int> &stride, std::vector<int> *padding_l, | |||||
| std::vector<int> *padding_r); | std::vector<int> *padding_r); | ||||
| void AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc = false); | void AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc = false); | ||||
| void SetArgumentHandle(int arg_key, void *ptr); | void SetArgumentHandle(int arg_key, void *ptr); | ||||
| @@ -40,13 +40,14 @@ void AvgPoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| if (origin_kernel_sizes.size() != 4 || strides.size() != 4) { | if (origin_kernel_sizes.size() != 4 || strides.size() != 4) { | ||||
| MS_LOG(EXCEPTION) << "Invalid kernel size " << origin_kernel_sizes.size() << " or stride size " << strides.size(); | MS_LOG(EXCEPTION) << "Invalid kernel size " << origin_kernel_sizes.size() << " or stride size " << strides.size(); | ||||
| } | } | ||||
| std::vector<int> stride{strides[2], strides[3]}; | |||||
| dnnl::memory::dims strides_dims{strides[2], strides[3]}; | dnnl::memory::dims strides_dims{strides[2], strides[3]}; | ||||
| dnnl::memory::dims kernels_dims{origin_kernel_sizes[2], origin_kernel_sizes[3]}; | dnnl::memory::dims kernels_dims{origin_kernel_sizes[2], origin_kernel_sizes[3]}; | ||||
| const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | ||||
| std::vector<int> int_padding_l; | std::vector<int> int_padding_l; | ||||
| std::vector<int> int_padding_r; | std::vector<int> int_padding_r; | ||||
| std::vector<size_t> kernel_size({IntToSize(origin_kernel_sizes[2]), IntToSize(origin_kernel_sizes[3])}); | std::vector<size_t> kernel_size({IntToSize(origin_kernel_sizes[2]), IntToSize(origin_kernel_sizes[3])}); | ||||
| GetPadding(kernel_node, pad_mode, src_shape, kernel_size, strides[3], &int_padding_l, &int_padding_r); | |||||
| GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); | |||||
| if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { | if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "Pooling avg get padding failed"; | MS_LOG(EXCEPTION) << "Pooling avg get padding failed"; | ||||
| } | } | ||||
| @@ -34,7 +34,6 @@ class AvgPoolingGradCPUKernel : public MKLCPUKernel { | |||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| private: | private: | ||||
| int stride_{0}; | |||||
| std::vector<size_t> kernel_size_; | std::vector<size_t> kernel_size_; | ||||
| }; | }; | ||||
| @@ -39,13 +39,14 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| if (origin_kernel_sizes.size() != 4 || strides.size() != 4) { | if (origin_kernel_sizes.size() != 4 || strides.size() != 4) { | ||||
| MS_LOG(EXCEPTION) << "invalid kernel size " << origin_kernel_sizes.size() << " or stride size " << strides.size(); | MS_LOG(EXCEPTION) << "invalid kernel size " << origin_kernel_sizes.size() << " or stride size " << strides.size(); | ||||
| } | } | ||||
| std::vector<int> stride{strides[2], strides[3]}; | |||||
| dnnl::memory::dims strides_dims{strides[2], strides[3]}; | dnnl::memory::dims strides_dims{strides[2], strides[3]}; | ||||
| dnnl::memory::dims kernels_dims{origin_kernel_sizes[2], origin_kernel_sizes[3]}; | dnnl::memory::dims kernels_dims{origin_kernel_sizes[2], origin_kernel_sizes[3]}; | ||||
| const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | ||||
| std::vector<int> int_padding_l; | std::vector<int> int_padding_l; | ||||
| std::vector<int> int_padding_r; | std::vector<int> int_padding_r; | ||||
| std::vector<size_t> kernel_size({IntToSize(origin_kernel_sizes[2]), IntToSize(origin_kernel_sizes[3])}); | std::vector<size_t> kernel_size({IntToSize(origin_kernel_sizes[2]), IntToSize(origin_kernel_sizes[3])}); | ||||
| GetPadding(kernel_node, pad_mode, src_shape, kernel_size, strides[3], &int_padding_l, &int_padding_r); | |||||
| GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); | |||||
| if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { | if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "pooling get padding failed"; | MS_LOG(EXCEPTION) << "pooling get padding failed"; | ||||
| } | } | ||||
| @@ -41,7 +41,8 @@ void MaxPoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| std::vector<int> padding_r; | std::vector<int> padding_r; | ||||
| const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE); | ||||
| kernel_size_ = {IntToSize(kernel_sizes[2]), IntToSize(kernel_sizes[3])}; | kernel_size_ = {IntToSize(kernel_sizes[2]), IntToSize(kernel_sizes[3])}; | ||||
| stride_ = strides[3]; | |||||
| stride_.push_back(strides[2]); | |||||
| stride_.push_back(strides[3]); | |||||
| GetPadding(kernel_node, pad_mode, src_shape_, kernel_size_, stride_, &padding_l_, &padding_r); | GetPadding(kernel_node, pad_mode, src_shape_, kernel_size_, stride_, &padding_l_, &padding_r); | ||||
| } | } | ||||
| @@ -94,9 +95,9 @@ void MaxPoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float | |||||
| box[1].second = IntToSize(std::min(w_start + SizeToInt(kernel_size_[1]), src_width)); | box[1].second = IntToSize(std::min(w_start + SizeToInt(kernel_size_[1]), src_width)); | ||||
| RowPoolingGrad(input, output, diff[diff_index], box, &row_max_pair); | RowPoolingGrad(input, output, diff[diff_index], box, &row_max_pair); | ||||
| diff_index += 1; | diff_index += 1; | ||||
| w_start += stride_; | |||||
| w_start += stride_[1]; | |||||
| } | } | ||||
| h_start += stride_; | |||||
| h_start += stride_[0]; | |||||
| } | } | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ class MaxPoolingGradCPUKernel : public MKLCPUKernel { | |||||
| void RowPoolingGrad(const float *input, float *output, float diff, const std::vector<std::pair<size_t, size_t>> &box, | void RowPoolingGrad(const float *input, float *output, float diff, const std::vector<std::pair<size_t, size_t>> &box, | ||||
| std::vector<std::pair<size_t, float>> *row_max_pair); | std::vector<std::pair<size_t, float>> *row_max_pair); | ||||
| void ChannelPoolingGrad(const float *input, const float *diff, float *output); | void ChannelPoolingGrad(const float *input, const float *diff, float *output); | ||||
| int stride_{0}; | |||||
| std::vector<int> stride_; | |||||
| std::vector<size_t> kernel_size_; | std::vector<size_t> kernel_size_; | ||||
| std::vector<int> padding_l_; | std::vector<int> padding_l_; | ||||
| std::vector<size_t> src_shape_; | std::vector<size_t> src_shape_; | ||||
| @@ -0,0 +1,88 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.ops.composite import GradOperation | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.loss = P.CTCLoss() | |||||
| self.div = P.RealDiv() | |||||
| self.mean = P.ReduceMean() | |||||
| def construct(self, probs, label, input_length, indices): | |||||
| x, _ = self.loss(probs, indices, label, input_length) | |||||
| x = self.mean(x) | |||||
| return x | |||||
| class GradData(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(GradData, self).__init__() | |||||
| self.grad = GradOperation(get_all=True, sens_param=False) | |||||
| self.network = network | |||||
| def construct(self, probs, indices, labels, input_lengths): | |||||
| return self.grad(self.network)(probs, indices, labels, input_lengths) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_ctcloss(): | |||||
| probs = Tensor([[[-4.4131, -4.6093, -3.4333, -3.9268, -2.8917, -3.4093, -4.2243, -1.1379, -7.1046, -0.6902], | |||||
| [-2.5109, -3.3397, -4.9384, -1.2723, -1.1443, -2.4683, -2.6768, -4.1282, -2.7062, -3.1906], | |||||
| [-2.5092, -1.6392, -2.0864, -4.0059, -1.5610, -2.3223, -2.4816, -2.9922, -3.1412, -2.3311]], | |||||
| [[-2.1243, -3.5773, -3.1108, -4.4253, -2.7080, -1.9653, -2.0499, -2.4418, -1.8620, -1.5229], | |||||
| [-2.2479, -3.5128, -1.4189, -2.8701, -1.8562, -2.2752, -2.7019, -2.1865, -2.5634, -2.9869], | |||||
| [-3.2144, -1.3986, -3.1083, -3.9634, -3.5131, -3.2317, -2.6200, -1.7938, -1.8159, -1.7255]], | |||||
| [[-3.1301, -2.1649, -0.9286, -2.9452, -2.5992, -2.0263, -2.9201, -3.2155, -2.8302, -3.3636], | |||||
| [-1.4661, -3.6311, -2.4781, -4.6180, -2.7308, -1.7019, -1.5570, -2.6012, -4.0788, -2.3073], | |||||
| [-2.6833, -1.5033, -3.6922, -2.6360, -2.6974, -2.6847, -2.7579, -2.1396, -1.4093, -2.9630]], | |||||
| [[-2.0094, -2.3024, -3.3673, -1.0220, -2.8326, -2.2613, -3.0535, -2.9879, -3.7015, -2.4510], | |||||
| [-1.9071, -3.2603, -2.3229, -2.0572, -4.3450, -2.1284, -2.6306, -1.3824, -2.9815, -2.5061], | |||||
| [-2.7931, -3.7631, -3.2440, -4.3887, -1.0271, -3.8851, -1.2418, -4.5123, -2.2993, -2.4607]], | |||||
| [[-1.5763, -2.7539, -3.6941, -3.8166, -1.2599, -2.6903, -2.5826, -4.8208, -2.9562, -1.6321], | |||||
| [-3.3031, -3.0087, -1.9982, -1.9081, -3.8731, -2.8764, -2.2485, -2.3808, -1.4283, -2.1625], | |||||
| [-2.4516, -3.2394, -4.2053, -4.3541, -2.5229, -4.0717, -1.4894, -2.3151, -1.1098, -2.3465]]], | |||||
| dtype=mstype.float32) | |||||
| labels = Tensor([3, 4, 6, 4, 7, 1, 4, 6, 6, 8], dtype=mstype.int32) | |||||
| indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1], [2, 2], [2, 3]] | |||||
| indices = Tensor(indices, dtype=mstype.int64) | |||||
| input_lengths = Tensor([5, 5, 5], dtype=mstype.int32) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| net = Net() | |||||
| ctc_loss = net(probs, labels, input_lengths, indices) | |||||
| expect_loss = [9.083767] | |||||
| assert np.allclose(ctc_loss.asnumpy(), expect_loss) | |||||
| grad = GradData(net)(probs, labels, input_lengths, indices) | |||||
| grad = P.ReduceMean()(grad[0]) | |||||
| expect_grad = [-5.9604646e-09] | |||||
| assert np.allclose(grad.asnumpy(), expect_grad, atol=1e-5) | |||||