
!9356 Fix infer type for IsNan, IsFinite, IsInf

From: @VectorSL
Reviewed-by: @liangchenghui,@linqingke
Signed-off-by: @liangchenghui
tags/v1.1.0
Committed-by: mindspore-ci-bot
commit 9fb7aff589
14 changed files with 112 additions and 12 deletions
  1. +7  -0   mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc
  2. +2  -0   mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc
  3. +2  -0   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu
  4. +7  -0   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
  5. +13 -0   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
  6. +3  -0   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu
  7. +16 -0   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
  8. +1  -1   mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
  9. +6  -6   mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h
  10. +3  -0  mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc
  11. +33 -0  mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
  12. +16 -0  mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc
  13. +0  -2  mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
  14. +3  -3  mindspore/ops/operations/math_ops.py

mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc (+7 -0)

@@ -39,5 +39,12 @@ MS_REG_GPU_KERNEL_ONE(Select,
                         .AddInputAttr(kNumberTypeInt32)
                         .AddOutputAttr(kNumberTypeInt32),
                       SelectGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(Select,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeBool)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64),
+                      SelectGpuKernel, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
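The registration above adds an int64 variant of the GPU Select kernel (bool condition, int64 branches). A minimal usage sketch, not part of the diff, assuming a GPU build of MindSpore that includes this change and the v1.1-era operations API:

import numpy as np
import mindspore as ms
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(device_target="GPU")

cond = Tensor(np.array([True, False, True]), ms.bool_)
x = Tensor(np.array([1, 2, 3], dtype=np.int64))    # int64 inputs should now dispatch to SelectGpuKernel<int64_t>
y = Tensor(np.array([10, 20, 30], dtype=np.int64))
out = P.Select()(cond, x, y)                       # expected value [1, 20, 3], dtype int64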

mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc (+2 -0)

@@ -26,6 +26,8 @@ MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOu
                       SliceGpuFwdKernel, half)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                       SliceGpuFwdKernel, int16_t)
+MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      SliceGpuFwdKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                       SliceGpuFwdKernel, uchar)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),


mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu (+2 -0)

@@ -38,3 +38,5 @@ template void CalAssignAdd<float>(const size_t size, float* ref, const float* va
 template void CalAssignAdd<half>(const size_t size, half* ref, const half* value, half* output,
                                  cudaStream_t cuda_stream);
 template void CalAssignAdd<int>(const size_t size, int* ref, const int* value, int* output, cudaStream_t cuda_stream);
+template void CalAssignAdd<int64_t>(const size_t size, int64_t* ref, const int64_t* value, int64_t* output,
+                                    cudaStream_t cuda_stream);

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu (+7 -0)

@@ -121,6 +121,9 @@ template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &
 template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
                               const half *x1, const half *x2, const half *dy, half *dx1, half *dx2,
                               cudaStream_t stream);
+template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
+                              const int64_t *x1, const int64_t *x2, const int64_t *dy, int64_t *dx1, int64_t *dx2,
+                              cudaStream_t stream);
 template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1,
@@ -133,3 +136,7 @@ template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const i
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const half *x1,
                             const half *x2, const half *dy, half *dx1, half *dx2, cudaStream_t stream);
+template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
+                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
+                            const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const int64_t *x1,
+                            const int64_t *x2, const int64_t *dy, int64_t *dx1, int64_t *dx2, cudaStream_t stream);

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu (+13 -0)

@@ -203,6 +203,8 @@ template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int8_t
                          cudaStream_t stream);
 template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, bool *y,
                          cudaStream_t stream);
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, bool *y,
+                         cudaStream_t stream);

 // Element-wise ArithMetic
 template <typename T, typename Func>
@@ -269,6 +271,8 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int8_
                            cudaStream_t stream);
 template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, uint8_t *y,
                            cudaStream_t stream);
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, int64_t *y,
+                           cudaStream_t stream);

 // Broadcast comparation
 __device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
@@ -347,6 +351,9 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector
 template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                            const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
                            const uint8_t *x1, bool *y, cudaStream_t stream);
+template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int64_t *x0,
+                           const int64_t *x1, bool *y, cudaStream_t stream);

 // Broadcast Arithmetic
 template <typename T, typename Func>
@@ -468,6 +475,9 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect
 template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                              const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
                              const uint8_t *x1, uint8_t *y, cudaStream_t stream);
+template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int64_t *x0,
+                             const int64_t *x1, int64_t *y, cudaStream_t stream);

 // BroadcastTo
 template <typename T>
@@ -500,3 +510,6 @@ template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2,
 template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
                           const size_t &o1, const size_t &o2, const size_t &o3, const half *input_addr,
                           half *output_addr, cudaStream_t stream);
+template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
+                          const size_t &o1, const size_t &o2, const size_t &o3, const int64_t *input_addr,
+                          int64_t *output_addr, cudaStream_t stream);

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu (+3 -0)

@@ -40,3 +40,6 @@ template void CalSelect<int>(const size_t size, const bool* cond, const int* inp
                              cudaStream_t cuda_stream);
 template void CalSelect<half>(const size_t size, const bool* cond, const half* input_X, const half* input_y,
                               half* output, cudaStream_t cuda_stream);
+template void CalSelect<int64_t>(const size_t size, const bool* cond, const int64_t* input_X, const int64_t* input_y,
+                                 int64_t* output, cudaStream_t cuda_stream);


mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu (+16 -0)

@@ -204,6 +204,16 @@ template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigne
                                           const std::vector<size_t> in_shape, const std::vector<int> begin,
                                           const std::vector<int> size, unsigned char *output, cudaStream_t cuda_stream);

+template void FillDeviceArray<int64_t>(const size_t input_size, int64_t *addr, const float value,
+                                       cudaStream_t cuda_stream);
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
+                            cudaStream_t stream);
+template void CalSliceGrad<int64_t>(const size_t input_size, const int64_t *dy, const std::vector<size_t> in_shape,
+                                    const std::vector<int> begin, const std::vector<int> size, int64_t *output,
+                                    cudaStream_t cuda_stream);
+
 template void FillDeviceArray<bool>(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
@@ -230,6 +240,9 @@ template void StridedSlice(const std::vector<size_t> &input_shape, const std::ve
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
                            const std::vector<int> &strides, const std::vector<size_t> &output_shape, const bool *input,
                            bool *output, cudaStream_t cuda_stream);
+template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
+                           const std::vector<int> &strides, const std::vector<size_t> &output_shape,
+                           const int64_t *input, int64_t *output, cudaStream_t cuda_stream);

 template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
                                const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const float *dy,
@@ -249,3 +262,6 @@ template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::v
 template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
                                const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
                                bool *dx, cudaStream_t cuda_stream);
+template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
+                               const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const int64_t *dy,
+                               int64_t *dx, cudaStream_t cuda_stream);

mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h (+1 -1)

@@ -45,7 +45,7 @@ static constexpr float kSignedMinFloat = -3.402823466e+38F;
 // Used by mixprecision, cudnn dtype select
 static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {{"kNumberTypeFloat32", CUDNN_DATA_FLOAT},
                                                                 {"kNumberTypeFloat16", CUDNN_DATA_HALF},
-                                                                {"kNumberTypeInt64", CUDNN_DATA_DOUBLE},
+                                                                {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE},
                                                                 {"kNumberTypeInt32", CUDNN_DATA_INT32}};
 // Used by mixprecision, cuda dtype select
 static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},


mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h (+6 -6)

@@ -51,7 +51,7 @@ class AddNGpuFwdKernel : public GpuKernel {
     }
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     auto work_addr = output_addr;
-    for (size_t i = 0; i < IntToSize(num_input_); i++) {
+    for (size_t i = 0; i < num_input_; i++) {
       if (output_addr == GetDeviceAddress<T>(inputs, i)) {
         work_addr = GetDeviceAddress<T>(workspace, 0);
         break;
@@ -63,7 +63,7 @@ class AddNGpuFwdKernel : public GpuKernel {
     }
     const float alpha = 1;
     const float beta = 0;
-    for (size_t i = 0; i < IntToSize(num_input_); i++) {
+    for (size_t i = 0; i < num_input_; i++) {
       T *input_addr = GetDeviceAddress<T>(inputs, i);
       if (cudnn_data_type_ == CUDNN_DATA_INT32) {
         ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
@@ -85,8 +85,8 @@ class AddNGpuFwdKernel : public GpuKernel {
     InitResource();
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    num_input_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "n"));
-    if (IntToSize(num_input_) != input_num) {
+    num_input_ = GetAttr<int64_t>(kernel_node, "n");
+    if (num_input_ != input_num) {
      MS_LOG(ERROR) << "Input number is " << num_input_ << " in attr, but got " << input_num << "input.";
      return false;
    }
@@ -137,7 +137,7 @@ class AddNGpuFwdKernel : public GpuKernel {
       CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
                                   "cudnnGetTensorSizeInBytes failed");
     }
-    for (int i = 0; i < num_input_; i++) {
+    for (size_t i = 0; i < num_input_; i++) {
       input_size_list_.push_back(input_size_);
     }
     output_size_list_.push_back(input_size_);
@@ -157,7 +157,7 @@ class AddNGpuFwdKernel : public GpuKernel {
   size_t output_size_;
   size_t workspace_size_;
   bool is_null_input_;
-  int num_input_;
+  size_t num_input_;
 };
 }  // namespace kernel
 }  // namespace mindspore


mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc (+3 -0)

@@ -21,6 +21,9 @@ namespace kernel {
 MS_REG_GPU_KERNEL_ONE(
   AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   AssignAddGpuFwdKernel, int)
+MS_REG_GPU_KERNEL_ONE(
+  AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  AssignAddGpuFwdKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(
   AssignAdd,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
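With the int64 registration above, AssignAdd can update an int64 parameter on GPU. A small sketch in the usual docstring style; the AssignAddNet cell and its names are illustrative, not part of the diff:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.ops import operations as P

context.set_context(device_target="GPU")

class AssignAddNet(nn.Cell):  # hypothetical helper for the example
    def __init__(self):
        super(AssignAddNet, self).__init__()
        self.assign_add = P.AssignAdd()
        self.var = Parameter(Tensor(np.array([1, 2, 3], dtype=np.int64)), name="var")

    def construct(self, value):
        return self.assign_add(self.var, value)

net = AssignAddNet()
out = net(Tensor(np.array([4, 5, 6], dtype=np.int64)))  # var should become [5, 7, 9], dtype int64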


mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc (+33 -0)

@@ -148,6 +148,39 @@ MS_REG_GPU_KERNEL_ONE(
   DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   BroadcastOpGpuKernel, int)

+// int64
+// int32
+MS_REG_GPU_KERNEL_ONE(
+  Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Less, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Maximum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Mul, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
+
 // int8
 MS_REG_GPU_KERNEL_ONE(
   DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
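These registrations give the broadcast and element-wise ops listed above int64 GPU kernels. A rough sketch of calls that should now resolve, assuming a GPU build with this change (TensorAdd is the v1.1-era name of the elementwise add primitive):

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(device_target="GPU")

a = Tensor(np.array([[1, 2], [3, 4]], dtype=np.int64))
b = Tensor(np.array([10], dtype=np.int64))   # broadcast against the last axis

add_out = P.TensorAdd()(a, b)   # int64 output
mul_out = P.Mul()(a, b)         # int64 output
cmp_out = P.Greater()(a, b)     # bool output, matching the kNumberTypeBool output attr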


mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc (+16 -0)

@@ -66,5 +66,21 @@ MS_REG_GPU_KERNEL_ONE(MaximumGrad,
                         .AddOutputAttr(kNumberTypeInt32)
                         .AddOutputAttr(kNumberTypeInt32),
                       BroadcastOpGradGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(MinimumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64),
+                      BroadcastOpGradGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(MaximumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64),
+                      BroadcastOpGradGpuKernel, int64_t)
 }  // namespace kernel
 }  // namespace mindspore

mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc (+0 -2)

@@ -24,8 +24,6 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOut
                       ActivationGpuFwdKernel, half)
 MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       ActivationGpuFwdKernel, int32_t)
-MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-                      ActivationGpuFwdKernel, int64_t)

 MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       ActivationGpuFwdKernel, float)


mindspore/ops/operations/math_ops.py (+3 -3)

@@ -2962,7 +2962,7 @@ class IsNan(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        return mstype.bool_
+        return mstype.tensor_type(mstype.bool_)


 class IsInf(PrimitiveWithInfer):
@@ -2993,7 +2993,7 @@ class IsInf(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        return mstype.bool_
+        return mstype.tensor_type(mstype.bool_)


 class IsFinite(PrimitiveWithInfer):
@@ -3027,7 +3027,7 @@ class IsFinite(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype):
         validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
-        return mstype.bool_
+        return mstype.tensor_type(mstype.bool_)


 class FloatStatus(PrimitiveWithInfer):
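The infer_dtype change above is the core of this PR: IsNan, IsInf and IsFinite now report a tensor of bool, mstype.tensor_type(mstype.bool_), instead of the scalar type mstype.bool_, which matches what the kernels actually produce. A quick check, illustrative only and assuming a backend that provides kernels for these ops:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.array([1.0, float("nan"), float("inf")]), ms.float32)

print(P.IsNan()(x))     # expected [False  True False], dtype Bool
print(P.IsInf()(x))     # expected [False False  True], dtype Bool
print(P.IsFinite()(x))  # expected [ True False False], dtype Bool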

