|
|
|
@@ -144,6 +144,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList |
|
|
|
MS_EXCEPTION_IF_NULL(arg); |
|
|
|
return arg->Broaden(); |
|
|
|
}); |
|
|
|
if (func_graph_->joined_shapes_.size() != broaded_list.size()) { |
|
|
|
MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size() |
|
|
|
<< " does not equal to number of original buffer arguments " |
|
|
|
<< func_graph_->joined_shapes_.size(); |
|
|
|
} |
|
|
|
for (size_t i = 0; i < broaded_list.size(); ++i) { |
|
|
|
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) |
|
|
|
<< ", broaded: " << mindspore::ToString(broaded_list); |
|
|
|
return broaded_list; |
|
|
|
@@ -171,6 +179,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa |
|
|
|
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. |
|
|
|
if (!(joined_args_spec_list == args_spec_list)) { |
|
|
|
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); |
|
|
|
func_graph_->joined_shapes_.clear(); |
|
|
|
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), |
|
|
|
std::back_inserter(func_graph_->joined_shapes_), |
|
|
|
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); |
|
|
|
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; |
|
|
|
} |
|
|
|
return joined_args_spec_list; |
|
|
|
@@ -185,6 +197,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa |
|
|
|
if (!(joined_args_spec_list == args_spec_list)) { |
|
|
|
trace_.push_back(joined_args_spec_list); |
|
|
|
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); |
|
|
|
func_graph_->joined_shapes_.clear(); |
|
|
|
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), |
|
|
|
std::back_inserter(func_graph_->joined_shapes_), |
|
|
|
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); |
|
|
|
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); |
|
|
|
|