| @@ -27,7 +27,8 @@ namespace kernel { | |||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| class UnsortedSegmentSumGpuKernel : public GpuKernel { | class UnsortedSegmentSumGpuKernel : public GpuKernel { | ||||
| public: | public: | ||||
| UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {} | |||||
| UnsortedSegmentSumGpuKernel() | |||||
| : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1), is_null_input_(false) {} | |||||
| ~UnsortedSegmentSumGpuKernel() override = default; | ~UnsortedSegmentSumGpuKernel() override = default; | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| @@ -36,6 +37,9 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel { | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | ||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | ||||
| if (is_null_input_) { | |||||
| return true; | |||||
| } | |||||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | T *input_addr = GetDeviceAddress<T>(inputs, 0); | ||||
| S *indices_addr = GetDeviceAddress<S>(inputs, 1); | S *indices_addr = GetDeviceAddress<S>(inputs, 1); | ||||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | T *output_addr = GetDeviceAddress<T>(outputs, 0); | ||||
| @@ -50,6 +54,12 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel { | |||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| is_null_input_ = CHECK_NULL_INPUT(input_shapes); | |||||
| if (is_null_input_) { | |||||
| MS_LOG(WARNING) << "UnsortedSegmentSum input is null"; | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | ||||
| auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); | auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||
| @@ -83,6 +93,7 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel { | |||||
| size_t input_dim1_; | size_t input_dim1_; | ||||
| size_t output_dim0_; | size_t output_dim0_; | ||||
| size_t output_dim1_; | size_t output_dim1_; | ||||
| bool is_null_input_; | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||