Browse Source

fix Split

tags/v0.6.0-beta
jiangjinsheng 5 years ago
parent
commit
f3badea5bc
2 changed files with 6 additions and 4 deletions
  1. +4
    -2
      mindspore/ops/operations/array_ops.py
  2. +2
    -2
      mindspore/ops/operations/nn_ops.py

+ 4
- 2
mindspore/ops/operations/array_ops.py View File

@@ -643,8 +643,10 @@ class Split(PrimitiveWithInfer):
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name)
output_valid_check = x_shape[self.axis] % self.output_num
validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ,
self.name)
if output_valid_check != 0:
raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
f" output_num {self.output_num}")

x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
out_shapes = []
out_dtypes = []


+ 2
- 2
mindspore/ops/operations/nn_ops.py View File

@@ -4951,8 +4951,7 @@ class LRN(PrimitiveWithInfer):
bias (float): An offset (usually positive to avoid dividing by 0).
alpha (float): A scale factor, usually positive.
beta (float): An exponent.
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL".
Default: "ACROSS_CHANNELS".
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS". Default: "ACROSS_CHANNELS".

Inputs:
- **x** (Tensor) - A 4D Tensor with float16 or float32 data type.
@@ -4974,6 +4973,7 @@ class LRN(PrimitiveWithInfer):
validator.check_value_type("alpha", alpha, [float], self.name)
validator.check_value_type("beta", beta, [float], self.name)
validator.check_value_type("norm_region", norm_region, [str], self.name)
validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name)

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)


Loading…
Cancel
Save