From: @ch-l
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
tags/v1.1.0
@@ -309,7 +309,14 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
   }
   std::vector<int64_t> axis_list;
-  auto iter = ops[iter_ops]->attrs().find(AXIS);
+  string axis_name = AXIS;
+  int64_t default_axis = -1;
+  if (ops[iter_ops]->type() == LAYER_NORM) {
+    axis_name = "begin_norm_axis";
+    default_axis = 1;
+  }
+  auto iter = ops[iter_ops]->attrs().find(axis_name);
   if (iter != ops[iter_ops]->attrs().end()) {
     MS_EXCEPTION_IF_NULL(iter->second);
     if (iter->second->isa<Int64Imm>()) {
@@ -326,8 +333,9 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
     }
   } else {
-    axis_list.push_back(-1);
+    axis_list.push_back(default_axis);
   }
   for (auto &axis : axis_list) {
     if (axis < 0) {
       int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
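Read on its own, the hunk above resolves which axis attribute to use and what it defaults to, then handles negative axes against the input rank. A minimal standalone sketch of that logic, assuming a simplified attribute map (ResolveAxes and the plain std::map are illustrative stand-ins, not the real OperatorInfo API, and only the single-axis case is shown):

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Pick the axis attribute per operator type: LayerNorm reads "begin_norm_axis"
// (default 1), other axis-related ops read "axis" (default -1).
std::vector<int64_t> ResolveAxes(const std::string &op_type,
                                 const std::map<std::string, int64_t> &attrs,
                                 int64_t input_dim) {
  std::string axis_name = "axis";
  int64_t default_axis = -1;
  if (op_type == "LayerNorm") {  // LAYER_NORM in the real code
    axis_name = "begin_norm_axis";
    default_axis = 1;
  }
  auto it = attrs.find(axis_name);
  std::vector<int64_t> axis_list{it != attrs.end() ? it->second : default_axis};
  for (auto &axis : axis_list) {
    if (axis < 0) {
      // Assumed continuation of the truncated hunk: normalize a negative axis
      // against the input rank.
      axis += input_dim;
    }
  }
  return axis_list;
}
```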
@@ -481,10 +489,10 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
-  } else if (type == SOFTMAX) {
+  } else if ((type == SOFTMAX) || (type == LAYER_NORM)) {
     return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
   } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
-             (type == "FusedBatchNormEx") || (type == "Dropout")) {
+             (type == "FusedBatchNormEx") || (type == "Dropout") || (type == BATCH_MATMUL)) {
     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
@@ -51,7 +51,8 @@ enum OperatorType {
   kRecReduce,
   kRecPReLU,
   kRecGatherV2,
-  kRecArgWithValue
+  kRecArgWithValue,
+  kRecUnsortedSegmentOp
 };

 enum InfoType { kApplication, kConstant };
@@ -61,6 +61,10 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
     NewOp.tensor_parm = MakeTensor(
       ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
       ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]);
+  } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 3) {
+    NewOp.tensor_parm = MakeTensor(1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
+                                   ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
+                                   ops[iter_ops]->outputs_tensor_info()[0].shape()[2]);
   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
     NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
                                    ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
@@ -69,7 +73,7 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) {
     NewOp.tensor_parm = MakeTensor(1, 1, 1, 1);
   } else {
-    MS_LOG(ERROR) << "Tensor's shape is unknown.";
+    MS_LOG(ERROR) << ops[iter_ops]->name() << ": output tensor shape is unexpected.";
   }

   NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp);
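The new rank-3 branch here, and the matching one in CompleteOperatorInputs below, follow the existing convention of left-padding lower-rank shapes with 1s so they fit the fixed 4-slot (n, c, h, w) tensor parameter. A small illustrative sketch of that convention, with simplified names (PadTo4D and the raw std::array stand in for the real MakeTensor/TensorParam):

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Left-pad a shape of rank <= 4 with 1s into fixed (n, c, h, w) slots:
// (a, b, c) -> (1, a, b, c), (a, b) -> (1, 1, a, b), a scalar -> (1, 1, 1, 1).
std::array<int64_t, 4> PadTo4D(const std::vector<int64_t> &shape) {
  std::array<int64_t, 4> dims = {1, 1, 1, 1};
  if (shape.size() > 4) {
    return dims;  // unexpected rank; the real code logs an error for this case
  }
  const size_t offset = 4 - shape.size();
  for (size_t i = 0; i < shape.size(); ++i) {
    dims[offset + i] = shape[i];
  }
  return dims;
}
```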
@@ -90,6 +94,11 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
                  ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
                  ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2],
                  ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]);
+    } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 3) {
+      NewTensor.apply.arguments[iter_input_tensors] =
+        MakeTensor(1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],
+                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
+                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2]);
     } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) {
       NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
     } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) {
@@ -98,7 +107,7 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
     } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) {
       NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
     } else {
-      MS_LOG(ERROR) << "Tensor's shape is unknown.";
+      MS_LOG(ERROR) << ops[iter_ops]->name() << ": input tensor shape is unexpected.";
     }
   }
   return NewTensor.apply;
@@ -47,6 +47,7 @@ const std::map<std::string, OperatorType> DictOpType{
   {BIAS_ADD, OperatorType::kRecBiasAdd},
   {BATCH_NORM, OperatorType::kRecBatchNorm},
   {FUSE_BATCH_NORM, OperatorType::kRecBatchNorm},
+  {LAYER_NORM, OperatorType::kRecBatchNorm},
   {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
   {ONEHOT, OperatorType::kRecOneHot},
   {SQUEEZE, OperatorType::kRecSqueeze},
@@ -58,6 +59,9 @@ const std::map<std::string, OperatorType> DictOpType{
   {GATHERV2, OperatorType::kRecGatherV2},
   {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
   {ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
+  {UNSORTED_SEGMENT_SUM, OperatorType::kRecUnsortedSegmentOp},
+  {UNSORTED_SEGMENT_MAX, OperatorType::kRecUnsortedSegmentOp},
+  {UNSORTED_SEGMENT_MIN, OperatorType::kRecUnsortedSegmentOp},
   // Activation OP
   {ACTIVATION, OperatorType::kRecReLU},
   {RELU, OperatorType::kRecReLU},
@@ -139,7 +143,8 @@ const std::map<std::string, OperatorType> DictOpType{
   {ASSIGN, OperatorType::kRecElmWiseOp},
   {ASSIGN_ADD, OperatorType::kRecElmWiseOp},
   {ASSIGN_SUB, OperatorType::kRecElmWiseOp},
-  {"AssignAdd", OperatorType::kRecElmWiseOp}};
+  {"AssignAdd", OperatorType::kRecElmWiseOp},
+  {DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp}};

 const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w);
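The three DictOpType hunks above classify the new operators for the recursive strategy search: LayerNorm is costed like a batch-norm, the UnsortedSegment ops get the new kRecUnsortedSegmentOp category, and DropoutDoMask is treated as element-wise. As a hypothetical illustration only (the helper name, fallback value, and trimmed enum below are not the parser's actual API), a lookup into such a map typically looks like this:

```cpp
#include <map>
#include <string>

// Trimmed stand-in for the real OperatorType enum.
enum class OperatorType { kRecBatchNorm, kRecUnsortedSegmentOp, kRecElmWiseOp, kRecUnknownOpType };

// Hypothetical helper: resolve an operator's type name through a
// DictOpType-style map, falling back to an "unknown" category.
OperatorType LookupOpType(const std::string &type_name,
                          const std::map<std::string, OperatorType> &dict) {
  auto it = dict.find(type_name);
  return it != dict.end() ? it->second : OperatorType::kRecUnknownOpType;
}
```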
@@ -76,7 +76,8 @@ double GetWeights(const Graph::NodeType &node) {
     return cost_ptr->GetMinCostIn();
   } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
-             op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax ||
+             op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecUnsortedSegmentOp ||
+             op.op_type == OperatorType::kRecSoftmax ||
              op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
              op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
     // For BatchParallel op
@@ -172,7 +173,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
   } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
              node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax ||
-             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
+             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
+             node.apply.op_type == kRecUnsortedSegmentOp) {
     // For BatchParallel type
     auto cost_ptr = std::make_shared<CostBatchParallel>();
     return cost_ptr->GetOptimalStr(node);
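Both hunks above add kRecUnsortedSegmentOp to the batch-parallel group, so UnsortedSegment operators are costed as batch-parallel in GetWeights and partitioned via CostBatchParallel rather than an op-specific cost model. As a restatement only, the GetWeights condition after this change amounts to the predicate below (PartitionNode checks a slightly different list; the trimmed enum is a stand-in so the snippet compiles on its own):

```cpp
// Trimmed stand-in for the real OperatorType enum.
enum class OperatorType {
  kRecBatchNorm,
  kRecOneHot,
  kRecPReLU,
  kRecSoftmax,
  kRecUnsortedSegmentOp,
  kRecSparseSoftmaxCrossEntropyWithLogits,
  kRecSoftmaxCrossEntropyWithLogits,
};

// True when GetWeights falls through to the batch-parallel cost path,
// mirroring the condition above with kRecUnsortedSegmentOp included.
bool CostsAsBatchParallel(OperatorType t) {
  return t == OperatorType::kRecBatchNorm || t == OperatorType::kRecOneHot ||
         t == OperatorType::kRecPReLU || t == OperatorType::kRecUnsortedSegmentOp ||
         t == OperatorType::kRecSoftmax ||
         t == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
         t == OperatorType::kRecSoftmaxCrossEntropyWithLogits;
}
```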