Browse Source

Add the sharding propagation function:

1) users configure sharding strategies for operators;
2) the framework propagates the strategies from configured ops to
non-configured ops using BFS;
3) the goal of propagation is to minimize redistribution communication
cost.
tags/v1.4.0
Xiaoda Zhang 5 years ago
parent
commit
04381273b3
17 changed files with 1124 additions and 12 deletions
  1. +60
    -0
      mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
  2. +4
    -1
      mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
  3. +97
    -0
      mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
  4. +3
    -0
      mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
  5. +3
    -0
      mindspore/ccsrc/frontend/parallel/context.cc
  6. +5
    -0
      mindspore/ccsrc/frontend/parallel/context.h
  7. +17
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
  8. +4
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
  9. +41
    -8
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
  10. +3
    -0
      mindspore/ccsrc/pipeline/jit/init.cc
  11. +30
    -3
      mindspore/parallel/_auto_parallel_context.py
  12. +318
    -0
      tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation.py
  13. +314
    -0
      tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation2.py
  14. +87
    -0
      tests/ut/python/parallel/test_auto_parallel_shard_propagation.py
  15. +60
    -0
      tests/ut/python/parallel/test_auto_parallel_shard_propagation2.py
  16. +72
    -0
      tests/ut/python/parallel/test_auto_parallel_shard_propagation3.py
  17. +6
    -0
      tests/ut/python/parallel/test_auto_parallel_two_matmul.py

+ 60
- 0
mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc View File

@@ -330,6 +330,66 @@ Status Edge::CalculateMemoryCostForInference() {
return SUCCESS;
}

// Returns the first cost registered in 'cost_map_' for the given
// (prev-op strategy, next-op strategy) pair, or nullptr if the pair has no entry.
// Raises an exception if the pair is present but its cost list is empty.
CostPtr Edge::GetCostByStrategyPair(const CostPtrKey &stra_pair) {
  // Single lookup instead of find() followed by operator[], which performed a
  // second lookup and copied the whole CostPtrList.
  auto iter = cost_map_.find(stra_pair);
  if (iter == cost_map_.end()) {
    return nullptr;
  }
  const auto &cost_vec = iter->second;
  if (cost_vec.empty()) {
    PrintStrategy(stra_pair.first);
    PrintStrategy(stra_pair.second);
    MS_LOG(EXCEPTION) << "No available cost under current strategy pair of the edge: " << edge_name_;
  }
  if (cost_vec.size() > 1) {
    PrintStrategy(stra_pair.first);
    PrintStrategy(stra_pair.second);
    // Fixed typo in the original log message ("stratey" -> "strategy").
    MS_LOG(INFO) << "Multiple costs available under the strategy pair of the edge: " << edge_name_;
  }
  return cost_vec[0];
}

// Given the selected strategy of the previous operator on this edge, finds a strategy
// for the next operator that incurs zero communication (redistribution) cost.
// Among multiple candidates, the one with the minimum computation cost is returned.
// Returns nullptr if no zero-communication strategy exists.
StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &prev_op_stra) {
  std::vector<std::pair<StrategyPtr, double>> next_op_stras;
  for (auto &key_value : cost_map_) {
    const auto &candidate_prev_op_stra = key_value.first.first;
    // NOTE(review): assumes every cost list in 'cost_map_' is non-empty — confirm with InitEdgeCost.
    if (prev_op_stra->IsEqual(candidate_prev_op_stra) && (key_value.second[0]->communication_cost_ == 0.0)) {
      next_op_stras.push_back({key_value.first.second, key_value.second[0]->computation_cost_});
    }
  }
  if (next_op_stras.empty()) {
    MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
    return nullptr;
  } else if (next_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << ", choose the one with"
                    " minimum computation costs.";
  }
  // Use std::min_element with a strict '<' comparator. The original code called std::sort
  // with 'a.second <= b.second', which is not a strict weak ordering and is undefined
  // behavior; only the minimum element is needed anyway.
  auto min_iter = std::min_element(next_op_stras.begin(), next_op_stras.end(),
                                   [](const std::pair<StrategyPtr, double> &a,
                                      const std::pair<StrategyPtr, double> &b) { return a.second < b.second; });
  return min_iter->first;
}

// Given the selected strategy of the next operator on this edge, finds a strategy
// for the previous operator that incurs zero communication (redistribution) cost.
// Among multiple candidates, the one with the minimum computation cost is returned.
// Returns nullptr if no zero-communication strategy exists.
StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &next_op_stra) {
  std::vector<std::pair<StrategyPtr, double>> prev_op_stras;
  for (auto &key_value : cost_map_) {
    const auto &candidate_next_op_stra = key_value.first.second;
    // NOTE(review): assumes every cost list in 'cost_map_' is non-empty — confirm with InitEdgeCost.
    if (next_op_stra->IsEqual(candidate_next_op_stra) && (key_value.second[0]->communication_cost_ == 0.0)) {
      prev_op_stras.push_back({key_value.first.first, key_value.second[0]->computation_cost_});
    }
  }
  if (prev_op_stras.empty()) {
    MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
    return nullptr;
  } else if (prev_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << ", choose the one with minimum "
                    "computation costs.";
  }
  // Use std::min_element with a strict '<' comparator. The original code called std::sort
  // with 'a.second <= b.second', which is not a strict weak ordering and is undefined
  // behavior; only the minimum element is needed anyway.
  auto min_iter = std::min_element(prev_op_stras.begin(), prev_op_stras.end(),
                                   [](const std::pair<StrategyPtr, double> &a,
                                      const std::pair<StrategyPtr, double> &b) { return a.second < b.second; });
  return min_iter->first;
}

