@@ -164,9 +164,34 @@ std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph,
   return strategies;
 }
 
-std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s) {
+std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                  const size_t iter_ops, std::vector<int32_t> s) {
   std::vector<std::vector<int32_t>> strategies;
-  strategies.push_back(*s);
+
+  int32_t axis = 0;
+  auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2));
+  if (axis_input < 0) {
+    axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
+  }
+  axis = axis_input;
+  if (axis >= SizeToInt(s.size())) {
+    MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is out of range.";
+  }
+  s[axis] = 1;
+  strategies.push_back(s);
+
+  auto pos = ops[iter_ops]->name().find("Info");
+  auto name = ops[iter_ops]->name().substr(0, pos);
+  if (name == "GatherV2") {
+    return strategies;
+  }
+
+  std::vector<int32_t> s_indices;
+  for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
+    s_indices.push_back(1);
+  }
+  strategies.push_back(s_indices);
+
   return strategies;
 }
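
Note on the hunk above: PrepareGatherV2 no longer receives a precomputed strategy pointer; it derives the strategy itself. It reads the axis operand (input value 2) from the op, wraps a negative axis by the rank of the first input, rejects an axis beyond the strategy length, forces the cut along the gathered axis to 1, and, for gather-like variants other than plain GatherV2, appends an all-ones strategy for the indices tensor. A minimal standalone sketch of that shaping logic, with plain ints and vectors standing in for OperatorInfo; all names below are illustrative, not the real rec_core API:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// BuildGatherV2Strategies is a hypothetical name; axis and ranks come in as
// plain arguments instead of being read from OperatorInfo.
std::vector<std::vector<int32_t>> BuildGatherV2Strategies(std::vector<int32_t> s, int32_t axis,
                                                          size_t input_rank, size_t indices_rank,
                                                          bool is_plain_gatherv2) {
  std::vector<std::vector<int32_t>> strategies;
  if (axis < 0) {
    axis += static_cast<int32_t>(input_rank);  // wrap a negative axis, as the patch does
  }
  if (axis < 0 || axis >= static_cast<int32_t>(s.size())) {
    throw std::out_of_range("GatherV2's axis is out of range.");
  }
  s[axis] = 1;  // never split the dimension being gathered along
  strategies.push_back(s);
  if (is_plain_gatherv2) {
    return strategies;  // plain GatherV2 carries only the parameter strategy
  }
  strategies.emplace_back(indices_rank, 1);  // all-ones strategy for the indices tensor
  return strategies;
}

int main() {
  auto strategies = BuildGatherV2Strategies({4, 2}, -2, 2, 1, false);
  for (const auto &st : strategies) {
    for (int32_t d : st) std::cout << d << ' ';
    std::cout << '\n';
  }
  // prints: "1 2" then "1"
}

Forcing s[axis] = 1 is the key invariant here, presumably because splitting the dimension being gathered along would scatter the lookup table across devices.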
@@ -607,7 +632,7 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
     return PrepareBiasAdd(s_ptr);
   }
   if (ops[iter_ops]->type() == GATHERV2) {
-    return PrepareGatherV2(s_ptr);
+    return PrepareGatherV2(ops, iter_ops, basic_stra);
   }
   if (ops[iter_ops]->type() == L2_NORMALIZE) {
     return PrepareL2Normalize(ops, iter_ops, basic_stra);
@@ -38,7 +38,8 @@ std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vect
 std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);
+std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                  const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
@@ -40,7 +40,7 @@ const TensorParam MakeTensor(int n, int c, int h, int w) {
   return tensor;
 }
 
-Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops) {
+Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops) {
   Graph::NodeType NewOp;
   NewOp.name = ops[iter_ops]->name();
   NewOp.info = InfoType::kApplication;
@@ -140,7 +140,7 @@ std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
   return graph;
 }
 
-void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, std::shared_ptr<Graph> graph) {
+void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) {
   for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) {
     for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) {
       size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]);
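
The MakeEdge change itself is signature-only, but the loop in context is the whole edge-building pattern: entry iter_i of input_tensor_names holds an op's own name followed by the names of its inputs, and each input name is resolved back to the producing node's index. A simplified self-contained sketch of that pattern; SimpleGraph, Node, FindByName, and BuildEdges are stand-ins, not the real rec_core types:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-ins for the rec_core graph types.
struct Node {
  std::string name;
  std::vector<size_t> node_in;   // indices of producer nodes
  std::vector<size_t> node_out;  // indices of consumer nodes
};

struct SimpleGraph {
  std::vector<Node> nodes;
};

// Entry i of names is {op_name, input_0, input_1, ...}; returns names.size()
// when the target has no producing op (i.e. it is a graph input).
size_t FindByName(const std::vector<std::vector<std::string>> &names, const std::string &target) {
  for (size_t i = 0; i < names.size(); i++) {
    if (!names[i].empty() && names[i][0] == target) return i;
  }
  return names.size();
}

void BuildEdges(const std::vector<std::vector<std::string>> &input_tensor_names, SimpleGraph *graph) {
  for (size_t i = 0; i < input_tensor_names.size(); i++) {
    for (size_t j = 1; j < input_tensor_names[i].size(); j++) {
      size_t head = FindByName(input_tensor_names, input_tensor_names[i][j]);
      if (head == input_tensor_names.size()) continue;  // input is a parameter, not an op
      graph->nodes[head].node_out.push_back(i);         // producer -> consumer
      graph->nodes[i].node_in.push_back(head);
    }
  }
}

int main() {
  // C consumes the outputs of A and B; B consumes A.
  std::vector<std::vector<std::string>> names = {{"A"}, {"B", "A"}, {"C", "A", "B"}};
  SimpleGraph g;
  g.nodes = {{"A", {}, {}}, {"B", {}, {}}, {"C", {}, {}}};
  BuildEdges(names, &g);
  std::cout << "node C has " << g.nodes[2].node_in.size() << " incoming edges\n";  // 2
}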
@@ -110,7 +110,7 @@ const std::map<std::string, OperatorType> DictOpType{
 const TensorParam MakeTensor(int n, int c, int h, int w);
 
-Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops);
+Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops);
 
 OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                    Graph::NodeType NewTensor);
@@ -121,7 +121,7 @@ TensorParam Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                   const std::vector<std::vector<std::string>> &input_tensor_names);
 
-void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, std::shared_ptr<Graph> graph);
+void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph);
 
 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names,
                                   const std::string &input_name);
@@ -93,7 +93,7 @@ double GetWeights(const Graph::NodeType &node) {
 }
 
 // Sort all the nodes by their weights
-std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
+std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph) {
   MS_EXCEPTION_IF_NULL(graph);
 
   std::vector<std::pair<double, size_t>> weight_to_node_index;
@@ -124,7 +124,7 @@ std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
 // Get optimal strategy to partition the target node
 StrategyRec PartitionNode(const Graph::NodeType &node,
                           const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                          std::shared_ptr<Graph> graph) {
+                          const std::shared_ptr<Graph> &graph) {
   bool enable_conv_chw_partition = false;
   MS_EXCEPTION_IF_NULL(graph);
 
@@ -191,7 +191,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
 }
 
 // Partition graph into all devices.
-Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
+Status PartitionForAllDevices(const size_t num_device, const double device_memory,
+                              const std::shared_ptr<Graph> &graph) {
   if (num_device < 1) {
     MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
   }
@@ -261,7 +262,7 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
   return Node;
 }
 
-Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   if (num_device == 0) {
     MS_LOG(EXCEPTION) << "Failure: device number is 0.";
@@ -32,19 +32,19 @@
 namespace mindspore {
 namespace parallel {
 
-std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph);
+std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph);
 
 double GetWeights(const Graph::NodeType &node);
 
 StrategyRec PartitionNode(const Graph::NodeType &node,
                           const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                          std::shared_ptr<Graph> graph);
+                          const std::shared_ptr<Graph> &graph);
 
-Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);
 
 Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
 
-Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);
 
 size_t GetDataTypeSize(const TensorType &type);
 
 }  // namespace parallel
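
The remaining hunks are one mechanical fix applied throughout: std::shared_ptr<Graph> (and, in MakeNewOperator, a whole std::vector of OperatorInfo pointers) was passed by value, costing an atomic reference-count bump per call for the shared_ptr and a deep copy for the vector. None of these callees take ownership, so a const reference is the cheaper signature with the same guarantees. A small demonstration of the difference, using a stub Graph type (illustrative only):

#include <iostream>
#include <memory>
#include <vector>

struct Graph {
  std::vector<int> nodes;  // stub; the real Graph is rec_core's node list
};

// Before: the parameter is a copy, so every call atomically increments and
// later decrements the reference count.
void TouchByValue(std::shared_ptr<Graph> graph) {
  std::cout << "use_count in by-value call:     " << graph.use_count() << '\n';  // 2
}

// After: no refcount traffic, and the callee still cannot rebind the pointer.
void TouchByConstRef(const std::shared_ptr<Graph> &graph) {
  std::cout << "use_count in by-const-ref call: " << graph.use_count() << '\n';  // 1
}

int main() {
  auto graph = std::make_shared<Graph>();
  TouchByValue(graph);
  TouchByConstRef(graph);
}

A by-value shared_ptr parameter is only warranted when the callee stores the pointer and genuinely shares ownership; none of the functions touched here do.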