Browse Source

!12038 Frontend operator attribute modification

From: @zhupuxu
Reviewed-by: @kingxian,@zhoufeng54
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
aa7f82cf1b
3 changed files with 0 additions and 5 deletions
  1. +0
    -2
      mindspore/ops/operations/array_ops.py
  2. +0
    -2
      mindspore/ops/operations/math_ops.py
  3. +0
    -1
      mindspore/ops/operations/nn_ops.py

+ 0
- 2
mindspore/ops/operations/array_ops.py View File

@@ -2167,7 +2167,6 @@ class Concat(PrimitiveWithInfer):
x_shp = input_x['shape']
x_type = input_x['dtype']
_, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
self.add_prim_attr('T', x_type[0].element_type())
self.add_prim_attr('inputNums', len(x_shp))
ret_shp = x_shp[0].copy()
value = None
@@ -2616,7 +2615,6 @@ class Select(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, cond_type, x_type, y_type):
self.add_prim_attr('T', x_type)
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name)


+ 0
- 2
mindspore/ops/operations/math_ops.py View File

@@ -313,7 +313,6 @@ class _Reduce(PrimitiveWithInfer):
"""Initialize Reduce"""
validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
self.add_prim_attr("io_format", "ND")

def __call__(self, x, axis=()):
args = [x, axis]
@@ -753,7 +752,6 @@ class MatMul(PrimitiveWithCheck):
cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
self.add_prim_attr("io_format", "ND")

def check_shape_size(self, x1, x2):
if len(x1) != 2 or len(x2) != 2:


+ 0
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -1457,7 +1457,6 @@ class Conv2D(PrimitiveWithCheck):
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('groups', self.group)
self.add_prim_attr('offset_a', 0)

def check_shape(self, x_shape, w_shape, b_shape=None):
x_shape_norm = x_shape if self.format == "NCHW" else (x_shape[0], x_shape[3], x_shape[1], x_shape[2])


Loading…
Cancel
Save