void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) {
cost_map_ = cost_map;
pre_op_output_.clear();


+ 4
- 1
mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h View File

@@ -81,6 +81,9 @@ class Edge {
// Init cost_map_: for each output layout and input layout, calculate the cost
Status InitEdgeCost();
std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; }
CostPtr GetCostByStrategyPair(const CostPtrKey &);
StrategyPtr GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &);
StrategyPtr GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &);
void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &);
// For two operators u--->v, given the output tensor layout of u,
// and the input tensor layout of v, return the redistribution cost,
@@ -153,7 +156,7 @@ class Edge {
// the index of outputs of prev_op, and the index of inputs of next_op
size_t prev_op_output_index_, next_op_input_index_;

// 'pre_op_output_indexs_' and 'next_op_input_indexs_' store the indexes of inputs and outputs if is_combined = true
// pre_op_output_indexs_ and next_op_input_indexs_ store the indices of inputs and outputs if is_combined = true
std::vector<size_t> pre_op_output_indexs_;
std::vector<size_t> next_op_input_indexs_;
// is this edge constructed by combining multiple edges? If is is, then is_combined = true, else is_combined = false


+ 97
- 0
mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc View File

@@ -19,6 +19,7 @@
#include <string>
#include <utility>
#include <vector>
#include <queue>

#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "frontend/parallel/ops_info/reshape_info.h"
@@ -87,6 +88,102 @@ bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t outp
return false;
}

void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &ops_stras) {
if (ops_stras.empty()) {
MS_LOG(EXCEPTION) << "There is no operator that is configured sharding strategy.";
}
std::map<OperatorInfoPtr, bool> visited;
for (auto &op : ops_) {
visited[op] = false;
}
for (auto &op_stra : ops_stras) {
BFS(op_stra.first, op_stra.second, ops_stras, &visited);
}
}

void CheckShardingConsisitency(std::map<OperatorInfoPtr, StrategyPtr> configured_ops, OperatorInfoPtr curr_op,
OperatorInfoPtr another_op, CostPtr cost, EdgePtr edge) {
if ((configured_ops.find(another_op) == configured_ops.end()) &&
(cost == nullptr || cost->communication_cost_ != 0.0)) {
PrintStrategy(another_op->selected_strategy());
PrintStrategy(curr_op->selected_strategy());
MS_LOG(EXCEPTION) << "There are redistribution cost occurs at edge: " << edge->edge_name()
<< ", consider configuring sharding strategies for two operators."
<< " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
<< " and " << another_op->cnode()->fullname_with_scope();
}
}

// Breadth-first propagation of a sharding strategy, starting from operator 'op' whose
// configured strategy is 'op_stra'. Each dequeued operator applies its strategy and then
// pushes zero-communication strategies to its unvisited neighbors (both successors and
// predecessors). 'configured_ops' holds all user-configured operators; 'visited' marks
// operators already processed by some BFS. Raises on any unavoidable redistribution.
void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
                    std::map<OperatorInfoPtr, StrategyPtr> configured_ops, std::map<OperatorInfoPtr, bool> *visited) {
  // Queue entries: ((operator, strategy to apply), BFS depth at which the operator was reached).
  std::queue<std::pair<std::pair<OperatorInfoPtr, StrategyPtr>, size_t>> next_level;
  next_level.push({{op, op_stra}, 0});
  while (!next_level.empty()) {
    auto curr_op = next_level.front().first.first;
    auto configured_stra = next_level.front().first.second;
    auto curr_depth = next_level.front().second;
    // NOTE(review): nodes are marked visited on dequeue (not on enqueue), so an operator can be
    // enqueued more than once; SetSelectedStrategy's depth check resolves which strategy wins.
    visited->at(curr_op) = true;
    MS_LOG(INFO) << "curr_depth: " << curr_depth;
    curr_op->SetSelectedStrategy(configured_stra, curr_depth);
    // Forward pass: propagate along outgoing edges to successor operators.
    for (auto &edge : curr_op->succ_edges()) {
      const auto &next_op = edge->next_operator();
      if (visited->at(next_op)) {
        // Both endpoints already decided: verify the pair incurs no redistribution cost.
        const auto cost = edge->GetCostByStrategyPair({curr_op->selected_strategy(), next_op->selected_strategy()});
        CheckShardingConsisitency(configured_ops, curr_op, next_op, cost, edge);
        continue;
      }
      // Reaching another user-configured operator (beyond the BFS source): its configured
      // strategy must equal the zero-communication strategy implied by 'curr_op'.
      if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) {
        const auto &next_op_conf_stra = configured_ops[next_op];
        const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
        if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) {
          MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
                            << "Currently reaching " << curr_op->name() << " and " << next_op->name() << "."
                            << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
                            << " and " << next_op->cnode()->fullname_with_scope();
        }
      }
      // Configured operators start their own BFS in StrategyPropagate; do not re-enqueue them.
      if (configured_ops.find(next_op) != configured_ops.end()) {
        continue;
      }
      const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
      if (next_op_stra == nullptr) {
        PrintStrategy(curr_op->selected_strategy());
        MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
      }
      next_level.push({{next_op, next_op_stra}, curr_depth + 1});
    }
    // Backward pass: propagate along incoming edges to predecessor operators (mirror of above).
    for (auto &edge : curr_op->prev_edges()) {
      const auto &prev_op = edge->prev_operator();
      if (visited->at(prev_op)) {
        const auto cost = edge->GetCostByStrategyPair({prev_op->selected_strategy(), curr_op->selected_strategy()});
        CheckShardingConsisitency(configured_ops, curr_op, prev_op, cost, edge);
        continue;
      }
      if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) {
        const auto &prev_op_conf_stra = configured_ops[prev_op];
        const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
        if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) {
          MS_LOG(ERROR) << "curr_depth: " << curr_depth;
          MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
                            << "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "."
                            << " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope()
                            << " and " << curr_op->cnode()->fullname_with_scope();
        }
      }
      if (configured_ops.find(prev_op) != configured_ops.end()) {
        continue;
      }
      const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
      if (prev_op_stra == nullptr) {
        PrintStrategy(curr_op->selected_strategy());
        MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
      }
      next_level.push({{prev_op, prev_op_stra}, curr_depth + 1});
    }
    next_level.pop();
  }
}

