
bug fix for nn.Dense and P.MatMul

tags/v1.1.0
chenzomi committed 5 years ago, commit b22cb38dab
4 changed files with 50 additions and 56 deletions:

1. mindspore/nn/layer/basic.py (+18 -24)
2. mindspore/ops/operations/math_ops.py (+31 -29)
3. mindspore/ops/operations/nn_ops.py (+0 -2)
4. tests/st/gnn/aggregator.py (+1 -1)

mindspore/nn/layer/basic.py (+18 -24)

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+
 """basic"""
+
 import numpy as np
 import mindspore.common.dtype as mstype
 from mindspore.common.seed import get_seed
@@ -28,7 +30,6 @@ from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore.common.api import ms_function
 from mindspore import context
-from mindspore.ops import _selected_ops
 from ..cell import Cell
 from .activation import get_activation
 from ..._checkparam import Validator as validator
@@ -139,10 +140,8 @@ class Flatten(Cell):
         the product of the remaining dimensions.
 
     Examples:
-        >>> net = nn.Flatten()
         >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
-        >>> input.shape
-        (2, 2, 2)
+        >>> net = nn.Flatten()
         >>> net(input)
         [[1.2 1.2 2.1 2.1]
          [2.2 2.2 3.2 3.2]]
@@ -157,9 +156,9 @@ class Flatten(Cell):
 
 class Dense(Cell):
     r"""
-    The fully connected layer.
+    The dense connected layer.
 
-    Applies dense-connected layer for the input. This layer implements the operation as:
+    Applies dense connected layer for the input. This layer implements the operation as:
 
     .. math::
         \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
@@ -190,8 +189,8 @@ class Dense(Cell):
         Tensor of shape :math:`(N, out\_channels)`.
 
     Examples:
-        >>> net = nn.Dense(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
+        >>> net = nn.Dense(3, 4)
         >>> net(input)
         [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
         [ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
@@ -212,41 +211,36 @@ class Dense(Cell):
         if isinstance(weight_init, Tensor):
             if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
-                raise ValueError("weight_init shape error")
-
+                raise ValueError("Weight init shape error.")
         self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
 
+        self.bias = None
         if self.has_bias:
             if isinstance(bias_init, Tensor):
                 if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
-                    raise ValueError("bias_init shape error")
-
+                    raise ValueError("Bias init shape error.")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+            self.bias_add = P.BiasAdd()
 
         self.matmul = P.MatMul(transpose_b=True)
-        self.bias_add = _selected_ops.BiasAdd()
-
         self.activation = get_activation(activation)
         self.activation_flag = self.activation is not None
 
     def construct(self, x):
-        output = self.matmul(x, self.weight)
+        x = self.matmul(x, self.weight)
         if self.has_bias:
-            output = self.bias_add(output, self.bias)
+            x = self.bias_add(x, self.bias)
         if self.activation_flag:
-            return self.activation(output)
-        return output
+            x = self.activation(x)
+        return x
 
     def extend_repr(self):
-        str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
-            .format(self.in_channels, self.out_channels, self.weight, self.has_bias)
+        s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
         if self.has_bias:
-            str_info = str_info + ', bias={}'.format(self.bias)
-
+            s += ', has_bias={}'.format(self.has_bias)
         if self.activation_flag:
-            str_info = str_info + ', activation={}'.format(self.activation)
-
-        return str_info
+            s += ', activation={}'.format(self.activation)
+        return s
 
 
 @constexpr
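
What the basic.py change fixes: self.bias is now initialized to None before the has_bias branch, so the attribute always exists even for a bias-free layer, and the BiasAdd op (now plain P.BiasAdd rather than _selected_ops.BiasAdd) is only constructed when a bias is actually present. A minimal usage sketch of the fixed layer, assuming the post-fix MindSpore 1.1 API:

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor

    # With the fix, a bias-free Dense keeps self.bias as None and never
    # constructs a BiasAdd op, so construct() takes the matmul-only path.
    net = nn.Dense(3, 4, has_bias=False)
    x = Tensor(np.ones([2, 3]), mindspore.float32)
    print(net(x).shape)  # (2, 4)
    print(net)           # repr now reads: input_channels=3, output_channels=4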


mindspore/ops/operations/math_ops.py (+31 -29)

@@ -611,9 +611,9 @@ class CumProd(PrimitiveWithInfer):
 
 class MatMul(PrimitiveWithInfer):
     """
-    Multiplies matrix `a` by matrix `b`.
+    Multiplies matrix `a` and matrix `b`.
 
-    The rank of input tensors must be `2`.
+    The rank of input tensors must equal to `2`.
 
     Args:
         transpose_a (bool): If True, `a` is transposed before multiplication. Default: False.
@@ -629,10 +629,10 @@ class MatMul(PrimitiveWithInfer):
         Tensor, the shape of the output tensor is :math:`(N, M)`.
 
     Examples:
-        >>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
-        >>> input_y = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
+        >>> input_x1 = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
+        >>> input_x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
         >>> matmul = P.MatMul()
-        >>> output = matmul(input_x, input_y)
+        >>> output = matmul(input_x1, input_x2)
     """
 
     @prim_attr_register
@@ -643,42 +643,44 @@ class MatMul(PrimitiveWithInfer):
         validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
         self.add_prim_attr("io_format", "ND")
 
-    def check_shape_size(self, x, y):
-        if len(x) != 2 or len(y) != 2:
-            raise ValueError('MatMul input x, y should be the same dimension size and should be '
-                             + f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
+    def check_shape_size(self, x1, x2):
+        if len(x1) != 2 or len(x2) != 2:
+            raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and '
+                             + f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')
 
-    def infer_shape(self, x, y):
-        self.check_shape_size(x, y)
+    def infer_shape(self, x1, x2):
+        self.check_shape_size(x1, x2)
         cls_name = self.name
         # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
-        for i in range(len(x) - 2):
-            if x[i] != y[i]:
-                raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')
+        for i in range(len(x1) - 2):
+            if x1[i] != x2[i]:
+                raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, '
+                                 + f'while x1 is {x1[i]}, x2 is {x2[i]}')
 
-        # validate whether last two dims satifing matrix multiply
-        x_last = x[-2:]
-        y_last = y[-2:]
-
-        x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0]
-        y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1]
-        if x_col != y_row:
+        # validate whether last two dims satisfying matrix multiply
+        x1_last = x1[-2:]
+        x2_last = x2[-2:]
+        # x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
+        x1_col = x1_last[not self.transpose_a]
+        # x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
+        x2_row = x2_last[self.transpose_b]
+        if x1_col != x2_row:
             raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
-                             + f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})'
-                             + f', y shape {y}(transpose_b={self.transpose_b}).')
+                             + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
+                             + f', x2 shape {x2}(transpose_b={self.transpose_b}).')
         # set attribute
         self.add_prim_attr('transpose_x1', self.transpose_a)
         self.add_prim_attr('transpose_x2', self.transpose_b)
 
-        ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
+        ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
         return ret_dims
 
-    def infer_dtype(self, x, y):
-        args = {"x": x, "y": y}
+    def infer_dtype(self, x1, x2):
+        args = {"x1": x1, "x2": x2}
         validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
-        if x.element_type() == mstype.int8:
+        if x1.element_type() == mstype.int8:
             return mstype.tensor_type(mstype.int32)
-        return x
+        return x1
 
 
 class BatchMatMul(MatMul):
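
The math_ops.py change renames the MatMul inference inputs from x/y to x1/x2 so they match the op's transpose_x1/transpose_x2 attributes, and moves the inline comments onto their own lines; the shape logic itself is unchanged. That logic relies on Python bools indexing the last two dims (False picks index 0, True picks index 1). A standalone sketch of the rule, using a hypothetical helper name rather than the MindSpore source:

    # Bool-as-index shape rule used by MatMul.infer_shape: `not transpose_a`
    # selects the column dim of x1, `transpose_b` selects the row dim of x2.
    def matmul_out_shape(x1, x2, transpose_a=False, transpose_b=False):
        x1_last, x2_last = x1[-2:], x2[-2:]
        x1_col = x1_last[not transpose_a]  # x1_last[1] unless transpose_a
        x2_row = x2_last[transpose_b]      # x2_last[0] unless transpose_b
        if x1_col != x2_row:
            raise ValueError(f'inner dimensions do not match: {x1_col} vs {x2_row}')
        return x1[:-2] + [x1_last[transpose_a], x2_last[not transpose_b]]

    print(matmul_out_shape([1, 3], [3, 4]))                    # [1, 4]
    print(matmul_out_shape([3, 1], [3, 4], transpose_a=True))  # [1, 4]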


mindspore/ops/operations/nn_ops.py (+0 -2)

@@ -18,9 +18,7 @@
 import math
 import operator
 from functools import reduce
-
 import numpy as np
-
 from ... import context
 from .. import signature as sig
 from ..._checkparam import Validator as validator


tests/st/gnn/aggregator.py (+1 -1)

@@ -58,7 +58,7 @@ class GNNFeatureTransform(nn.Cell):
         Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
 
     Examples:
-        >>> net = nn.Dense(3, 4)
+        >>> net = nn.GNNFeatureTransform(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
         >>> net(input)
         [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]

