Browse Source

!7140 UnsortedSegmentSum null check

Merge pull request !7140 from chenweifeng/unsorted-null-checkout
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
2fb6b0cc5d
1 changed files with 12 additions and 1 deletions
  1. +12
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h

+ 12
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h View File

@@ -27,7 +27,8 @@ namespace kernel {
template <typename T, typename S> template <typename T, typename S>
class UnsortedSegmentSumGpuKernel : public GpuKernel { class UnsortedSegmentSumGpuKernel : public GpuKernel {
public: public:
UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {}
UnsortedSegmentSumGpuKernel()
: input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1), is_null_input_(false) {}
~UnsortedSegmentSumGpuKernel() override = default; ~UnsortedSegmentSumGpuKernel() override = default;


const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -36,6 +37,9 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {


bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override { const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0); T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *indices_addr = GetDeviceAddress<S>(inputs, 1); S *indices_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0); T *output_addr = GetDeviceAddress<T>(outputs, 0);
@@ -50,6 +54,12 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {


bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shapes);
if (is_null_input_) {
MS_LOG(WARNING) << "UnsortedSegmentSum input is null";
InitSizeLists();
return true;
}
auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);


@@ -83,6 +93,7 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
size_t input_dim1_; size_t input_dim1_;
size_t output_dim0_; size_t output_dim0_;
size_t output_dim1_; size_t output_dim1_;
bool is_null_input_;


std::vector<size_t> input_size_list_; std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_; std::vector<size_t> output_size_list_;


Loading…
Cancel
Save