std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
std::vector<OperatorInfoPtr> alive_ops) {
std::map<OperatorInfoPtr, bool> visited;


+ 3
- 0
mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h View File

@@ -52,6 +52,9 @@ class CostGraph {
}
void RemoveOperator(const OperatorInfoPtr &op);
bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
void StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &);
void BFS(const OperatorInfoPtr &, const StrategyPtr &, std::map<OperatorInfoPtr, StrategyPtr>,
std::map<OperatorInfoPtr, bool> *);
// the edge is in the form: u --> v
void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; }


+ 3
- 0
mindspore/ccsrc/frontend/parallel/context.cc View File

@@ -71,6 +71,7 @@ void ParallelContext::Reset() {
communi_parallel_mode_ = ALL_GROUP_PARALLEL;
optimizer_weight_shard_size_ = -1;
optimizer_weight_shard_aggregated_save_ = false;
sharding_propagation_ = false;
}

void ParallelContext::set_device_num(int64_t device_num) {
@@ -266,5 +267,7 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func

MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}

// Setter for 'sharding_propagation_': when true (AUTO_PARALLEL mode), strategies configured
// on operators are propagated to non-configured operators instead of running the DP search.
void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
} // namespace parallel
} // namespace mindspore

+ 5
- 0
mindspore/ccsrc/frontend/parallel/context.h View File

@@ -124,6 +124,8 @@ class ParallelContext {

bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
std::string communi_parallel_mode() const { return communi_parallel_mode_; }
void set_sharding_propagation(const bool);
bool sharding_propagation() const { return sharding_propagation_; }

void Reset();
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
@@ -160,6 +162,9 @@ class ParallelContext {
std::string communi_parallel_mode_;
int64_t optimizer_weight_shard_size_;
bool optimizer_weight_shard_aggregated_save_;
// In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators
// will propagate the sharding strategies to other operators with minimum redistribution cost.
bool sharding_propagation_;
};

} // namespace parallel


+ 17
- 0
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc View File

@@ -1699,6 +1699,23 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra
}
}

// Records 's_strategy' as this operator's selected strategy, found at BFS depth 'curr_depth'.
// A candidate found at a strictly deeper level than an already-recorded strategy is ignored;
// shallower (or equal-depth) candidates overwrite the previous choice.
void OperatorInfo::SetSelectedStrategy(const StrategyPtr &s_strategy, size_t curr_depth) {
  MS_EXCEPTION_IF_NULL(s_strategy);
  const auto depth = SizeToLong(curr_depth);
  const bool already_recorded = (selected_strategy_depth_ != -1);
  if (already_recorded && (depth > selected_strategy_depth_)) {
    MS_LOG(INFO) << name_ << " has already been set strategy.";
    return;
  }
  MS_LOG(INFO) << "Set strategy for: " << name_;
  PrintStrategy(s_strategy);
  selected_strategy_ = s_strategy;
  selected_strategy_depth_ = depth;
}

// Returns the CNode this operator-info was built from; raises if 'cnode_' was never set
// (it is initialized to nullptr in the constructor and assigned via set_cnode).
CNodePtr OperatorInfo::cnode() {
  MS_EXCEPTION_IF_NULL(cnode_);
  return cnode_;
}

double OperatorInfo::GetForwardMemoryCostFromCNode() {
return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
}


+ 4
- 0
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h View File

@@ -67,6 +67,7 @@ class OperatorInfo {
refkey_parameter_name_ = "";
stage_device_list_ = g_device_manager->GetDeviceListInThisStage();
stage_device_size_ = SizeToLong(stage_device_list_.size());
cnode_ = nullptr;
}

virtual ~OperatorInfo() = default;
@@ -140,6 +141,7 @@ class OperatorInfo {
selected_strategy_ = s_strategy;
selected_cost_ = cost;
}
void SetSelectedStrategy(const StrategyPtr &s_strategy, size_t);
StrategyPtr selected_strategy() const { return selected_strategy_; }
CostPtr selected_cost() const { return selected_cost_; }
// Approximate the list of available strategies
@@ -155,6 +157,7 @@ class OperatorInfo {
const std::vector<ValuePtr> &input_value() const { return input_value_; }
void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; }
void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
CNodePtr cnode();
bool is_alive() const { return is_alive_; }
void SetNotAlive() { is_alive_ = false; }
StrategyPtr strategy() const { return strategy_; }
@@ -274,6 +277,7 @@ class OperatorInfo {
std::vector<std::shared_ptr<Edge>> prev_edges_;
std::vector<std::shared_ptr<Edge>> succ_edges_;
StrategyPtr selected_strategy_;
int64_t selected_strategy_depth_ = -1;
// Used in DP algorithm
bool is_alive_;
CostPtr selected_cost_;


+ 41
- 8
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc View File

@@ -232,6 +232,8 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn
return true;
}

