|
|
@@ -403,7 +403,7 @@ AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePt |
|
|
ShapeVector shape = x->shape()->shape(); |
|
|
ShapeVector shape = x->shape()->shape(); |
|
|
ShapeVector min_shape = x->shape()->min_shape(); |
|
|
ShapeVector min_shape = x->shape()->min_shape(); |
|
|
ShapeVector max_shape = x->shape()->max_shape(); |
|
|
ShapeVector max_shape = x->shape()->max_shape(); |
|
|
(void)CheckMinMaxShape(shape, &min_shape, &max_shape); |
|
|
|
|
|
|
|
|
CheckMinMaxShape(shape, &min_shape, &max_shape); |
|
|
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); |
|
|
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -417,7 +417,7 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv |
|
|
ShapeVector shape = x->shape()->shape(); |
|
|
ShapeVector shape = x->shape()->shape(); |
|
|
ShapeVector min_shape = x->shape()->min_shape(); |
|
|
ShapeVector min_shape = x->shape()->min_shape(); |
|
|
ShapeVector max_shape = x->shape()->max_shape(); |
|
|
ShapeVector max_shape = x->shape()->max_shape(); |
|
|
(void)CheckMinMaxShape(shape, &min_shape, &max_shape); |
|
|
|
|
|
|
|
|
CheckMinMaxShape(shape, &min_shape, &max_shape); |
|
|
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); |
|
|
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -774,7 +774,7 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr |
|
|
ShapeVector min_shp; |
|
|
ShapeVector min_shp; |
|
|
ShapeVector x_max_shp = input->shape()->max_shape(); |
|
|
ShapeVector x_max_shp = input->shape()->max_shape(); |
|
|
ShapeVector x_min_shp = input->shape()->min_shape(); |
|
|
ShapeVector x_min_shp = input->shape()->min_shape(); |
|
|
(void)CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp); |
|
|
|
|
|
|
|
|
CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp); |
|
|
for (size_t i = 0; i < perm_vec.size(); i++) { |
|
|
for (size_t i = 0; i < perm_vec.size(); i++) { |
|
|
auto idx = static_cast<size_t>(perm_vec[i]); |
|
|
auto idx = static_cast<size_t>(perm_vec[i]); |
|
|
result_shp.push_back(input_shp[idx]); |
|
|
result_shp.push_back(input_shp[idx]); |
|
|
@@ -984,7 +984,7 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
int64_t rank_base = SizeToLong(shape_base.size()); |
|
|
int64_t rank_base = SizeToLong(shape_base.size()); |
|
|
ShapeVector min_shape_base = tensor_base->shape()->min_shape(); |
|
|
ShapeVector min_shape_base = tensor_base->shape()->min_shape(); |
|
|
ShapeVector max_shape_base = tensor_base->shape()->max_shape(); |
|
|
ShapeVector max_shape_base = tensor_base->shape()->max_shape(); |
|
|
(void)CheckMinMaxShape(shape_base, &min_shape_base, &max_shape_base); |
|
|
|
|
|
|
|
|
CheckMinMaxShape(shape_base, &min_shape_base, &max_shape_base); |
|
|
|
|
|
|
|
|
primitive->set_attr("T", tensor_base->element()->BuildType()); |
|
|
primitive->set_attr("T", tensor_base->element()->BuildType()); |
|
|
primitive->set_attr("inputNums", MakeValue(SizeToLong(tuple_len))); |
|
|
primitive->set_attr("inputNums", MakeValue(SizeToLong(tuple_len))); |
|
|
@@ -1009,7 +1009,7 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
int64_t rank_tensor = SizeToLong(shape_tensor.size()); |
|
|
int64_t rank_tensor = SizeToLong(shape_tensor.size()); |
|
|
ShapeVector min_shape_tensor = tensor->shape()->min_shape(); |
|
|
ShapeVector min_shape_tensor = tensor->shape()->min_shape(); |
|
|
ShapeVector max_shape_tensor = tensor->shape()->max_shape(); |
|
|
ShapeVector max_shape_tensor = tensor->shape()->max_shape(); |
|
|
(void)CheckMinMaxShape(shape_tensor, &min_shape_tensor, &max_shape_tensor); |
|
|
|
|
|
|
|
|
CheckMinMaxShape(shape_tensor, &min_shape_tensor, &max_shape_tensor); |
|
|
(void)CheckDtypeSame(op_name, tensor_base, tensor); |
|
|
(void)CheckDtypeSame(op_name, tensor_base, tensor); |
|
|
if (rank_tensor != rank_base) { |
|
|
if (rank_tensor != rank_base) { |
|
|
MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Rank"; |
|
|
MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Rank"; |
|
|
@@ -1033,7 +1033,7 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
auto shape = ret->shape()->shape(); |
|
|
auto shape = ret->shape()->shape(); |
|
|
auto min_shape = ret->shape()->min_shape(); |
|
|
auto min_shape = ret->shape()->min_shape(); |
|
|
auto max_shape = ret->shape()->max_shape(); |
|
|
auto max_shape = ret->shape()->max_shape(); |
|
|
(void)CheckMinMaxShape(shape, &min_shape, &max_shape); |
|
|
|
|
|
|
|
|
CheckMinMaxShape(shape, &min_shape, &max_shape); |
|
|
shape[axis_value] = all_shp; |
|
|
shape[axis_value] = all_shp; |
|
|
min_shape[axis_value] = min_all_shp; |
|
|
min_shape[axis_value] = min_all_shp; |
|
|
max_shape[axis_value] = max_all_shp; |
|
|
max_shape[axis_value] = max_all_shp; |
|
|
@@ -1107,13 +1107,13 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit |
|
|
}; |
|
|
}; |
|
|
// main calculate shape func |
|
|
// main calculate shape func |
|
|
auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void { |
|
|
auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void { |
|
|
shape.insert(shape.end(), x_shape.begin(), x_shape.end()); |
|
|
|
|
|
|
|
|
(void)shape.insert(shape.end(), x_shape.begin(), x_shape.end()); |
|
|
auto axis_value = GetValue<int64_t>(axis); |
|
|
auto axis_value = GetValue<int64_t>(axis); |
|
|
check_axis(axis_value, x_shape.size()); |
|
|
check_axis(axis_value, x_shape.size()); |
|
|
if (keep_dims_value) { |
|
|
if (keep_dims_value) { |
|
|
shape[axis_value] = 1; |
|
|
shape[axis_value] = 1; |
|
|
} else { |
|
|
} else { |
|
|
shape.erase(std::begin(shape) + axis_value); |
|
|
|
|
|
|
|
|
(void)shape.erase(std::begin(shape) + axis_value); |
|
|
} |
|
|
} |
|
|
}; |
|
|
}; |
|
|
ShapeVector shape = {}; |
|
|
ShapeVector shape = {}; |
|
|
@@ -1122,7 +1122,7 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit |
|
|
ShapeVector x_shape = x->shape()->shape(); |
|
|
ShapeVector x_shape = x->shape()->shape(); |
|
|
ShapeVector x_min_shape = x->shape()->min_shape(); |
|
|
ShapeVector x_min_shape = x->shape()->min_shape(); |
|
|
ShapeVector x_max_shape = x->shape()->max_shape(); |
|
|
ShapeVector x_max_shape = x->shape()->max_shape(); |
|
|
(void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); |
|
|
|
|
|
|
|
|
CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); |
|
|
cal_shape(shape, x_shape); |
|
|
cal_shape(shape, x_shape); |
|
|
cal_shape(min_shape, x_min_shape); |
|
|
cal_shape(min_shape, x_min_shape); |
|
|
cal_shape(max_shape, x_max_shape); |
|
|
cal_shape(max_shape, x_max_shape); |
|
|
|