From 96183f8b124d8e852c745c55716c2b94bae99f01 Mon Sep 17 00:00:00 2001 From: linqingke Date: Sat, 31 Oct 2020 16:36:12 +0800 Subject: [PATCH] new add matrix inverse. --- .../gpu/math/matrix_inverse_gpu_kernel.cc | 26 ++++ .../gpu/math/matrix_inverse_gpu_kernel.h | 145 ++++++++++++++++++ mindspore/ops/operations/__init__.py | 4 +- mindspore/ops/operations/math_ops.py | 41 +++++ model_zoo/official/cv/psenet/README.md | 112 ++++++++------ model_zoo/official/cv/psenet/README_CN.md | 2 +- tests/st/ops/gpu/test_matrix_inverse_op.py | 56 +++++++ 7 files changed, 333 insertions(+), 53 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/matrix_inverse_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/matrix_inverse_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_matrix_inverse_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matrix_inverse_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matrix_inverse_gpu_kernel.cc new file mode 100644 index 0000000000..987a4f7282 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matrix_inverse_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/matrix_inverse_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(MatrixInverse, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MatrixInverseGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(MatrixInverse, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + MatrixInverseGpuKernel, double) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matrix_inverse_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matrix_inverse_gpu_kernel.h new file mode 100644 index 0000000000..71230cd1ee --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matrix_inverse_gpu_kernel.h @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MATRIX_INVERSE_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MATRIX_INVERSE_GPU_KERNEL_H_ +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class MatrixInverseGpuKernel : public GpuKernel { + public: + MatrixInverseGpuKernel() : input_size_(0), adjoint_(false), batch_size_(1), size_(1) {} + ~MatrixInverseGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + auto lu_batch_addr = GetDeviceAddress(workspace, 0); + auto inv_batch_addr = GetDeviceAddress(workspace, 1); + auto pivo_addr = GetDeviceAddress(workspace, 2); + auto info_addr = GetDeviceAddress(workspace, 3); + + int len = SizeToInt(size_); + int batchsize = SizeToInt(batch_size_); + for (size_t i = 0; i < batch_size_; i++) { + lu_addr_[i] = input_addr + i * len * len; + inv_addr_[i] = output_addr + i * len * len; + } + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(lu_batch_addr, lu_addr_.data(), sizeof(T *) * batch_size_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(inv_batch_addr, inv_addr_.data(), sizeof(T *) * batch_size_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + if (std::is_same::value) { + CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_, + cublasSgetrfBatched(handle_, len, reinterpret_cast(lu_batch_addr), len, + pivo_addr, info_addr, batchsize), + "cublas trsm batched Fail"); + CHECK_CUBLAS_RET_WITH_EXCEPT( + kernel_node_, + cublasSgetriBatched(handle_, len, reinterpret_cast(lu_batch_addr), len, pivo_addr, + reinterpret_cast(inv_batch_addr), len, info_addr, batchsize), + "cublas trsm batched Fail"); + } else if (std::is_same::value) { + CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_, + cublasDgetrfBatched(handle_, len, reinterpret_cast(lu_batch_addr), len, + pivo_addr, info_addr, batchsize), + "cublas trsm batched Fail"); + CHECK_CUBLAS_RET_WITH_EXCEPT( + kernel_node_, + cublasDgetriBatched(handle_, len, reinterpret_cast(lu_batch_addr), len, pivo_addr, + reinterpret_cast(inv_batch_addr), len, info_addr, batchsize), + "cublas trsm batched Fail"); + } else { + MS_LOG(EXCEPTION) << "The data type entered must be float or double."; + } + + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + + if (input_shape.empty() || input_shape.size() < 2) { + MS_LOG(EXCEPTION) << "The dim entered needs to be greater than 2, but " << input_shape.size() << " was taken"; + } + size_t last_index = input_shape.size() - 1; + if (input_shape[last_index] != input_shape[last_index - 1]) { + MS_LOG(EXCEPTION) << "The last two dimensions of the input matrix should be equal!"; + } + size_ = input_shape[last_index]; + for (size_t i = 0; i < last_index - 1; i++) { + batch_size_ *= input_shape[i]; + } + + input_size_ = sizeof(T); + for (auto dim : input_shape) { + input_size_ *= dim; + } + adjoint_ = GetAttr(kernel_node, "adjoint"); + lu_addr_.resize(batch_size_); + inv_addr_.resize(batch_size_); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(input_size_); + size_t lu_size = batch_size_ * sizeof(T *); + workspace_size_list_.push_back(lu_size); + size_t inv_size = batch_size_ * sizeof(T *); + workspace_size_list_.push_back(inv_size); + size_t pivo_size = batch_size_ * size_ * sizeof(int); + workspace_size_list_.push_back(pivo_size); + size_t info_size = batch_size_ * sizeof(int); + workspace_size_list_.push_back(info_size); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t input_size_; + bool adjoint_; + cublasHandle_t handle_; + size_t batch_size_; + size_t size_; + std::vector lu_addr_; + std::vector inv_addr_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MATRIX_INVERSE_GPU_KERNEL_H_ diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index a4d3bfafee..e72f7cb985 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -54,7 +54,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, - Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) + Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, + MatrixInverse) from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler, @@ -400,6 +401,7 @@ __all__ = [ "Pull", "ReLUV2", "SparseToDense", + "MatrixInverse", ] __all__.sort() diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 87999a1eab..2b5ee934e8 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -4008,3 +4008,44 @@ class LinSpace(PrimitiveWithInfer): 'dtype': start['dtype'], 'value': None} return out + +class MatrixInverse(PrimitiveWithInfer): + """ + Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown + result may be returned + + Args: + adjoint (bool) : An optional bool. Default: False. + + Inputs: + - **x** (Tensor) - A matrix to be calculated. + types: float32, double. + + Outputs: + Tensor, has the same type and shape as input `x`. + + Examples: + >>> x = Tensor(np.random.uniform(-2, 2, (2, 2, 2)), mstype.float32) + >>> matrix_inverse = P.MatrixInverse(adjoint=False) + >>> result = matrix_inverse(x) + [[[ 0.6804 0.8111] + [-2.3257 -1.0616] + [[-0.7074 -0.4963] + [0.1896 -1.5285]]] + """ + + @prim_attr_register + def __init__(self, adjoint=False): + """Initialize MatrixInverse""" + validator.check_value_type("adjoint", adjoint, [bool], self.name) + self.adjoint = adjoint + + def infer_dtype(self, x_dtype): + valid_type = [mstype.float32, mstype.double] + validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_type, self.name) + return x_dtype + + def infer_shape(self, x_shape): + validator.check_int(len(x_shape), 2, Rel.GE, self.name, None) + validator.check_equal_int(x_shape[-1], x_shape[-2], self.name, None) + return x_shape diff --git a/model_zoo/official/cv/psenet/README.md b/model_zoo/official/cv/psenet/README.md index 8893fbab64..c2a2e17fd5 100644 --- a/model_zoo/official/cv/psenet/README.md +++ b/model_zoo/official/cv/psenet/README.md @@ -5,7 +5,7 @@ - [Features](#features) - [Mixed Precision](#mixed-precision) - [Environment Requirements](#environment-requirements) -- [Quick Start](#quick-start) +- [Quick Start](#quick-start) - [Script Description](#script-description) - [Script and Sample Code](#script-and-sample-code) - [Script Parameters](#script-parameters) @@ -19,19 +19,20 @@ - [Evaluation Performance](#evaluation-performance) - [Inference Performance](#evaluation-performance) - [How to use](#how-to-use) - - [Inference](#inference) + - [Inference](#inference) - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model) - - [Transfer Learning](#transfer-learning) - + - [Transfer Learning](#transfer-learning) # [PSENet Description](#contents) -With the development of convolutional neural network, scene text detection technology has been developed rapidly. However, there are still two problems in this algorithm, which hinders its application in industry. On the one hand, most of the existing algorithms require quadrilateral bounding boxes to accurately locate arbitrary shape text. On the other hand, two adjacent instances of text can cause error detection overwriting both instances. Traditionally, a segmentation-based approach can solve the first problem, but usually not the second. To solve these two problems, a new PSENet (PSENet) is proposed, which can accurately detect arbitrary shape text instances. More specifically, PSENet generates different scale kernels for each text instance and gradually expands the minimum scale kernel to a text instance with full shape. Because of the large geometric margins between the minimum scale kernels, our method can effectively segment closed text instances, making it easier to detect arbitrary shape text instances. The effectiveness of PSENet has been verified by numerous experiments on CTW1500, full text, ICDAR 2015, and ICDAR 2017 MLT. -[Paper](https://openaccess.thecvf.com/content_CVPR_2019/html/Wang_Shape_Robust_Text_Detection_With_Progressive_Scale_Expansion_Network_CVPR_2019_paper.html): Wenhai Wang, Enze Xie, Xiang Li, Wenbo Hou, Tong Lu, Gang Yu, Shuai Shao; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 9336-9345 +With the development of convolutional neural network, scene text detection technology has been developed rapidly. However, there are still two problems in this algorithm, which hinders its application in industry. On the one hand, most of the existing algorithms require quadrilateral bounding boxes to accurately locate arbitrary shape text. On the other hand, two adjacent instances of text can cause error detection overwriting both instances. Traditionally, a segmentation-based approach can solve the first problem, but usually not the second. To solve these two problems, a new PSENet (PSENet) is proposed, which can accurately detect arbitrary shape text instances. More specifically, PSENet generates different scale kernels for each text instance and gradually expands the minimum scale kernel to a text instance with full shape. Because of the large geometric margins between the minimum scale kernels, our method can effectively segment closed text instances, making it easier to detect arbitrary shape text instances. The effectiveness of PSENet has been verified by numerous experiments on CTW1500, full text, ICDAR 2015, and ICDAR 2017 MLT. +[Paper](https://openaccess.thecvf.com/content_CVPR_2019/html/Wang_Shape_Robust_Text_Detection_With_Progressive_Scale_Expansion_Network_CVPR_2019_paper.html): Wenhai Wang, Enze Xie, Xiang Li, Wenbo Hou, Tong Lu, Gang Yu, Shuai Shao; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 9336-9345 # PSENet Example + ## Description + Progressive Scale Expansion Network (PSENet) is a text detector which is able to well detect the arbitrary-shape text in natural scene. # [Dataset](#contents) @@ -39,23 +40,26 @@ Progressive Scale Expansion Network (PSENet) is a text detector which is able to Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below. Dataset used: [ICDAR2015](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization) -A training set of 1000 images containing about 4500 readable words +A training set of 1000 images containing about 4500 readable words A testing set containing about 2000 readable words # [Environment Requirements](#contents) + - Hardware(Ascend) - - Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. + - Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. - Framework - - [MindSpore](http://www.mindspore.cn/install/en) + - [MindSpore](http://www.mindspore.cn/install/en) - For more information, please check the resources below: - - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) - - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) - install Mindspore - install [pyblind11](https://github.com/pybind/pybind11) -- install [Opencv3.4](https://docs.opencv.org/3.4.9/d7/d9f/tutorial_linux_install.html) +- install [Opencv3.4](https://docs.opencv.org/3.4.9/) # [Quick Start](#contents) -After installing MindSpore via the official website, you can start training and evaluation as follows: + +After installing MindSpore via the official website, you can start training and evaluation as follows: + ```python # run distributed training example sh scripts/run_distribute_train.sh rank_table_file pretrained_model.ckpt @@ -83,34 +87,34 @@ sh scripts/run_eval_ascend.sh # [Script Description](#contents) ## [Script and Sample Code](#contents) -``` + +```path └── PSENet - ├── README.md // descriptions about PSENet - ├── scripts - ├── run_distribute_train.sh // shell script for distributed - └── run_eval_ascend.sh // shell script for evaluation - ├── src - ├── __init__.py - ├── ETSNET - ├── __init__.py - ├── base.py // convolution and BN operator - ├── dice_loss.py // calculate PSENet loss value - ├── etsnet.py // Subnet in PSENet - ├── fpn.py // Subnet in PSENet - ├── resnet50.py // Subnet in PSENet - ├── pse // Subnet in PSENet + ├── README.md // descriptions about PSENet + ├── scripts + ├── run_distribute_train.sh // shell script for distributed + └── run_eval_ascend.sh // shell script for evaluation + ├──src + ├── __init__.py + ├── ETSNET + ├── __init__.py + ├── base.py // convolution and BN operator + ├── dice_loss.py // calculate PSENet loss value + ├── etsnet.py // Subnet in PSENet + ├── fpn.py // Subnet in PSENet + ├── resnet50.py // Subnet in PSENet + ├── pse // Subnet in PSENet ├── __init__.py ├── adaptor.cpp ├── adaptor.h ├── Makefile - ├── config.py // parameter configuration - ├── dataset.py // creating dataset - ├── lr_schedule.py // learning ratio generation - └── network_define.py // PSENet architecture - ├── export.py // export mindir file - ├── mindspore_hub_conf.py // hub config file - ├── test.py // test script - └── train.py // training script + ├──config.py // parameter configuration + ├──dataset.py // creating dataset + ├──network_define.py // learning ratio generation + ├──export.py // export mindir file + ├──mindspore_hub_conf.py // hub config file + ├──test.py // test script + ├──train.py // training script ``` @@ -120,26 +124,26 @@ sh scripts/run_eval_ascend.sh Major parameters in train.py and config.py are: --pre_trained: Whether training from scratch or training based on the - pre-trained model.Optional values are True, False. + pre-trained model.Optional values are True, False. --device_id: Device ID used to train or evaluate the dataset. Ignore it when you use train.sh for distributed training. --device_num: devices used when you use train.sh for distributed training. ``` - ## [Training Process](#contents) ### Distributed Training -``` + +```shell sh scripts/run_distribute_train.sh rank_table_file pretrained_model.ckpt ``` rank_table_file which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). -The above shell script will run distribute training in the background. You can view the results through the file +The above shell script will run distribute training in the background. You can view the results through the file `device[X]/test_*.log`. The loss value will be achieved as follows: -``` +```log # grep "epoch: " device_*/loss.log device_0/log:epoch: 1, step: 20, loss is 0.80383 device_0/log:epcoh: 2, step: 40, loss is 0.77951 @@ -150,25 +154,32 @@ device_1/log:epcoh: 2, step: 40, loss is 0.76629 ``` ## [Evaluation Process](#contents) + ### run test code + python test.py --ckpt=./device*/ckpt*/ETSNet-*.ckpt ### Eval Script for ICDAR2015 + #### Usage -+ step 1: download eval method from [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization). -+ step 2: click "My Methods" button,then download Evaluation Scripts. -+ step 3: it is recommended to symlink the eval method root to $MINDSPORE/model_zoo/psenet/eval_ic15/. if your folder structure is different,you may need to change the corresponding paths in eval script files. -``` + +step 1: download eval method from [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization). +step 2: click "My Methods" button,then download Evaluation Scripts. +step 3: it is recommended to symlink the eval method root to $MINDSPORE/model_zoo/psenet/eval_ic15/. if your folder structure is different,you may need to change the corresponding paths in eval script files. + +```shell sh ./script/run_eval_ascend.sh.sh ``` + #### Result -Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0} +Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0} # [Model Description](#contents) + ## [Performance](#contents) -### Evaluation Performance +### Evaluation Performance | Parameters | PSENet | | -------------------------- | ----------------------------------------------------------- | @@ -186,8 +197,7 @@ Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean | Total time | 1pc: 75.48 h; 8pcs: 10.01 h | | Parameters (M) | 27.36 | | Checkpoint for Fine tuning | 109.44M (.ckpt file) | -| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/psenet | - +| Scripts | | ### Inference Performance @@ -207,11 +217,11 @@ Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). Following the steps below, this is a simple example: -``` +```python # Load unseen dataset for inference dataset = dataset.create_dataset(cfg.data_path, 1, False) -# Define model +# Define model config.INFERENCE = False net = ETSNet(config) net = net.set_train() diff --git a/model_zoo/official/cv/psenet/README_CN.md b/model_zoo/official/cv/psenet/README_CN.md index dfd1422c33..b94633f53f 100644 --- a/model_zoo/official/cv/psenet/README_CN.md +++ b/model_zoo/official/cv/psenet/README_CN.md @@ -56,7 +56,7 @@ - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html) - 安装Mindspore - 安装[pyblind11](https://github.com/pybind/pybind11) -- 安装[Opencv3.4](https://docs.opencv.org/3.4.9/d7/d9f/tutory_linux_install.html) +- 安装[Opencv3.4](https://docs.opencv.org/3.4.9/) # 快速入门 diff --git a/tests/st/ops/gpu/test_matrix_inverse_op.py b/tests/st/ops/gpu/test_matrix_inverse_op.py new file mode 100644 index 0000000000..75bb62182a --- /dev/null +++ b/tests/st/ops/gpu/test_matrix_inverse_op.py @@ -0,0 +1,56 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either matrix_inverseress or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +from numpy.linalg import inv +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetMatrixInverse(nn.Cell): + def __init__(self): + super(NetMatrixInverse, self).__init__() + self.matrix_inverse = P.MatrixInverse() + + def construct(self, x): + return self.matrix_inverse(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_matrix_inverse(): + x0_np = np.random.uniform(-2, 2, (3, 4, 4)).astype(np.float32) + x0 = Tensor(x0_np) + expect0 = inv(x0_np) + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + matrix_inverse = NetMatrixInverse() + output0 = matrix_inverse(x0) + diff0 = output0.asnumpy() - expect0 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + matrix_inverse = NetMatrixInverse() + output0 = matrix_inverse(x0) + diff0 = output0.asnumpy() - expect0 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape