| @@ -46,6 +46,6 @@ if [[ ! -z ${QUANTIZE} ]]; then | |||
| QUANT_OPTIONS="--configFile=${WEIGHT_QUANT_CONFIG}" | |||
| fi | |||
| LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS | |||
| if [ -n "$3" ]; then | |||
| if [[ ! -z ${MIX_FLAG} ]]; then | |||
| LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=mix_lenet_tod.mindir --outputFile=mix_lenet_tod | |||
| fi | |||
| @@ -104,7 +104,7 @@ fi | |||
| cd model/ || exit 1 | |||
| rm -f *.ms | |||
| EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} ./prepare_model.sh $BATCH $DOCKER $MIX_FLAG || exit 1 | |||
| EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} MIX_FLAG=${MIX_FLAG} ./prepare_model.sh $BATCH $DOCKER || exit 1 | |||
| cd ../ | |||
| # Copy the .ms model to the package folder | |||
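Note: these two hunks have to land together. The caller previously forwarded `$MIX_FLAG` as the third positional argument, which `prepare_model.sh` read as `$3`; both sides now agree on the environment variable instead. So something like `MIX_FLAG=on ./prepare_model.sh $BATCH $DOCKER` enables the mix_lenet_tod conversion (the value `on` is illustrative — any non-empty string passes the `[[ ! -z ${MIX_FLAG} ]]` test).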
| @@ -257,7 +257,7 @@ void NetRunner::Usage() { | |||
| bool NetRunner::ReadArgs(int argc, char *argv[]) { | |||
| int opt; | |||
| while ((opt = getopt(argc, argv, "f:e:d:s:ihc:vob:")) != -1) { | |||
| while ((opt = getopt(argc, argv, "f:e:d:s:ihc:vmob:")) != -1) { | |||
| switch (opt) { | |||
| case 'f': | |||
| ms_file_ = std::string(optarg); | |||
| @@ -280,8 +280,8 @@ bool NetRunner::ReadArgs(int argc, char *argv[]) { | |||
| case 'b': | |||
| virtual_batch_ = atoi(optarg); | |||
| break; | |||
| case 'r': | |||
| is_raw_mix_precision_ = atoi(optarg); | |||
| case 'm': | |||
| is_raw_mix_precision_ = true; | |||
| break; | |||
| case 'h': | |||
| default: | |||
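Note: the old `case 'r':` branch was unreachable, since `r` was never registered in the optstring `"f:e:d:s:ihc:vob:"`. The replacement registers `m` without a trailing colon, so it is an argument-less flag and the handler simply sets `is_raw_mix_precision_ = true`. Raw mix precision is then requested as, e.g., `./net_runner -f mix_lenet_tod.ms -e 5 -m` (binary name and other argument values here are illustrative assumptions).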
| @@ -27,7 +27,7 @@ class LiteSession; | |||
| class TrainLoop; | |||
| struct TrainLoopCallBackData { | |||
| TrainLoopCallBackData(bool train_mode, int epoch, LiteSession *session, TrainLoop *loop) | |||
| TrainLoopCallBackData(bool train_mode, unsigned int epoch, LiteSession *session, TrainLoop *loop) | |||
| : train_mode_(train_mode), epoch_(epoch), session_(session), loop_(loop) {} | |||
| bool train_mode_; /**< training mode of LiteSession object */ | |||
| @@ -28,8 +28,8 @@ | |||
| namespace mindspore { | |||
| Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> i_cbs) { | |||
| if ((impl_ == nullptr) || (impl_->session_ == nullptr)) { | |||
| MS_LOG(ERROR) << "Model implement is null."; | |||
| if ((impl_ == nullptr) || (impl_->session_ == nullptr) || ds == nullptr) { | |||
| MS_LOG(ERROR) << "Model implement or dataset is null."; | |||
| return kLiteUninitializedObj; | |||
| } | |||
| auto loop = std::unique_ptr<session::TrainLoop>(session::TrainLoop::CreateTrainLoop((impl_->session_).get())); | |||
| @@ -67,8 +67,8 @@ Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vecto | |||
| } | |||
| Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> i_cbs) { | |||
| if ((impl_ == nullptr) || (impl_->session_ == nullptr)) { | |||
| MS_LOG(ERROR) << "Model implement is null."; | |||
| if ((impl_ == nullptr) || (impl_->session_ == nullptr) || ds == nullptr) { | |||
| MS_LOG(ERROR) << "Model implement or dataset is null."; | |||
| return kLiteUninitializedObj; | |||
| } | |||
| @@ -47,6 +47,10 @@ Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> * | |||
| } | |||
| auto model_metrics = GetMetrics(); | |||
| for (auto m : model_metrics) { | |||
| if (m == nullptr) { | |||
| MS_LOG(ERROR) << "Null input metrics"; | |||
| return kLiteUninitializedObj; | |||
| } | |||
| if (m->metrics_impl_) { | |||
| // For off-the-shelf metrics it is guaranteed that we have also an MSLite implementation | |||
| auto internal_m = m->metrics_impl_->GetInternalMetrics(); | |||
| @@ -79,6 +83,9 @@ Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i | |||
| return kLiteUninitializedObj; | |||
| } | |||
| for (auto cb : *i_cbs) { | |||
| if (cb == nullptr) { | |||
| return kLiteUninitializedObj; | |||
| } | |||
| if (cb->callback_impl_) { | |||
| // For off-the-shelf callback it is guaranteed that we have also an MSLite implementation | |||
| auto internal_cb = cb->callback_impl_->GetInternalCallback(); | |||
| @@ -91,6 +91,10 @@ class OptimizerKernel : public InnerKernel { | |||
| indices.push_back(lr_idx_); | |||
| for (size_t ix = 0; ix < indices.size(); ix++) { | |||
| if (param->tensor_name() == in_tensors_.at(indices[ix])->tensor_name()) { | |||
| if (param->Size() != in_tensors_.at(indices[ix])->Size()) { | |||
| MS_LOG(ERROR) << "Tensor: " << param->tensor_name() << "set size not same"; | |||
| return false; | |||
| } | |||
| auto value = static_cast<float *>(param->MutableData())[0]; | |||
| static_cast<float *>(in_tensors_.at(indices[ix])->MutableData())[0] = value; | |||
| if (lr_idx_ == indices[ix]) { | |||
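Note: the guard matters because the copy below writes element 0 through `MutableData()` as `float`; matching on tensor name alone would silently accept a caller-supplied tensor of mismatched size or element type (for example the int8 {20, 20} tensor the new `SetOptimizerParams` test further down constructs), so the byte-size comparison via `Size()` rejects it before any write happens.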
| @@ -22,6 +22,7 @@ | |||
| #include "include/errorcode.h" | |||
| #include "include/dataset/iterator.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "nnacl/op_base.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
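The new include is what provides the `MS_CHECK_*` guard macros used by the rewritten Train/Eval below. A sketch of their assumed shape, for reading the hunks that follow (the exact expansion lives in nnacl/op_base.h and may differ):

```cpp
// Assumed shape -- not the verbatim macros: log a message and bail out with
// `errcode` when the check fails.
#define MS_CHECK_TRUE_MSG(value, errcode, msg) \
  do {                                         \
    if (!(value)) {                            \
      MS_LOG(ERROR) << (msg);                  \
      return (errcode);                        \
    }                                          \
  } while (0)

#define MS_CHECK_GT(value, thres, errcode) \
  MS_CHECK_TRUE_MSG((value) > (thres), (errcode), "check greater-than failed")
```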
| @@ -35,14 +36,24 @@ using session::RET_STOP_TRAINING; | |||
| TrainLoop::~TrainLoop() {} | |||
| int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) { | |||
| train_session_->Train(); | |||
| MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr"); | |||
| auto ret = train_session_->Train(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "TrainLoop train failed"; | |||
| return RET_ERROR; | |||
| } | |||
| MS_CHECK_GT(epochs, 0, RET_ERROR); | |||
| session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this); | |||
| if (load_func == nullptr) load_func = TrainLoop::LoadData; | |||
| for (auto cb : cbs) cb->Begin(cb_data); | |||
| for (auto cb : cbs) { | |||
| MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr"); | |||
| cb->Begin(cb_data); | |||
| } | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr"); | |||
| for (int i = 0; i < epochs; i++) { | |||
| cb_data.epoch_ = epoch_++; | |||
| for (auto cb : cbs) cb->EpochBegin(cb_data); | |||
| @@ -51,10 +62,9 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall | |||
| int s = 0; | |||
| iter->GetNextRow(&row_vec); | |||
| while (row_vec.size() != 0) { | |||
| auto ret = load_func(cb_data.session_->GetInputs(), &row_vec); | |||
| while (!row_vec.empty()) { | |||
| ret = load_func(cb_data.session_->GetInputs(), &row_vec); | |||
| if (ret != RET_OK) break; | |||
| cb_data.step_ = s++; | |||
| for (auto cb : cbs) cb->StepBegin(cb_data); | |||
| @@ -64,7 +74,7 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall | |||
| } | |||
| int break_loop = false; | |||
| for (auto cb : cbs) { | |||
| int ret = cb->EpochEnd(cb_data); | |||
| ret = cb->EpochEnd(cb_data); | |||
| if (ret != RET_CONTINUE) { | |||
| if (ret == RET_EXIT) { | |||
| MS_LOG(ERROR) << "Error in TrainLoop callback"; | |||
| @@ -85,23 +95,35 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall | |||
| } | |||
| int TrainLoop::Eval(Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) { | |||
| train_session_->Eval(); | |||
| MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr"); | |||
| auto ret = train_session_->Eval(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "TrainLoop train failed"; | |||
| return RET_ERROR; | |||
| } | |||
| session::TrainLoopCallBackData cb_data(false, epoch_, train_session_, this); | |||
| if (load_func == nullptr) load_func = TrainLoop::LoadData; | |||
| for (auto metric : metrics_) metric->Clear(); | |||
| for (auto cb : cbs) cb->Begin(cb_data); | |||
| for (auto metric : metrics_) { | |||
| MS_CHECK_TRUE_MSG(metric != nullptr, RET_ERROR, "metric cannot be nullptr"); | |||
| metric->Clear(); | |||
| } | |||
| for (auto cb : cbs) { | |||
| MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr"); | |||
| cb->Begin(cb_data); | |||
| } | |||
| for (auto cb : cbs) cb->EpochBegin(cb_data); | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr"); | |||
| MSTensorVec row_vec; | |||
| int s = 0; | |||
| iter->GetNextRow(&row_vec); | |||
| while (row_vec.size() != 0) { | |||
| while (!row_vec.empty()) { | |||
| if (s >= max_steps) break; | |||
| auto ret = load_func(cb_data.session_->GetInputs(), &row_vec); | |||
| ret = load_func(cb_data.session_->GetInputs(), &row_vec); | |||
| if (ret != RET_OK) break; | |||
| cb_data.step_ = ++s; | |||
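For orientation, the `EpochEnd` return-code contract both loops depend on (`RET_CONTINUE` / `RET_STOP_TRAINING` / `RET_EXIT`, imported at the top of this file) is easiest to see in a small callback. A sketch assuming the usual virtual hooks on `session::TrainLoopCallBack` (base-class details and the header location are assumptions; they are not part of this patch):

```cpp
#include "include/train/train_loop_callback.h"  // assumed header location

// Stops training cleanly after three epochs; RET_CONTINUE keeps the loop
// going, and RET_EXIT is treated as an error by the hunk above.
class EarlyStop : public mindspore::session::TrainLoopCallBack {
 public:
  int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override {
    return cb_data.epoch_ >= 3 ? mindspore::session::RET_STOP_TRAINING
                               : mindspore::session::RET_CONTINUE;
  }
};
```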
| @@ -63,6 +63,10 @@ int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) { | |||
| } | |||
| cfg_ = *train_cfg; | |||
| } | |||
| if (context == nullptr) { | |||
| MS_LOG(ERROR) << "context cannot be nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| allocator_ = context->allocator; | |||
| return lite::LiteSession::Init(context); | |||
| } | |||
| @@ -141,7 +141,7 @@ Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16) { | |||
| std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeFloat}; | |||
| std::vector<TypeId> fp32_type = {kNumberTypeFloat32, kNumberTypeFloat}; | |||
| if (!IsContain(valid_type, tensor->data_type())) { | |||
| MS_LOG(ERROR) << "source data type must be fp32 or fp16"; | |||
| MS_LOG(ERROR) << "source data type must be fp32 or fp16,cur is " << tensor->data_type(); | |||
| return nullptr; | |||
| } | |||
| @@ -139,6 +139,18 @@ TEST_F(TestCxxApiLiteModel, test_getparams_SUCCESS) { | |||
| for (size_t ix = 0; ix < params1.size(); ix++) { | |||
| ASSERT_EQ(static_cast<float *>(params1[ix].MutableData())[0], static_cast<float>(ix) + pi); | |||
| } | |||
| if (!params.empty()) { | |||
| auto ¶m = params.at(0); | |||
| param.SetShape({20, 20}); | |||
| param.SetDataType(DataType::kNumberTypeInt8); | |||
| } | |||
| ASSERT_TRUE(model.SetOptimizerParams(params) != kSuccess); | |||
| if (!params.empty()) { | |||
| auto ¶m = params.at(0); | |||
| param.SetTensorName("failed_name"); | |||
| } | |||
| ASSERT_TRUE(model.SetOptimizerParams(params) != kSuccess); | |||
| } | |||
 TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {

@@ -159,5 +171,62 @@ TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
     static_cast<float *>(graients[ix].MutableData())[0] = static_cast<float>(ix) + pi;
   }
   ASSERT_TRUE(model.ApplyGradients(graients) == kSuccess);
+  if (!graients.empty()) {
+    auto &param = graients.at(0);
+    param.SetShape({20, 20});
+  }
+  ASSERT_TRUE(model.ApplyGradients(graients) != kSuccess);
+  if (!graients.empty()) {
+    auto &param = graients.at(0);
+    param.SetTensorName("failed_name");
+  }
+  ASSERT_TRUE(model.ApplyGradients(graients) != kSuccess);
 }
+
+TEST_F(TestCxxApiLiteModel, test_fp32_SUCCESS) {
+  Model model;
+  Graph graph;
+  auto context = std::make_shared<Context>();
+  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
+  cpu_context->SetEnableFP16(true);
+  context->MutableDeviceInfo().push_back(cpu_context);
+  auto train_cfg = std::make_shared<TrainCfg>();
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
+
+  ASSERT_TRUE(Serialization::Load("./nets/conv_train_model.ms", ModelType::kMindIR, &graph) == kSuccess);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = false;
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+  cpu_context->SetEnableFP16(false);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+}
+
+TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) {
+  Model model;
+  Graph graph;
+  auto context = std::make_shared<Context>();
+  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
+  cpu_context->SetEnableFP16(true);
+  context->MutableDeviceInfo().push_back(cpu_context);
+  auto train_cfg = std::make_shared<TrainCfg>();
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
+
+  ASSERT_TRUE(Serialization::Load("./nets/mix_lenet_tod.ms", ModelType::kMindIR, &graph) == kSuccess);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = false;
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+  cpu_context->SetEnableFP16(false);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+}
 }  // namespace mindspore
| @@ -115,7 +115,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: | |||
| fusion_pm->AddPass(std::make_shared<opt::AffineFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>()); | |||
| } | |||
| if (config->fmk == converter::kFmkTypeMs) { | |||
| if (config->fmk == converter::kFmkTypeMs && !config->trainModel) { | |||
| auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>(); | |||
| if (remove_unused_cast_pass == nullptr) { | |||
| MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified"; | |||
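Note: scoping RemoveUnusedCastOpPass to non-training conversions fits the rest of the patch: in a raw-mix-precision training graph, Cast nodes mark intentional fp16/fp32 boundaries rather than removable no-ops, and `kNameCast` is correspondingly added to DynamicFormatOpList further down.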
| @@ -14,30 +14,30 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/parser_utils.h" | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h" | |||
| #include "tools/converter/parser/unused_node_remove_pass.h" | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "ops/transpose.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/parser/conv1d_inout_adjust.h" | |||
| #include "tools/converter/parser/inputs_adjust.h" | |||
| #include "ops/transpose.h" | |||
| #include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h" | |||
| #include "tools/converter/parser/unused_node_remove_pass.h" | |||
| #include "tools/converter/quant_param_holder.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "tools/optimizer/format/to_format_base.h" | |||
 namespace mindspore::lite {
 namespace {
-constexpr size_t kNumWeightIndex = 2;
-bool IsWeightNodeSensitive(const AnfNodePtr &node) {
-  return opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
-         opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) ||
-         opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion) ||
-         opt::CheckPrimitiveType(node, prim::kPrimApplyMomentum) || opt::CheckPrimitiveType(node, prim::kPrimSGD) ||
-         opt::CheckPrimitiveType(node, prim::kPrimAdam);
-}
+std::unordered_map<std::string, size_t> weight_indexs = {{ops::kNameConv2DFusion, 2},
+                                                         {ops::kNameConv2DBackpropInputFusion, 2},
+                                                         {ops::kNameConv2dTransposeFusion, 2},
+                                                         {ops::kNameApplyMomentum, 1},
+                                                         {ops::kNameSGD, 1},
+                                                         {ops::kNameAdam, 1}};
 }  // namespace
 void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
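The map subsumes both the removed predicate and the hard-coded `kInputIndexTwo`: each supported primitive name resolves to the input slot that carries its weight, slot 2 for the convolution family and slot 1 for the optimizer ops. A sketch of the lookup pattern the hunks below adopt (the wrapper function itself is illustrative and relies on this file's includes):

```cpp
// Resolve a node's weight-input index by primitive name, mirroring the lookup
// added to UnifyConvWeightFormat below; returns 0 when the op is unsupported.
size_t WeightIndexOf(const CNodePtr &cnode) {
  auto primitive_ptr = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive_ptr == nullptr) {
    return 0;  // input(0) was not a primitive value node
  }
  auto it = weight_indexs.find(primitive_ptr->name());
  return it == weight_indexs.end() ? 0 : it->second;
}
```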
| @@ -146,15 +146,9 @@ int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format | |||
| return lite::RET_OK; | |||
| } | |||
| AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode) { | |||
| AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index) { | |||
| MS_ASSERT(graph != nullptr && cnode != nullptr); | |||
| if (!opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) && | |||
| !opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) && | |||
| !opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) { | |||
| MS_LOG(ERROR) << "cnode is not a member of convolution's family."; | |||
| return nullptr; | |||
| } | |||
| auto weight_node = cnode->input(opt::kInputIndexTwo); | |||
| auto weight_node = cnode->input(index); | |||
| bool is_real_weight = | |||
| !opt::CheckPrimitiveType(weight_node, opt::kPrimIdentity) && !opt::CheckPrimitiveType(weight_node, prim::kPrimLoad); | |||
| while (!is_real_weight) { | |||
| @@ -169,7 +163,7 @@ AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnod | |||
| } | |||
| auto manager = Manage(graph); | |||
| MS_ASSERT(manager != nullptr); | |||
| manager->Replace(cnode->input(opt::kInputIndexTwo), weight_node); | |||
| manager->Replace(cnode->input(index), weight_node); | |||
| return weight_node; | |||
| } | |||
| @@ -179,18 +173,19 @@ int UnifyConvWeightFormat(const FuncGraphPtr &graph, const CNodePtr &cnode, sche | |||
| if (src_format == dst_format) { | |||
| return lite::RET_OK; | |||
| } | |||
| if (!opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) && | |||
| !opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) && | |||
| !opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) { | |||
| MS_LOG(ERROR) << "cnode is not a member of convolution's family."; | |||
| auto primitive_ptr = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| auto primitive_name = primitive_ptr->name(); | |||
| if (weight_indexs.find(primitive_name) == weight_indexs.end()) { | |||
| MS_LOG(ERROR) << primitive_name << " is not a member of convolution's family."; | |||
| return RET_ERROR; | |||
| } | |||
| if (GetRealConvWeightNode(graph, cnode) == nullptr) { | |||
| size_t index = weight_indexs[primitive_name]; | |||
| if (GetRealConvWeightNode(graph, cnode, index) == nullptr) { | |||
| MS_LOG(ERROR) << "current conv node is invalid, node name is " << cnode->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| bool is_const_weight = true; | |||
| auto weight_node = cnode->input(opt::kInputIndexTwo); | |||
| auto weight_node = cnode->input(index); | |||
| if (utils::isa<CNode>(weight_node)) { | |||
| is_const_weight = false; | |||
| } else if (utils::isa<Parameter>(weight_node)) { | |||
| @@ -234,7 +229,7 @@ int UnifyVariableConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_ | |||
| MS_LOG(ERROR) << "post node is invalid."; | |||
| return RET_ERROR; | |||
| } | |||
| if (!IsWeightNodeSensitive(post_node)) { | |||
| if (!opt::ToFormatBase::IsWeightNodeSensitive(post_node)) { | |||
| continue; | |||
| } | |||
| has_visited->insert(post_node); | |||
| @@ -285,6 +280,9 @@ int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_nod | |||
| MS_LOG(ERROR) << "conv weight is non-const."; | |||
| return RET_ERROR; | |||
| } | |||
| if (weight_value->shape().size() != kShape4dDims) { | |||
| return lite::RET_OK; | |||
| } | |||
| auto status = opt::TransFilterFormat(weight_value, src_format, dst_format); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(src_format) << "To" << EnumNameFormat(dst_format) | |||
| @@ -328,7 +326,7 @@ int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &wei | |||
| MS_LOG(ERROR) << "post node is invalid."; | |||
| return RET_ERROR; | |||
| } | |||
| if (IsWeightNodeSensitive(post_node)) { | |||
| if (opt::ToFormatBase::IsWeightNodeSensitive(post_node)) { | |||
| has_visited->insert(post_node); | |||
| continue; | |||
| } | |||
| @@ -30,7 +30,7 @@ void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all | |||
| int CommonAnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs); | |||
| int GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm); | |||
| int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm); | |||
| AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode); | |||
| AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index); | |||
| int UnifyConvWeightFormat(const FuncGraphPtr &graph, const CNodePtr &cnode, schema::Format src_format, | |||
| schema::Format dst_format, std::set<AnfNodePtr> *has_visited); | |||
| int UnifyVariableConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format, | |||
| @@ -25,6 +25,7 @@ | |||
| #include "ops/batch_norm.h" | |||
| #include "ops/batch_to_space.h" | |||
| #include "ops/bias_add.h" | |||
| #include "ops/cast.h" | |||
| #include "ops/concat.h" | |||
| #include "ops/crop.h" | |||
| #include "ops/depth_to_space.h" | |||
| @@ -102,12 +103,12 @@ static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{ | |||
| // a certain op whose input's format is not fixed, bool value determines whether the op has axis attribute or not. | |||
| static const std::unordered_map<std::string, bool> DynamicFormatOpList = { | |||
| {ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true}, | |||
| {ops::kNameConcat, true}, {ops::kNameEltwise, false}, {ops::kNameMaximum, false}, | |||
| {ops::kNameAddFusion, false}, {ops::kNameDivFusion, false}, {ops::kNameMulFusion, false}, | |||
| {ops::kNamePadFusion, false}, {ops::kNamePowFusion, false}, {ops::kNameActivation, false}, | |||
| {ops::kNameSliceFusion, true}, {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false}, | |||
| {ops::kNameQuantDTypeCast, false}}; | |||
| {ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true}, | |||
| {ops::kNameConcat, true}, {ops::kNameEltwise, false}, {ops::kNameMaximum, false}, | |||
| {ops::kNameAddFusion, false}, {ops::kNameDivFusion, false}, {ops::kNameMulFusion, false}, | |||
| {ops::kNamePadFusion, false}, {ops::kNamePowFusion, false}, {ops::kNameActivation, false}, | |||
| {ops::kNameSliceFusion, true}, {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false}, | |||
| {ops::kNameQuantDTypeCast, false}, {ops::kNameCast, false}}; | |||
| const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; } | |||
| const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; } | |||
| @@ -368,9 +368,7 @@ STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<A | |||
| } | |||
| continue; | |||
| } | |||
| if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) && | |||
| !CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) && | |||
| !CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) { | |||
| if (!IsWeightNodeSensitive(cnode)) { | |||
| continue; | |||
| } | |||
| if (has_visited->find(node) != has_visited->end()) { | |||
| @@ -26,6 +26,11 @@ | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/format_utils.h" | |||
| #include "tools/optimizer/graph/infershape_pass.h" | |||
| #include "ops/fusion/conv2d_fusion.h" | |||
| #include "ops/fusion/conv2d_transpose_fusion.h" | |||
| #include "ops/adam.h" | |||
| #include "ops/sgd.h" | |||
| #include "ops/apply_momentum.h" | |||
| using mindspore::converter::FmkType; | |||
| namespace mindspore { | |||
| @@ -37,6 +42,16 @@ class ToFormatBase : public Pass { | |||
| : Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {} | |||
| ~ToFormatBase() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| static bool IsConvFamilyNode(const AnfNodePtr &node) { | |||
| return opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) || | |||
| opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) || | |||
| opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion); | |||
| } | |||
| static bool IsOptimizerNode(const AnfNodePtr &node) { | |||
| return opt::CheckPrimitiveType(node, prim::kPrimApplyMomentum) || opt::CheckPrimitiveType(node, prim::kPrimSGD) || | |||
| opt::CheckPrimitiveType(node, prim::kPrimAdam); | |||
| } | |||
| static bool IsWeightNodeSensitive(const AnfNodePtr &node) { return IsConvFamilyNode(node) || IsOptimizerNode(node); } | |||
| private: | |||
| bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph); | |||
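Note: because the helpers are static, converter-side code such as parser_utils.cc above can call `opt::ToFormatBase::IsWeightNodeSensitive(post_node)` without holding a pass instance. Splitting the predicate into `IsConvFamilyNode` and `IsOptimizerNode` also documents why optimizer ops count as weight-sensitive at all: per the `weight_indexs` map earlier in the patch, SGD/Adam/ApplyMomentum carry the parameter being updated in input slot 1.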
| @@ -366,8 +366,10 @@ bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const | |||
| MS_ASSERT(pre_type != nullptr && post_type != nullptr); | |||
| size_t trans_count = 0; | |||
| std::vector<AnfNodePtr> in_nodes; | |||
| auto graph_inputs = func_graph->get_inputs(); | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i)) || | |||
| std::find(graph_inputs.begin(), graph_inputs.end(), cnode->input(i)) != graph_inputs.end()) { | |||
| in_nodes.push_back(cnode->input(i)); | |||
| } | |||
| } | |||