
!9597 [Auto parallel] Simplifying step_auto_parallel

From: @xiaoda_zh
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot committed 5 years ago · commit decf796fb5
6 changed files with 367 additions and 349 deletions:

  1. mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc (+328, -1)
  2. mindspore/ccsrc/frontend/parallel/graph_util/node_info.h (+24, -0)
  3. mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc (+12, -283)
  4. mindspore/ccsrc/frontend/parallel/step_auto_parallel.h (+3, -9)
  5. mindspore/ccsrc/frontend/parallel/step_parallel.cc (+0, -52)
  6. mindspore/ccsrc/frontend/parallel/step_parallel.h (+0, -4)

mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc (+328, -1)

@@ -18,10 +18,11 @@

#include <string>

#include "ir/anf.h"
#include "ir/param_info.h"
#include "ir/meta_tensor.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace parallel {
@@ -45,5 +46,331 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
  }
  return param_value->requires_grad();
}

// Given the node, return whether each input is a parameter or an output of an operator.
// The returned boolean vector is in the same order as the inputs; its implementation is
// therefore kept closely consistent with ExtractShape() in step_parallel.cc.
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
  std::vector<bool> is_parameter;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  // If the input is a ValueList or ValueTuple, then none of the inputs is a parameter.
  if ((node_inputs.size() == 2) &&
      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    return std::vector<bool>(inputs_seq.size(), false);
  }
  if ((node_inputs.size() == 2) &&
      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }
  for (size_t i = 1; i < node_inputs.size(); ++i) {
    auto input = node_inputs[i];

    if (input->isa<Parameter>()) {
      auto input_parameter = input->cast<ParameterPtr>();
      is_parameter.push_back(ParameterRequireGrad(input_parameter));
    } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
      is_parameter.push_back(false);
    }
  }
  return is_parameter;
}
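
To make the ordering contract concrete: for a node such as MatMul(weight, activation), the helper yields one flag per data input, true only for a gradient-carrying Parameter. A minimal standalone sketch of that contract, using illustrative stand-ins (ToyNode, ClassifyInputs) rather than MindSpore APIs:

#include <iostream>
#include <vector>

struct ToyNode {
  enum Kind { kParameter, kCNode, kValueNode };
  Kind kind;
  bool requires_grad;
};

// Mirrors the contract: one flag per data input, in input order; input 0 of a
// real CNode holds the primitive, so data inputs start at index 1.
std::vector<bool> ClassifyInputs(const std::vector<ToyNode> &inputs) {
  std::vector<bool> is_parameter;
  for (size_t i = 1; i < inputs.size(); ++i) {
    is_parameter.push_back(inputs[i].kind == ToyNode::kParameter && inputs[i].requires_grad);
  }
  return is_parameter;
}

int main() {
  // MatMul(weight, activation): primitive slot, trainable parameter, CNode output.
  std::vector<ToyNode> inputs = {{ToyNode::kValueNode, false},
                                 {ToyNode::kParameter, true},
                                 {ToyNode::kCNode, false}};
  for (bool b : ClassifyInputs(inputs)) std::cout << b << ' ';  // prints: 1 0
  std::cout << '\n';
  return 0;
}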

// Given the type, return the number of bytes needed to represent this type
size_t GetLengthOfDataType(const TypePtr &type) {
  switch (type->type_id()) {
    case kNumberTypeBool:
      return sizeof(bool);
    case kNumberTypeInt8:
      return sizeof(int8_t);
    case kNumberTypeInt16:
      return sizeof(int16_t);
    case kNumberTypeInt32:
      return sizeof(int32_t);
    case kNumberTypeInt64:
      return sizeof(int64_t);
    case kNumberTypeUInt8:
      return sizeof(uint8_t);
    case kNumberTypeUInt16:
      return sizeof(uint16_t);
    case kNumberTypeUInt32:
      return sizeof(uint32_t);
    case kNumberTypeUInt64:
      return sizeof(uint64_t);
    case kNumberTypeFloat16:
      return sizeof(float) / 2;
    case kNumberTypeFloat32:
      return sizeof(float);
    case kNumberTypeFloat64:
      return sizeof(double);
    case kNumberTypeInt:
      return sizeof(int64_t);
    case kNumberTypeUInt:
      return sizeof(uint64_t);  // "unsigned int64_t" is ill-formed C++; uint64_t is the intended 8-byte type
    case kNumberTypeFloat:
      return sizeof(float);
    default:
      MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  }
}
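
A quick standalone check of the byte widths this switch relies on. Note that the Float16 case is computed as sizeof(float) / 2, which equals 2 only where float is 4 bytes wide; that holds on the platforms MindSpore targets, but it is not guaranteed by the C++ standard:

#include <cstdint>
#include <iostream>

int main() {
  static_assert(sizeof(int8_t) == 1 && sizeof(int64_t) == 8, "fixed-width integer sizes");
  std::cout << "bool:    " << sizeof(bool) << " byte(s)\n"       // typically 1
            << "float16: " << sizeof(float) / 2 << " byte(s)\n"  // typically 2
            << "float64: " << sizeof(double) << " byte(s)\n";    // typically 8
  return 0;
}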

size_t GetInputsTypeLen(const AnfNodePtr &input) {
  MS_EXCEPTION_IF_NULL(input);
  if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
    MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
  }

  size_t input_type_len = 0;
  auto type = input->Type();
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<mindspore::TensorType>()) {
    auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
    input_type_len = GetLengthOfDataType(input_element_type);
  } else {
    MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  }
  return input_type_len;
}

std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<size_t> inputs_type_len;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};

  if ((node_inputs.size() == 2) &&
      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    for (auto &ele : inputs_seq) {
      auto tensor = ele->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
    }
    return inputs_type_len;
  }

  if ((node_inputs.size() == 2) &&
      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }

  // extract the length of each input's element type
  for (auto &input : node_inputs) {
    if (IsValueNode<RefKey>(input)) {
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
    } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
      // extract the type length from the parameter or apply node
      inputs_type_len.push_back(GetInputsTypeLen(input));
    }
  }
  return inputs_type_len;
}

std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<TypePtr> outputs_type;
  // extract the element type of each output
  auto primary_output_type = node->Type();
  MS_EXCEPTION_IF_NULL(primary_output_type);
  if (primary_output_type->isa<mindspore::Tuple>()) {
    // in this case, the output is a tuple
    auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
    auto elements = tuple_output_type->elements();
    for (auto &ele : elements) {
      if (ele->isa<mindspore::TensorType>()) {
        auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
        outputs_type.push_back(ele_element_type);
      } else {
        MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
      }
    }
  } else {
    // in this case, the output is a single tensor
    if (primary_output_type->isa<mindspore::TensorType>()) {
      auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
      outputs_type.push_back(element_type);
    } else {
      MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
    }
  }
  return outputs_type;
}
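
The tuple-versus-single-tensor branching above normalizes every output into a flat list of element types. The same shape of logic in a small self-contained sketch (ElementType and NormalizeOutputs are illustrative stand-ins, not MindSpore APIs):

#include <iostream>
#include <string>
#include <variant>
#include <vector>

using ElementType = std::string;  // stand-in for TypePtr
using NodeOutput = std::variant<ElementType, std::vector<ElementType>>;

// A tuple output contributes one entry per member; a single tensor yields one entry.
std::vector<ElementType> NormalizeOutputs(const NodeOutput &out) {
  if (auto *tuple = std::get_if<std::vector<ElementType>>(&out)) {
    return *tuple;
  }
  return {std::get<ElementType>(out)};
}

int main() {
  NodeOutput single = ElementType{"float32"};
  NodeOutput tuple = std::vector<ElementType>{"float16", "int32"};
  std::cout << NormalizeOutputs(single).size() << ' '   // prints 1
            << NormalizeOutputs(tuple).size() << '\n';  // prints 2
  return 0;
}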

std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> parameters;
  if (!IsValueNode<RefKey>(node)) {
    MS_LOG(ERROR) << "The node is not a ref key";
    return parameters;
  }

  auto ref_key = GetValueNode<RefKeyPtr>(node);
  MS_EXCEPTION_IF_NULL(ref_key);
  auto name = ref_key->tag();

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto roots = manager->roots();
  if (roots.size() != 1) {
    MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1";
    return parameters;
  }

  FuncGraphPtr root_g = roots.back();
  MS_EXCEPTION_IF_NULL(root_g);
  for (auto &param_node : root_g->parameters()) {
    auto param = param_node->cast<ParameterPtr>();
    if (param && (name == param->name())) {
      parameters.push_back(param_node);
      MS_LOG(INFO) << "The name of ref key is: " << name;
      return parameters;
    }
  }

  MS_LOG(ERROR) << "The name of ref key is: " << name << ", but the parameter has not been found";
  return parameters;
}

bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto cnode = anf_node->cast<CNodePtr>();
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }

  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == prim_name) {
    return true;
  }
  return false;
}

bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == RESHAPE) {
    auto operator_info = cnode->user_data<OperatorInfo>();
    std::string op_info_name = operator_info->name();
    if (op_cache->find(op_info_name) != op_cache->end()) {
      return false;
    }
    op_cache->insert(op_info_name);
    return true;
  }
  return false;
}
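
The op_cache argument gives FindReshape first-visit-wins semantics: a Reshape whose operator-info name is already cached is reported as not found, so each Reshape is costed at most once. A minimal sketch of that pattern (VisitOnce is an illustrative stand-in):

#include <iostream>
#include <string>
#include <unordered_set>

// Returns true only on the first visit of a given name; later visits are no-ops.
bool VisitOnce(const std::string &op_name, std::unordered_set<std::string> *cache) {
  if (cache->find(op_name) != cache->end()) {
    return false;  // already handled
  }
  cache->insert(op_name);
  return true;  // first visit: caller should process it
}

int main() {
  std::unordered_set<std::string> cache;
  std::cout << VisitOnce("Reshape-op0", &cache)           // 1
            << VisitOnce("Reshape-op0", &cache)           // 0
            << VisitOnce("Reshape-op1", &cache) << '\n';  // 1
  return 0;
}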

// Find the previous node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index) {
  // if the previous node is a parameter, handle it outside of this function.
  if (node->isa<Parameter>()) {
    return false;
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  auto node_op_info = cnode->user_data<OperatorInfo>();
  if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
    *pre_operator_info = node_op_info;
    *out_index = 0;
    return true;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == TUPLE_GETITEM) {
    *out_index = GetTupleGetItemIndex(cnode);
    // find tuple_get_item's previous node
    auto pre_node = cnode->input(1);
    if (!pre_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
    }
    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
    auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
    if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
      *pre_operator_info = pre_op_info;
      return true;
    }
    return false;
  }
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
      continue;
    }
    return true;
  }
  MS_LOG(WARNING)
    << "FindReshapePreNodeStraCosts failed; if Reshape is not the first primitive, there must be some error";
  return false;
}
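
The search above stops at the first upstream operator carrying OperatorInfo and recurses through pass-through primitives; for Depend, only the data input is followed (in a real CNode input 0 is the primitive itself, which is why the code skips index != 1). A toy, self-contained version of the same walk, with Node and FindPreOpInfo as illustrative stand-ins that omit the primitive slot:

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string prim;
  bool has_op_info;  // stand-in for cnode->user_data<OperatorInfo>() != nullptr
  std::vector<const Node *> inputs;
};

// Stop at the first node with operator info; through Depend, follow only the data input.
const Node *FindPreOpInfo(const Node *node) {
  if (node == nullptr) return nullptr;
  if (node->has_op_info) return node;
  for (size_t i = 0; i < node->inputs.size(); ++i) {
    if (node->prim == "Depend" && i != 0) continue;  // skip control-only edges
    if (const Node *found = FindPreOpInfo(node->inputs[i])) return found;
  }
  return nullptr;
}

int main() {
  Node matmul{"MatMul", true, {}};
  Node depend{"Depend", false, {&matmul}};
  Node reshape{"Reshape", false, {&depend}};
  std::cout << FindPreOpInfo(reshape.inputs[0])->prim << '\n';  // prints: MatMul
  return 0;
}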

// Find the next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
// If Reshape's output connects to several primitives, return the first layout found.
bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    auto op_info = use_apply->user_data<OperatorInfo>();
    if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
      MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name();
      *next_operator_info = op_info;
      *in_index = node_pair.second - 1;
      return true;
    }
    MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << " "
                  << IsParallelCareNode(use_apply) << " " << (op_info != nullptr);

    if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
      return true;
    }
  }
  return false;
}
} // namespace parallel
} // namespace mindspore

mindspore/ccsrc/frontend/parallel/graph_util/node_info.h (+24, -0)

@@ -18,13 +18,37 @@
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_

#include <string>
#include <vector>
#include <memory>
#include <unordered_set>
#include "base/base.h"
#include "ir/anf.h"
#include "frontend/parallel/ops_info/operator_info.h"

namespace mindspore {
namespace parallel {
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
std::string ParameterName(const AnfNodePtr &node_ptr);

bool ParameterRequireGrad(const AnfNodePtr &node_ptr);

size_t GetLengthOfDataType(const TypePtr &type);

std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node);

std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);

std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);

std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph);

bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);

bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache);

bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index);

bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index);
} // namespace parallel
} // namespace mindspore



mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc (+12, -283)

@@ -63,7 +63,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
  // check whether strategy_search_mode is valid
  std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
  if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
-    // Setting searching mode: dynanic programming as default.
+    // Setting searching mode: dynamic programming as default.
    strategy_search_mode = DYNAMIC_PROGRAMMING;
    MS_LOG(INFO) << "Non-indicated strategy searching mode, using DP searching mode as default";
  }
@@ -112,170 +112,6 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
  return changes;
}

// Given the node, return whether each input is a parameter or an output of an operator.
// The returned boolean vector is in the same order as the inputs; its implementation is
// therefore kept closely consistent with ExtractShape() in step_parallel.cc.
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
  std::vector<bool> is_parameter;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  // If the input is a ValueList or ValueTuple, then none of the inputs is a parameter.
  if ((node_inputs.size() == 2) &&
      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    return std::vector<bool>(inputs_seq.size(), false);
  }
  if ((node_inputs.size() == 2) &&
      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }
  for (size_t i = 1; i < node_inputs.size(); ++i) {
    auto input = node_inputs[i];

    if (input->isa<Parameter>()) {
      auto input_parameter = input->cast<ParameterPtr>();
      is_parameter.push_back(ParameterRequireGrad(input_parameter));
    } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
      is_parameter.push_back(false);
    }
  }
  return is_parameter;
}

// Given the type, return the number of bytes needed to represent this type
size_t GetLengthOfDataType(const TypePtr &type) {
  switch (type->type_id()) {
    case kNumberTypeBool:
      return sizeof(bool);
    case kNumberTypeInt8:
      return sizeof(int8_t);
    case kNumberTypeInt16:
      return sizeof(int16_t);
    case kNumberTypeInt32:
      return sizeof(int32_t);
    case kNumberTypeInt64:
      return sizeof(int64_t);
    case kNumberTypeUInt8:
      return sizeof(uint8_t);
    case kNumberTypeUInt16:
      return sizeof(uint16_t);
    case kNumberTypeUInt32:
      return sizeof(uint32_t);
    case kNumberTypeUInt64:
      return sizeof(uint64_t);
    case kNumberTypeFloat16:
      return sizeof(float) / 2;
    case kNumberTypeFloat32:
      return sizeof(float);
    case kNumberTypeFloat64:
      return sizeof(double);
    case kNumberTypeInt:
      return sizeof(int64_t);
    case kNumberTypeUInt:
      return sizeof(uint64_t);  // "unsigned int64_t" is ill-formed C++; uint64_t is the intended 8-byte type
    case kNumberTypeFloat:
      return sizeof(float);
    default:
      MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  }
}

size_t GetInputsTypeLen(const AnfNodePtr &input) {
  MS_EXCEPTION_IF_NULL(input);
  if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
    MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
  }

  size_t input_type_len = 0;
  auto type = input->Type();
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<mindspore::TensorType>()) {
    auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
    input_type_len = GetLengthOfDataType(input_element_type);
  } else {
    MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  }
  return input_type_len;
}

std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<size_t> inputs_type_len;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};

  if ((node_inputs.size() == 2) &&
      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    for (auto &ele : inputs_seq) {
      auto tensor = ele->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
    }
    return inputs_type_len;
  }

  if ((node_inputs.size() == 2) &&
      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }

  // extract the length of each input's element type
  for (auto &input : node_inputs) {
    if (IsValueNode<RefKey>(input)) {
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
    } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
      // extract the type length from the parameter or apply node
      inputs_type_len.push_back(GetInputsTypeLen(input));
    }
  }
  return inputs_type_len;
}

std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<TypePtr> outputs_type;
  // extract the element type of each output
  auto primary_output_type = node->Type();
  MS_EXCEPTION_IF_NULL(primary_output_type);
  if (primary_output_type->isa<mindspore::Tuple>()) {
    // in this case, the output is a tuple
    auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
    auto elements = tuple_output_type->elements();
    for (auto &ele : elements) {
      if (ele->isa<mindspore::TensorType>()) {
        auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
        outputs_type.push_back(ele_element_type);
      } else {
        MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
      }
    }
  } else {
    // in this case, the output is a single tensor
    if (primary_output_type->isa<mindspore::TensorType>()) {
      auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
      outputs_type.push_back(element_type);
    } else {
      MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
    }
  }
  return outputs_type;
}

bool IsElementWiseOperator(const std::string &op_name) {
  // clang-format off
  static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH,
@@ -381,6 +217,11 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn
  return true;
}

void InitCostGraph() {
  entire_costgraph = std::make_shared<CostGraph>();
  entire_costgraph->SetDeviceMemoryAndCostParameter();
}
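
This helper is the heart of the simplification: the two set-up lines it wraps were previously duplicated at the top of both ConstructCostGraphNodesByUniqueId and ConstructCostGraphNodesByUniqueIdTC, and are now run once by the caller (see the InitCostGraph() call added to ParallelStrategySearch further down). A minimal sketch of the hoisting pattern, with stand-in types:

#include <memory>

struct CostGraph {
  void SetDeviceMemoryAndCostParameter() {}  // stand-in for the real configuration step
};
std::shared_ptr<CostGraph> entire_costgraph;

// Shared set-up hoisted out of both construction paths.
void InitCostGraph() {
  entire_costgraph = std::make_shared<CostGraph>();
  entire_costgraph->SetDeviceMemoryAndCostParameter();
}

void ConstructByUniqueId() { /* no longer re-creates entire_costgraph */ }
void ConstructByUniqueIdTC() { /* likewise */ }

int main() {
  InitCostGraph();  // mirrors the single call at the top of ParallelStrategySearch
  ConstructByUniqueIdTC();
  return 0;
}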

OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(cnode);
@@ -491,8 +332,6 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
// Using CNode's UniqueIds to construct nodes
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
-  entire_costgraph = std::make_shared<CostGraph>();
-  entire_costgraph->SetDeviceMemoryAndCostParameter();
  // The map from CNode's UniqueId to its operatorInfo
  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  // The operator_infos in a loop
@@ -506,7 +345,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
    }
  }
  // Step 1
  for (auto &node : all_nodes) {
    // NOTE: we only care about splittable Primitive operators
    auto cnode = node->cast<CNodePtr>();
@@ -588,8 +427,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
// Using CNode's UniqueIdThroughCopys to construct nodes
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
-  entire_costgraph = std::make_shared<CostGraph>();
-  entire_costgraph->SetDeviceMemoryAndCostParameter();
  // The map from CNode's UniqueIdThroughCopy to its operatorInfo
  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  // The operator_infos in a loop
@@ -937,115 +774,6 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  }
}

bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == RESHAPE) {
    auto operator_info = cnode->user_data<OperatorInfo>();
    std::string op_info_name = operator_info->name();
    if (op_cache->find(op_info_name) != op_cache->end()) {
      return false;
    }
    op_cache->insert(op_info_name);
    return true;
  }
  return false;
}

// Find the previous node, then obtain its strategy_cost_ vector to get its layout vector.
bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index) {
  // if the previous node is a parameter, handle it outside of this function.
  if (node->isa<Parameter>()) {
    return false;
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  auto node_op_info = cnode->user_data<OperatorInfo>();
  if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
    *pre_operator_info = node_op_info;
    *out_index = 0;
    return true;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == TUPLE_GETITEM) {
    *out_index = GetTupleGetItemIndex(cnode);
    // find tuple_get_item's previous node
    auto pre_node = cnode->input(1);
    if (!pre_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
    }
    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
    auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
    if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
      *pre_operator_info = pre_op_info;
      return true;
    }
    return false;
  }
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
      continue;
    }
    return true;
  }
  MS_LOG(WARNING) << "FindPreNodeStraCosts failed; if Reshape is not the first primitive, there must be some error";
  return false;
}

// Find the next node, then obtain its strategy_cost_ vector to get its layout vector.
// If Reshape's output connects to several primitives, return the first layout found.
bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    auto op_info = use_apply->user_data<OperatorInfo>();
    if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
      *next_operator_info = op_info;
      *in_index = node_pair.second - 1;
      return true;
    }
    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
                  << " " << (op_info != nullptr);

    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
      return true;
    }
  }
  return false;
}

void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
  std::unordered_set<std::string> op_cache;
  for (auto node : all_nodes) {
@@ -1066,8 +794,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
      pre_operator_info = reshape_info;
      pre_stra_costs = reshape_info->strategy_cost();
    } else {
-      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
-        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
+      if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
+        MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed";
      }
      pre_stra_costs = pre_operator_info->strategy_cost();
    }
@@ -1075,9 +803,9 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
    int64_t in_index = 0;
    OperatorInfoPtr next_operator_info;
    std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
-    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
+    bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index);
    if (!find_next_node) {
-      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
+      MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
    }
    // set input_layout and output_layout for reshape.
    // init reshape and set cost for each input_layout and output_layout.
@@ -1122,6 +850,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
  //
  // OUTPUT: the determined strategy for each operator.

+  InitCostGraph();
  // Step 1
  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {


mindspore/ccsrc/frontend/parallel/step_auto_parallel.h (+3, -9)

@@ -28,20 +28,14 @@

namespace mindspore {
namespace parallel {
-bool IsSplittableOperator(const std::string &);
-
-bool IsAutoParallelCareNode(const CNodePtr &);
-
// main step of Auto-parallel
bool StepAutoParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);

-size_t GetLengthOfDataType(const TypePtr &type);
-
-std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node);
+bool IsSplittableOperator(const std::string &);

-std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);
+bool IsAutoParallelCareNode(const CNodePtr &);

-std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);
+void InitCostGraph();

Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);



mindspore/ccsrc/frontend/parallel/step_parallel.cc (+0, -52)

@@ -292,22 +292,6 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
  return tensorinfo_in.tensor_layout();
}

bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto cnode = anf_node->cast<CNodePtr>();
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }

  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == prim_name) {
    return true;
  }
  return false;
}

std::string GetPrimName(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsValueNode<Primitive>(node->input(0))) {
@@ -1219,42 +1203,6 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
  return shapes;
}

std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> parameters;
  if (!IsValueNode<RefKey>(node)) {
    MS_LOG(ERROR) << "The node is not a ref key";
    return parameters;
  }

  auto ref_key = GetValueNode<RefKeyPtr>(node);
  MS_EXCEPTION_IF_NULL(ref_key);
  auto name = ref_key->tag();

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto roots = manager->roots();
  if (roots.size() != 1) {
    MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1";
    return parameters;
  }

  FuncGraphPtr root_g = roots.back();
  MS_EXCEPTION_IF_NULL(root_g);
  for (auto &param_node : root_g->parameters()) {
    auto param = param_node->cast<ParameterPtr>();
    if (param && (name == param->name())) {
      parameters.push_back(param_node);
      MS_LOG(INFO) << "The name of ref key is: " << name;
      return parameters;
    }
  }

  MS_LOG(ERROR) << "The name of ref key is: " << name << ", but the parameter has not been found";
  return parameters;
}

Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);


mindspore/ccsrc/frontend/parallel/step_parallel.h (+0, -4)

@@ -100,8 +100,6 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs);

Shapes GetNodeShape(const AnfNodePtr &node);

-std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph);

// Extract shape from anfnode
std::vector<Shapes> ExtractShape(const CNodePtr &node);

@@ -154,8 +152,6 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);

std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);

-bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);

using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>;


