Merge pull request !7696 from peixu_ren/custom_gputags/v1.1.0
| @@ -0,0 +1,83 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "determinant_triangle_impl.cuh" | |||||
/**
 * Computes the determinant of every triangular matrix in a batch as the product of the
 * entries on its main diagonal.
 *
 * input holds `count` matrices of matrix_n_ * matrix_n_ elements stored back to back, and
 * output[i] receives the determinant of matrix i. Each thread handles matrices
 * i, i + stride, i + 2 * stride, ... so any 1-D launch configuration covers the whole batch.
 *
 * The kernel does not verify that the matrices are triangular; callers are expected to have
 * validated that beforehand.
 */
template <typename T>
__global__ void DetTriangleKernel(T *input, T *output, size_t matrix_n_, size_t count) {
  const size_t matrix_size = matrix_n_ * matrix_n_;
  // Consecutive diagonal entries of an n x n matrix are n + 1 elements apart.
  const size_t diag_step = matrix_n_ + 1;
  // Widen before multiplying so the index arithmetic cannot wrap in 32 bits.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < count; i += stride) {
    const T *matrix = input + i * matrix_size;
    // Accumulate locally and store once instead of read-modify-writing output[i] per step.
    T det = 1;
    for (size_t pos = 0; pos < matrix_size; pos += diag_step) {
      det *= matrix[pos];
    }
    output[i] = det;
  }
}
/**
 * Host-side launcher for DetTriangleKernel.
 *
 * Enqueues the kernel on cuda_stream with a launch configuration derived from `count`
 * (one logical work item per matrix). The launch is asynchronous, so `output` is only
 * valid once the work queued on cuda_stream has completed.
 */
template <typename T>
void DetTriangle(T *input, T *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream) {
  DetTriangleKernel<T><<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, matrix_n_, count);
}
// Device-side verdict shared by all threads of a CheckTriangleKernel launch.
// Protocol: the host sets it to true before the launch, and the kernel only ever clears it.
// Because every device write stores the same value (false), concurrent writers cannot
// overwrite a failure with a success.
// NOTE(review): this is a single global flag, so two CheckTriangle calls running concurrently
// on different streams would still share it.
__device__ bool dev_error_res = false;

/**
 * Verifies that each matrix of the batch has zeros in the triangle selected by fill_mode_.
 *
 *   fill_mode_ == 0: every element strictly above the main diagonal (col > row) must be 0.
 *   fill_mode_ == 1: every element strictly below the main diagonal (col < row) must be 0.
 *   any other value: treated as a failure.
 *
 * Elements are addressed as input[i * n * n + row * n + col]. On the first violation a thread
 * clears dev_error_res and exits; threads that find no violation leave the flag untouched.
 */
template <typename T>
__global__ void CheckTriangleKernel(T *input, int fill_mode_, size_t matrix_n_, size_t count) {
  const size_t matrix_size = matrix_n_ * matrix_n_;
  // Widen before multiplying so the index arithmetic cannot wrap in 32 bits.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < count; i += stride) {
    const T *matrix = input + i * matrix_size;
    if (fill_mode_ == 0) {  // UPPER half must be zero
      for (size_t row = 0; row < matrix_n_; row++) {
        for (size_t col = row + 1; col < matrix_n_; col++) {
          if (static_cast<float>(matrix[row * matrix_n_ + col]) != 0) {
            dev_error_res = false;
            return;
          }
        }
      }
    } else if (fill_mode_ == 1) {  // LOWER half must be zero
      for (size_t row = 0; row < matrix_n_; row++) {
        for (size_t col = 0; col < row; col++) {
          if (static_cast<float>(matrix[row * matrix_n_ + col]) != 0) {
            dev_error_res = false;
            return;
          }
        }
      }
    } else {
      dev_error_res = false;
      return;
    }
  }
  // Deliberately no "dev_error_res = true" here: a passing (or idle) thread must not be able
  // to overwrite a failure reported by another thread.
}

/**
 * Runs CheckTriangleKernel on cuda_stream and returns its verdict.
 *
 * @param input        device pointer to `count` matrices of matrix_n_ * matrix_n_ elements.
 * @param fill_mode_   0 or 1, see CheckTriangleKernel; any other value yields false.
 * @param matrix_n_    matrix dimension.
 * @param count        number of matrices in the batch.
 * @param cuda_stream  stream used for the flag reset, the kernel and the read-back.
 * @return true when every matrix passes the check; false when a matrix fails, fill_mode_ is
 *         invalid, or a CUDA call reports an error.
 *
 * This call blocks until the work queued on cuda_stream up to the read-back has finished.
 */
