|
|
|
@@ -1852,19 +1852,25 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
        MS_LOG(WARNING) << "The parameter " << param_name << " has not tensor layout, skip it";
        continue;
      }
      cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
      auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
      MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      // from pipeline or grad accumulation
      if (param_name.find(ACCU_GRADS) != std::string::npos) {
        auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
        std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
        MS_EXCEPTION_IF_NULL(parallel_shape);
        cloned_abstract->set_shape(parallel_shape);
        // in opt shard, accu_grad's shape is different from the original param's shape
        if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
          tensor_layout->set_opt_shard_group("");
        }
      } else {
        cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
      }
      cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
      cloned_parameter_node->set_abstract(cloned_abstract);
      MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                   << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
|
|
|
|
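For context, a minimal sketch of the shape choice this hunk implements. The helper name PickClonedShape, the Shape alias, and the plain "accu_grads" string are illustrative assumptions, not MindSpore API: a cloned gradient-accumulation buffer takes the slice shape recorded in the tensor layout, while any other cloned parameter keeps the shape of the parameter it was cloned from.

// Minimal sketch, hypothetical names (not MindSpore types).
#include <cstdint>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

// Chooses the abstract shape for a cloned parameter, mirroring the branch in the hunk above.
Shape PickClonedShape(const std::string &param_name, const Shape &cloned_from_shape, const Shape &slice_shape) {
  if (param_name.find("accu_grads") != std::string::npos) {
    return slice_shape;  // accumulation buffer follows the per-rank (sliced) layout
  }
  return cloned_from_shape;  // other clones mirror the parameter they were cloned from
}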