|
|
|
@@ -148,7 +148,7 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr |
|
|
|
|
|
|
|
ValuePtr axis = primitive->GetAttr("axis"); |
|
|
|
// Axis value should be in [-(rank_base + 1), rank_base). |
|
|
|
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base); |
|
|
|
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base, "input_x"); |
|
|
|
|
|
|
|
for (size_t i = 1; i < tuple_len; ++i) { |
|
|
|
AbstractTensorPtr tensor = nullptr; |
|
|
|
@@ -948,7 +948,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr |
|
|
|
int64_t rank = SizeToLong(x_shape.size()); |
|
|
|
|
|
|
|
ValuePtr axis = primitive->GetAttr("axis"); |
|
|
|
int64_t axis_value_pos = CheckAxis(op_name, "axis", axis, -(rank + 1), rank); |
|
|
|
int64_t axis_value_pos = CheckAxis(op_name, "axis", axis, -(rank + 1), rank, "input_x"); |
|
|
|
int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num")); |
|
|
|
if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) { |
|
|
|
MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos] |
|
|
|
@@ -1093,7 +1093,7 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
|
|
|
|
|
ValuePtr axis = primitive->GetAttr("axis"); |
|
|
|
// Axis value should be in [-(rank_base + 1), rank_base). |
|
|
|
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base); |
|
|
|
int64_t axis_value = CheckAxis(op_name, "axis", axis, -(rank_base + 1), rank_base, "input_x"); |
|
|
|
|
|
|
|
int64_t all_shp = shape_base[axis_value]; |
|
|
|
int64_t min_all_shp = min_shape_base[axis_value]; |
|
|
|
|