Merge pull request !2474 from Chong/zctags/v0.6.0-beta
| @@ -703,5 +703,48 @@ StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, Stra | |||
| } | |||
| return str; | |||
| } | |||
| // Chose strategy for CostSoftmaxCrossEntropyWithLogits | |||
| StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) { | |||
| uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); | |||
| if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { | |||
| return str; | |||
| } | |||
| switch (min_position) { | |||
| case 0: | |||
| str.inputTensor[0].str_n /= 2.0; | |||
| str.inputTensor[1].str_n /= 2.0; | |||
| str.cut_counter += 1; | |||
| str.cost = str.cost + cost_in_; | |||
| break; | |||
| case 1: | |||
| str.inputTensor[0].str_c /= 2.0; | |||
| str.inputTensor[1].str_c /= 2.0; | |||
| str.cut_counter += 1; | |||
| str.cost = str.cost + cost_in_; | |||
| break; | |||
| case 2: | |||
| str.inputTensor[0].str_h /= 2.0; | |||
| str.inputTensor[1].str_h /= 2.0; | |||
| str.outputTensor.str_w /= 2.0; | |||
| str.cut_counter += 1; | |||
| str.cost = str.cost + cost_in_; | |||
| break; | |||
| case 3: | |||
| str.inputTensor[0].str_w /= 2.0; | |||
| str.inputTensor[1].str_w /= 2.0; | |||
| str.cut_counter += 1; | |||
| str.cost = str.cost + cost_in_; | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed."; | |||
| } | |||
| return str; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -222,6 +222,12 @@ class CostBatchParallel { | |||
| class CostBatchNorm : public CostBatchParallel {}; | |||
| class CostOneHot : public CostBatchParallel {}; | |||
| class CostPRelu : public CostBatchParallel {}; | |||
| class CostSoftmax : public CostBatchParallel {}; | |||
| class CostSoftmaxCrossEntropyWithLogits : public CostBatchParallel { | |||
| StrategyRec ChoseStr(const std::vector<double> &cost_op, StrategyRec str); | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // PARALLEL_AUTO_PARALLEL_REC_COST_H_ | |||
| @@ -127,14 +127,6 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr | |||
| return strategies; | |||
| } | |||
| std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph, | |||
| const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops) { | |||
| std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); | |||
| strategies[1][0] = 1; | |||
| return strategies; | |||
| } | |||
| std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s) { | |||
| std::vector<std::vector<int32_t>> strategies; | |||
| strategies.push_back(*s); | |||
| @@ -164,6 +156,32 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vec | |||
| return strategies; | |||
| } | |||
| std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_ops, std::vector<int32_t> s) { | |||
| int32_t axis = 0; | |||
| auto iter = ops[iter_ops]->attrs().find(AXIS); | |||
| if (iter != ops[iter_ops]->attrs().end()) { | |||
| MS_EXCEPTION_IF_NULL(iter->second); | |||
| if (iter->second->isa<Int32Imm>()) { | |||
| axis = iter->second->cast<Int32ImmPtr>()->value(); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " : The value of axis is not int."; | |||
| } | |||
| } | |||
| int32_t axis_index = axis; | |||
| if (axis < 0) { | |||
| size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); | |||
| axis_index = static_cast<int32_t>(input_dim) + axis; | |||
| } | |||
| s[IntToSize(axis_index)] = 1; | |||
| std::vector<std::vector<int32_t>> strategies; | |||
| strategies.push_back(s); | |||
| return strategies; | |||
| } | |||
| std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, | |||
| const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops) { | |||
| @@ -279,13 +297,8 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> & | |||
| if (type == MATMUL) { | |||
| return PrepareMatMul(graph, ops, iter_graph, iter_ops); | |||
| } else if (type == PRELU) { | |||
| return PreparePReLU(graph, ops, iter_graph, iter_ops); | |||
| } else if (type == ONEHOT) { | |||
| return PrepareOneHot(graph, ops, iter_graph, iter_ops); | |||
| } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS || | |||
| type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) { | |||
| return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); | |||
| } else { | |||
| return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); | |||
| } | |||
| @@ -510,6 +523,9 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect | |||
| if (ops[iter_ops]->type() == GATHERV2) { | |||
| return PrepareGatherV2(s_ptr); | |||
| } | |||
| if (ops[iter_ops]->type() == L2_NORMALIZE) { | |||
| return PrepareL2Normalize(ops, iter_ops, basic_stra); | |||
| } | |||
| for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); | |||
| iter_op_inputs++) { | |||
| @@ -34,14 +34,13 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share | |||
| std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph, | |||
| const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops); | |||
| std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph, | |||
| const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops); | |||
| std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s); | |||
| std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph, | |||
| const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops); | |||
| std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s); | |||
| std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_ops, std::vector<int32_t> s); | |||
| std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, | |||
| const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops); | |||
| @@ -38,6 +38,7 @@ enum OperatorType { | |||
| kRecBiasAdd, | |||
| kRecSoftmax, | |||
| kRecSparseSoftmaxCrossEntropyWithLogits, | |||
| kRecSoftmaxCrossEntropyWithLogits, | |||
| kRecOneHot, | |||
| kRecLog, | |||
| kRecExp, | |||
| @@ -250,12 +250,22 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph, | |||
| new_graph->nodes.push_back(graph->nodes[i]); | |||
| auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; | |||
| for (size_t j = 0; j < node_in->size(); j++) { | |||
| node_in->at(j) = index_list->at(node_in->at(j)); | |||
| for (size_t j = node_in->size(); j > 0; j--) { | |||
| bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX); | |||
| if (IsEliminated) { | |||
| node_in->erase(node_in->begin() + j - 1); | |||
| } else { | |||
| node_in->at(j - 1) = index_list->at(node_in->at(j - 1)); | |||
| } | |||
| } | |||
| auto *node_out = &new_graph->nodes[index_list->at(i)].node_out; | |||
| for (size_t j = 0; j < node_out->size(); j++) { | |||
| node_out->at(j) = index_list->at(node_out->at(j)); | |||
| for (size_t j = node_out->size(); j > 0; j--) { | |||
| bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX); | |||
| if (IsEliminated) { | |||
| node_out->erase(node_out->begin() + j - 1); | |||
| } else { | |||
| node_out->at(j - 1) = index_list->at(node_out->at(j - 1)); | |||
| } | |||
| } | |||
| } | |||
| return new_graph; | |||
| @@ -67,7 +67,7 @@ const std::map<std::string, OperatorType> DictOpType{ | |||
| {REAL_DIV, OperatorType::kRecElmWiseOp}, | |||
| {SOFTMAX, OperatorType::kRecSoftmax}, | |||
| {LOG_SOFTMAX, OperatorType::kRecSoftmax}, | |||
| {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmax}, | |||
| {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits}, | |||
| {SQRT, OperatorType::kRecElmWiseOp}, | |||
| {NEG, OperatorType::kRecElmWiseOp}, | |||
| {POW, OperatorType::kRecElmWiseOp}, | |||
| @@ -76,15 +76,16 @@ double GetWeights(const Graph::NodeType &node) { | |||
| auto cost_ptr = std::make_shared<CostCommon>(); | |||
| return cost_ptr->GetMinCostIn(); | |||
| } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot) { | |||
| } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot || | |||
| op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax || | |||
| op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits || | |||
| op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { | |||
| // For BatchParallel op | |||
| auto cost_ptr = std::make_shared<CostBatchParallel>(); | |||
| return cost_ptr->GetMaxCostIn(); | |||
| } else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU || | |||
| op.op_type == OperatorType::kRecSoftmax || | |||
| op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) { | |||
| // For unprocessed type | |||
| } else if (op.op_type == OperatorType::kRecUnkownType) { | |||
| // For Unkown type | |||
| return 0.0; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed."; | |||
| @@ -170,14 +171,18 @@ StrategyRec PartitionNode(const Graph::NodeType &node, | |||
| auto cost_ptr = std::make_shared<CostCommon>(); | |||
| return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); | |||
| } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot) { | |||
| } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot || | |||
| node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax || | |||
| node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) { | |||
| // For BatchParallel type | |||
| auto cost_ptr = std::make_shared<CostBatchParallel>(); | |||
| return cost_ptr->GetOptimalStr(node); | |||
| } else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU || | |||
| node.apply.op_type == OperatorType::kRecSoftmax || | |||
| node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) { | |||
| // For unprocessed type | |||
| } else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { | |||
| // For SoftmaxCrossEntropyWithLogits type | |||
| auto cost_ptr = std::make_shared<CostSoftmaxCrossEntropyWithLogits>(); | |||
| return cost_ptr->GetOptimalStr(node); | |||
| } else if (node.apply.op_type == OperatorType::kRecUnkownType) { | |||
| // For Unkown type | |||
| StrategyRec default_strategy; | |||
| return default_strategy; | |||
| } else { | |||