Browse Source

!9122 [AutoParallel] add LayerNorm Dropout and SegmentSum/Max/Min for GPT

From: @ch-l
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a09f1e30b6
5 changed files with 35 additions and 10 deletions
  1. +12
    -4
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
  2. +2
    -1
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_graph.h
  3. +11
    -2
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc
  4. +6
    -1
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h
  5. +4
    -2
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc

+ 12
- 4
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc View File

@@ -309,7 +309,14 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
}

std::vector<int64_t> axis_list;
auto iter = ops[iter_ops]->attrs().find(AXIS);
string axis_name = AXIS;
int64_t default_axis = -1;
if (ops[iter_ops]->type() == LAYER_NORM) {
axis_name = "begin_norm_axis";
default_axis = 1;
}

auto iter = ops[iter_ops]->attrs().find(axis_name);
if (iter != ops[iter_ops]->attrs().end()) {
MS_EXCEPTION_IF_NULL(iter->second);
if (iter->second->isa<Int64Imm>()) {
@@ -326,8 +333,9 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
}
} else {
axis_list.push_back(-1);
axis_list.push_back(default_axis);
}

for (auto &axis : axis_list) {
if (axis < 0) {
int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
@@ -481,10 +489,10 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
} else if (type == ONEHOT) {
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
} else if (type == SOFTMAX) {
} else if ((type == SOFTMAX) || (type == LAYER_NORM)) {
return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
(type == "FusedBatchNormEx") || (type == "Dropout")) {
(type == "FusedBatchNormEx") || (type == "Dropout") || (type == BATCH_MATMUL)) {
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
} else {
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);


+ 2
- 1
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_graph.h View File

@@ -51,7 +51,8 @@ enum OperatorType {
kRecReduce,
kRecPReLU,
kRecGatherV2,
kRecArgWithValue
kRecArgWithValue,
kRecUnsortedSegmentOp
};

enum InfoType { kApplication, kConstant };


+ 11
- 2
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc View File

@@ -61,6 +61,10 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
NewOp.tensor_parm = MakeTensor(
ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 3) {
NewOp.tensor_parm = MakeTensor(1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
ops[iter_ops]->outputs_tensor_info()[0].shape()[2]);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
@@ -69,7 +73,7 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) {
NewOp.tensor_parm = MakeTensor(1, 1, 1, 1);
} else {
MS_LOG(ERROR) << "Tensor's shape is unknown.";
MS_LOG(ERROR) << ops[iter_ops]->name() << ": output tensor shape is unexpected.";
}
NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp);
@@ -90,6 +94,11 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2],
ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]);
} else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 3) {
NewTensor.apply.arguments[iter_input_tensors] =
MakeTensor(1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],
ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2]);
} else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) {
NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
} else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) {
@@ -98,7 +107,7 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
} else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) {
NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
} else {
MS_LOG(ERROR) << "Tensor's shape is unknown.";
MS_LOG(ERROR) << ops[iter_ops]->name() << ": input tensor shape is unexpected.";
}
}
return NewTensor.apply;


+ 6
- 1
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h View File

@@ -47,6 +47,7 @@ const std::map<std::string, OperatorType> DictOpType{
{BIAS_ADD, OperatorType::kRecBiasAdd},
{BATCH_NORM, OperatorType::kRecBatchNorm},
{FUSE_BATCH_NORM, OperatorType::kRecBatchNorm},
{LAYER_NORM, OperatorType::kRecBatchNorm},
{SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
{ONEHOT, OperatorType::kRecOneHot},
{SQUEEZE, OperatorType::kRecSqueeze},
@@ -58,6 +59,9 @@ const std::map<std::string, OperatorType> DictOpType{
{GATHERV2, OperatorType::kRecGatherV2},
{ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
{ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
{UNSORTED_SEGMENT_SUM, OperatorType::kRecUnsortedSegmentOp},
{UNSORTED_SEGMENT_MAX, OperatorType::kRecUnsortedSegmentOp},
{UNSORTED_SEGMENT_MIN, OperatorType::kRecUnsortedSegmentOp},
// Activation OP
{ACTIVATION, OperatorType::kRecReLU},
{RELU, OperatorType::kRecReLU},
@@ -139,7 +143,8 @@ const std::map<std::string, OperatorType> DictOpType{
{ASSIGN, OperatorType::kRecElmWiseOp},
{ASSIGN_ADD, OperatorType::kRecElmWiseOp},
{ASSIGN_SUB, OperatorType::kRecElmWiseOp},
{"AssignAdd", OperatorType::kRecElmWiseOp}};
{"AssignAdd", OperatorType::kRecElmWiseOp},
{DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp}};

const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w);



+ 4
- 2
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc View File

@@ -76,7 +76,8 @@ double GetWeights(const Graph::NodeType &node) {

return cost_ptr->GetMinCostIn();
} else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax ||
op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecUnsortedSegmentOp ||
op.op_type == OperatorType::kRecSoftmax ||
op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
// For BatchParallel op
@@ -172,7 +173,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
} else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax ||
node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
node.apply.op_type == kRecUnsortedSegmentOp) {
// For BatchParallel type
auto cost_ptr = std::make_shared<CostBatchParallel>();
return cost_ptr->GetOptimalStr(node);


Loading…
Cancel
Save