@@ -29,52 +29,55 @@
 namespace mindspore {
 namespace parallel {
-#define DEVICE_MEMORY 1024.0 * 1024.0 * 1024.0  // 1GB
 // Get the target node's weight for sorting.
 double GetWeights(const Graph::NodeType &node) {
   const OperatorRec &op = node.apply;

-  if (op.op_type == 0) {
+  if (op.op_type == OperatorType::kRecMatMul) {
     // For MatMul
     auto cost_ptr = std::make_shared<CostMatMul>();

     return cost_ptr->GetMinCostIn(op);
-  } else if (op.op_type == 1) {
+  } else if (op.op_type == OperatorType::kRecConvolution) {
     // For Convolution
     auto cost_ptr = std::make_shared<CostConvolution>();

     return cost_ptr->GetMinCostIn(node);
-  } else if (op.op_type == 2) {
+  } else if (op.op_type == OperatorType::kRecPooling) {
     // For Pooling
     auto cost_ptr = std::make_shared<CostPooling>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 3) {
+  } else if (op.op_type == OperatorType::kRecAdd) {
     // For Add
     auto cost_ptr = std::make_shared<CostAdd>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 4 || op.op_type == 7 || op.op_type == 9) {
+  } else if (op.op_type == OperatorType::kRecSoftmax || op.op_type == OperatorType::kRecReLU ||
+             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
     // For Softmax & Activation
     auto cost_ptr = std::make_shared<CostCommon>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 5) {
+  } else if (op.op_type == OperatorType::kRecReshape) {
     // For Reshape
     auto cost_ptr = std::make_shared<CostReshape>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 6) {
+  } else if (op.op_type == OperatorType::kRecBiasAdd) {
     // For BiasAdd
     auto cost_ptr = std::make_shared<CostBiasAdd>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 8) {
+  } else if (op.op_type == OperatorType::kRecBatchNorm) {
     // For BatchNorm
     auto cost_ptr = std::make_shared<CostBatchNorm>();

     return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecUnkownType) {
+    // For unknown type
+    return 0.0;
   } else {
     MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
   }
 }
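
The integer codes removed in this hunk map one-to-one onto the new OperatorType enumerators. A minimal sketch of that mapping, reconstructed from the hunk alone (the actual enum is defined elsewhere in the rec_core sources, so this is an inference from the patch, not the real definition):

    // Sketch of the OperatorType values implied by this hunk; not the
    // actual definition, which lives in a rec_core header.
    enum OperatorType {
      kRecMatMul = 0,
      kRecConvolution = 1,
      kRecPooling = 2,
      kRecAdd = 3,
      kRecSoftmax = 4,
      kRecReshape = 5,
      kRecBiasAdd = 6,
      kRecReLU = 7,
      kRecBatchNorm = 8,
      kRecSparseSoftmaxCrossEntropyWithLogits = 9,
      kRecUnkownType = 10  // spelling as it appears in the source
    };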
@@ -155,13 +158,17 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared<CostBatchNorm>();

     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == OperatorType::kRecUnkownType) {
+    // For unknown type
+    StrategyRec default_strategy;
+    return default_strategy;
   } else {
     MS_LOG(EXCEPTION) << "Failure: Partition Operator failed.";
   }
 }

 // Partition graph into all devices.
-Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> graph) {
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
   if (num_device < 1) {
     MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
   }
@@ -207,7 +214,7 @@ Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> gr
   }

   InferUndecideStrategy(graph);
-  if (DevicesMemoryControl(graph) != SUCCESS) {
+  if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
     return FAILED;
   } else {
     return SUCCESS;
@@ -306,15 +313,15 @@ void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph)
   }
 }

-Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
+Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
   MS_EXCEPTION_IF_NULL(graph);

   uint64_t iter_nodes = graph->nodes.size();
+  double used_memory = 0.0;

   for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
     if (graph->nodes[i_node].info == 0) {
       Graph::NodeType &Node = graph->nodes[i_node];
-      double used_memory = 0.0;

       for (int index = 0; index < 2; index++) {
         used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
@@ -329,12 +336,12 @@ Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
                      Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
                      Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
                      GetDataTypeSize(Node.tensor_parm.tensor_type);
-      if (DEVICE_MEMORY < used_memory) {
-        MS_LOG(EXCEPTION) << "Failure: Out of memory!";
-        return FAILED;
-      }
     }
   }
+  if (device_memory < used_memory) {
+    MS_LOG(EXCEPTION) << "Failure: Out of memory!";
+    return FAILED;
+  }

   return SUCCESS;
 }
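
With the DEVICE_MEMORY macro gone, the memory budget is supplied by the caller, and the check now compares the total used_memory accumulated across all nodes against that budget once, rather than checking each node against the hard-coded 1 GB inside the loop. A minimal sketch of a call site that reproduces the old default (the wrapper function and device count are assumptions for illustration, not part of this patch):

    // Hypothetical call site; only PartitionForAllDevices and its new
    // device_memory parameter come from the patch above.
    #include <memory>

    constexpr double kGB = 1024.0 * 1024.0 * 1024.0;

    Status RunPartition(const std::shared_ptr<Graph> &graph) {
      const size_t num_device = 8;             // assumed device count
      const double device_memory = 1.0 * kGB;  // the old 1 GB default, now explicit
      return PartitionForAllDevices(num_device, device_memory, graph);
    }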