Changed default max_output_length to 1000000; changed docstring; fixed CI; renamed max_output_length to maxlen. (tags/v1.2.0-rc1)
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cstdint> | |||||
| #include "backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(Range, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| DynamicRangeGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(Range, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat64) | |||||
| .AddInputAttr(kNumberTypeFloat64) | |||||
| .AddInputAttr(kNumberTypeFloat64) | |||||
| .AddOutputAttr(kNumberTypeFloat64), | |||||
| DynamicRangeGpuKernel, double) | |||||
| MS_REG_GPU_KERNEL_ONE(Range, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt32), | |||||
| DynamicRangeGpuKernel, int32_t) | |||||
| MS_REG_GPU_KERNEL_ONE(Range, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeInt64) | |||||
| .AddInputAttr(kNumberTypeInt64) | |||||
| .AddInputAttr(kNumberTypeInt64) | |||||
| .AddOutputAttr(kNumberTypeInt64), | |||||
| DynamicRangeGpuKernel, int64_t) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,121 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_ | |||||
| #include <cuda_runtime.h> | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class DynamicRangeGpuKernel : public GpuKernel { | |||||
| public: | |||||
| DynamicRangeGpuKernel() { ResetResource(); } | |||||
| ~DynamicRangeGpuKernel() = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| T *range_start = GetDeviceAddress<T>(inputs, 0); | |||||
| T *range_end = GetDeviceAddress<T>(inputs, 1); | |||||
| T *range_delta = GetDeviceAddress<T>(inputs, 2); | |||||
| T *output_device_address = GetDeviceAddress<T>(outputs, 0); | |||||
| int64_t *output_shape_device_address = GetDeviceAddress<int64_t>(workspace, 0); | |||||
| stream_ptr_ = stream_ptr; | |||||
| CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address, | |||||
| max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| // use workspace[0] for actual output shape, we know it must be 1d | |||||
| CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_, | |||||
| cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t), | |||||
| cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "Failed to copy gpu memory."); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed"); | |||||
| return true; | |||||
| } | |||||
| void PostExecute() override { | |||||
| // required synchronize for PostExecute | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)), | |||||
| "cudaStreamSynchronize failed"); | |||||
| std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(c_node_ptr_, 0)}; | |||||
| std::vector<std::vector<size_t>> output_shape = {{(size_t)output_shape_}}; | |||||
| AnfAlgo::SetOutputInferTypeAndShape(output_type, output_shape, c_node_ptr_.get()); | |||||
| } | |||||
| void ResetResource() noexcept override { | |||||
| stream_ptr_ = nullptr; | |||||
| c_node_ptr_ = nullptr; | |||||
| output_shape_ = 0; | |||||
| max_output_length_ = 0; | |||||
| input_size_list_.clear(); | |||||
| output_size_list_.clear(); | |||||
| workspace_size_list_.clear(); | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_count != 3) { | |||||
| MS_LOG(ERROR) << input_count << " inputs were provided, but DynamicRangeGpuKernel expects 3."; | |||||
| return false; | |||||
| } | |||||
| max_output_length_ = GetAttr<int64_t>(kernel_node, "maxlen"); | |||||
| c_node_ptr_ = kernel_node; | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(sizeof(T)); | |||||
| input_size_list_.push_back(sizeof(T)); | |||||
| input_size_list_.push_back(sizeof(T)); | |||||
| output_size_list_.push_back(max_output_length_ * sizeof(T)); | |||||
| // this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape. | |||||
| workspace_size_list_.push_back(sizeof(int64_t)); | |||||
| return; | |||||
| } | |||||
| private: | |||||
| void *stream_ptr_; | |||||
| CNodePtr c_node_ptr_; | |||||
| int64_t output_shape_; | |||||
| int64_t max_output_length_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_ | |||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "dynamic_range_impl.cuh" | |||||
| #include <cuda_runtime.h> | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
// Device-side validation of the range arguments. Executes the PTX `trap` instruction when
// delta is zero, or when the sign of delta points away from `end` relative to `start`.
template <typename T>
__device__ void CheckInputs(const T &start, const T &end, const T &delta) {
  const bool zero_step = (delta == 0);
  const bool ascending_with_negative_step = (start < end) && (delta < 0);
  const bool descending_with_positive_step = (start > end) && (delta > 0);

  if (zero_step || ascending_with_negative_step || descending_with_positive_step) {
    asm("trap;");
  }
}
// Fills `output` with start, start + delta, start + 2 * delta, ... for every index below the
// computed length ceil((end - start) / delta), and stores that length in `*output_shape`.
//
// Launch layout: 1-D grid of 1-D blocks (only the .x dimensions are used); any grid size works
// because each thread walks the output with a grid-sized stride.
// Preconditions: range_start/range_end/range_delta each point to at least one element and
// `output` has room for `max_output_size` elements. The kernel traps on invalid arguments
// (see CheckInputs) or when the computed length does not fit in [0, max_output_size].
template <typename T>
__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output,
                      int64_t *output_shape, const int64_t max_output_size) {
  const T start = range_start[0];
  const T end = range_end[0];
  const T delta = range_delta[0];
  CheckInputs(start, end, delta);

  const int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
  // A negative length would become a huge unsigned loop bound below, so reject it together
  // with lengths that exceed the allocated output buffer.
  if (real_output_shape < 0 || real_output_shape > max_output_size) {
    asm("trap;");
  }

  // Widen before multiplying so the index arithmetic cannot wrap in 32 bits.
  const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  // Every thread computes the same length; a single store is enough.
  if (first == 0) {
    *output_shape = real_output_shape;
  }

  const size_t count = static_cast<size_t>(real_output_shape);
  for (size_t gt_id = first; gt_id < count; gt_id += stride) {
    output[gt_id] = gt_id * delta + start;
  }
}
// Host-side launcher for the Range kernel. The grid is sized from `max_output_size` because the
// real output length is only computed inside the kernel. The launch is enqueued on `cuda_stream`
// and this function does not synchronize.
template <typename T>
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
              const int64_t max_output_size, cudaStream_t cuda_stream) {
  const auto block_count = GET_BLOCKS(max_output_size);
  const auto thread_count = GET_THREADS;
  Range<<<block_count, thread_count, 0, cuda_stream>>>(range_start, range_end, range_delta, output, output_shape,
                                                       max_output_size);
}

