| @@ -1852,19 +1852,25 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||||
| MS_LOG(WARNING) << "The parameter " << param_name << " has not tensor layout, skip it"; | MS_LOG(WARNING) << "The parameter " << param_name << " has not tensor layout, skip it"; | ||||
| continue; | continue; | ||||
| } | } | ||||
| cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>()); | |||||
| auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>(); | |||||
| MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); | MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); | ||||
| MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); | MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); | ||||
| auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); | auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); | ||||
| MS_EXCEPTION_IF_NULL(cloned_abstract); | MS_EXCEPTION_IF_NULL(cloned_abstract); | ||||
| // from pipeline or grad accumulation | |||||
| if (param_name.find(ACCU_GRADS) != std::string::npos) { | if (param_name.find(ACCU_GRADS) != std::string::npos) { | ||||
| auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array(); | auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array(); | ||||
| std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape); | std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape); | ||||
| MS_EXCEPTION_IF_NULL(parallel_shape); | MS_EXCEPTION_IF_NULL(parallel_shape); | ||||
| cloned_abstract->set_shape(parallel_shape); | cloned_abstract->set_shape(parallel_shape); | ||||
| // in opt shard, accu_grad' shape is different from the original param's shape | |||||
| if (ParallelContext::GetInstance()->enable_parallel_optimizer()) { | |||||
| tensor_layout->set_opt_shard_group(""); | |||||
| } | |||||
| } else { | } else { | ||||
| cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); | cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); | ||||
| } | } | ||||
| cloned_parameter->set_user_data<TensorLayout>(tensor_layout); | |||||
| cloned_parameter_node->set_abstract(cloned_abstract); | cloned_parameter_node->set_abstract(cloned_abstract); | ||||
| MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() | MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() | ||||
| << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name() | << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name() | ||||