diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index f7115e1848..0b2b1fe612 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -1008,9 +1008,11 @@ class MatrixSetDiag(Cell): dimensions :math:`[I, J, K, ..., min(M, N)]`. Then the output is a tensor of rank :math:`k+1` with dimensions :math:`[I, J, K, ..., M, N]` where: - :math:`output[i, j, k, ..., m, n] = diagnoal[i, j, k, ..., n]\ for\ m == n` + .. math:: + output[i, j, k, ..., m, n] = diagnoal[i, j, k, ..., n]\ for\ m == n - :math:`output[i, j, k, ..., m, n] = x[i, j, k, ..., m, n]\ for\ m != n` + .. math:: + output[i, j, k, ..., m, n] = x[i, j, k, ..., m, n]\ for\ m != n Inputs: - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types: