| @@ -30,7 +30,14 @@ MS_REG_GPU_KERNEL_ONE( | |||
| UnsortedSegmentMax, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMaxGpuKernel, int) | |||
| // Dynamic Mode | |||
| // Dynamic Mode - registered for int32/int64 3rd input | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMaxGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -38,6 +45,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMaxGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMaxGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| @@ -45,6 +59,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMaxGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMaxGpuKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| @@ -69,7 +69,11 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel { | |||
| } else { | |||
| MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 2"; | |||
| } | |||
| auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); | |||
| if (value_count.size() != 1) { | |||
| MS_LOG(ERROR) << "For UnsortedSegmentMax, output shape incorrect rank. Expect Rank: 1, got Rank: " | |||
| << value_count.size() << "."; | |||
| } | |||
| num_segments_ = output_shapes[0]; | |||
| input_size_ = 1; | |||
| for (size_t i = 0; i < input_shapes.size(); i++) { | |||
| @@ -117,7 +121,7 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel { | |||
| } | |||
| private: | |||
| int num_segments_; | |||
| int64_t num_segments_; | |||
| size_t inner_size_; | |||
| size_t outer_size_; | |||
| size_t input_size_; | |||
| @@ -30,7 +30,14 @@ MS_REG_GPU_KERNEL_ONE( | |||
| UnsortedSegmentMin, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMinGpuKernel, int) | |||
| // Dynamic Mode | |||
| // Dynamic Mode - registered for int32/int64 3rd input | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMinGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -38,6 +45,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMinGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMinGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| @@ -45,6 +59,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMinGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMinGpuKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| @@ -65,7 +65,11 @@ class UnsortedSegmentMinGpuKernel : public GpuKernel { | |||
| } else { | |||
| MS_LOG(INFO) << "UnsortedSegmentMin Kernel Input count is 2"; | |||
| } | |||
| auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); | |||
| if (value_count.size() != 1) { | |||
| MS_LOG(ERROR) << "For UnsortedSegmentMin, output shape incorrect rank. Expect Rank: 1, got Rank: " | |||
| << value_count.size() << "."; | |||
| } | |||
| num_segments_ = output_shapes[0]; | |||
| input_size_ = 1; | |||
| for (size_t i = 0; i < input_shapes.size(); i++) { | |||
| @@ -113,7 +117,7 @@ class UnsortedSegmentMinGpuKernel : public GpuKernel { | |||
| } | |||
| private: | |||
| int num_segments_; | |||
| int64_t num_segments_; | |||
| size_t inner_size_; | |||
| size_t outer_size_; | |||
| size_t input_size_; | |||
| @@ -18,8 +18,8 @@ | |||
| #include <limits> | |||
| template <typename T> | |||
| __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, | |||
| size_t inner_size, bool fp16_flag, T init_K, T *output) { | |||
| __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, bool fp16_flag, T init_K, T *output) { | |||
| if (fp16_flag) { | |||
| init_K = __int2half_rd(-65504); // min value representable by float16 | |||
| } | |||
| @@ -57,7 +57,7 @@ __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const | |||
| } | |||
| template <typename T> | |||
| void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, | |||
| void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, | |||
| size_t inner_size, T *output, cudaStream_t stream) { | |||
| int size = (inner_size * KWARPSIZE * num_segments); | |||
| bool fp16_flag = false; | |||
| @@ -71,9 +71,9 @@ void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num | |||
| return; | |||
| } | |||
| template void CalUnsortedSegmentMax<float>(const float *input, const int *segment_ids, const int num_segments, | |||
| template void CalUnsortedSegmentMax<float>(const float *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, float *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<half>(const half *input, const int *segment_ids, const int num_segments, | |||
| template void CalUnsortedSegmentMax<half>(const half *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, half *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<int>(const int *input, const int *segment_ids, const int num_segments, | |||
| template void CalUnsortedSegmentMax<int>(const int *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, int *output, cudaStream_t stream); | |||
| @@ -22,9 +22,8 @@ | |||
| // Setting warp size to sync data across threads | |||
| #define KWARPSIZE 32 | |||
| template <typename T> | |||
| void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, | |||
| void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, | |||
| size_t inner_size, T *output, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MAX_H_ | |||
| @@ -17,19 +17,19 @@ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh" | |||
| #include <limits> | |||
| template<typename T> | |||
| template <typename T> | |||
| __device__ __forceinline__ void max_val_init(T *init_val) { | |||
| *init_val = std::numeric_limits<T>::max(); | |||
| } | |||
| // Handle fp16 differently for assignment | |||
| template<> | |||
| template <> | |||
| __device__ __forceinline__ void max_val_init(half *init_val) { | |||
| *init_val = __int2half_rd(65504); // Max value for Half | |||
| } | |||
| template <typename T> | |||
| __global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, | |||
| size_t inner_size, T init_K, T *output) { | |||
| __global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, T init_K, T *output) { | |||
| max_val_init(&init_K); | |||
| for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size; | |||
| t_idx += blockDim.x * gridDim.x) { | |||
| @@ -62,18 +62,18 @@ __global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const | |||
| } | |||
| template <typename T> | |||
| void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, | |||
| void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, | |||
| size_t inner_size, T *output, cudaStream_t stream) { | |||
| int size = (inner_size * KWARPSIZE * num_segments); | |||
| T init_K = std::numeric_limits<T>::lowest(); // only init here - overwritten later | |||
| T init_K = std::numeric_limits<T>::lowest(); // only init here - overwritten later | |||
| UnsortedSegmentMin<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, segment_ids, num_segments, outer_size, | |||
| inner_size, init_K, output); | |||
| return; | |||
| } | |||
| template void CalUnsortedSegmentMin<float>(const float *input, const int *segment_ids, const int num_segments, | |||
| template void CalUnsortedSegmentMin<float>(const float *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, float *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMin<half>(const half *input, const int *segment_ids, const int num_segments, | |||
| template void CalUnsortedSegmentMin<half>(const half *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, half *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMin<int>(const int *input, const int *segment_ids, const int num_segments, | |||
| template void CalUnsortedSegmentMin<int>(const int *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, int *output, cudaStream_t stream); | |||
| @@ -23,6 +23,6 @@ | |||
| // Setting warp size to sync data across threads | |||
| #define KWARPSIZE 32 | |||
| template <typename T> | |||
| void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, | |||
| void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, | |||
| size_t inner_size, T *output, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MIN_H_ | |||
| @@ -29,7 +29,7 @@ namespace kernel { | |||
| template <typename T> | |||
| class PadGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| PadGpuFwdKernel() : shape_size_(0), temp(0), input_size_(0), output_size_(0), workspace_size_(0) {} | |||
| PadGpuFwdKernel() : shape_size_(0), temp(0), input_size_(1), output_size_(1), workspace_size_(0) {} | |||
| ~PadGpuFwdKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -53,13 +53,11 @@ class PadGpuFwdKernel : public GpuKernel { | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| // check number of inputs -> should be 1 | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but Pad needs 1 input."; | |||
| return false; | |||
| } | |||
| // check number of output -> should be 1 | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but Pad needs 1 output."; | |||
| @@ -67,8 +65,7 @@ class PadGpuFwdKernel : public GpuKernel { | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| shape_size_ = input_shape.size(); | |||
| // shape adjustement -> from 2d/3d to 4d to standardize | |||
| if (shape_size_ == 4) { | |||
| if (shape_size_ == 4) { // shape adjustement from 2d/3d to 4d | |||
| } else if (shape_size_ == 3) { | |||
| auto it = input_shape.begin(); | |||
| input_shape.insert(it, 1); // batch padding | |||
| @@ -87,8 +84,7 @@ class PadGpuFwdKernel : public GpuKernel { | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| return shape; | |||
| }); | |||
| // shape adjustement -> from 2d/3d to 4d to standardize | |||
| if (paddings.size() == 4) { | |||
| if (paddings.size() == 4) { // shape adjustement from 2d/3d to 4d | |||
| } else if (paddings.size() == 3) { | |||
| auto it = paddings.begin(); | |||
| paddings.insert(it, 1, {0, 0}); // batch padding | |||
| @@ -96,13 +92,11 @@ class PadGpuFwdKernel : public GpuKernel { | |||
| auto it = paddings.begin(); | |||
| paddings.insert(it, 2, {0, 0}); // channel padding | |||
| } | |||
| input_size_ = 1; | |||
| for (size_t i = 0; i < shape_size_; i++) { | |||
| input_size_ *= input_shape[i]; | |||
| input_shape_.push_back(input_shape[i]); | |||
| } | |||
| input_size_ *= sizeof(T); | |||
| output_size_ = 1; | |||
| for (size_t i = 0; i < shape_size_; i++) { | |||
| temp = input_shape[i] + (paddings[i][0] + paddings[i][1]); // compute new dim size | |||
| output_size_ *= temp; | |||
| @@ -227,10 +227,18 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri | |||
| MS_EXCEPTION_IF_NULL(num_segments_value_ptr); | |||
| auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(num_segments_tensor); | |||
| num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); | |||
| if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { | |||
| num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); | |||
| } else { | |||
| num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c()); | |||
| } | |||
| } else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar | |||
| auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2); | |||
| num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); | |||
| if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { | |||
| num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); | |||
| } else { | |||
| num_segments_value = GetValue<int32_t>(num_segments->BuildValue()); | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentSum"; | |||
| } | |||
| @@ -300,10 +308,19 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri | |||
| MS_EXCEPTION_IF_NULL(num_segments_value_ptr); | |||
| auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(num_segments_tensor); | |||
| num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); | |||
| if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { | |||
| num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); | |||
| } else { | |||
| num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c()); | |||
| } | |||
| // num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); | |||
| } else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar | |||
| auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2); | |||
| num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); | |||
| if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { | |||
| num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); | |||
| } else { | |||
| num_segments_value = GetValue<int32_t>(num_segments->BuildValue()); | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMax"; | |||
| } | |||
| @@ -368,10 +385,18 @@ AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const Pri | |||
| MS_EXCEPTION_IF_NULL(num_segments_value_ptr); | |||
| auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(num_segments_tensor); | |||
| num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); | |||
| if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { | |||
| num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); | |||
| } else { | |||
| num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c()); | |||
| } | |||
| } else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar | |||
| auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2); | |||
| num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); | |||
| if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { | |||
| num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); | |||
| } else { | |||
| num_segments_value = GetValue<int32_t>(num_segments->BuildValue()); | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMin"; | |||
| } | |||
| @@ -1893,8 +1893,10 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||
| validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name) | |||
| validator.check(f'rank of input_x', len(x_shp), | |||
| 'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name) | |||
| for i, value in enumerate(segment_ids_shp): | |||
| validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name) | |||
| if (not -1 in x_shp and not -1 in segment_ids_shp): | |||
| # only validate when both shapes fully known | |||
| for i, value in enumerate(segment_ids_shp): | |||
| validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name) | |||
| num_segments_v = num_segments['value'] | |||
| num_segments_type = num_segments['dtype'] | |||
| validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) | |||
| @@ -1968,7 +1970,7 @@ class UnsortedSegmentMin(PrimitiveWithCheck): | |||
| num_segments_type = num_segments['dtype'] | |||
| validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) | |||
| if isinstance(num_segments_type, type(mstype.tensor)): | |||
| validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64], | |||
| validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64], | |||
| self.name) | |||
| else: | |||
| validator.check_value_type('num_segments', num_segments['value'], [int], self.name) | |||
| @@ -2021,7 +2023,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck): | |||
| num_segments_type = num_segments['dtype'] | |||
| validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) | |||
| if isinstance(num_segments_type, type(mstype.tensor)): | |||
| validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64], | |||
| validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64], | |||
| self.name) | |||
| else: | |||
| validator.check_value_type('num_segments', num_segments['value'], [int], self.name) | |||