|
|
|
@@ -232,7 +232,7 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (DevicesMemoryControl(device_memory, graph) != SUCCESS) { |
|
|
|
if (DevicesMemoryControl(num_device, device_memory, graph) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} else { |
|
|
|
return SUCCESS; |
|
|
|
@@ -257,16 +257,15 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) { |
|
|
|
return Node; |
|
|
|
} |
|
|
|
|
|
|
|
Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) { |
|
|
|
Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
|
|
uint64_t iter_nodes = graph->nodes.size(); |
|
|
|
double used_memory = 0.0; |
|
|
|
|
|
|
|
for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) { |
|
|
|
if (graph->nodes[i_node].info == 0) { |
|
|
|
Graph::NodeType &Node = graph->nodes[i_node]; |
|
|
|
double used_memory = 0.0; |
|
|
|
|
|
|
|
for (int index = 0; index < 2; index++) { |
|
|
|
used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n * |
|
|
|
Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c * |
|
|
|
@@ -274,21 +273,15 @@ Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> g |
|
|
|
Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w * |
|
|
|
GetDataTypeSize(Node.apply.arguments[index].tensor_type); |
|
|
|
} |
|
|
|
|
|
|
|
used_memory += Node.tensor_parm.tensor_str.str_n * Node.tensor_parm.tensor_shape.shape_n * |
|
|
|
Node.tensor_parm.tensor_str.str_c * Node.tensor_parm.tensor_shape.shape_c * |
|
|
|
Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h * |
|
|
|
Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w * |
|
|
|
GetDataTypeSize(Node.tensor_parm.tensor_type); |
|
|
|
|
|
|
|
if (device_memory < used_memory) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure: Out of memory!"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
if (device_memory < (used_memory / num_device)) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure: Out of memory!"; |
|
|
|
return FAILED; |
|
|
|
} else { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
size_t GetDataTypeSize(const TensorType &type) { |
|
|
|
|