| @@ -27,7 +27,8 @@ namespace kernel { | |||||
| template <typename T> | template <typename T> | ||||
| class SliceGpuFwdKernel : public GpuKernel { | class SliceGpuFwdKernel : public GpuKernel { | ||||
| public: | public: | ||||
| SliceGpuFwdKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {} | |||||
| SliceGpuFwdKernel() | |||||
| : is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {} | |||||
| ~SliceGpuFwdKernel() override = default; | ~SliceGpuFwdKernel() override = default; | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | ||||
| @@ -35,6 +36,9 @@ class SliceGpuFwdKernel : public GpuKernel { | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | ||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | ||||
| if (is_null_input_) { | |||||
| return true; | |||||
| } | |||||
| T *input = GetDeviceAddress<T>(inputs, 0); | T *input = GetDeviceAddress<T>(inputs, 0); | ||||
| T *output = GetDeviceAddress<T>(outputs, 0); | T *output = GetDeviceAddress<T>(outputs, 0); | ||||
| if (is_strided_slice_) { | if (is_strided_slice_) { | ||||
| @@ -79,7 +83,11 @@ class SliceGpuFwdKernel : public GpuKernel { | |||||
| if (size_[i] < 0) { | if (size_[i] < 0) { | ||||
| size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; | size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; | ||||
| } | } | ||||
| if (size_[i] == 0) { | |||||
| if (begin_[i] == size_[i] && is_strided_slice_) { | |||||
| MS_LOG(WARNING) << "Output is null."; | |||||
| is_null_input_ = true; | |||||
| } | |||||
| if (size_[i] == 0 && strides_[i] > 0) { | |||||
| size_[i] = begin_[i] + 1; | size_[i] = begin_[i] + 1; | ||||
| } | } | ||||
| } | } | ||||
| @@ -143,6 +151,7 @@ class SliceGpuFwdKernel : public GpuKernel { | |||||
| std::vector<size_t> workspace_size_list_; | std::vector<size_t> workspace_size_list_; | ||||
| bool is_strided_slice_; | bool is_strided_slice_; | ||||
| bool is_null_input_; | |||||
| size_t input_size_; | size_t input_size_; | ||||
| size_t output_size_; | size_t output_size_; | ||||
| size_t workspace_size_; | size_t workspace_size_; | ||||