From: @danishnxt Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
// Kernel registrations for the UnsortedSegmentMin operator. Each entry binds one combination of
// input/output dtypes to an instantiation of UnsortedSegmentMinGpuKernel<T>.
//
// Two-input form: (data of type T, int32 second input) -> output of type T, for float32, float16 and int32.
| MS_REG_GPU_KERNEL_ONE( | |||
| UnsortedSegmentMin, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMinGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| UnsortedSegmentMin, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMinGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| UnsortedSegmentMin, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMinGpuKernel, int) | |||
// Three-input form: same dtypes as above plus an additional int64 input.
// NOTE(review): the extra int64 input is presumably num_segments supplied as a tensor - confirm against the
// Python primitive definition.
| // Dynamic Mode | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMinGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMinGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMinGpuKernel, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,131 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MIN_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MIN_H_ | |||
| #include <vector> | |||
| #include <limits> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class UnsortedSegmentMinGpuKernel : public GpuKernel { | |||
| public: | |||
| UnsortedSegmentMinGpuKernel() { ResetResource(); } | |||
| ~UnsortedSegmentMinGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| int *indices_addr = GetDeviceAddress<int>(inputs, 1); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| CalUnsortedSegmentMin(input_addr, indices_addr, num_segments_, outer_size_, inner_size_, output_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| auto input_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shapes); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "UnsortedSegmentMin input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| auto segment_ids_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1); | |||
| auto output_shapes = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num == 3) { | |||
| MS_LOG(INFO) << "UnsortedSegmentMin Kernel Input count is 3 - dynamic mode"; | |||
| } else { | |||
| MS_LOG(INFO) << "UnsortedSegmentMin Kernel Input count is 2"; | |||
| } | |||
| num_segments_ = output_shapes[0]; | |||
| input_size_ = 1; | |||
| for (size_t i = 0; i < input_shapes.size(); i++) { | |||
| input_size_ *= input_shapes[i]; | |||
| } | |||
| segment_ids_size_ = 1; | |||
| for (size_t i = 0; i < segment_ids_shapes.size(); i++) { | |||
| segment_ids_size_ *= segment_ids_shapes[i]; | |||
| } | |||
| output_size_ = 1; | |||
| for (size_t i = 0; i < output_shapes.size(); i++) { | |||
| output_size_ *= output_shapes[i]; | |||
| } | |||
| outer_size_ = input_shapes[0]; | |||
| inner_size_ = 1; | |||
| for (size_t i = 1; i < input_shapes.size(); i++) { | |||
| inner_size_ *= input_shapes[i]; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| num_segments_ = 1; | |||
| inner_size_ = 1; | |||
| outer_size_ = 1; | |||
| input_size_ = 1; | |||
| segment_ids_size_ = 1; | |||
| output_size_ = 1; | |||
| is_null_input_ = false; | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| input_size_list_.push_back(segment_ids_size_ * sizeof(int)); | |||
| output_size_list_.push_back(output_size_ * sizeof(T)); | |||
| } | |||
| private: | |||
| int num_segments_; | |||
| size_t inner_size_; | |||
| size_t outer_size_; | |||
| size_t input_size_; | |||
| size_t segment_ids_size_; | |||
| size_t output_size_; | |||
| bool is_null_input_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MIN_H_ | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh" | |||
| #include <limits> | |||
| template<typename T> | |||
| __device__ __forceinline__ void max_val_init(T *init_val) { | |||
| *init_val = std::numeric_limits<T>::max(); | |||
| } | |||
| // Handle fp16 differently for assignment | |||
| template<> | |||
| __device__ __forceinline__ void max_val_init(half *init_val) { | |||
| *init_val = __int2half_rd(65504); // Max value for Half | |||
| } | |||
| template <typename T> | |||
| __global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, | |||
| size_t inner_size, T init_K, T *output) { | |||
| max_val_init(&init_K); | |||
| for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size; | |||
| t_idx += blockDim.x * gridDim.x) { | |||
| int segment_id = t_idx / KWARPSIZE / inner_size; | |||
| int inner_id = t_idx / KWARPSIZE % inner_size; | |||
| int lane_id = threadIdx.x % KWARPSIZE; | |||
| T threadK = init_K; | |||
| for (int i = lane_id; i < outer_size; i += KWARPSIZE) { | |||
| if (segment_ids[i] != segment_id) continue; | |||
| T other_K = input[i * inner_size + inner_id]; | |||
| if (threadK > other_K) { | |||
| threadK = other_K; | |||
| } | |||
| } | |||
| __syncwarp(); | |||
| for (int offset = KWARPSIZE / 2; offset > 0; offset /= 2) { | |||
| T other_K = __shfl_down_sync(0xffffffff, threadK, offset); | |||
| if (threadK > other_K) { | |||
| threadK = other_K; | |||
| } | |||
| } | |||
| __syncwarp(); | |||
| if (lane_id == 0) { | |||
| output[segment_id * inner_size + inner_id] = threadK; | |||
| } | |||
| __syncthreads(); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, | |||
| size_t inner_size, T *output, cudaStream_t stream) { | |||
| int size = (inner_size * KWARPSIZE * num_segments); | |||
| T init_K = std::numeric_limits<T>::lowest(); // only init here - overwritten later | |||
| UnsortedSegmentMin<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, segment_ids, num_segments, outer_size, | |||
| inner_size, init_K, output); | |||
| return; | |||
| } | |||
// Explicit instantiations of CalUnsortedSegmentMin for float, half and int. The template is defined in this
// translation unit only, so these are the element types other translation units can link against.
| template void CalUnsortedSegmentMin<float>(const float *input, const int *segment_ids, const int num_segments, | |||
| size_t outer_size, size_t inner_size, float *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMin<half>(const half *input, const int *segment_ids, const int num_segments, | |||
| size_t outer_size, size_t inner_size, half *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMin<int>(const int *input, const int *segment_ids, const int num_segments, | |||
| size_t outer_size, size_t inner_size, int *output, cudaStream_t stream); | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MIN_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MIN_H_ | |||
| #include <cuda_runtime.h> | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
// Number of threads that cooperate on one output element; the kernel uses it as the width of its
// shuffle-based reduction.
| // Setting warp size to sync data across threads | |||
| #define KWARPSIZE 32 | |||
// Launches the UnsortedSegmentMin computation on `stream`.
//   input        data buffer, read at index i * inner_size + inner_id for i in [0, outer_size)
//   segment_ids  one segment id per row i in [0, outer_size)
//   num_segments number of output segments
//   output       result buffer, written at index segment_id * inner_size + inner_id
// Output positions whose segment id matches no row receive the maximum value of T.
| template <typename T> | |||
| void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, | |||
| size_t inner_size, T *output, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MIN_H_ | |||
| @@ -115,6 +115,8 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -341,6 +341,74 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| } | |||
| AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||
| auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(segment_ids); | |||
| MS_EXCEPTION_IF_NULL(segment_ids->shape()); | |||
| auto segment_ids_shape = segment_ids->shape()->shape(); | |||
| (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMin should be %s"); | |||
| (void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMin should be %s"); | |||
| // check if dynamic shape | |||
| bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); | |||
| bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty()); | |||
| bool op_is_dynamic = x_is_dyn && ids_is_dyn; | |||
| auto x_shape = x->shape()->shape(); | |||
| ShapeVector shape; | |||
| int64_t num_segments_value = 0; | |||
| if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor | |||
| auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(num_segments); | |||
| auto num_segments_value_ptr = num_segments->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(num_segments_value_ptr); | |||
| auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(num_segments_tensor); | |||
| num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); | |||
| } else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar | |||
| auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2); | |||
| num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMin"; | |||
| } | |||
| if (num_segments_value <= 0) { | |||
| MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMin"; | |||
| } | |||
| shape.emplace_back(num_segments_value); | |||
| shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end()); | |||
| if (!op_is_dynamic) { | |||
| if (x_shape[0] != segment_ids_shape[0]) { | |||
| MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin"; | |||
| } | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape)); | |||
| } | |||
| // is dynamic | |||
| ShapeVector min_shape; | |||
| ShapeVector max_shape; | |||
| min_shape.emplace_back(num_segments_value); | |||
| max_shape.emplace_back(num_segments_value); | |||
| // only run validation if shape values are known | |||
| bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; }); | |||
| bool ids_any_shape = | |||
| std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; }); | |||
| if (!x_any_shape && !ids_any_shape) { | |||
| if (x_shape[0] != segment_ids_shape[0]) { | |||
| MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin"; | |||
| } | |||
| } | |||
| ShapeVector x_shape_min; | |||
| ShapeVector x_shape_max; | |||
| x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape(); | |||
| x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape(); | |||
| min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end()); | |||
| max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end()); | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| } | |||
| AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| @@ -59,6 +59,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, | |||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | |||
| {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, | |||
| {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, | |||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, | |||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, | |||
| {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}}, | |||
| @@ -1922,7 +1922,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||
| return out | |||
| class UnsortedSegmentMin(PrimitiveWithInfer): | |||
| class UnsortedSegmentMin(PrimitiveWithCheck): | |||
| """ | |||
| Computes the minimum of a tensor along segments. | |||
| @@ -1959,26 +1959,19 @@ class UnsortedSegmentMin(PrimitiveWithInfer): | |||
| """Initialize UnsortedSegmentMin""" | |||
| self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) | |||
| def __infer__(self, x, segment_ids, num_segments): | |||
| x_type = x['dtype'] | |||
| x_shape = x['shape'] | |||
| def __check__(self, x, segment_ids, num_segments): | |||
| segment_ids_shape = segment_ids['shape'] | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) | |||
| validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name) | |||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | |||
| validator.check(f'first shape of input_x', x_shape[0], | |||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | |||
| num_segments_v = num_segments['value'] | |||
| validator.check_value_type('num_segments', num_segments_v, [int], self.name) | |||
| validator.check_positive_int(num_segments_v, "num_segments", self.name) | |||
| segment_ids_shape_len = len(segment_ids_shape) | |||
| out_shape = [num_segments_v] | |||
| out_shape += x_shape[segment_ids_shape_len:] | |||
| out = {'shape': out_shape, | |||
| 'dtype': x_type, | |||
| 'value': None} | |||
| return out | |||
| num_segments_type = num_segments['dtype'] | |||
| validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) | |||
| if isinstance(num_segments_type, type(mstype.tensor)): | |||
| validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64], | |||
| self.name) | |||
| else: | |||
| validator.check_value_type('num_segments', num_segments['value'], [int], self.name) | |||
| class UnsortedSegmentMax(PrimitiveWithCheck): | |||
| @@ -222,39 +222,12 @@ class UnsortedSegmentMaxDynNet(nn.Cell): | |||
| @pytest.mark.env_onecard | |||
| def test_3d_float32_dyn(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| input_x = Tensor(np.arange( | |||
| 4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32) | |||
| segment_ids = Tensor([2, 1, 1, -1], mstype.int32) | |||
| num_segments = 3 | |||
| num_segments = 4 | |||
| net = UnsortedSegmentMaxDynNet(num_segments) | |||
| output = net(input_x, segment_ids).asnumpy() | |||
| expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]], | |||
| [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01], | |||
| [3.3000000e+01, 3.4000000e+01, 3.5000000e+01], | |||
| [3.6000000e+01, 3.7000000e+01, 3.8000000e+01], | |||
| [3.9000000e+01, 4.0000000e+01, 4.1000000e+01], | |||
| [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]], | |||
| [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00], | |||
| [3.0000000e+00, 4.0000000e+00, 5.0000000e+00], | |||
| [6.0000000e+00, 7.0000000e+00, 8.0000000e+00], | |||
| [9.0000000e+00, 1.0000000e+01, 1.1000000e+01], | |||
| [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32) | |||
| np.testing.assert_array_almost_equal(output, expect) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_3d_single_init_dyn(): | |||
| context.set_context(device_target='GPU') | |||
| input_x = Tensor(np.arange( | |||
| 4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32) | |||
| segment_ids = Tensor([3, 0, 1, -1], mstype.int32) | |||
| num_segments = 4 | |||
| net = UnsortedSegmentMaxDynNet(num_segments) | |||
| output = net(input_x, segment_ids).asnumpy() | |||
| expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01], | |||
| [1.8000000e+01, 1.9000000e+01, 2.0000000e+01], | |||
| @@ -278,7 +251,15 @@ def test_3d_single_init_dyn(): | |||
| [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32) | |||
| np.testing.assert_array_almost_equal(output, expect) | |||
| num_segments = 6 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_3d_single_init_dyn(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| input_x = Tensor(np.arange( | |||
| 4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32) | |||
| segment_ids = Tensor([3, 0, 1, -1], mstype.int32) | |||
| num_segments = 4 | |||
| net = UnsortedSegmentMaxDynNet(num_segments) | |||
| output = net(input_x, segment_ids).asnumpy() | |||
| expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01], | |||
| @@ -300,15 +281,40 @@ def test_3d_single_init_dyn(): | |||
| [3.0000000e+00, 4.0000000e+00, 5.0000000e+00], | |||
| [6.0000000e+00, 7.0000000e+00, 8.0000000e+00], | |||
| [9.0000000e+00, 1.0000000e+01, 1.1000000e+01], | |||
| [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]], | |||
| [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]], | |||
| [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32) | |||
| [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32) | |||
| np.testing.assert_array_almost_equal(output, expect) | |||
| # changing the input shape here for same net | |||
| input_x = Tensor(np.arange( | |||
| 4 * 7 * 2, dtype=np.float32).reshape(4, 7, 2), dtype=mindspore.float32) | |||
| segment_ids = Tensor([3, 0, 1, -1], mstype.int32) | |||
| output = net(input_x, segment_ids).asnumpy() | |||
| expect = np.array([[[1.4000000e+01, 1.5000000e+01], | |||
| [1.6000000e+01, 1.7000000e+01], | |||
| [1.8000000e+01, 1.9000000e+01], | |||
| [2.0000000e+01, 2.1000000e+01], | |||
| [2.2000000e+01, 2.3000000e+01], | |||
| [2.4000000e+01, 2.5000000e+01], | |||
| [2.6000000e+01, 2.7000000e+01]], | |||
| [[2.8000000e+01, 2.9000000e+01], | |||
| [3.0000000e+01, 3.1000000e+01], | |||
| [3.2000000e+01, 3.3000000e+01], | |||
| [3.4000000e+01, 3.5000000e+01], | |||
| [3.6000000e+01, 3.7000000e+01], | |||
| [3.8000000e+01, 3.9000000e+01], | |||
| [4.0000000e+01, 4.1000000e+01]], | |||
| [[-3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38], | |||
| [-3.4028235e+38, -3.4028235e+38]], | |||
| [[0.0000000e+00, 1.0000000e+00], | |||
| [2.0000000e+00, 3.0000000e+00], | |||
| [4.0000000e+00, 5.0000000e+00], | |||
| [6.0000000e+00, 7.0000000e+00], | |||
| [8.0000000e+00, 9.0000000e+00], | |||
| [1.0000000e+01, 1.1000000e+01], | |||
| [1.2000000e+01, 1.3000000e+01]]]).astype(np.float32) | |||
| np.testing.assert_array_almost_equal(output, expect) | |||
| @@ -0,0 +1,318 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
class UnsortedSegmentMinNet(nn.Cell):
    """Cell that applies P.UnsortedSegmentMin with a num_segments value fixed at construction."""

    def __init__(self, num_segments):
        super().__init__()
        self.num_segments = num_segments
        self.unsorted_segment_min = P.UnsortedSegmentMin()

    def construct(self, data, ids):
        """Return the segment-wise minimum of `data` grouped by `ids`."""
        result = self.unsorted_segment_min(data, ids, self.num_segments)
        return result
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_1d_int32():
    """1-D int32 input in PyNative mode; the segment with no members yields the int32 maximum."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    data = Tensor([1, 2, 3, 4], mstype.int32)
    ids = Tensor([0, 0, 1, 2], mstype.int32)
    segment_min = UnsortedSegmentMinNet(4)
    actual = segment_min(data, ids).asnumpy()
    int32_max = np.iinfo(np.int32).max  # 2147483647
    expected = np.array([1, 3, 4, int32_max], dtype=np.int32)
    np.testing.assert_array_almost_equal(actual, expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_2d_int32():
    """2-D int32 input in graph mode; segments 0 and 3 have no members and yield the int32 maximum."""
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    rows = [[1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12]]
    data = Tensor(rows, mstype.int32)
    ids = Tensor([2, 1, 1], mstype.int32)
    segment_min = UnsortedSegmentMinNet(4)
    actual = segment_min(data, ids).asnumpy()
    empty_segment = [np.iinfo(np.int32).max] * 4  # 2147483647 in every column
    expected = np.array([empty_segment,
                         [5, 6, 7, 8],
                         [1, 2, 3, 4],
                         empty_segment], dtype=np.int32)
    np.testing.assert_array_almost_equal(actual, expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float16():
    """3-D float16 input in PyNative mode with one negative segment id and three empty segments."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    x_np = np.arange(4 * 5 * 3, dtype=np.float16).reshape(4, 5, 3)
    data = Tensor(x_np, dtype=mindspore.float16)
    ids = Tensor([2, 1, 1, -1], mstype.int32)
    segment_min = UnsortedSegmentMinNet(5)
    actual = segment_min(data, ids).asnumpy()
    # Segments without members are expected to hold 6.55e+04 after conversion to float16.
    empty_segment = np.full((5, 3), 6.55e+04)
    # Segment 1 collects slices 1 and 2, whose element-wise minimum is slice 1 (values 15..29);
    # segment 2 collects slice 0 only (values 0..14); the slice with id -1 is expected nowhere.
    expected = np.stack([empty_segment,
                         x_np[1],
                         x_np[0],
                         empty_segment,
                         empty_segment]).astype(np.float16)
    np.testing.assert_array_almost_equal(actual, expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float32():
    """3-D float32 input in graph mode with one negative segment id and one empty segment."""
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    x_np = np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3)
    data = Tensor(x_np, dtype=mindspore.float32)
    ids = Tensor([2, 1, 1, -1], mstype.int32)
    segment_min = UnsortedSegmentMinNet(3)
    actual = segment_min(data, ids).asnumpy()
    # Segment 0 has no members and is expected to hold 3.4028235e+38 after conversion to float32.
    empty_segment = np.full((5, 3), 3.4028235e+38)
    # Segment 1 collects slices 1 and 2, whose element-wise minimum is slice 1 (values 15..29);
    # segment 2 collects slice 0 only (values 0..14); the slice with id -1 is expected nowhere.
    expected = np.stack([empty_segment, x_np[1], x_np[0]]).astype(np.float32)
    np.testing.assert_array_almost_equal(actual, expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init():
    """Reuse one P.UnsortedSegmentMin instance with two different segment counts.

    The input is a (4, 5, 3) float32 tensor with segment ids [3, 0, 1, -1].
    The primitive is created once and called first with num_segments=4 and
    then with num_segments=6. For both calls the expected output is:
      * segment 0 -> values 15..29 (row 1),
      * segment 1 -> values 30..44 (row 2),
      * segment 2 -> all 3.4028235e+38 (no rows assigned),
      * segment 3 -> values 0..14 (row 0),
    and, for the 6-segment call, two further segments of 3.4028235e+38.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3),
                     dtype=mindspore.float32)
    segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
    net = P.UnsortedSegmentMin()

    # Building blocks of the expected tensors.
    untouched = np.full((5, 3), 3.4028235e+38, dtype=np.float32)
    row0, row1, row2 = np.arange(3 * 5 * 3, dtype=np.float32).reshape(3, 5, 3)
    populated = [row1, row2, untouched, row0]

    cases = (
        (4, np.stack(populated)),
        (6, np.stack(populated + [untouched, untouched])),
    )
    # Same primitive object for both calls; only the segment count changes.
    for num_segments, expect in cases:
        output = net(input_x, segment_ids, num_segments).asnumpy()
        np.testing.assert_array_almost_equal(output, expect)
| # For testing Dynamic Shape operation | |||
class UnsortedSegmentMinDynNet(nn.Cell):
    """Cell that feeds UnsortedSegmentMin with inputs passed through GpuConvertToDynamicShape.

    Args:
        num_segments: segment count forwarded to UnsortedSegmentMin on every call.
    """

    def __init__(self, num_segments):
        super().__init__()
        self.num_segments = num_segments
        self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        self.unsorted_segment_min = P.UnsortedSegmentMin()

    def construct(self, data, ids):
        """Convert `data` and `ids`, then apply UnsortedSegmentMin with the stored segment count."""
        return self.unsorted_segment_min(
            self.gpu_convert_to_dynamic_shape(data),
            self.gpu_convert_to_dynamic_shape(ids),
            self.num_segments)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float32_dyn():
    """Check UnsortedSegmentMinDynNet on a (4, 5, 3) float32 tensor in GPU graph mode.

    The segment ids are [2, 1, 1, -1] with 3 segments. The expected output has:
      * segment 0 (no rows assigned) filled with 3.4028235e+38,
      * segment 1 equal to values 15..29 (the element-wise min of rows 1 and 2),
      * segment 2 equal to values 0..14 (row 0).
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    host_data = np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3)
    input_x = Tensor(host_data, dtype=mindspore.float32)
    segment_ids = Tensor([2, 1, 1, -1], mstype.int32)

    dyn_net = UnsortedSegmentMinDynNet(3)
    output = dyn_net(input_x, segment_ids).asnumpy()

    # Expected tensor, one (5, 3) slab per segment.
    slabs = [
        np.full((5, 3), 3.4028235e+38, dtype=np.float32),   # segment 0: no rows
        np.arange(15, 30, dtype=np.float32).reshape(5, 3),  # segment 1
        np.arange(0, 15, dtype=np.float32).reshape(5, 3),   # segment 2
    ]
    np.testing.assert_array_almost_equal(output, np.stack(slabs))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init_dyn():
    """Run one UnsortedSegmentMinDynNet instance on two differently shaped inputs.

    The net is built once with 4 segments and called first with a (4, 5, 3)
    input and then with a (4, 7, 2) input, both using segment ids [3, 0, 1, -1].
    For each shape the expected output is:
      * segment 0 -> row 1 of the input,
      * segment 1 -> row 2 of the input,
      * segment 2 -> all 3.4028235e+38 (no rows assigned),
      * segment 3 -> row 0 of the input.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    num_segments = 4
    net = UnsortedSegmentMinDynNet(num_segments)

    # The second shape exercises the same net with a changed input shape.
    for height, width in ((5, 3), (7, 2)):
        input_x = Tensor(
            np.arange(4 * height * width, dtype=np.float32).reshape(4, height, width),
            dtype=mindspore.float32)
        segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
        output = net(input_x, segment_ids).asnumpy()

        untouched = np.full((height, width), 3.4028235e+38, dtype=np.float32)
        row0, row1, row2 = np.arange(
            3 * height * width, dtype=np.float32).reshape(3, height, width)
        expect = np.stack([row1, row2, untouched, row0])
        np.testing.assert_array_almost_equal(output, expect)