@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
 """basic"""
+import numpy as np
 import mindspore.common.dtype as mstype
+from mindspore.common.seed import get_seed
@@ -28,7 +30,6 @@ from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
-from mindspore.common.api import ms_function
-from mindspore import context
+from mindspore.ops import _selected_ops
 from ..cell import Cell
 from .activation import get_activation
 from ..._checkparam import Validator as validator
@@ -139,10 +140,8 @@ class Flatten(Cell):
         the product of the remaining dimensions.

     Examples:
+        >>> net = nn.Flatten()
         >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
-        >>> input.shape
-        (2, 2, 2)
-        >>> net = nn.Flatten()
         >>> net(input)
         [[1.2 1.2 2.1 2.1]
          [2.2 2.2 3.2 3.2]]
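Reviewer note: the reordered Flatten example keeps the 0-th (batch) axis and collapses everything else into one dimension. A minimal NumPy sketch of the same reshape; `flatten_batch` is a hypothetical helper, not part of the MindSpore API:

```python
import numpy as np

def flatten_batch(x):
    # Keep dim 0 (batch), collapse the remaining dims into one.
    return x.reshape(x.shape[0], -1)

x = np.array([[[1.2, 1.2], [2.1, 2.1]],
              [[2.2, 2.2], [3.2, 3.2]]], dtype=np.float32)
print(flatten_batch(x))  # [[1.2 1.2 2.1 2.1]
                         #  [2.2 2.2 3.2 3.2]]
```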
@@ -157,9 +156,9 @@
 class Dense(Cell):
     r"""
-    The fully connected layer.
+    The dense connected layer.

-    Applies dense-connected layer for the input. This layer implements the operation as:
+    Applies dense connected layer for the input. This layer implements the operation as:

     .. math::
         \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
@@ -190,8 +189,8 @@ class Dense(Cell):
         Tensor of shape :math:`(N, out\_channels)`.

     Examples:
+        >>> net = nn.Dense(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
-        >>> net = nn.Dense(3, 4)
         >>> net(input)
         [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
          [ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
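Reviewer note: as a sanity check on the math block above, here is a NumPy sketch of the forward pass Dense performs. The weight is stored as (out_channels, in_channels), which is why the layer builds P.MatMul(transpose_b=True); the names below are illustrative stand-ins, not MindSpore API:

```python
import numpy as np

def dense_forward(x, weight, bias=None, activation=None):
    # weight: (out_channels, in_channels), so multiply by its transpose.
    out = x @ weight.T
    if bias is not None:
        out = out + bias
    if activation is not None:
        out = activation(out)
    return out

x = np.random.randint(0, 255, [2, 3]).astype(np.float32)
w = np.random.randn(4, 3).astype(np.float32)
b = np.zeros(4, dtype=np.float32)
print(dense_forward(x, w, b).shape)  # (2, 4)
```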
@@ -212,41 +211,36 @@ class Dense(Cell):
         if isinstance(weight_init, Tensor):
             if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
-                raise ValueError("weight_init shape error")
+                raise ValueError("Weight init shape error.")
         self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

         self.bias = None
         if self.has_bias:
             if isinstance(bias_init, Tensor):
                 if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
-                    raise ValueError("bias_init shape error")
+                    raise ValueError("Bias init shape error.")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

-        self.bias_add = P.BiasAdd()
         self.matmul = P.MatMul(transpose_b=True)
+        self.bias_add = _selected_ops.BiasAdd()
         self.activation = get_activation(activation)
         self.activation_flag = self.activation is not None
     def construct(self, x):
-        output = self.matmul(x, self.weight)
+        x = self.matmul(x, self.weight)
         if self.has_bias:
-            output = self.bias_add(output, self.bias)
+            x = self.bias_add(x, self.bias)
         if self.activation_flag:
-            return self.activation(output)
-        return output
+            x = self.activation(x)
+        return x
     def extend_repr(self):
-        str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
-            .format(self.in_channels, self.out_channels, self.weight, self.has_bias)
+        s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
         if self.has_bias:
-            str_info = str_info + ', bias={}'.format(self.bias)
+            s += ', has_bias={}'.format(self.has_bias)
         if self.activation_flag:
-            str_info = str_info + ', activation={}'.format(self.activation)
-        return str_info
+            s += ', activation={}'.format(self.activation)
+        return s
 @constexpr
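Reviewer note: the construct rewrite above is behavior-preserving; it threads a single local name through matmul, bias_add, and activation instead of branching on early returns. A NumPy stand-in (not the actual MindSpore ops) showing the two variants agree:

```python
import numpy as np

w = np.random.randn(4, 3).astype(np.float32)
b = np.random.randn(4).astype(np.float32)

def relu(t):
    return np.maximum(t, 0.0)

def construct_old(x):
    output = x @ w.T       # self.matmul(x, self.weight)
    output = output + b    # self.bias_add(output, self.bias)
    return relu(output)    # self.activation(output)

def construct_new(x):
    x = x @ w.T
    x = x + b
    x = relu(x)
    return x

x = np.random.randn(2, 3).astype(np.float32)
assert np.allclose(construct_old(x), construct_new(x))
```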
@@ -611,9 +611,9 @@ class CumProd(PrimitiveWithInfer):
 class MatMul(PrimitiveWithInfer):
     """
-    Multiplies matrix `a` by matrix `b`.
+    Multiplies matrix `a` and matrix `b`.

-    The rank of input tensors must be `2`.
+    The rank of input tensors must be equal to `2`.

     Args:
         transpose_a (bool): If True, `a` is transposed before multiplication. Default: False.
@@ -629,10 +629,10 @@ class MatMul(PrimitiveWithInfer):
         Tensor, the shape of the output tensor is :math:`(N, M)`.

     Examples:
-        >>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
-        >>> input_y = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
+        >>> input_x1 = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
+        >>> input_x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
         >>> matmul = P.MatMul()
-        >>> output = matmul(input_x, input_y)
+        >>> output = matmul(input_x1, input_x2)
     """
     @prim_attr_register
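Reviewer note: a NumPy sketch of what the renamed example computes, and of how the transpose flags change which shapes are accepted; plain `@` stands in for P.MatMul here:

```python
import numpy as np

x1 = np.ones((1, 3), dtype=np.float32)
x2 = np.ones((3, 4), dtype=np.float32)
print(x1 @ x2)      # shape (1, 4), every entry 3.0

# transpose_b=True multiplies against x2 transposed, so x2 must be (4, 3):
x2_t = np.ones((4, 3), dtype=np.float32)
print(x1 @ x2_t.T)  # same (1, 4) result
```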
@@ -643,42 +643,44 @@ class MatMul(PrimitiveWithInfer):
         validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
         self.add_prim_attr("io_format", "ND")

-    def check_shape_size(self, x, y):
-        if len(x) != 2 or len(y) != 2:
-            raise ValueError('MatMul input x, y should be the same dimension size and should be '
-                             + f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
+    def check_shape_size(self, x1, x2):
+        if len(x1) != 2 or len(x2) != 2:
+            raise ValueError('P.MatMul inputs x1, x2 should have the same dimension size and '
+                             + f'be equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')
-    def infer_shape(self, x, y):
-        self.check_shape_size(x, y)
+    def infer_shape(self, x1, x2):
+        self.check_shape_size(x1, x2)
         cls_name = self.name
-        # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
-        for i in range(len(x) - 2):
-            if x[i] != y[i]:
-                raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')
-        # validate whether last two dims satifing matrix multiply
-        x_last = x[-2:]
-        y_last = y[-2:]
-        x_col = x_last[not self.transpose_a]  # x_col = x_last[1] if (not transpose_a) else x_last[0]
-        y_row = y_last[self.transpose_b]  # y_row = y_last[0] if (not transpose_b) else y_last[1]
-        if x_col != y_row:
+        for i in range(len(x1) - 2):
+            if x1[i] != x2[i]:
+                raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, '
+                                 + f'while x1 is {x1[i]}, x2 is {x2[i]}')
+        # validate whether the last two dims satisfy matrix multiply
+        x1_last = x1[-2:]
+        x2_last = x2[-2:]
+        # x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
+        x1_col = x1_last[not self.transpose_a]
+        # x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
+        x2_row = x2_last[self.transpose_b]
+        if x1_col != x2_row:
             raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
-                             + f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})'
-                             + f', y shape {y}(transpose_b={self.transpose_b}).')
+                             + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
+                             + f', x2 shape {x2}(transpose_b={self.transpose_b}).')
         # set attribute
         self.add_prim_attr('transpose_x1', self.transpose_a)
         self.add_prim_attr('transpose_x2', self.transpose_b)

-        ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
+        ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
         return ret_dims
-    def infer_dtype(self, x, y):
-        args = {"x": x, "y": y}
+    def infer_dtype(self, x1, x2):
+        args = {"x1": x1, "x2": x2}
         validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
-        if x.element_type() == mstype.int8:
+        if x1.element_type() == mstype.int8:
             return mstype.tensor_type(mstype.int32)
-        return x
+        return x1


 class BatchMatMul(MatMul):
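Reviewer note: the new inline comments encode the shape rule compactly, since indexing a two-element list with a bool selects the inner dimension after the transpose flags are applied (True == 1, False == 0). A pure-Python mirror of the rank-2 rule, as an illustrative sketch only, not the primitive itself:

```python
def matmul_out_shape(x1, x2, transpose_a=False, transpose_b=False):
    # Inner dims must agree once the transpose flags are applied.
    x1_col = x1[-2:][not transpose_a]
    x2_row = x2[-2:][transpose_b]
    if x1_col != x2_row:
        raise ValueError(f'cannot multiply: inner dims {x1_col} and {x2_row} differ')
    return x1[:-2] + [x1[-2:][transpose_a], x2[-2:][not transpose_b]]

print(matmul_out_shape([1, 3], [3, 4]))                    # [1, 4]
print(matmul_out_shape([3, 1], [3, 4], transpose_a=True))  # [1, 4]
```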
@@ -18,9 +18,7 @@
 import math
 import operator
 from functools import reduce
-import numpy as np
-from ... import context
 from .. import signature as sig
 from ..._checkparam import Validator as validator
@@ -58,7 +58,7 @@ class GNNFeatureTransform(nn.Cell):
         Tensor, the shape of the output tensor is :math:`(*B, N, M)`.

     Examples:
-        >>> net = nn.Dense(3, 4)
+        >>> net = nn.GNNFeatureTransform(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
         >>> net(input)
         [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]