// 'configured_stra_ops_' includes all operators that are configured sharding strategies.
std::map<OperatorInfoPtr, StrategyPtr> configured_stra_ops_;
void InitCostGraph() {
if (entire_costgraph == nullptr) {
entire_costgraph = std::make_shared<CostGraph>();
@@ -239,6 +241,7 @@ void InitCostGraph() {
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
CostModelContext::GetInstance()->PrintCostModel();
entire_costgraph->Init();
configured_stra_ops_.clear();
}

void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim,
@@ -266,6 +269,7 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
// 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel
if (used_devices == 1) {
configured_stra_ops_.insert({operator_info, strategyPtr});
return;
}
// 'used_devices == -1' means that 'used_devices_' is not set
@@ -275,6 +279,7 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
<< ", total devices: " << total_device_num;
}
}
configured_stra_ops_.insert({operator_info, strategyPtr});
}
}

@@ -344,6 +349,15 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
return nullptr;
}
if (ParallelContext::GetInstance()->sharding_propagation() &&
(operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
const auto &swc_vec = operator_info->GetStrategyCost();
if (swc_vec.empty()) {
MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
}
MS_EXCEPTION_IF_NULL(swc_vec[0]->strategy_ptr);
configured_stra_ops_.insert({operator_info, swc_vec[0]->strategy_ptr});
}
// If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (approximation) {
@@ -608,9 +622,25 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
node_op_info->AddPrevEdge(edge_ptr);
prev_op_info->AddSuccEdge(edge_ptr);
entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
if (ParallelContext::GetInstance()->sharding_propagation() && (prev_prim->name() == CAST) &&
(configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
const auto next_op_stra = configured_stra_ops_[node_op_info];
const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithZeroComm(next_op_stra);
if (cast_stra == nullptr) {
MS_LOG(EXCEPTION) << "No available strategy for: " << prev_op_info->name();
}
prev_op_info->ClearStrategyCost();
if (prev_op_info->SetCostUnderStrategy(cast_stra) != SUCCESS) {
MS_LOG(EXCEPTION) << "Failure: operator " << prev_op_info->name() << " SetCostUnderStrategy failed";
}
if (edge_ptr->InitEdgeCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Edge cost re-initialization failed.";
}
MS_LOG(INFO) << "Set strategy for: " << prev_op_info->name() << " under the strategy of: " << node_op_info->name();
configured_stra_ops_.insert({prev_op_info, cast_stra});
}
MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " << node_op_info->name();
(*edge_count)++;
return;
}

void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
@@ -884,11 +914,11 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
// subsequent operator;
// Step 3.1: Calculate memory usage:
// note the memory usage calculation is different in training phase and inference phase.
// Step 4: Run the Dynamic Programming algorithm:
// in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
// cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
// tensor layout. Note that there may be several connected components in the costgraph, and the DP algorithm
// runs on each of them.
// Step 4: Run the strategy searching algorithm:
// If 'sharding_propagation' is configured to be true, then the configured-sharding-strategies will propagate
// to the non-configured operators, with the goal of minimizing redistribution cost.
// Otherwise, DP algorithm is used to search strategy of the costgraph. Note that there may be several connected
// components in the costgraph, and the DP algorithm runs on each of them.
//
// OUTPUT: the determined strategy for each operator.

@@ -929,8 +959,11 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
}

// Step 4: run DP algorithm on the costgraph.
if (GetStrategy(entire_costgraph) != SUCCESS) {
// Step 4: run the strategy searching algorithm
if (ParallelContext::GetInstance()->sharding_propagation()) {
entire_costgraph->StrategyPropagate(configured_stra_ops_);
configured_stra_ops_.clear();
} else if (GetStrategy(entire_costgraph) != SUCCESS) {
MS_LOG(ERROR) << "Strategy search for cost-graph fails";
return FAILED;
}


+ 3
- 0
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -186,6 +186,9 @@ PYBIND11_MODULE(_c_expression, m) {
"Set whether to integrated save weight shard when enable parallel optimizer.")
.def("get_optimizer_weight_shard_aggregated_save", &ParallelContext::optimizer_weight_shard_aggregated_save,
"Get whether to integrated save weight shard when enable parallel optimizer.")
.def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
"Set sharding strategy propagation value.")
.def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")


+ 30
- 3
mindspore/parallel/_auto_parallel_context.py View File

@@ -443,6 +443,26 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_enable_parallel_optimizer()

def set_sharding_propagation(self, sharding_propagation):
    """
    Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
    will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
    will search the desired strategies.
    Default: False.

    Args:
        sharding_propagation (bool): Enable/disable strategy propagation.

    Raises:
        TypeError: If `sharding_propagation` is not a bool.
    """
    self.check_context_handle()
    if not isinstance(sharding_propagation, bool):
        # Include the offending type in the message so misconfiguration is easy to diagnose.
        raise TypeError("'sharding_propagation' must be a bool, but got {}.".format(
            type(sharding_propagation).__name__))
    self._context_handle.set_sharding_propagation(sharding_propagation)

def get_sharding_propagation(self):
    """Get the value of sharding strategy propagation.

    Returns:
        bool, whether sharding strategy propagation is enabled in AUTO_PARALLEL mode.
    """
    self.check_context_handle()
    return self._context_handle.get_sharding_propagation()

