|
|
|
@@ -616,7 +616,7 @@ class MatrixDiagPart(PrimitiveWithInfer): |
|
|
|
Tensor, data type same as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])]. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) |
|
|
|
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) |
|
|
|
>>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32) |
|
|
|
>>> matrix_diag_part = P.MatrixDiagPart() |
|
|
|
>>> result = matrix_diag_part(x, assist) |
|
|
|
@@ -658,11 +658,11 @@ class MatrixSetDiag(PrimitiveWithInfer): |
|
|
|
Tensor, data type same as input `x`. The shape same as `x`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) |
|
|
|
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) |
|
|
|
>>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32) |
|
|
|
>>> matrix_set_diag = P.MatrixSetDiag() |
|
|
|
>>> result = matrix_set_diag(x, diagonal) |
|
|
|
[[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]] |
|
|
|
[[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]] |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
@@ -681,10 +681,10 @@ class MatrixSetDiag(PrimitiveWithInfer): |
|
|
|
validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) |
|
|
|
|
|
|
|
if x_shape[-2] < x_shape[-1]: |
|
|
|
validator.check("x shape excluding the last dimension", x_shape[:-1], "diagnoal shape", |
|
|
|
diagonal_shape, Rel.EQ, self.name) |
|
|
|
validator.check("diagnoal shape", diagonal_shape, "x shape excluding the last dimension", |
|
|
|
x_shape[:-1], Rel.EQ, self.name) |
|
|
|
else: |
|
|
|
validator.check("x shape excluding the second to last dimension", x_shape[:-2]+x_shape[-1:], |
|
|
|
"diagonal shape", diagonal_shape, Rel.EQ, self.name) |
|
|
|
validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension", |
|
|
|
x_shape[:-2] + x_shape[-1:], Rel.EQ, self.name) |
|
|
|
|
|
|
|
return assist_shape |