| @@ -164,9 +164,34 @@ std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &gr | |||
| return strategies; | |||
| } | |||
| std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s) { | |||
| std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_ops, std::vector<int32_t> s) { | |||
| std::vector<std::vector<int32_t>> strategies; | |||
| strategies.push_back(*s); | |||
| int32_t axis = 0; | |||
| auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2)); | |||
| if (axis_input < 0) { | |||
| axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); | |||
| } | |||
| axis = axis_input; | |||
| if (axis >= SizeToInt(s.size())) { | |||
| MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; | |||
| } | |||
| s[axis] = 1; | |||
| strategies.push_back(s); | |||
| auto pos = ops[iter_ops]->name().find("Info"); | |||
| auto name = ops[iter_ops]->name().substr(0, pos); | |||
| if (name == "GatherV2") { | |||
| return strategies; | |||
| } | |||
| std::vector<int32_t> s_indices; | |||
| for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { | |||
| s_indices.push_back(1); | |||
| } | |||
| strategies.push_back(s_indices); | |||
| return strategies; | |||
| } | |||
| @@ -607,7 +632,7 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect | |||
| return PrepareBiasAdd(s_ptr); | |||
| } | |||
| if (ops[iter_ops]->type() == GATHERV2) { | |||
| return PrepareGatherV2(s_ptr); | |||
| return PrepareGatherV2(ops, iter_ops, basic_stra); | |||
| } | |||
| if (ops[iter_ops]->type() == L2_NORMALIZE) { | |||
| return PrepareL2Normalize(ops, iter_ops, basic_stra); | |||
| @@ -38,7 +38,8 @@ std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vect | |||
| std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph, | |||
| const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops); | |||
| std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s); | |||
| std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_ops, std::vector<int32_t> s); | |||
| std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_ops, std::vector<int32_t> s); | |||
| std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, | |||
| @@ -40,7 +40,7 @@ const TensorParam MakeTensor(int n, int c, int h, int w) { | |||
| return tensor; | |||
| } | |||
| Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops) { | |||
| Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops) { | |||
| Graph::NodeType NewOp; | |||
| NewOp.name = ops[iter_ops]->name(); | |||
| NewOp.info = InfoType::kApplication; | |||
| @@ -140,7 +140,7 @@ std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo | |||
| return graph; | |||
| } | |||
| void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, std::shared_ptr<Graph> graph) { | |||
| void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) { | |||
| for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { | |||
| for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { | |||
| size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); | |||
| @@ -110,7 +110,7 @@ const std::map<std::string, OperatorType> DictOpType{ | |||
| const TensorParam MakeTensor(int n, int c, int h, int w); | |||
| Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops); | |||
| Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops); | |||
| OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, | |||
| Graph::NodeType NewTensor); | |||
| @@ -121,7 +121,7 @@ TensorParam Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &o | |||
| std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const std::vector<std::vector<std::string>> &input_tensor_names); | |||
| void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, std::shared_ptr<Graph> graph); | |||
| void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph); | |||
| size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names, | |||
| const std::string &input_name); | |||
| @@ -93,7 +93,7 @@ double GetWeights(const Graph::NodeType &node) { | |||
| } | |||
| // Sort all the nodes by their weights | |||
| std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) { | |||
| std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<std::pair<double, size_t>> weight_to_node_index; | |||
| @@ -124,7 +124,7 @@ std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) { | |||
| // Get optimal strategy to partition the target node | |||
| StrategyRec PartitionNode(const Graph::NodeType &node, | |||
| const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy, | |||
| std::shared_ptr<Graph> graph) { | |||
| const std::shared_ptr<Graph> &graph) { | |||
| bool enable_conv_chw_partition = false; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -191,7 +191,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node, | |||
| } | |||
| // Parttion graph into all devices. | |||
| Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) { | |||
| Status PartitionForAllDevices(const size_t num_device, const double device_memory, | |||
| const std::shared_ptr<Graph> &graph) { | |||
| if (num_device < 1) { | |||
| MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << "."; | |||
| } | |||
| @@ -261,7 +262,7 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) { | |||
| return Node; | |||
| } | |||
| Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) { | |||
| Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| if (num_device == 0) { | |||
| MS_LOG(EXCEPTION) << "Failure: device number is 0."; | |||
| @@ -32,19 +32,19 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph); | |||
| std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph); | |||
| double GetWeights(const Graph::NodeType &node); | |||
| StrategyRec PartitionNode(const Graph::NodeType &node, | |||
| const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy, | |||
| std::shared_ptr<Graph> graph); | |||
| const std::shared_ptr<Graph> &graph); | |||
| Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph); | |||
| Status PartitionForAllDevices(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph); | |||
| Graph::NodeType ApplyStrToTensor(Graph::NodeType Node); | |||
| Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph); | |||
| Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph); | |||
| size_t GetDataTypeSize(const TensorType &type); | |||
| } // namespace parallel | |||