def set_communi_parallel_mode(self, communi_parallel_mode):
"""
Set communication parallel mode.
@@ -563,7 +583,8 @@ _set_auto_parallel_context_func_map = {
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save}
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().set_sharding_propagation}


_get_auto_parallel_context_func_map = {
@@ -584,7 +605,8 @@ _get_auto_parallel_context_func_map = {
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save}
"optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().get_sharding_propagation}


@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
@@ -593,7 +615,8 @@ _get_auto_parallel_context_func_map = {
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
communi_parallel_mode=str, optimizer_weight_shard_size=int,
optimizer_weight_shard_aggregated_save=bool)
optimizer_weight_shard_aggregated_save=bool,
sharding_propagation=bool)

def _set_auto_parallel_context(**kwargs):
"""
@@ -655,6 +678,10 @@ def _set_auto_parallel_context(**kwargs):
Default: -1, which means fully use parallel optimizer in data parallel dimension.
optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when enable parallel
optimizer. Default: False.
sharding_propagation (bool): Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True,
the strategy-configured operators will propagate the strategies to other
operators with minimum redistribution cost; otherwise, the algorithm will
search the desired strategies. Default: False.

Raises:
ValueError: If input key is not attribute in auto parallel context.


+ 318
- 0
tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation.py View File

@@ -0,0 +1,318 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _reset_op_id as resset_op_id
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True


def weight_variable():
    """Return a TruncatedNormal(0.02) initializer used for conv/dense weights in this test."""
    return TruncatedNormal(0.02)


def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 3x3 kernel size, weights initialized with TruncatedNormal(0.02)."""
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 1x1 kernel size, weights initialized with TruncatedNormal(0.02)."""
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 7x7 kernel size, weights initialized with TruncatedNormal(0.02)."""
    init_value = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


def _fused_bn(channels, momentum=0.9):
    """Get a fused batchnorm (nn.BatchNorm2d) over `channels` with the given momentum."""
    return nn.BatchNorm2d(channels, momentum=momentum)


class ResidualBlock(nn.Cell):
    """ResNet bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    `expansion` is the channel expansion factor of the final 1x1 convolution.
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 momentum=0.9):
        super(ResidualBlock, self).__init__()

        # Bottleneck width: the 1x1/3x3 convs operate on out_channels // expansion channels.
        out_chls = out_channels // self.expansion
        self.conv1 = _conv1x1(in_channels, out_chls, stride=1)
        self.bn1 = _fused_bn(out_chls, momentum=momentum)

        # The 3x3 conv carries the block's stride (spatial downsampling, if any).
        self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)
        self.bn2 = _fused_bn(out_chls, momentum=momentum)

        self.conv3 = _conv1x1(out_chls, out_channels, stride=1)
        self.bn3 = _fused_bn(out_channels, momentum=momentum)

        self.relu = P.ReLU()
        # Shortcut path: project with 1x1 conv + BN when channel counts differ;
        # otherwise, if strided, downsample the identity with a 1x1 max-pool.
        self.downsample = (in_channels != out_channels)
        self.stride = stride
        if self.downsample:
            self.conv_down_sample = _conv1x1(in_channels, out_channels,
                                             stride=stride)
            self.bn_down_sample = _fused_bn(out_channels, momentum=momentum)
        elif self.stride != 1:
            self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')

        self.add = P.Add()

    def construct(self, x):
        # Main path: conv1-bn1-relu -> conv2-bn2-relu -> conv3-bn3.
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Shortcut path: projection, max-pool downsample, or plain identity.
        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)
        elif self.stride != 1:
            identity = self.maxpool_down(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    """Configurable ResNet backbone used by the sharding-propagation tests.

    `matmul_stra` / `squeeze_stra` are optional sharding strategies applied to the
    final Dense layer's matmul and to the Squeeze op, respectively; they are the
    user-configured strategies that propagation starts from.
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides=None,
                 num_classes=100,
                 matmul_stra=None,
                 squeeze_stra=None):
        super(ResNet, self).__init__()
        if strides is None:
            strides = [1, 2, 2, 2]
        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of "
                             "layer_num, inchannel, outchannel list must be 4!")

        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool.
        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _fused_bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Four residual stages.
        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        # Head: global average pool -> squeeze -> fp16 Dense classifier.
        self.mean = P.ReduceMean(keep_dims=True)
        self.end_point = nn.Dense(2048, num_classes, has_bias=True,
                                  weight_init=weight_variable(),
                                  bias_init=weight_variable()).add_flags_recursive(fp16=True)
        # User-configured sharding strategies (may be None, i.e. unconfigured).
        self.end_point.matmul.shard(matmul_stra)
        self.squeeze = P.Squeeze().shard(squeeze_stra)
        self.cast = P.Cast()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        # Build one stage: first block may project channels; the LAST block carries the
        # stride (note: unlike torchvision, downsampling happens at the end of the stage).
        layers = []

        resblk = block(in_channel, out_channel, stride=1)
        layers.append(resblk)

        for _ in range(1, layer_num - 1):
            resblk = block(out_channel, out_channel, stride=1)
            layers.append(resblk)

        resblk = block(out_channel, out_channel, stride=stride)
        layers.append(resblk)

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        # Average over spatial dims (H, W), then drop the size-1 axes before the classifier.
        out = self.mean(c5, (2, 3))
        out = self.squeeze(out)
        out = self.end_point(out)

        return out


def resnet50(class_num=10, matmul_stra=None, squeeze_stra=None):
    """Build a ResNet-50 (bottleneck layout 3-4-6-3) with optional sharding strategies."""
    layer_nums = [3, 4, 6, 3]
    in_channels = [64, 256, 512, 1024]
    out_channels = [256, 512, 1024, 2048]
    strides = [2, 2, 2, 1]
    return ResNet(ResidualBlock, layer_nums, in_channels, out_channels, strides,
                  class_num, matmul_stra=matmul_stra, squeeze_stra=squeeze_stra)


class SoftmaxCrossEntropyExpand(_Loss):
    """Softmax cross-entropy loss expanded into primitive ops, several of which carry
    explicit (1, 8) sharding strategies so the loss side of the graph is pre-configured."""

    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = P.Exp()
        # Ops with explicit shard configs split the class dimension over 8 devices.
        self.sum = P.ReduceSum(keep_dims=True).shard(((1, 8),))
        self.onehot = P.OneHot().shard(((1, 8), (), ()))
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = P.Div().shard(((1, 8), (1, 1)))
        self.log = P.Log()
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.mul2 = P.Mul()
        self.cast = P.Cast()
        # "cross_batch" marks the reduction as crossing the data-parallel batch dimension.
        self.mean = P.ReduceMean(keep_dims=False).add_prim_attr("cross_batch", True)
        self.sparse = sparse
        self.max = P.ReduceMax(keep_dims=True).shard(((1, 8),))
        self.sub = P.Sub().shard(((1, 8), (1, 1)))
        self.cast1 = P.Cast()

    def construct(self, logit, label):
        # Numerically-stable softmax: subtract the per-row max before exponentiating.
        logit = self.cast1(logit, mstype.float32)
        logit_max = self.max(logit)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)

        # Cross entropy: -sum(label * log(softmax)) averaged over the batch.
        softmax_result_log = self.log(softmax_result)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.mean(loss, -1)

        return loss


class DatasetLenet():
    """Minimal in-memory dataset stub: yields the same (predict, label) pair
    `length` times per epoch and mimics the dataset interface Model expects."""

    def __init__(self, predict, label, length=3):
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if not self.index < self.length:
            raise StopIteration
        self.index = self.index + 1
        return self.predict, self.label

    def reset(self):
        # Rewind so the dataset can be iterated again.
        self.index = 0

    def get_dataset_size(self):
        # Fixed size reported to Model (not tied to `length`).
        return 32

    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        return self


def test_train_64k_8p(batch_size=32, num_classes=65536):  # 1048576 #131072 #32768 #8192
    """Sharding-propagation UT (64k classes, 8 devices): the head MatMul,
    Squeeze, and a few loss ops carry user strategies; after compile, verify
    the strategies propagated to Conv2D/MatMul/ReduceSum."""
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num,
                                      sharding_propagation=True)
    cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    # Reset op ids so the op-name keys in the strategy dict are deterministic.
    resset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    # Model-parallel head: shard the fc weight's output dim across devices.
    matmul_stra = ((1, 1), (dev_num, 1))
    squeeze_stra = ((dev_num, 1, 1, 1),)
    net = resnet50(num_classes, matmul_stra, squeeze_stra)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _executor._get_shard_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            # Convolutions should end up data-parallel on the batch dim.
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [dev_num, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[1, dev_num]]
    context.reset_auto_parallel_context()

+ 314
- 0
tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation2.py View File

@@ -0,0 +1,314 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _reset_op_id as resset_op_id
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm

# Compile in graph mode targeting Ascend. The UT has no real distributed
# environment, so HCCL env-var checking is disabled just around init().
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True


def weight_variable():
    """Return the default weight initializer: TruncatedNormal(sigma=0.02)."""
    initializer = TruncatedNormal(0.02)
    return initializer


def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 3x3 kernel size."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=padding, pad_mode=pad_mode,
                     weight_init=weight_variable())


def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 1x1 kernel size."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                     padding=padding, pad_mode=pad_mode,
                     weight_init=weight_variable())


def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
    """Get a conv2d layer with 7x7 kernel size."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride,
                     padding=padding, pad_mode=pad_mode,
                     weight_init=weight_variable())


