Browse Source

add grad impl for op MatrixInverse

tags/v1.2.0-rc1
zhouyuanshen 4 years ago
parent
commit
a078e11f66
2 changed files with 25 additions and 3 deletions
  1. +17
    -1
      mindspore/ops/_grad/grad_math_ops.py
  2. +8
    -2
      mindspore/ops/operations/math_ops.py

+ 17
- 1
mindspore/ops/_grad/grad_math_ops.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -165,6 +165,22 @@ def get_bprop_tensor_add(self):
return bprop


@bprop_getters.register(P.MatrixInverse)
def get_bprop_matrix_inverse(self):
"""Grad definition for `MatrixInverse` operation."""
batchmatmul_a = P.math_ops.BatchMatMul(transpose_a=True)
batchmatmul_b = P.math_ops.BatchMatMul(transpose_b=True)
neg = P.Neg()

def bprop(x, out, dout):
dx = batchmatmul_b(dout, out)
dx = batchmatmul_a(out, dx)
dx = neg(dx)
return dx

return bprop


@bprop_getters.register(P.Neg)
def get_bprop_neg(self):
"""Grad definition for `Neg` operation."""


+ 8
- 2
mindspore/ops/operations/math_ops.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -4131,6 +4131,9 @@ class MatrixInverse(PrimitiveWithInfer):
Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown
result may be returned

Note:
The parameter 'adjoint' is only supporting False right now. Because complex number is not supported at present.

Args:
adjoint (bool) : An optional bool. Default: False.

@@ -4141,6 +4144,9 @@ class MatrixInverse(PrimitiveWithInfer):
Outputs:
Tensor, has the same type and shape as input `x`.

Supported Platforms:
``GPU``

Examples:
>>> x = Tensor(np.random.uniform(-2, 2, (2, 2, 2)), mstype.float32)
>>> matrix_inverse = P.MatrixInverse(adjoint=False)
@@ -4154,7 +4160,7 @@ class MatrixInverse(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, adjoint=False):
"""Initialize MatrixInverse"""
validator.check_value_type("adjoint", adjoint, [bool], self.name)
validator.check_type_name("adjoint", adjoint, False, self.name)
self.adjoint = adjoint

def infer_dtype(self, x_dtype):


Loading…
Cancel
Save