Browse Source

fix accu grads shape when enabling optimizer shard

pull/15904/head
Ziyan 4 years ago
parent
commit
3a11b8b39c
1 changed files with 7 additions and 1 deletions
  1. +7
    -1
      mindspore/ccsrc/frontend/parallel/step_parallel.cc

+ 7
- 1
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -1852,19 +1852,25 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
MS_LOG(WARNING) << "The parameter " << param_name << " has not tensor layout, skip it";
continue;
}
cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>();
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
MS_EXCEPTION_IF_NULL(cloned_abstract);
// from pipeline or grad accumulation
if (param_name.find(ACCU_GRADS) != std::string::npos) {
auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
MS_EXCEPTION_IF_NULL(parallel_shape);
cloned_abstract->set_shape(parallel_shape);
// in opt shard, accu_grad's shape is different from the original param's shape
if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
tensor_layout->set_opt_shard_group("");
}
} else {
cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
}
cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
cloned_parameter_node->set_abstract(cloned_abstract);
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
<< " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()


Loading…
Cancel
Save