def _fused_bn(channels, momentum=0.9):
    """Get a fused batchnorm (BatchNorm2d) over `channels` feature maps."""
    return nn.BatchNorm2d(channels, momentum=momentum)


class ResidualBlock(nn.Cell):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convs with a skip path.

    The skip path is a 1x1 conv + BN when channel counts differ, a strided
    max-pool when only the spatial size changes, and identity otherwise.
    """
    # out_channels = expansion * bottleneck width
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 momentum=0.9):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = _conv1x1(in_channels, out_chls, stride=1)
        self.bn1 = _fused_bn(out_chls, momentum=momentum)

        # Only the middle 3x3 conv carries the (possible) stride.
        self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)
        self.bn2 = _fused_bn(out_chls, momentum=momentum)

        self.conv3 = _conv1x1(out_chls, out_channels, stride=1)
        self.bn3 = _fused_bn(out_channels, momentum=momentum)

        self.relu = P.ReLU()
        self.downsample = (in_channels != out_channels)
        self.stride = stride
        if self.downsample:
            # Projection shortcut to match channel count (and stride).
            self.conv_down_sample = _conv1x1(in_channels, out_channels,
                                             stride=stride)
            self.bn_down_sample = _fused_bn(out_channels, momentum=momentum)
        elif self.stride != 1:
            # Same channels but reduced spatial size: stride-2 pooling shortcut.
            self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')

        self.add = P.Add()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)
        elif self.stride != 1:
            identity = self.maxpool_down(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    """ResNet backbone with a dense classification head.

    `matmul_stra` / `squeeze_stra` are user shard strategies applied to the
    head's MatMul and the Squeeze; the rest of the ops are left unconfigured
    so the sharding-propagation pass must derive their strategies.
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides=None,
                 num_classes=100,
                 matmul_stra=None,
                 squeeze_stra=None):
        super(ResNet, self).__init__()
        if strides is None:
            strides = [1, 2, 2, 2]
        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of "
                             "layer_num, inchannel, outchannel list must be 4!")

        # Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool.
        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _fused_bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.mean = P.ReduceMean(keep_dims=True)
        # fp16 head; the MatMul inside Dense gets the user strategy.
        self.end_point = nn.Dense(2048, num_classes, has_bias=True,
                                  weight_init=weight_variable(),
                                  bias_init=weight_variable()).add_flags_recursive(fp16=True)
        self.end_point.matmul.shard(matmul_stra)
        self.squeeze = P.Squeeze().shard(squeeze_stra)
        self.cast = P.Cast()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """Stack `layer_num` blocks; only the LAST block applies `stride`."""
        layers = []

        resblk = block(in_channel, out_channel, stride=1)
        layers.append(resblk)

        for _ in range(1, layer_num - 1):
            resblk = block(out_channel, out_channel, stride=1)
            layers.append(resblk)

        resblk = block(out_channel, out_channel, stride=stride)
        layers.append(resblk)

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Forward: stem -> 4 stages -> global mean pool -> squeeze -> fc."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        out = self.mean(c5, (2, 3))
        out = self.squeeze(out)
        out = self.end_point(out)

        return out


