Browse Source

!12416 fix the bug for getting output layout

From: @yangzhenzhang
Reviewed-by: @stsuteng,@kisnwang
Signed-off-by: @stsuteng
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
bcab044715
1 changed file with 160 additions and 133 deletions
  1. +160
    -133
      mindspore/ccsrc/frontend/parallel/step_parallel.cc

+ 160
- 133
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -956,64 +956,71 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
}
}

static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
if (IsValueNode<RefKey>(node)) {
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
if (param_v.size() != 1) {
MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
<< param_v.size();
}
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
return std::make_pair(nullptr, true);
}
return std::make_pair(node, true);
}
return std::make_pair(nullptr, false);
}

// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
return std::make_pair(nullptr, false);
} else if (node->isa<Parameter>()) {
}

if (node->isa<Parameter>()) {
auto param_ptr = node->user_data<parallel::TensorLayout>();
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
return std::make_pair(nullptr, false);
} else {
return std::make_pair(node, false);
}
} else if (node->isa<ValueNode>()) {
if (IsValueNode<RefKey>(node)) {
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
if (param_v.size() != 1) {
MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
<< param_v.size();
}
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
return std::make_pair(nullptr, true);
} else {
return std::make_pair(node, true);
}
return std::make_pair(node, false);
}

if (node->isa<ValueNode>()) {
return FindParameterByValueNode(node, func_graph);
}

CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
if (!FindParameter(cnode->input(index), func_graph).first) {
continue;
}
return FindParameter(cnode->input(index), func_graph);
}
}

if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
return std::make_pair(node, false);
}

if (IsParallelCareNode(cnode)) {
return std::make_pair(nullptr, false);
} else {
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
if (!FindParameter(cnode->input(index), func_graph).first) {
continue;
}
return FindParameter(cnode->input(index), func_graph);
}
} else {
if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
return std::make_pair(node, false);
}
if (IsParallelCareNode(cnode)) {
return std::make_pair(nullptr, false);
} else {
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim);
if ((prim->name() == DEPEND || prim->name() == LOAD) && index != 1) {
continue;
}
if (!FindParameter(cnode->input(index), func_graph).first) {
continue;
}
return FindParameter(cnode->input(index), func_graph);
}
}
}

ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim);
if ((prim->name() == DEPEND || prim->name() == LOAD) && index != 1) {
continue;
}
if (!FindParameter(cnode->input(index), func_graph).first) {
continue;
}
return FindParameter(cnode->input(index), func_graph);
}
return std::make_pair(nullptr, false);
}
@@ -1101,6 +1108,25 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par
MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
}

// Decide whether mirror operators should be inserted for `node`.
// Returns false for cases handled elsewhere: the node's single real input is a
// value sequence, or it is a make_tuple/make_list (mirrors are attached to the
// tuple node instead). Raises if the mirror-op count does not match the number
// of real inputs (node_size minus the primitive input).
static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
  const bool has_single_real_input = (node->inputs().size() == 2);
  if (has_single_real_input) {
    const auto &real_input = node->input(1);
    if (IsValueNode<ValueSequeue>(real_input)) {
      MS_LOG(INFO) << "Input is ValueList, skip it.";
      return false;
    }
    const bool is_tuple_or_list =
      AnfNodeIsPrimitive(real_input, MAKE_TUPLE) || AnfNodeIsPrimitive(real_input, MAKE_LIST);
    if (is_tuple_or_list) {
      MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
      return false;
    }
  }

  if (mirror_ops.size() != node_size - 1) {
    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
                      << node_size - 1;
  }
  return true;
}

void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
size_t node_size = node->inputs().size();
@@ -1113,21 +1139,11 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
node_size--;
}
}
if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
MS_LOG(INFO) << "Input is ValueList, skip it.";
return;
}

if ((node->inputs().size() == 2) &&
(AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
if (!CheckInsertMirrorOps(mirror_ops, node, node_size)) {
return;
}

if (mirror_ops.size() != node_size - 1) {
MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
<< node_size - 1;
}
for (size_t index = 1; index < node_size; ++index) {
OperatorVector backward_op = mirror_ops[index - 1];
if (backward_op.empty()) {
@@ -1181,15 +1197,15 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
// pipeline mirror would not be set, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first);
}
} else {
for (auto &op : backward_op) {
AnfNodePtr pre_node = node->input(index);
InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
auto comm_op = node->input(index)->cast<CNodePtr>();
// add fusion flag
// pipeline mirror would not be set, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first);
}
continue;
}
for (auto &op : backward_op) {
AnfNodePtr pre_node = node->input(index);
InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
auto comm_op = node->input(index)->cast<CNodePtr>();
// add fusion flag
// pipeline mirror would not be set, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first);
}
}
}
@@ -1849,13 +1865,29 @@ void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
strategyPtr->ResetInputs(strategys);
}

