comment fix docString fix added asserts in test file atop np checks lint lint-2 lint3tags/v1.1.0
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/linspace.cuh" | |||||
| #include <iostream> | |||||
| template <typename T> | |||||
| __global__ void LinSpaceKernel(const T *start, const T *stop, const size_t value_count, T *output) { | |||||
| T add_value = ((*stop - *start) / (value_count - 1)); | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < value_count; i += gridDim.x * blockDim.x) { | |||||
| output[i] = *start + (add_value * i); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void calLinSpace(const T *start, const T *stop, const size_t value_count, T *output, cudaStream_t cuda_stream) { | |||||
| LinSpaceKernel<<<GET_BLOCKS(value_count), GET_THREADS, 0, cuda_stream>>>(start, stop, value_count, output); | |||||
| } | |||||
| template void calLinSpace<float>(const float *start, const float *stop, const size_t value_count, float *output, | |||||
| cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,23 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_LINSPACE_IMPL_CU_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_LINSPACE_IMPL_CU_H_ | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void calLinSpace(const T *start, const T *stop, const size_t value_count, T *output, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_LINSPACE_IMPL_CU_H_ | |||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/math/linspace.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(LinSpace, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeInt64) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| LinSpaceGpuKernel, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,102 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LINSPACE_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LINSPACE_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <iostream> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/linspace.cuh" | |||||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class LinSpaceGpuKernel : public GpuKernel { | |||||
| public: | |||||
| LinSpaceGpuKernel() { ResetResource(); } | |||||
| ~LinSpaceGpuKernel() = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| VARIABLE_NOT_USED(workspace); | |||||
| T *start_addr = GetDeviceAddress<T>(inputs, 0); | |||||
| T *stop_addr = GetDeviceAddress<T>(inputs, 1); | |||||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||||
| calLinSpace(start_addr, stop_addr, value_count_, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 3) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but DynamicLinSpace needs 3 inputs."; | |||||
| return false; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 1) { | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but DynamicLinSpace needs 1 output."; | |||||
| return false; | |||||
| } | |||||
| auto input_1 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); | |||||
| auto input_2 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1); | |||||
| // error checking input data | |||||
| if ((input_1.size() != 0) || (input_2.size() != 0)) { | |||||
| MS_LOG(ERROR) << "For LinShape " | |||||
| << "both start and end must be 0-D Tensors. Got " << input_1.size() << " and " << input_2.size() | |||||
| << "."; | |||||
| return false; | |||||
| } | |||||
| auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); | |||||
| if (value_count.size() != 1) { | |||||
| MS_LOG(ERROR) << "For LinShape, output shape incorrect rank. Expect Rank: 1, got Rank: " << value_count.size() | |||||
| << "."; | |||||
| } | |||||
| value_count_ = value_count[0]; | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| void ResetResource() noexcept override { | |||||
| value_count_ = 0; | |||||
| input_size_list_.clear(); | |||||
| output_size_list_.clear(); | |||||
| workspace_size_list_.clear(); | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(sizeof(T)); // Scalar tensor | |||||
| input_size_list_.push_back(sizeof(T)); // Scalar tensor | |||||
| output_size_list_.push_back(value_count_ * sizeof(T)); | |||||
| } | |||||
| private: | |||||
| size_t value_count_; | |||||
| int num_input_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LINSPACE_GPU_KERNEL_H_ | |||||
| @@ -247,6 +247,8 @@ AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr & | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -167,5 +167,47 @@ AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine_ptr, const Pri | |||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| return InferImplBinaryBase(engine_ptr, primitive, args_spec_list); | return InferImplBinaryBase(engine_ptr, primitive, args_spec_list); | ||||
| } | } | ||||
| AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 3); | |||||
| auto start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(start); | |||||
| MS_EXCEPTION_IF_NULL(start->shape()); | |||||
| auto stop = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||||
| MS_EXCEPTION_IF_NULL(stop); | |||||
| MS_EXCEPTION_IF_NULL(stop->shape()); | |||||
| (void)CheckTensorDType(start, {kFloat32}, "Input 0 (start) for LinSpace should be %s"); | |||||
| (void)CheckTensorDType(stop, {kFloat32}, "Input 1 (stop) for LinSpace should be %s"); | |||||
| ShapeVector shape; | |||||
| ShapeVector max_shape; | |||||
| ShapeVector min_shape; | |||||
| int64_t num_val = 0; | |||||
| // 3rd input is a Tensor when LinSpace is a dynamic shape operator | |||||
| if (args_spec_list[2]->isa<AbstractTensor>()) { | |||||
| auto num = args_spec_list[2]->cast<AbstractTensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(num); | |||||
| auto num_value_ptr = num->BuildValue(); | |||||
| MS_EXCEPTION_IF_NULL(num_value_ptr); | |||||
| auto num_tensor = num_value_ptr->cast<tensor::TensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(num_tensor); | |||||
| num_val = *static_cast<int64_t *>(num_tensor->data_c()); | |||||
| } else if (args_spec_list[2]->isa<AbstractScalar>()) { | |||||
| auto num = args_spec_list[2]->cast<AbstractScalarPtr>(); | |||||
| num_val = GetValue<int64_t>(num->BuildValue()); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Invalid abstract type:" << args_spec_list[2]->type_name(); | |||||
| } | |||||
| shape.emplace_back(num_val); | |||||
| if (shape[0] < 0) { | |||||
| MS_LOG(EXCEPTION) << "num must be >= 0 in LinSpace"; | |||||
| } | |||||
| max_shape.emplace_back(num_val); | |||||
| min_shape.emplace_back(num_val); | |||||
| AbstractTensorPtr ret = | |||||
| std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||||
| return ret; | |||||
| } | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -45,6 +45,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimEqual, {InferImplEqual, true}}, | {prim::kPrimEqual, {InferImplEqual, true}}, | ||||
| {prim::kPrimMinimum, {InferImplMinimum, true}}, | {prim::kPrimMinimum, {InferImplMinimum, true}}, | ||||
| {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, | {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, | ||||
| {prim::kPrimLinSpace, {InferImplLinSpace, true}}, | |||||
| // Array | // Array | ||||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, | {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, | ||||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | ||||
| @@ -241,6 +241,7 @@ inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp"); | |||||
| inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | ||||
| inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt"); | inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt"); | ||||
| inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV"); | inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV"); | ||||
| inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace"); | |||||
| // Statements | // Statements | ||||
| inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); | inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); | ||||
| @@ -3946,15 +3946,19 @@ class Eps(PrimitiveWithInfer): | |||||
| class LinSpace(PrimitiveWithInfer): | class LinSpace(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Generates values in an interval and returns the corresponding interpolation accroding to assist. | |||||
| Generates values in an interval (inclusive of start and stop) and returns the corresponding | |||||
| interpolated array with **num** number of ticks. | |||||
| Inputs: | Inputs: | ||||
| - **start** (Tensor[float32]) - The start of interval, With shape of 0-D. | |||||
| - **stop** (Tensor[float32]) - The end of interval, With shape of 0-D. | |||||
| - **num** (int) - Ticks number in the interval, the ticks include start and stop value. | |||||
| - **start** (Tensor[float32]) - Start value of interval, With shape of 0-D. | |||||
| - **stop** (Tensor[float32]) - Last value of interval, With shape of 0-D. | |||||
| - **num** (int) - Number of ticks in the interval, inclusive of start and stop. | |||||
| Outputs: | Outputs: | ||||
| Tensor, has the same shape as `assist`. | |||||
| Tensor, has the same shape as `start`. | |||||
| Supported Platforms: | |||||
| ``Ascend`` ``GPU`` | |||||
| Examples: | Examples: | ||||
| >>> linspace = P.LinSpace() | >>> linspace = P.LinSpace() | ||||
| @@ -0,0 +1,99 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.context as context | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.nn import Cell | |||||
| from mindspore.ops import operations as P | |||||
| class LinSpaceNet(Cell): | |||||
| def __init__(self, num): | |||||
| super(LinSpaceNet, self).__init__() | |||||
| self.ls_op = P.LinSpace() | |||||
| self.num = num | |||||
| def construct(self, start, stop): | |||||
| output = self.ls_op(start, stop, self.num) | |||||
| return output | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_lin_space_1(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||||
| start_np = 5 | |||||
| stop_np = 150 | |||||
| num_np = 12 | |||||
| start = Tensor(start_np, dtype=mstype.float32) | |||||
| stop = Tensor(stop_np, dtype=mstype.float32) | |||||
| num = num_np | |||||
| ls_op = P.LinSpace() | |||||
| result_ms = ls_op(start, stop, num).asnumpy() | |||||
| result_np = np.linspace(start_np, stop_np, num_np) | |||||
| assert np.allclose(result_ms, result_np) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_lin_shape_2(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||||
| start_np = -25 | |||||
| stop_np = 147 | |||||
| num_np = 10 | |||||
| start = Tensor(start_np, dtype=mstype.float32) | |||||
| stop = Tensor(stop_np, dtype=mstype.float32) | |||||
| num = num_np | |||||
| ls_op = P.LinSpace() | |||||
| result_ms = ls_op(start, stop, num).asnumpy() | |||||
| result_np = np.linspace(start_np, stop_np, num_np) | |||||
| assert np.allclose(result_ms, result_np) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_lin_shape_3(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||||
| start_np = 25 | |||||
| stop_np = -147 | |||||
| num_np = 20 | |||||
| start = Tensor(start_np, dtype=mstype.float32) | |||||
| stop = Tensor(stop_np, dtype=mstype.float32) | |||||
| net = LinSpaceNet(num_np) | |||||
| result_ms = net(start, stop).asnumpy() | |||||
| result_np = np.linspace(start_np, stop_np, num_np) | |||||
| assert np.allclose(result_ms, result_np) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_lin_shape_4(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||||
| start_np = -25.3 | |||||
| stop_np = -147 | |||||
| num_np = 36 | |||||
| start = Tensor(start_np, dtype=mstype.float32) | |||||
| stop = Tensor(stop_np, dtype=mstype.float32) | |||||
| net = LinSpaceNet(num_np) | |||||
| result_ms = net(start, stop).asnumpy() | |||||
| result_np = np.linspace(start_np, stop_np, num_np) | |||||
| assert np.allclose(result_ms, result_np) | |||||