From: @ch-l
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
tags/v1.1.0
@@ -309,7 +309,14 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
   }
   std::vector<int64_t> axis_list;
-  auto iter = ops[iter_ops]->attrs().find(AXIS);
+  string axis_name = AXIS;
+  int64_t default_axis = -1;
+  if (ops[iter_ops]->type() == LAYER_NORM) {
+    axis_name = "begin_norm_axis";
+    default_axis = 1;
+  }
+  auto iter = ops[iter_ops]->attrs().find(axis_name);
   if (iter != ops[iter_ops]->attrs().end()) {
     MS_EXCEPTION_IF_NULL(iter->second);
     if (iter->second->isa<Int64Imm>()) {
@@ -326,8 +333,9 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
     }
   } else {
-    axis_list.push_back(-1);
+    axis_list.push_back(default_axis);
   }
   for (auto &axis : axis_list) {
     if (axis < 0) {
       int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
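Note on the hunk above: LayerNorm reads its axis from the "begin_norm_axis" attribute with a default of 1, while other axis-related ops keep AXIS with a default of -1; a negative axis is then normalized against the input rank. A minimal standalone sketch of that rule (not MindSpore source; the `axis += input_dim` wrap is an assumption since the loop body is truncated above, and plain strings stand in for the LAYER_NORM constant):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Resolve the softmax/layernorm axis the way the patched function does:
// pick the attribute name and default per op type, then wrap negatives.
int64_t ResolveAxis(const std::string &op_type,
                    const std::map<std::string, int64_t> &attrs,
                    int64_t input_dim) {
  std::string axis_name = "axis";
  int64_t axis = -1;  // default when the attribute is absent
  if (op_type == "LayerNorm") {
    axis_name = "begin_norm_axis";
    axis = 1;
  }
  auto it = attrs.find(axis_name);
  if (it != attrs.end()) axis = it->second;
  if (axis < 0) axis += input_dim;  // assumed wrap, e.g. -1 -> rank - 1
  return axis;
}

int main() {
  std::cout << ResolveAxis("Softmax", {}, 4) << "\n";    // 3 (-1 wrapped by rank 4)
  std::cout << ResolveAxis("LayerNorm", {}, 4) << "\n";  // 1 (begin_norm_axis default)
}
```

Keeping a single lookup path with a per-op attribute name and default avoids duplicating the Int64Imm/tuple parsing that follows the find().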
@@ -481,10 +489,10 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
-  } else if (type == SOFTMAX) {
+  } else if ((type == SOFTMAX) || (type == LAYER_NORM)) {
     return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
   } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
-             (type == "FusedBatchNormEx") || (type == "Dropout")) {
+             (type == "FusedBatchNormEx") || (type == "Dropout") || (type == BATCH_MATMUL)) {
     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
@@ -51,7 +51,8 @@ enum OperatorType {
   kRecReduce,
   kRecPReLU,
   kRecGatherV2,
-  kRecArgWithValue
+  kRecArgWithValue,
+  kRecUnsortedSegmentOp
 };
 
 enum InfoType { kApplication, kConstant };
@@ -61,6 +61,10 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
     NewOp.tensor_parm = MakeTensor(
       ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
       ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]);
+  } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 3) {
+    NewOp.tensor_parm = MakeTensor(1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
+                                   ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
+                                   ops[iter_ops]->outputs_tensor_info()[0].shape()[2]);
   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
     NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
                                    ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
@@ -69,7 +73,7 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) {
     NewOp.tensor_parm = MakeTensor(1, 1, 1, 1);
   } else {
-    MS_LOG(ERROR) << "Tensor's shape is unknown.";
+    MS_LOG(ERROR) << ops[iter_ops]->name() << ": output tensor shape is unexpected.";
   }
 
   NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp);
@@ -90,6 +94,11 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
                      ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
                      ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2],
                      ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]);
+    } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 3) {
+      NewTensor.apply.arguments[iter_input_tensors] =
+        MakeTensor(1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],
+                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
+                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2]);
     } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) {
       NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
     } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) {
@@ -98,7 +107,7 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
     } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) {
       NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
     } else {
-      MS_LOG(ERROR) << "Tensor's shape is unknown.";
+      MS_LOG(ERROR) << ops[iter_ops]->name() << ": input tensor shape is unexpected.";
     }
   }
   return NewTensor.apply;
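Both MakeNewOperator and CompleteOperatorInputs now cover rank-3 tensors with the same convention: a shape of rank < 4 is right-aligned into a 4-D (n, c, h, w) tensor and the leading dimensions are padded with 1, so (a, b, c) becomes (1, a, b, c). A minimal sketch of that padding rule (illustrative only; the real input path delegates the 2-D case to Complete2DInputs, which may differ, and ranks above 4 fall through to the MS_LOG(ERROR) branch):

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Right-align a shape into 4 slots, padding the front with 1s,
// mirroring the MakeTensor(1, ..., 1, ...) calls in the patch.
std::array<int64_t, 4> PadTo4D(const std::vector<int64_t> &shape) {
  std::array<int64_t, 4> padded = {1, 1, 1, 1};
  if (shape.size() > 4) return padded;  // the patch logs an error for this case
  for (size_t i = 0; i < shape.size(); ++i) {
    padded[4 - shape.size() + i] = shape[i];  // keep real dims trailing
  }
  return padded;
}

int main() {
  assert((PadTo4D({8, 16, 32}) == std::array<int64_t, 4>{1, 8, 16, 32}));
  assert((PadTo4D({}) == std::array<int64_t, 4>{1, 1, 1, 1}));
}
```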
@@ -47,6 +47,7 @@ const std::map<std::string, OperatorType> DictOpType{
   {BIAS_ADD, OperatorType::kRecBiasAdd},
   {BATCH_NORM, OperatorType::kRecBatchNorm},
   {FUSE_BATCH_NORM, OperatorType::kRecBatchNorm},
+  {LAYER_NORM, OperatorType::kRecBatchNorm},
   {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
   {ONEHOT, OperatorType::kRecOneHot},
   {SQUEEZE, OperatorType::kRecSqueeze},
@@ -58,6 +59,9 @@ const std::map<std::string, OperatorType> DictOpType{
   {GATHERV2, OperatorType::kRecGatherV2},
   {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
   {ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
+  {UNSORTED_SEGMENT_SUM, OperatorType::kRecUnsortedSegmentOp},
+  {UNSORTED_SEGMENT_MAX, OperatorType::kRecUnsortedSegmentOp},
+  {UNSORTED_SEGMENT_MIN, OperatorType::kRecUnsortedSegmentOp},
   // Activation OP
   {ACTIVATION, OperatorType::kRecReLU},
   {RELU, OperatorType::kRecReLU},
@@ -139,7 +143,8 @@ const std::map<std::string, OperatorType> DictOpType{
   {ASSIGN, OperatorType::kRecElmWiseOp},
   {ASSIGN_ADD, OperatorType::kRecElmWiseOp},
   {ASSIGN_SUB, OperatorType::kRecElmWiseOp},
-  {"AssignAdd", OperatorType::kRecElmWiseOp}};
+  {"AssignAdd", OperatorType::kRecElmWiseOp},
+  {DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp}};
 
 const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w);
@@ -76,7 +76,8 @@ double GetWeights(const Graph::NodeType &node) {
     return cost_ptr->GetMinCostIn();
   } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
-             op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax ||
+             op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecUnsortedSegmentOp ||
+             op.op_type == OperatorType::kRecSoftmax ||
              op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
              op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
     // For BatchParallel op
@@ -172,7 +173,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
   } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
              node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax ||
-             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
+             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
+             node.apply.op_type == kRecUnsortedSegmentOp) {
     // For BatchParallel type
     auto cost_ptr = std::make_shared<CostBatchParallel>();
     return cost_ptr->GetOptimalStr(node);
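With these last two hunks, kRecUnsortedSegmentOp joins the operator types routed to the CostBatchParallel model in both GetWeights and PartitionNode. An illustrative sketch of that routing rule (the predicate and trimmed enum are hypothetical, not part of the patch; note the two call sites keep slightly different lists, e.g. only GetWeights also checks kRecSoftmaxCrossEntropyWithLogits):

```cpp
#include <iostream>

// Trimmed stand-in for the patch's OperatorType enum.
enum class RecOpType {
  kRecBatchNorm, kRecOneHot, kRecPReLU, kRecSoftmax,
  kRecSparseSoftmaxCrossEntropyWithLogits, kRecUnsortedSegmentOp, kRecMatMul,
};

// Hypothetical shared predicate: true for ops costed as batch-parallel.
bool UsesCostBatchParallel(RecOpType t) {
  switch (t) {
    case RecOpType::kRecBatchNorm:
    case RecOpType::kRecOneHot:
    case RecOpType::kRecPReLU:
    case RecOpType::kRecSoftmax:
    case RecOpType::kRecSparseSoftmaxCrossEntropyWithLogits:
    case RecOpType::kRecUnsortedSegmentOp:
      return true;
    default:
      return false;  // e.g. kRecMatMul keeps its dedicated cost model
  }
}

int main() {
  std::cout << std::boolalpha
            << UsesCostBatchParallel(RecOpType::kRecUnsortedSegmentOp) << "\n";  // true
}
```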