|
|
|
@@ -27,7 +27,8 @@ namespace kernel { |
|
|
|
template <typename T> |
|
|
|
class SliceGpuFwdKernel : public GpuKernel { |
|
|
|
public: |
|
|
|
SliceGpuFwdKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {} |
|
|
|
SliceGpuFwdKernel() |
|
|
|
: is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {} |
|
|
|
~SliceGpuFwdKernel() override = default; |
|
|
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } |
|
|
|
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } |
|
|
|
@@ -35,6 +36,9 @@ class SliceGpuFwdKernel : public GpuKernel { |
|
|
|
|
|
|
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, |
|
|
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override { |
|
|
|
if (is_null_input_) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
T *input = GetDeviceAddress<T>(inputs, 0); |
|
|
|
T *output = GetDeviceAddress<T>(outputs, 0); |
|
|
|
if (is_strided_slice_) { |
|
|
|
@@ -79,7 +83,11 @@ class SliceGpuFwdKernel : public GpuKernel { |
|
|
|
if (size_[i] < 0) { |
|
|
|
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; |
|
|
|
} |
|
|
|
if (size_[i] == 0) { |
|
|
|
if (begin_[i] == size_[i] && is_strided_slice_) { |
|
|
|
MS_LOG(WARNING) << "Output is null."; |
|
|
|
is_null_input_ = true; |
|
|
|
} |
|
|
|
if (size_[i] == 0 && strides_[i] > 0) { |
|
|
|
size_[i] = begin_[i] + 1; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -143,6 +151,7 @@ class SliceGpuFwdKernel : public GpuKernel { |
|
|
|
std::vector<size_t> workspace_size_list_; |
|
|
|
|
|
|
|
bool is_strided_slice_; |
|
|
|
bool is_null_input_; |
|
|
|
size_t input_size_; |
|
|
|
size_t output_size_; |
|
|
|
size_t workspace_size_; |
|
|
|
|