Browse Source

skip strategy ckpt save for reshape

tags/v0.6.0-beta
yao_yf 5 years ago
parent
commit
37338813f0
4 changed files with 5 additions and 1 deletions
  1. +0
    -1
      mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
  2. +1
    -0
      mindspore/ccsrc/parallel/ops_info/ops_utils.h
  3. +3
    -0
      mindspore/ccsrc/parallel/step_parallel.cc
  4. +1
    -0
      mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc

+ 0
- 1
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc View File

@@ -41,7 +41,6 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int32_t RUN_PHASE = DEFAULT_RUN_PHASE; int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
constexpr char RESHAPEINFO[] = "ReshapeInfo";


void CostGraph::SetDeviceMemoryAndCostParameter() { void CostGraph::SetDeviceMemoryAndCostParameter() {
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());


+ 1
- 0
mindspore/ccsrc/parallel/ops_info/ops_utils.h View File

@@ -65,6 +65,7 @@ constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot";
constexpr char REQUIRES_GRAD[] = "requires_grad"; constexpr char REQUIRES_GRAD[] = "requires_grad";
constexpr char PARAM_NAME[] = "name"; constexpr char PARAM_NAME[] = "name";
constexpr char RESHAPEINFO[] = "ReshapeInfo";


constexpr char RELU_TYPE[] = "relu"; constexpr char RELU_TYPE[] = "relu";
constexpr char RELU6_TYPE[] = "relu6"; constexpr char RELU6_TYPE[] = "relu6";


+ 3
- 0
mindspore/ccsrc/parallel/step_parallel.cc View File

@@ -2120,6 +2120,9 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info(); OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info) { if (operator_info) {
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
continue;
}
StrategyPtr strategyPtr = operator_info->strategy(); StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope()); MS_EXCEPTION_IF_NULL(node->scope());
stra_map[param_name] = strategyPtr; stra_map[param_name] = strategyPtr;


+ 1
- 0
mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc View File

@@ -93,6 +93,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
parallel_strategy_item->set_node_name(node_stra.first); parallel_strategy_item->set_node_name(node_stra.first);
straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
MS_EXCEPTION_IF_NULL(parallel_strategys); MS_EXCEPTION_IF_NULL(parallel_strategys);
MS_EXCEPTION_IF_NULL(node_stra.second);
parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage()));
for (auto &dims : node_stra.second->GetInputDim()) { for (auto &dims : node_stra.second->GetInputDim()) {
straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();


Loading…
Cancel
Save