|
|
|
@@ -566,54 +566,69 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
 
 AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
-  const std::string &op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 2);
-  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  auto reshape = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
-  auto input_shp = input->shape()->shape();
-  auto reshape_val = reshape->BuildValue();
-  if (reshape_val->isa<AnyValue>()) {
-    MS_LOG(EXCEPTION) << "Input_shape can't be anything: " << args_spec_list[1]->ToString();
-  }
-  auto reshape_val_data = reshape_val->cast<ValueTuplePtr>()->value();
-  ShapeVector reshape_vec;
-  (void)std::transform(std::begin(reshape_val_data), std::end(reshape_val_data), std::back_inserter(reshape_vec),
-                       [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
-  ShapeVector result_shp;
-  auto input_prod = input_shp[0];
-  int64_t dim_prod = 1;
-  size_t neg_idx = 0;
-  for (size_t i = 1; i < input_shp.size(); i++) {
-    input_prod *= input_shp[i];
-  }
-  auto num_neg_one = count(std::begin(reshape_vec), std::end(reshape_vec), -1);
-  if (num_neg_one > 1) {
-    MS_LOG(EXCEPTION) << "The shape can only has one -1 at most, but " << num_neg_one;
-  }
-  for (size_t i = 0; i < reshape_vec.size(); i++) {
-    if (reshape_vec[i] == -1) {
-      neg_idx = i;
-      result_shp.push_back(-1);
-    } else {
-      dim_prod *= reshape_vec[i];
-      result_shp.push_back(reshape_vec[i]);
-    }
+  const std::string op_name = primitive->name();
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  ShapeVector shape;
+  ShapeVector x_shape = x->shape()->shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  if (x_max_shape.empty()) {
+    x_max_shape = x_shape;
+  }
+  if (x_min_shape.empty()) {
+    x_min_shape = x_shape;
+  }
+  ValuePtr sh = primitive->GetAttr("shape");
+  auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(reshape_value_tuple);
+  auto reshape_tuple = reshape_value_tuple->value();
+
+  (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
+                       [](const ValuePtr &e) -> int { return GetValue<int>(e); });
+
+  auto max_shape = shape;
+  auto min_shape = shape;
+  int x_num = 1;
+  int x_min_num = 1;
+  int x_max_num = 1;
+  for (int value : x_shape) {
+    x_num = IntMulWithOverflowCheck(value, x_num);
+  }
+  for (int value : x_min_shape) {
+    x_min_num = IntMulWithOverflowCheck(value, x_min_num);
+  }
+  for (int value : x_max_shape) {
+    x_max_num = IntMulWithOverflowCheck(value, x_max_num);
+  }
+
+  auto it_first = find(shape.begin(), shape.end(), -1);
+  if (it_first != shape.end()) {
+    auto it_second = find(it_first + 1, shape.end(), -1);
+    if (it_second != shape.end()) {
+      MS_LOG(EXCEPTION) << "At most one component of input shape can be -1";
+    }
+    int index = std::distance(it_first, shape.begin());
+    int infer_value = x_num;
+    int infer_min_value = x_min_num;
+    int infer_max_value = x_max_num;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      int value = shape[i];
+      if (value != -1 && value != 0) {
+        infer_value = infer_value / value;
+        infer_min_value = infer_min_value / value;
+        infer_max_value = infer_max_value / value;
+      }
+    }
+    shape[index] = infer_value;
+    min_shape[index] = infer_min_value;
+    max_shape[index] = infer_max_value;
   }
-  if (dim_prod < 0 || input_prod % dim_prod != 0) {
-    MS_LOG(EXCEPTION) << "The input_x shape product is " << input_prod << ", input_shape shape product is " << dim_prod
-                      << ", and this value should be > 0 and should divide product of input_x.";
-  }
-  if (num_neg_one == 1) {
-    int64_t val = static_cast<int64_t>(input_prod) / dim_prod;
-    dim_prod *= val;
-    result_shp[neg_idx] = val;
-  }
-  if (dim_prod != input_prod) {
-    MS_LOG(EXCEPTION)
-      << "The product of input_x shape should be equal to product of input_shape shape, but input_x shape is "
-      << input_prod << ", product of input_shape shape is " << dim_prod;
-  }
-  return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
+  AbstractTensorPtr ret =
+    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+  return ret;
 }
 } // namespace abstract
 } // namespace mindspore