Merge pull request !7696 from peixu_ren/custom_gputags/v1.1.0
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "determinant_triangle_impl.cuh" | |||
| template <typename T> | |||
| __global__ void DetTriangleKernel(T *input, T *output, size_t matrix_n_, size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = 1; | |||
| for (int pos = 0; pos < matrix_n_*matrix_n_; pos += matrix_n_+1) { | |||
| output[i] *= input[i * matrix_n_ * matrix_n_ + pos]; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void DetTriangle(T *input, T *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream) { | |||
| DetTriangleKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, matrix_n_, count); | |||
| return; | |||
| } | |||
| __device__ bool dev_error_res = false; | |||
| template <typename T> | |||
| __global__ void CheckTriangleKernel(T *input, int fill_mode_, size_t matrix_n_, size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| size_t idx = 0; | |||
| if (fill_mode_ == 0) { // UPPER half | |||
| for (size_t row = 0; row < matrix_n_; row++) { | |||
| for (size_t col = row + 1; col < matrix_n_; col++) { | |||
| idx = i * matrix_n_ * matrix_n_ + row * matrix_n_ + col; | |||
| if (static_cast<float>(input[idx]) != 0) { | |||
| dev_error_res = false; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| } else if (fill_mode_ == 1) { // LOWER half | |||
| for (size_t row = 0; row < matrix_n_; row++) { | |||
| for (size_t col = 0; col < row; col++) { | |||
| idx = i * matrix_n_ * matrix_n_ + row * matrix_n_ + col; | |||
| if (static_cast<float>(input[idx]) != 0) { | |||
| dev_error_res = false; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| dev_error_res = false; | |||
| return; | |||
| } | |||
| } | |||
| dev_error_res = true; | |||
| return; | |||
| } | |||
| template <typename T> | |||
| bool CheckTriangle(T *input, int fill_mode_, size_t matrix_n_, size_t count, cudaStream_t cuda_stream) { | |||
| CheckTriangleKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, fill_mode_, matrix_n_, count); | |||
| bool host_error_res = false; | |||
| cudaMemcpyFromSymbol(&host_error_res, dev_error_res, sizeof(bool)); | |||
| return host_error_res; | |||
| } | |||
| template void DetTriangle<float>(float *input, float *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream); | |||
| template void DetTriangle<half>(half *input, half *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream); | |||
| template bool CheckTriangle<float>(float *input, int fill_mode_, size_t matrix_n_, size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template bool CheckTriangle<half>(half *input, int fill_mode_, size_t matrix_n_, size_t count, | |||
| cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_ | |||
| #include <curand_kernel.h> | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void DetTriangle(T *input, T *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| bool CheckTriangle(T *input, int fill_mode_, size_t matrix_n_, size_t count, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_ | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(DetTriangle, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| DetTriangleGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(DetTriangle, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| DetTriangleGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,112 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETRMINANT_TRIANGLE_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETRMINANT_TRIANGLE_GPU_KERNEL_H_ | |||
| #include <cuda_runtime_api.h> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class DetTriangleGpuKernel : public GpuKernel { | |||
| public: | |||
| DetTriangleGpuKernel() : input_size_(sizeof(T)), output_size_(sizeof(T)) {} | |||
| ~DetTriangleGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| VARIABLE_NOT_USED(workspace); | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| if (!CheckTriangle(input_addr, fill_mode_, matrix_n_, outputs[0]->size / sizeof(T), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr))) { | |||
| if (fill_mode_ == 0) { | |||
| MS_LOG(ERROR) << "The elements in the upper half of the maxtices should be all 0."; | |||
| } else if (fill_mode_ == 1) { | |||
| MS_LOG(ERROR) << "The elements in the lower half of the maxtices should be all 0."; | |||
| } else { | |||
| MS_LOG(ERROR) << "The input matrix should be either upper filled or lower filled."; | |||
| } | |||
| return false; | |||
| } | |||
| DetTriangle(input_addr, output_addr, matrix_n_, outputs[0]->size / sizeof(T), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but DetTriangle needs 1 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but DetTriangle needs 1 output."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| matrix_n_ = input_shape[input_shape.size() - 1]; | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < output_shape.size(); i++) { | |||
| output_size_ *= output_shape[i]; | |||
| } | |||
| if (output_size_ != input_size_ / matrix_n_ / matrix_n_) { | |||
| MS_LOG(ERROR) << "The output shape is wrong."; | |||
| return false; | |||
| } | |||
| if (input_shape[input_shape.size() - 2] != input_shape[input_shape.size() - 1]) { | |||
| MS_LOG(ERROR) << "The maxtices should be in shape of square."; | |||
| return false; | |||
| } | |||
| fill_mode_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("fill_mode")); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| } | |||
| private: | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t matrix_n_; | |||
| int fill_mode_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETRMINANT_TRIANGLE_GPU_KERNEL_H_ | |||
| @@ -87,7 +87,7 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, Popul | |||
| from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, | |||
| CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, | |||
| CusMatMulCubeDenseRight, | |||
| CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky) | |||
| CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, DetTriangle) | |||
| from .sparse_ops import SparseToDense | |||
| from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx | |||
| @@ -636,3 +636,22 @@ class Cholesky(PrimitiveWithInfer): | |||
| def infer_dtype(self, x1_dtype): | |||
| validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) | |||
| return x1_dtype | |||
| class DetTriangle(PrimitiveWithInfer): | |||
| """ | |||
| Calculate the determinant of triangle matrices | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, fill_mode=0): | |||
| self.init_prim_io_names(inputs=['x1'], outputs=['y']) | |||
| self.fill_mode = fill_mode | |||
| self.add_prim_attr('fill_mode', self.fill_mode) | |||
| def infer_shape(self, x1_shape): | |||
| out_shape = x1_shape | |||
| del out_shape[-2:] | |||
| return out_shape | |||
| def infer_dtype(self, x1_dtype): | |||
| validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) | |||
| return x1_dtype | |||
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common import dtype as mstype | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| class Net(nn.Cell): | |||
| def __init__(self, fill_mode=0): | |||
| super(Net, self).__init__() | |||
| self.det_triangle = P.DetTriangle(fill_mode=fill_mode) | |||
| def construct(self, x): | |||
| return self.det_triangle(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_net_1D(): | |||
| fill_mode = 0 | |||
| input_x = np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]]).astype(np.float32) | |||
| net = Net(fill_mode=fill_mode) | |||
| tx = Tensor(input_x, mstype.float32) | |||
| output = net(tx) | |||
| assert output == 18 | |||