// Explicit instantiations for the element types used by the GPU kernel registrations.
template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output,
                            int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);

template void CalRange<int64_t>(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta,
                                int64_t *output, int64_t *output_shape, const int64_t max_output_size,
                                cudaStream_t cuda_stream);

template void CalRange<float>(const float *range_start, const float *range_end, const float *range_delta,
                              float *output, int64_t *output_shape, const int64_t max_output_size,
                              cudaStream_t cuda_stream);

template void CalRange<double>(const double *range_start, const double *range_end, const double *range_delta,
                               double *output, int64_t *output_shape, const int64_t max_output_size,
                               cudaStream_t cuda_stream);
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_ | |||||
| #include <cuda_runtime.h> | |||||
// Enqueues the Range computation on `cuda_stream`.
//   range_start, range_end, range_delta: device pointers to the single start / end / step values.
//   output:          device buffer that receives the generated sequence.
//   output_shape:    device pointer that receives the number of generated elements.
//   max_output_size: upper bound on the number of generated elements.
template <typename T>
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
              const int64_t max_output_size, cudaStream_t cuda_stream);
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_ | |||||
| @@ -1,7 +1,7 @@ | |||||
| /** | /** | ||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | ||||
| * | * | ||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -285,6 +285,8 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| template <typename T> | template <typename T> | ||||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -168,6 +168,7 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p | |||||
| if (max_shape.empty()) { | if (max_shape.empty()) { | ||||
| max_shape = shape->shape(); | max_shape = shape->shape(); | ||||
| } | } | ||||
| auto ids = | auto ids = | ||||
| std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape)); | std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape)); | ||||
| // Currently we choose the same data type as input for the idx. | // Currently we choose the same data type as input for the idx. | ||||
| @@ -186,6 +187,7 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p | |||||
| if (idx_max_shape.empty()) { | if (idx_max_shape.empty()) { | ||||
| idx_max_shape = shape->shape(); | idx_max_shape = shape->shape(); | ||||
| } | } | ||||
| auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape); | auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape); | ||||
| ids_idx->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape)); | ids_idx->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape)); | ||||
| // outputs: ids, ids_idx | // outputs: ids, ids_idx | ||||
| @@ -951,5 +953,36 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive | |||||
| ShapePtr output_shape = std::make_shared<Shape>(lengths_shape, lengths_shape_min, lengths_shape_max); | ShapePtr output_shape = std::make_shared<Shape>(lengths_shape, lengths_shape_min, lengths_shape_max); | ||||
| return std::make_shared<AbstractTensor>(kBool, output_shape); | return std::make_shared<AbstractTensor>(kBool, output_shape); | ||||
| } | } | ||||
| AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string &op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 3); | |||||
| AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||||
| AbstractTensorPtr range_delta = CheckArg<AbstractTensor>(op_name, args_spec_list, 2); | |||||
| TypePtrList supported_types = {kInt64, kInt32, kFloat32, kFloat64}; | |||||
| TypePtr range_start_type = CheckTensorDType(range_start, supported_types, "range_start input of Range should be %s"); | |||||
| TypePtr range_end_type = CheckTensorDType(range_end, supported_types, "range_start input of Range should be %s"); | |||||
| TypePtr range_delta_type = CheckTensorDType(range_delta, supported_types, "range_start input of Range should be %s"); | |||||
| // check all 3 inputs are same type | |||||
| if (!IsIdentidityOrSubclass(range_start_type, range_end_type) || | |||||
| !IsIdentidityOrSubclass(range_end_type, range_delta_type)) { | |||||
| MS_LOG(EXCEPTION) << "All inputs must have same type, but got: " << args_spec_list[0]->type_name() << ", " | |||||
| << args_spec_list[1]->type_name() << ", and " << args_spec_list[2]->type_name(); | |||||
| } | |||||
| int64_t max_output_length = -1; | |||||
| ValuePtr max_output_length_ptr = primitive->GetAttr("maxlen"); | |||||
| max_output_length = GetValue<int64_t>(max_output_length_ptr); | |||||
| ShapeVector output_shape = {Shape::SHP_ANY}; | |||||
| ShapeVector min_shape = {1}; | |||||
| ShapeVector max_shape = {max_output_length}; | |||||
| ShapePtr shape = std::make_shared<Shape>(output_shape, min_shape, max_shape); | |||||
| return std::make_shared<AbstractTensor>(range_start_type, shape); | |||||
| } | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,7 +1,7 @@ | |||||
| /** | /** | ||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | ||||
| * | * | ||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -81,6 +81,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | ||||
| {prim::kPrimSplit, {InferImplSplit, true}}, | {prim::kPrimSplit, {InferImplSplit, true}}, | ||||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | ||||
| {prim::kPrimRange, {InferImplRange, true}}, | |||||
| // Structure | // Structure | ||||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | ||||
| {prim::kPrimMakeList, {InferImplMakeList, true}}, | {prim::kPrimMakeList, {InferImplMakeList, true}}, | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -125,6 +125,7 @@ inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("Scat | |||||
| inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform"); | inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform"); | ||||
| inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split"); | inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split"); | ||||
| inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask"); | inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask"); | ||||
| inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range"); | |||||
| // NN | // NN | ||||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax, | Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax, | ||||
| UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | ||||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | ||||
| Unique, GatherD, Identity) | |||||
| Unique, GatherD, Identity, Range) | |||||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | ||||
| _MirrorOperator, ReduceOp, _VirtualDataset, | _MirrorOperator, ReduceOp, _VirtualDataset, | ||||
| _VirtualDiv, _GetTensorSlice, | _VirtualDiv, _GetTensorSlice, | ||||
| @@ -402,6 +402,7 @@ __all__ = [ | |||||
| "ReLUV2", | "ReLUV2", | ||||
| "SparseToDense", | "SparseToDense", | ||||
| "MatrixInverse", | "MatrixInverse", | ||||
| "Range", | |||||
| ] | ] | ||||
| __all__.sort() | __all__.sort() | ||||
| @@ -1,6 +1,6 @@ | |||||
| # coding: utf-8 | # coding: utf-8 | ||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -4722,3 +4722,63 @@ class Identity(PrimitiveWithInfer): | |||||
| 'dtype': x['dtype'], | 'dtype': x['dtype'], | ||||
| 'value': None} | 'value': None} | ||||
| return out | return out | ||||
class Range(PrimitiveWithCheck):
    r"""
    Creates a sequence of numbers that begins at `start` and extends by increments of
    `delta` up to but not including `limit`.

    The types of all 3 inputs must be the same. The type of the resulting tensor is
    the same as the type of the inputs.

    Args:
        maxlen (int): Memory that can fit `maxlen` many elements
            will be allocated for the output. Optional, must be positive, defaults to 1000000.
            If the output has more than `maxlen` elements, a runtime error
            will occur.

    Inputs:
        - **start** (Tensor) - A scalar Tensor. The first number in the sequence. Must have
          type: int32 or float32
        - **limit** (Tensor) - A scalar Tensor. End of the sequence, exclusive. Must
          have type: int32 or float32
        - **delta** (Tensor) - A scalar Tensor. Number that increments `start`. Must have
          type: int32 or float32

    Outputs:
        A 1-D Tensor, with the same type as the inputs.

    Examples:
        >>> start = Tensor(0, mstype.int32)
        >>> limit = Tensor(10, mstype.int32)
        >>> delta = Tensor(4, mstype.int32)
        >>> output = ops.Range()(start, limit, delta)
        >>> print(output)
        [0 4 8]

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    @prim_attr_register
    def __init__(self, maxlen=1000000):
        self.init_prim_io_names(inputs=['start', 'limit', 'delta'], outputs=['output'])
        # Reject non-int values first, then non-positive ones.
        validator.check_value_type("maxlen", maxlen, [int], self.name)
        validator.check_positive_int(maxlen, "maxlen", self.name)
        self.maxlen = maxlen
        self.add_prim_attr('maxlen', maxlen)
        # NOTE(review): the same attribute key is set three times in a row. If add_prim_attr
        # replaces the stored value, only [2] survives - confirm whether [0, 1, 2] was intended.
        self.add_prim_attr("dynamic_shape_depends", [0])
        self.add_prim_attr("dynamic_shape_depends", [1])
        self.add_prim_attr("dynamic_shape_depends", [2])

    def check_shape(self, start_shape, limit_shape, delta_shape):
        # All three inputs must be scalars, i.e. have rank 0.
        validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
        validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
        validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)

    def check_dtype(self, start_dtype, limit_dtype, delta_dtype):
        # The inputs must share one dtype, and that dtype must be int32 or float32.
        valid_dtypes = [mstype.int32, mstype.float32]
        inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
        validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)
| @@ -0,0 +1,93 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
class RangeNet(nn.Cell):
    """Minimal cell that forwards its three inputs to a default-constructed Range op."""

    def __init__(self):
        super().__init__()
        self.range = P.Range()

    def construct(self, s, e, d):
        return self.range(s, e, d)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_int():
    """Range on int32 inputs: ascending and descending sequences, unit and non-unit steps."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    # ((start, limit, delta), expected sequence)
    cases = [
        ((2, 5, 1), [2, 3, 4]),
        ((-24, 1, 4), [-24, -20, -16, -12, -8, -4, 0]),
        ((8, 1, -1), [8, 7, 6, 5, 4, 3, 2]),
        ((3, -11, -5), [3, -2, -7]),
    ]
    for (start, limit, delta), expected in cases:
        range_net = RangeNet()
        ms_out = range_net(Tensor(start, mstype.int32),
                           Tensor(limit, mstype.int32),
                           Tensor(delta, mstype.int32)).asnumpy()
        np_expected = np.array(expected)
        np.testing.assert_array_equal(ms_out, np_expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_float():
    """Range on float32 inputs, including a step larger than the whole interval."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    # ((start, limit, delta), expected sequence)
    cases = [
        ((2.3, 5.5, 1.2), [2.3, 3.5, 4.7]),
        ((-4, -1, 1.5), [-4.0, -2.5]),
        ((8.0, 1.0, -1.0), [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0]),
        ((1.5, -1, -18.9), [1.5]),
    ]
    for (start, limit, delta), expected in cases:
        range_net = RangeNet()
        ms_out = range_net(Tensor(start, mstype.float32),
                           Tensor(limit, mstype.float32),
                           Tensor(delta, mstype.float32)).asnumpy()
        np_expected = np.array(expected)
        np.testing.assert_array_almost_equal(ms_out, np_expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_invalid_max_output_length():
    """Range must reject a maxlen that is not a positive int.

    Each invalid value gets its own `pytest.raises` context: a `with` block ends at the first
    exception, so stacking several constructor calls in one block only ever exercises the first.
    """
    # Non-positive ints pass the type check and must fail the positivity check.
    for bad_maxlen in (0, -1):
        with pytest.raises(ValueError):
            _ = P.Range(bad_maxlen)

    # Non-int values are expected to fail the type check; the exact exception class depends on
    # the validator implementation, so accept either.
    for bad_maxlen in (None, '5'):
        with pytest.raises((TypeError, ValueError)):
            _ = P.Range(bad_maxlen)