template <typename T>
bool CheckTriangle(T *input, int fill_mode_, size_t matrix_n_, size_t count, cudaStream_t cuda_stream) {
  if (fill_mode_ != 0 && fill_mode_ != 1) {
    return false;
  }

  // Arm the flag on the same stream as the kernel so the reset is ordered before the launch.
  const bool initial_verdict = true;
  if (cudaMemcpyToSymbolAsync(dev_error_res, &initial_verdict, sizeof(bool), 0, cudaMemcpyHostToDevice,
                              cuda_stream) != cudaSuccess) {
    return false;
  }

  CheckTriangleKernel<T><<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, fill_mode_, matrix_n_, count);
  if (cudaGetLastError() != cudaSuccess) {
    return false;
  }

  // Read the verdict back on the same stream, then wait for it, so the value observed is the
  // one produced by this launch rather than whatever the default stream happens to see.
  bool host_error_res = false;
  if (cudaMemcpyFromSymbolAsync(&host_error_res, dev_error_res, sizeof(bool), 0, cudaMemcpyDeviceToHost,
                                cuda_stream) != cudaSuccess) {
    return false;
  }
  if (cudaStreamSynchronize(cuda_stream) != cudaSuccess) {
    return false;
  }
  return host_error_res;
}
// Explicit instantiations: the templates are defined in this translation unit only, so the
// element types used by callers (float and half) are instantiated here.
template void DetTriangle<float>(float *input, float *output, size_t matrix_n_, size_t count,
                                 cudaStream_t cuda_stream);
template void DetTriangle<half>(half *input, half *output, size_t matrix_n_, size_t count,
                                cudaStream_t cuda_stream);

template bool CheckTriangle<float>(float *input, int fill_mode_, size_t matrix_n_, size_t count,
                                   cudaStream_t cuda_stream);
template bool CheckTriangle<half>(half *input, int fill_mode_, size_t matrix_n_, size_t count,
                                  cudaStream_t cuda_stream);
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_ | |||||
| #include <curand_kernel.h> | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
// Computes, on cuda_stream, the determinant of `count` triangular matrices of size
// matrix_n_ x matrix_n_ stored back to back in `input`; output[i] receives the product of the
// main-diagonal entries of matrix i.
template <typename T>
void DetTriangle(T *input, T *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream);

