| @@ -41,7 +41,6 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; | |||||
| bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | ||||
| bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | ||||
| int32_t RUN_PHASE = DEFAULT_RUN_PHASE; | int32_t RUN_PHASE = DEFAULT_RUN_PHASE; | ||||
| constexpr char RESHAPEINFO[] = "ReshapeInfo"; | |||||
| void CostGraph::SetDeviceMemoryAndCostParameter() { | void CostGraph::SetDeviceMemoryAndCostParameter() { | ||||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | ||||
| @@ -65,6 +65,7 @@ constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; | |||||
| constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; | constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; | ||||
| constexpr char REQUIRES_GRAD[] = "requires_grad"; | constexpr char REQUIRES_GRAD[] = "requires_grad"; | ||||
| constexpr char PARAM_NAME[] = "name"; | constexpr char PARAM_NAME[] = "name"; | ||||
| constexpr char RESHAPEINFO[] = "ReshapeInfo"; | |||||
| constexpr char RELU_TYPE[] = "relu"; | constexpr char RELU_TYPE[] = "relu"; | ||||
| constexpr char RELU6_TYPE[] = "relu6"; | constexpr char RELU6_TYPE[] = "relu6"; | ||||
| @@ -2120,6 +2120,9 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| OperatorInfoPtr operator_info = cnode->operator_info(); | OperatorInfoPtr operator_info = cnode->operator_info(); | ||||
| if (operator_info) { | if (operator_info) { | ||||
| if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | |||||
| continue; | |||||
| } | |||||
| StrategyPtr strategyPtr = operator_info->strategy(); | StrategyPtr strategyPtr = operator_info->strategy(); | ||||
| MS_EXCEPTION_IF_NULL(node->scope()); | MS_EXCEPTION_IF_NULL(node->scope()); | ||||
| stra_map[param_name] = strategyPtr; | stra_map[param_name] = strategyPtr; | ||||
| @@ -93,6 +93,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { | |||||
| parallel_strategy_item->set_node_name(node_stra.first); | parallel_strategy_item->set_node_name(node_stra.first); | ||||
| straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); | straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); | ||||
| MS_EXCEPTION_IF_NULL(parallel_strategys); | MS_EXCEPTION_IF_NULL(parallel_strategys); | ||||
| MS_EXCEPTION_IF_NULL(node_stra.second); | |||||
| parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); | parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); | ||||
| for (auto &dims : node_stra.second->GetInputDim()) { | for (auto &dims : node_stra.second->GetInputDim()) { | ||||
| straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); | straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); | ||||