Browse Source

!3521 add AbsGrad support for GPU

Merge pull request !3521 from caojian05/ms_master_dev3
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
62fb2d1423
4 changed files with 29 additions and 4 deletions
  1. +13
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
  2. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
  3. +11
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
  4. +4
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h

+ 13
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu View File

@@ -87,6 +87,14 @@ struct FloorDivFunc<half, bool> {
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
};

template <typename T, typename S>
struct AbsGradFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
T zero = 0.0;
return lhs < zero ? -rhs : rhs;
}
};


template <>
struct PowerFunc<half, bool> {
@@ -149,6 +157,9 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const
case BROADCAST_TYPE_FLOORDIV:
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
case BROADCAST_TYPE_ABSGRAD:
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
}
}

@@ -192,6 +203,8 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const
return NoBroadcastOperator<T, S, AddFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_FLOORDIV:
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_ABSGRAD:
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
}
}



+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh View File

@@ -30,6 +30,7 @@ enum BroadcastOpType {
BROADCAST_TYPE_SUB = 7,
BROADCAST_TYPE_ADD = 8,
BROADCAST_TYPE_FLOORDIV = 9,
BROADCAST_TYPE_ABSGRAD = 10,
BROADCAST_TYPE_INVALID = 0xffffffff,
};



+ 11
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc View File

@@ -55,6 +55,10 @@ MS_REG_GPU_KERNEL_TWO(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)

// fp16
MS_REG_GPU_KERNEL_TWO(
@@ -93,6 +97,10 @@ MS_REG_GPU_KERNEL_TWO(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)

// int32
MS_REG_GPU_KERNEL_TWO(
@@ -113,5 +121,8 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
} // namespace kernel
} // namespace mindspore

+ 4
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h View File

@@ -96,10 +96,10 @@ class BroadcastOpGpuKernel : public GpuKernel {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);

static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
{"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
{"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
{"FloorDiv", BROADCAST_TYPE_FLOORDIV},
{"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
{"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
{"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
{"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
};

auto iter = kBroadcastTypeMap.find(kernel_name);


Loading…
Cancel
Save