// Checks the `count` matrices in `input` against fill_mode_ and returns the verdict to the host.
// fill_mode_ 0 requires the elements above the main diagonal to be zero, fill_mode_ 1 requires
// the elements below it to be zero; other values are reported as a failure.
template <typename T>
bool CheckTriangle(T *input, int fill_mode_, size_t matrix_n_, size_t count, cudaStream_t cuda_stream);
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_ | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(DetTriangle, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| DetTriangleGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(DetTriangle, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| DetTriangleGpuKernel, half) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,112 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETRMINANT_TRIANGLE_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETRMINANT_TRIANGLE_GPU_KERNEL_H_ | |||||
| #include <cuda_runtime_api.h> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
/**
 * GPU kernel wrapper for the DetTriangle operator.
 *
 * Init() reads the input/output shapes and the "fill_mode" attribute of the node; Launch()
 * first validates the input matrices with CheckTriangle and, if that succeeds, computes one
 * determinant per matrix with DetTriangle.
 */
template <typename T>
class DetTriangleGpuKernel : public GpuKernel {
 public:
  // All members get a defined value so a kernel that was never initialised holds no garbage.
  DetTriangleGpuKernel() : input_size_(sizeof(T)), output_size_(sizeof(T)), matrix_n_(0), fill_mode_(0) {}
  ~DetTriangleGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  /**
   * Validates the input matrices and computes their determinants on the given stream.
   *
   * @return false (after logging the reason) when the triangle check fails, true otherwise.
   */
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspace);
    T *input_addr = GetDeviceAddress<T>(inputs, 0);
    T *output_addr = GetDeviceAddress<T>(outputs, 0);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    // The output holds one element per matrix, so its element count is the batch size.
    const size_t matrix_count = outputs[0]->size / sizeof(T);

    if (!CheckTriangle(input_addr, fill_mode_, matrix_n_, matrix_count, stream)) {
      if (fill_mode_ == 0) {
        MS_LOG(ERROR) << "The elements in the upper half of the maxtices should be all 0.";
      } else if (fill_mode_ == 1) {
        MS_LOG(ERROR) << "The elements in the lower half of the maxtices should be all 0.";
      } else {
        MS_LOG(ERROR) << "The input matrix should be either upper filled or lower filled.";
      }
      return false;
    }
    DetTriangle(input_addr, output_addr, matrix_n_, matrix_count, stream);
    return true;
  }

  /**
   * Reads shapes and attributes from kernel_node and fills the size lists.
   *
   * @return false (after logging the reason) when the node does not have exactly one input
   *         and one output, the input has fewer than two dimensions, the trailing two
   *         dimensions are not equal or are zero, or the output size does not match one
   *         element per input matrix.
   */
  bool Init(const CNodePtr &kernel_node) override {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but DetTriangle needs 1 inputs.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but DetTriangle needs 1 output.";
      return false;
    }

    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    // Guard the [size - 2] / [size - 1] accesses below: size_t arithmetic would wrap otherwise.
    if (input_shape.size() < 2) {
      MS_LOG(ERROR) << "The input of DetTriangle should have at least 2 dimensions, but got "
                    << input_shape.size() << ".";
      return false;
    }
    const size_t rows = input_shape[input_shape.size() - 2];
    const size_t cols = input_shape[input_shape.size() - 1];
    if (rows != cols) {
      MS_LOG(ERROR) << "The maxtices should be in shape of square.";
      return false;
    }
    // matrix_n_ is used as a divisor below, so an empty matrix dimension must be rejected first.
    if (cols == 0) {
      MS_LOG(ERROR) << "The dimension of the matrices should be greater than 0.";
      return false;
    }

    // Sizes are in bytes: element size times the product of the dimensions.
    size_t input_bytes = sizeof(T);
    for (size_t i = 0; i < input_shape.size(); i++) {
      input_bytes *= input_shape[i];
    }
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    size_t output_bytes = sizeof(T);
    for (size_t i = 0; i < output_shape.size(); i++) {
      output_bytes *= output_shape[i];
    }
    if (output_bytes != input_bytes / cols / cols) {
      MS_LOG(ERROR) << "The output shape is wrong.";
      return false;
    }

    input_size_ = input_bytes;
    output_size_ = output_bytes;
    matrix_n_ = cols;
    fill_mode_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("fill_mode"));
    InitSizeLists();
    return true;
  }

 protected:
  // Publishes the byte sizes computed by Init(); no workspace is requested.
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
  }

 private:
  size_t input_size_;   // input tensor size in bytes
  size_t output_size_;  // output tensor size in bytes
  size_t matrix_n_;     // dimension of each (square) input matrix
  int fill_mode_;       // value of the "fill_mode" attribute; 0 and 1 are the accepted modes
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETRMINANT_TRIANGLE_GPU_KERNEL_H_ | |||||
| @@ -87,7 +87,7 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, Popul | |||||
| from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, | from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, | ||||
| CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, | CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, | ||||
| CusMatMulCubeDenseRight, | CusMatMulCubeDenseRight, | ||||
| CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky) | |||||
| CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, DetTriangle) | |||||
| from .sparse_ops import SparseToDense | from .sparse_ops import SparseToDense | ||||
| from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx | from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx | ||||
| @@ -636,3 +636,22 @@ class Cholesky(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x1_dtype): | def infer_dtype(self, x1_dtype): | ||||
| validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) | validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) | ||||
| return x1_dtype | return x1_dtype | ||||
class DetTriangle(PrimitiveWithInfer):
    """
    Calculate the determinant of triangle matrices.

    The determinant of each matrix is the product of the elements on its main diagonal.

    Args:
        fill_mode (int): Selects which triangle of the input must be zero. 0 requires the
            elements above the main diagonal to be zero, 1 requires the elements below the
            main diagonal to be zero. Default: 0.

    Inputs:
        - **x1** (Tensor) - Tensor whose last two dimensions form square matrices.
          The data type is float16 or float32.

    Outputs:
        Tensor, with the shape of `x1` minus its last two dimensions and the same data type.
    """

    @prim_attr_register
    def __init__(self, fill_mode=0):
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.fill_mode = fill_mode
        self.add_prim_attr('fill_mode', self.fill_mode)

    def infer_shape(self, x1_shape):
        if len(x1_shape) < 2:
            raise ValueError(f"For '{self.name}', the input must have at least 2 dimensions, "
                             f"but got shape {x1_shape}.")
        # Build a new list: deleting from x1_shape in place would corrupt the caller's shape.
        return list(x1_shape[:-2])

    def infer_dtype(self, x1_dtype):
        # float16 is accepted as well because a float16 GPU kernel is registered for this op.
        validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float16, mstype.float32], self.name)
        return x1_dtype
| @@ -0,0 +1,44 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common import dtype as mstype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
class Net(nn.Cell):
    """Minimal cell that applies P.DetTriangle with the given fill_mode to its input."""

    def __init__(self, fill_mode=0):
        super().__init__()
        self.det_triangle = P.DetTriangle(fill_mode=fill_mode)

    # Forward pass: hand the input tensor straight to the DetTriangle primitive.
    def construct(self, x):
        return self.det_triangle(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_1D():
    """DetTriangle on one 3x3 matrix whose upper half is zero returns the diagonal product."""
    fill_mode = 0
    input_x = np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]]).astype(np.float32)
    net = Net(fill_mode=fill_mode)
    tx = Tensor(input_x, mstype.float32)
    output = net(tx)
    # Product of the main diagonal: 1 * 3 * 6.
    expect = 18.0
    # Compare on the host with a tolerance instead of relying on Tensor == int and the
    # truthiness of the resulting Tensor.
    assert np.allclose(output.asnumpy(), expect)