static bool CheckExtractInfomation(const CNodePtr &cnode) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
return false;
}

ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
return false;
}

if (!IsParallelCareNode(cnode)) {
return false;
}
return true;
}

void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) {
// load strategy map from checkpoint
StrategyMap stra_map;
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
(StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
vector<std::string> last_forward_node_ids;
if (!is_training) {
@@ -1865,76 +1897,71 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini

for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
if (!CheckExtractInfomation(cnode)) {
continue;
}

SetVirtualDatasetStrategy(cnode);
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST || prim->name() == RECEIVE) {
continue;
}

auto attrs = prim->attrs();
MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();
if (IsParallelCareNode(cnode)) {
std::vector<Shapes> shape_list = ExtractShape(cnode);
if (shape_list.empty()) {
MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
}
OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
if (operator_ == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed";
}
auto &inputs = cnode->inputs();
std::vector<ValuePtr> input_value;
for (size_t index = 1; index < inputs.size(); ++index) {
if (inputs[index]->isa<ValueNode>()) {
input_value.push_back(GetValueNode(inputs[index]));
} else {
input_value.emplace_back(nullptr);
}
}
StrategyPtr strategyPtr = nullptr;
(*operator_).set_input_value(input_value);
(*operator_).set_outputs_dtype(cnode->Type());
(*operator_).set_cnode(cnode);
if (prim->name() == RESHAPE) {
cnode->set_user_data<OperatorInfo>(operator_);

std::vector<Shapes> shape_list = ExtractShape(cnode);
if (shape_list.empty()) {
MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
}
OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
MS_EXCEPTION_IF_NULL(operator_);

auto &inputs = cnode->inputs();
std::vector<ValuePtr> input_value;
for (size_t index = 1; index < inputs.size(); ++index) {
if (inputs[index]->isa<ValueNode>()) {
input_value.push_back(GetValueNode(inputs[index]));
continue;
}
// load strategy checkpoint
// key of strategy map
std::string strategy_key_name = "";
auto param_names = NodeParameterName(cnode);
if (!param_names.empty()) {
strategy_key_name = prim->name() + "_" + param_names[0].first;
}
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
bool full_batch = ParallelContext::GetInstance()->full_batch();
if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
} else if (StrategyFound(attrs)) {
strategyPtr = ExtractStrategy(attrs);
} else {
strategyPtr = stra_map[strategy_key_name];
}
if (strategyPtr != nullptr) {
if (is_last_nodes && full_batch) {
SetLastNodeStrategy(strategyPtr);
}
if (operator_->Init(strategyPtr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
}
cnode->set_user_data<OperatorInfo>(operator_);
} else {
MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
}
input_value.emplace_back(nullptr);
}
StrategyPtr strategyPtr = nullptr;
(*operator_).set_input_value(input_value);
(*operator_).set_outputs_dtype(cnode->Type());
(*operator_).set_cnode(cnode);
if (prim->name() == RESHAPE) {
cnode->set_user_data<OperatorInfo>(operator_);
continue;
}
// load strategy checkpoint
// key of strategy map
std::string strategy_key_name = "";
auto param_names = NodeParameterName(cnode);
if (!param_names.empty()) {
strategy_key_name = prim->name() + "_" + param_names[0].first;
}
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
bool full_batch = ParallelContext::GetInstance()->full_batch();
if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
} else if (StrategyFound(attrs)) {
strategyPtr = ExtractStrategy(attrs);
} else {
strategyPtr = stra_map[strategy_key_name];
}

MS_EXCEPTION_IF_NULL(strategyPtr);
if (is_last_nodes && full_batch) {
SetLastNodeStrategy(strategyPtr);
}
if (operator_->Init(strategyPtr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
}
cnode->set_user_data<OperatorInfo>(operator_);
}
}

@@ -1994,9 +2021,9 @@ std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, si
MS_EXCEPTION_IF_NULL(cnode);
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
MS_EXCEPTION_IF_NULL(distribute_operator);
if (distribute_operator->outputs_tensor_info().size() < output_index) {
if (distribute_operator->outputs_tensor_info().size() <= output_index) {
MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size()
<< ", must be less than output_index " << output_index;
<< ", must be greater than output_index " << output_index;
}
TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();


Loading…
Cancel
Save