def resnet50(class_num=10, matmul_stra=None, squeeze_stra=None):
    """Build a ResNet-50 with the standard stage configuration, forwarding
    the user-specified shard strategies for the head MatMul and Squeeze."""
    layer_nums = [3, 4, 6, 3]
    in_channels = [64, 256, 512, 1024]
    out_channels = [256, 512, 1024, 2048]
    strides = [2, 2, 2, 1]
    return ResNet(ResidualBlock, layer_nums, in_channels, out_channels, strides,
                  class_num, matmul_stra=matmul_stra, squeeze_stra=squeeze_stra)


class SoftmaxCrossEntropyExpand(_Loss):
    """Softmax cross-entropy expanded into primitive ops.

    Unlike the 64k-test variant, only ReduceMax and Sub carry user shard
    strategies (batch dim across 8 devices); sharding propagation must fill
    in the rest.
    """

    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = P.Exp()
        self.sum = P.ReduceSum(keep_dims=True)
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = P.Div()
        self.log = P.Log()
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.mul2 = P.Mul()
        self.cast = P.Cast()
        # cross_batch: reduce mean is taken across the sliced batch.
        self.mean = P.ReduceMean(keep_dims=False).add_prim_attr("cross_batch", True)
        self.sparse = sparse
        self.max = P.ReduceMax(keep_dims=True).shard(((8, 1),))
        self.sub = P.Sub().shard(((8, 1), (1, 1)))
        self.cast1 = P.Cast()

    def construct(self, logit, label):
        # Numerically-stable softmax: subtract the row max before exp.
        logit = self.cast1(logit, mstype.float32)
        logit_max = self.max(logit)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            # Expand integer labels to one-hot over the class dimension.
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)

        softmax_result_log = self.log(softmax_result)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.mean(loss, -1)

        return loss


class DatasetLenet():
    """Minimal in-memory dataset stub: yields the same (predict, label) pair
    `length` times per epoch and mimics the dataset interface Model expects."""

    def __init__(self, predict, label, length=3):
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if not self.index < self.length:
            raise StopIteration
        self.index = self.index + 1
        return self.predict, self.label

    def reset(self):
        # Rewind so the dataset can be iterated again.
        self.index = 0

    def get_dataset_size(self):
        # Fixed size reported to Model (not tied to `length`).
        return 32

    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        return self


def test_train_32k_8p(batch_size=32, num_classes=32768):
    """Sharding-propagation UT (32k classes, 8 devices): only the head
    MatMul and the loss's ReduceMax/Sub carry user strategies; after
    compile, verify the strategies propagated to Conv2D/MatMul/ReduceSum."""
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num,
                                      sharding_propagation=True)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    # Reset op ids so the op-name keys in the strategy dict are deterministic.
    resset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    # Data-parallel head: shard the MatMul's input batch dim across devices.
    matmul_stra = ((dev_num, 1), (1, 1))
    net = resnet50(num_classes, matmul_stra)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _executor._get_shard_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            # Convolutions should end up data-parallel on the batch dim.
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[dev_num, 1], [1, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[dev_num, 1]]
    # Fix: restore the default parallel context so subsequent UTs are not
    # left running in AUTO_PARALLEL mode (the 64k variant already resets).
    context.reset_auto_parallel_context()

+ 87
- 0
tests/ut/python/parallel/test_auto_parallel_shard_propagation.py View File

@@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    """Mul -> Sigmoid -> ReLU chain whose three ops take optional per-op
    shard strategies (None means: let sharding propagation decide)."""

    def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
        super().__init__()
        self.mul_weight = Parameter(mul_weight, "w1")
        self.mul = P.Mul().shard(strategy1)
        self.sigmoid = P.Sigmoid().shard(strategy2)
        self.relu = P.ReLU().shard(strategy3)

    def construct(self, x, b):
        y = self.mul(x, self.mul_weight)
        y = self.sigmoid(y)
        return self.relu(y)


# Fixed fp32 fixtures: _x/_b are the network inputs, _w1 is the Mul weight.
_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)


