Browse Source

Fix bug: AbsGrad support int as input.

tags/v1.6.0
hezhenhao1 4 years ago
parent
commit
4c12385050
2 changed files with 2 additions and 2 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
  2. +1
    -1
      tests/st/ops/cpu/test_abs_op.py

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc View File

@@ -267,7 +267,7 @@ void EltWiseGradCPUKernel<T>::InitComputeFunc() {
if constexpr (std::is_same_v<T, int>) {
static const std::map<std::string,
std::function<void(EltWiseGradCPUKernel *, const T *, const T *, T *, size_t, size_t)>>
elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCPUKernel<T>::AbsGrad}};
elt_map{{prim::kPrimAbsGrad->name(), &EltWiseGradCPUKernel<T>::AbsGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "EltWiseGradCPUKernel does not support " << kernel_name_ << " with int as input.";
}


+ 1
- 1
tests/st/ops/cpu/test_abs_op.py View File

@@ -47,7 +47,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('dtype', [np.int, np.float32, np.float64])
def test_abs(dtype):
"""
Feature: ALL To ALL


Loading…
Cancel
Save