Browse Source

Fix the error log message of CheckAxis (include the input's rank name in the out-of-range axis exception)

tags/v1.6.0
lianliguang 4 years ago
parent
commit
d907524023
5 changed files with 14 additions and 10 deletions
  1. +1
    -1
      mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc
  2. +4
    -3
      mindspore/core/abstract/param_validator.cc
  3. +2
    -1
      mindspore/core/abstract/param_validator.h
  4. +3
    -3
      mindspore/core/abstract/prim_arrays.cc
  5. +4
    -2
      mindspore/core/ops/layer_norm.cc

+ 1
- 1
mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc View File

@@ -240,7 +240,7 @@ AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValueP
}

for (auto &elem : axis_data) {
int64_t e_value = CheckAxis(primitive->name(), "axis", elem, -SizeToLong(x_rank), SizeToLong(x_rank));
int64_t e_value = CheckAxis(primitive->name(), "axis", elem, -SizeToLong(x_rank), SizeToLong(x_rank), "input_x");
(void)axis_set.insert(e_value);
}
MS_EXCEPTION_IF_NULL(x_shp_value->cast<ValueTuplePtr>());


+ 4
- 3
mindspore/core/abstract/param_validator.cc View File

@@ -162,7 +162,7 @@ TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_ba
}

int64_t CheckAxis(const std::string &op, const std::string &args_name, const ValuePtr &axis, int64_t minimum,
int64_t max) {
int64_t max, const std::string &rank_name) {
if (axis == nullptr) {
MS_LOG(EXCEPTION) << op << " evaluator axis is null";
}
@@ -171,8 +171,9 @@ int64_t CheckAxis(const std::string &op, const std::string &args_name, const Val
}
int64_t axis_value = GetValue<int64_t>(axis);
if (axis_value >= max || axis_value < minimum) {
MS_LOG(EXCEPTION) << "The primitive[" << op << "]'s \'" << args_name << "\' value should be in the range ["
<< minimum << ", " << max << "), but got " << axis_value;
MS_LOG(EXCEPTION) << "For primitive[" << op << "], " << rank_name << "'s rank is " << max << ", while the "
<< "\'" << args_name << "\' value should be in the range [" << minimum << ", " << max
<< "), but got " << axis_value;
}
if (axis_value < 0) {
axis_value = axis_value + SizeToLong(max);


+ 2
- 1
mindspore/core/abstract/param_validator.h View File

@@ -45,7 +45,8 @@ void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base,

TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor);

int64_t CheckAxis(const std::string &op, const std::string &arg_name, const ValuePtr &axis, int64_t min, int64_t max);
int64_t CheckAxis(const std::string &op, const std::string &arg_name, const ValuePtr &axis, int64_t min, int64_t max,
const std::string &rank_name);

void CheckArgsSize(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t size_expect);



+ 3
- 3
mindspore/core/abstract/prim_arrays.cc View File

@@ -148,7 +148,7 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr

ValuePtr axis = primitive->GetAttr("axis");
// Axis value should be in [-(rank_base + 1), rank_base).
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base);
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base, "input_x");

for (size_t i = 1; i < tuple_len; ++i) {
AbstractTensorPtr tensor = nullptr;
@@ -948,7 +948,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr
int64_t rank = SizeToLong(x_shape.size());

ValuePtr axis = primitive->GetAttr("axis");
int64_t axis_value_pos = CheckAxis(op_name, "axis", axis, -(rank + 1), rank);
int64_t axis_value_pos = CheckAxis(op_name, "axis", axis, -(rank + 1), rank, "input_x");
int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num"));
if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) {
MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos]
@@ -1093,7 +1093,7 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p

ValuePtr axis = primitive->GetAttr("axis");
// Axis value should be in [-(rank_base + 1), rank_base).
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base);
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base, "input_x");

int64_t all_shp = shape_base[axis_value];
int64_t min_all_shp = min_shape_base[axis_value];


+ 4
- 2
mindspore/core/ops/layer_norm.cc View File

@@ -61,10 +61,12 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit

// begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1
ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis");
int64_t begin_norm_axis = abstract::CheckAxis(op_name, "begin_norm_axis", bna_ptr, -1, SizeToLong(input_rank));
int64_t begin_norm_axis =
abstract::CheckAxis(op_name, "begin_norm_axis", bna_ptr, -1, SizeToLong(input_rank), "input_x");

ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis");
int64_t begin_params_axis = abstract::CheckAxis(op_name, "begin_params_axis", bpa_ptr, -1, SizeToLong(input_rank));
int64_t begin_params_axis =
abstract::CheckAxis(op_name, "begin_params_axis", bpa_ptr, -1, SizeToLong(input_rank), "input_x");

// the beta and gama shape should be x_shape[begin_params_axis:]
auto valid_types = {kFloat16, kFloat32};


Loading…
Cancel
Save