def compile_net(net):
    """Wrap `net` in a one-step training cell, compile it under the current
    auto-parallel context, then reset that context for the next test."""
    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, opt)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_auto_parallel_activation1():
    """Only Mul is configured; propagation must derive Sigmoid's strategy."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
                                      sharding_propagation=True)
    strategy1 = ((4, 4), (4, 4))
    strategy2 = None
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_auto_parallel_activation2():
    """Only Sigmoid is configured; propagation must derive Mul's strategy."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
                                      sharding_propagation=True)
    strategy1 = None
    strategy2 = ((4, 4),)
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)

# NOTE(review): this function lacks the `test_` prefix, so pytest does not
# collect it — confirm whether it was disabled on purpose or should be
# renamed to `test_auto_parallel_activation3`.
def auto_parallel_activation3():
    """Mul and ReLU configured with compatible strategies; Sigmoid derived."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
                                      sharding_propagation=True)
    strategy1 = ((4, 4), (4, 4))
    strategy2 = None
    strategy3 = ((4, 4),)
    net = Net(_w1, strategy1, strategy2, strategy3)
    compile_net(net)

def test_auto_parallel_activation4():
    """Conflicting user strategies on Mul ((4,4)) and ReLU ((8,2)) with an
    unconfigured Sigmoid between them: compilation is expected to fail."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
                                      sharding_propagation=True)
    strategy1 = ((4, 4), (4, 4))
    strategy2 = None
    strategy3 = ((8, 2),)
    net = Net(_w1, strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)

+ 60
- 0
tests/ut/python/parallel/test_auto_parallel_shard_propagation2.py View File

@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    """Mul -> Cast(fp16) -> Sigmoid chain; Cast carries no user strategy, so
    sharding propagation must carry the strategy through it."""

    def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.cast = P.Cast().shard(strategy2)
        self.sigmoid = P.Sigmoid().shard(strategy3)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        # `b` is unused but kept so the compile entry point's two-input
        # signature matches the other propagation UTs.
        out = self.mul(x, self.mul_weight)
        out = self.cast(out, mstype.float16)
        out = self.sigmoid(out)
        return out


# Fixed fp32 fixtures: _x/_b are the network inputs, _w1 is the Mul weight.
_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)


def compile_net(net):
    """Wrap `net` in a one-step training cell, compile it under the current
    auto-parallel context, then reset that context for the next test."""
    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, opt)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_auto_parallel_activation4():
    """Mul ((4,4)) and Sigmoid ((8,2)) differ, but the unconfigured Cast
    between them lets propagation insert a redistribution: compile succeeds."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
                                      sharding_propagation=True)
    strategy1 = ((4, 4), (4, 4))
    strategy2 = None
    strategy3 = ((8, 2),)
    net = Net(_w1, strategy1, strategy2, strategy3)
    compile_net(net)

+ 72
- 0
tests/ut/python/parallel/test_auto_parallel_shard_propagation3.py View File

@@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    """Add -> ReLU -> ReduceMax -> Sub graph to exercise propagation through
    a reduction op (ReduceMax changes the output layout)."""

    def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
        super().__init__()
        self.add = P.TensorAdd()
        self.relu = P.ReLU().shard(strategy1)
        self.max = P.ReduceMax(keep_dims=True).shard(strategy2)
        self.sub = P.Sub().shard(strategy3)
        # Named "mul_weight" for consistency with the sibling UTs, though it
        # feeds TensorAdd here.
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        # `b` is unused but keeps the two-input compile signature.
        out = self.add(x, self.mul_weight)
        out = self.relu(out)
        out2 = self.max(out)
        out = self.sub(out, out2)
        return out


# Fixed fp32 fixtures: _x/_b are the network inputs, _w1 feeds the Add.
_x = Tensor(np.ones([64, 32000]), dtype=ms.float32)
_w1 = Tensor(np.ones([64, 32000]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32000]), dtype=ms.float32)


def compile_net(net):
    """Wrap `net` in a one-step training cell, compile it under the current
    auto-parallel context, then reset that context for the next test."""
    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, opt)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_auto_parallel_activation1():
    """ReduceMax sharded ((8,1)) conflicts with Sub's input ((1,8)) while
    ReLU is unconfigured: compilation is expected to fail."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
                                      sharding_propagation=True)
    strategy1 = None
    strategy2 = ((8, 1),)
    strategy3 = ((1, 8), (1, 1))
    net = Net(_w1, strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)

def test_auto_parallel_activation2():
    """Consistent strategies: ReLU (1,8), unsharded ReduceMax, Sub (1,8)/(1,1)
    — compilation is expected to succeed."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
                                      sharding_propagation=True)
    strategy1 = ((1, 8),)
    strategy2 = ((1, 1),)
    strategy3 = ((1, 8), (1, 1))
    net = Net(_w1, strategy1, strategy2, strategy3)
    compile_net(net)

+ 6
- 0
tests/ut/python/parallel/test_auto_parallel_two_matmul.py View File

@@ -68,6 +68,12 @@ def test_two_matmul():

size = 16
context.set_auto_parallel_context(device_num=size, global_rank=0)
strategy_pro = context.get_auto_parallel_context("sharding_propagation")
assert not strategy_pro
context.set_auto_parallel_context(sharding_propagation=True)
strategy_pro = context.get_auto_parallel_context("sharding_propagation")
assert strategy_pro
context.set_auto_parallel_context(sharding_propagation=False)
cost_model_context.set_cost_model_context(device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
costmodel_alpha=1.0,
costmodel_beta=60.0,


Loading…
Cancel
Save