Merge pull request !1823 from jiangjinsheng/vm_ConfusionMatrixtags/v0.5.0-beta
| @@ -237,3 +237,4 @@ from .basic_lstm_cell import _basic_lstm_cell_tbe | |||||
| from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe | from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe | ||||
| from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe | from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe | ||||
| from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe | from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe | ||||
| from .confusion_matrix import _confusion_matrix_tbe | |||||
| @@ -0,0 +1,63 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ConfusionMatrix op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| confusion_matrix_op_info = TBERegOp("ConfusionMatrix") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("confusion_matrix.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("confusion_matrix") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("num_classes", "required", "int", "all") \ | |||||
| .attr("dtype", "required", "str", "all") \ | |||||
| .input(0, "labels", False, "required", "all") \ | |||||
| .input(1, "predictions", False, "required", "all") \ | |||||
| .input(2, "weights", False, "optional", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(confusion_matrix_op_info) | |||||
| def _confusion_matrix_tbe(): | |||||
| """ConfusionMatrix TBE register""" | |||||
| return | |||||
| @@ -73,7 +73,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | ||||
| ApplyProximalAdagrad, SparseApplyProximalAdagrad, | ApplyProximalAdagrad, SparseApplyProximalAdagrad, | ||||
| ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell) | ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell) | ||||
| from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop | |||||
| from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, | |||||
| CheckValid, MakeRefKey, CheckBprop, ConfusionMatrix) | |||||
| from . import _quant_ops | from . import _quant_ops | ||||
| from ._quant_ops import * | from ._quant_ops import * | ||||
| from .thor_ops import * | from .thor_ops import * | ||||
| @@ -287,7 +288,8 @@ __all__ = [ | |||||
| "BesselI1e", | "BesselI1e", | ||||
| "Atan", | "Atan", | ||||
| "Atanh", | "Atanh", | ||||
| "BasicLSTMCell" | |||||
| "BasicLSTMCell", | |||||
| "ConfusionMatrix" | |||||
| ] | ] | ||||
| __all__.extend(_quant_ops.__all__) | __all__.extend(_quant_ops.__all__) | ||||
| @@ -366,3 +366,50 @@ class CheckBprop(PrimitiveWithInfer): | |||||
| raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype}," | raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype}," | ||||
| f" but got {xdtype}.") | f" but got {xdtype}.") | ||||
| return xdtypes | return xdtypes | ||||
| class ConfusionMatrix(PrimitiveWithInfer): | |||||
| r""" | |||||
| Calculate the confusion matrix from labels and predictions. | |||||
| Args: | |||||
| num_classes (int): The num of classes. | |||||
| dtype (str): Data type of confusion matrix. Default: 'int32'. | |||||
| Inputs: | |||||
| - **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer. | |||||
| - **predictions** (Tensor) - the labels from prediction, tensor of 1-D. | |||||
| the shape same as `labels` and the dtype must be non-negative Integer. | |||||
| - **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`. | |||||
| Outputs: | |||||
| Tensor, the confusion matrix, with shape (`num_classes`, `num_classes`). | |||||
| Examples: | |||||
| >>> confusion_matrix = P.ConfusionMatrix(4) | |||||
| >>> labels = Tensor([0, 1, 1, 3], mindspore.int32) | |||||
| >>> predictions = Tensor([1, 2, 1, 3], mindspore.int32) | |||||
| >>> confusion_matrix(labels, predictions) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, num_classes, dtype="int32"): | |||||
| validator.check_value_type("num_classes", num_classes, [int], self.name) | |||||
| validator.check_value_type("dtype", dtype, [str], self.name) | |||||
| def infer_shape(self, labels, predictions, weights=None): | |||||
| validator.check('labels dimension', len(labels), '', 1, Rel.EQ, self.name) | |||||
| validator.check('labels shape', labels, 'predictions shape', predictions, Rel.EQ, self.name) | |||||
| if weights is not None: | |||||
| validator.check('labels shape', labels, 'weights shape', weights, Rel.EQ, self.name) | |||||
| ret = (self.num_classes, self.num_classes) | |||||
| return ret | |||||
| def infer_dtype(self, labels, predictions, weights=None): | |||||
| validator.check_subclass('labels', labels, mstype.tensor, self.name) | |||||
| validator.check_subclass('predictions', predictions, mstype.tensor, self.name) | |||||
| if weights is not None: | |||||
| validator.check_subclass('weights', weights, mstype.tensor, self.name) | |||||
| args = {"labels": labels, "predictions": predictions} | |||||
| validator.check_tensor_type_same(args, (mstype.number_type), self.name) | |||||
| return labels | |||||
| @@ -285,6 +285,16 @@ class SpaceToBatchNDNet(Cell): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return self.space_to_batch_nd(x) | return self.space_to_batch_nd(x) | ||||
| class ConfusionMatrixNet(Cell): | |||||
| def __init__(self): | |||||
| super(ConfusionMatrixNet, self).__init__() | |||||
| self.confusion_matrix = P.ConfusionMatrix(4, "int32") | |||||
| def construct(self, x, y): | |||||
| return self.confusion_matrix(x, y) | |||||
| test_case_array_ops = [ | test_case_array_ops = [ | ||||
| ('CustNet1', { | ('CustNet1', { | ||||
| 'block': CustNet1(), | 'block': CustNet1(), | ||||
| @@ -325,6 +335,9 @@ test_case_array_ops = [ | |||||
| ('BatchToSpaceNDNet', { | ('BatchToSpaceNDNet', { | ||||
| 'block': BatchToSpaceNDNet(), | 'block': BatchToSpaceNDNet(), | ||||
| 'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}), | 'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}), | ||||
| ('ConfusionMatrixNet', { | |||||
| 'block': ConfusionMatrixNet(), | |||||
| 'desc_inputs': [Tensor([0, 1, 1, 3], ms.int32), Tensor([0, 1, 1, 3], ms.int32)]}), | |||||
| ] | ] | ||||
| test_case_lists = [test_case_array_ops] | test_case_lists = [test_case_array_ops] | ||||