GitOrigin-RevId: ec44429f55
tags/v1.10.0
| @@ -180,7 +180,13 @@ void DataParser::parse_npy(const std::string& name, const std::string& path) { | |||||
| inputs.insert(std::make_pair(name, std::move(hv))); | inputs.insert(std::make_pair(name, std::move(hv))); | ||||
| } | } | ||||
| void DataParser::parse_string(const std::string name, const std::string& str) { | |||||
| void DataParser::parse_string(const std::string& name, const std::string& str) { | |||||
| //! parse shape | |||||
| if ('{' == str[0]) { | |||||
| parse_shape(name, str); | |||||
| return; | |||||
| } | |||||
| // data type | // data type | ||||
| megdnn::DType data_type = mgb::dtype::Int32(); | megdnn::DType data_type = mgb::dtype::Int32(); | ||||
| if (str.find(".") != std::string::npos or str.find(".") != std::string::npos) { | if (str.find(".") != std::string::npos or str.find(".") != std::string::npos) { | ||||
| @@ -257,3 +263,31 @@ void DataParser::parse_string(const std::string name, const std::string& str) { | |||||
| } | } | ||||
| inputs.insert(std::make_pair(name, std::move(hv))); | inputs.insert(std::make_pair(name, std::move(hv))); | ||||
| } | } | ||||
| void DataParser::parse_shape(const std::string& name, const std::string& str) { | |||||
| //! {d0,d1,..,dn} | |||||
| mgb_assert( | |||||
| "{" == str.substr(0, 1), | |||||
| "invalid value: %s for parse_shape, valid format: {d0,d1,..,dn}\n", | |||||
| str.c_str()); | |||||
| megdnn::SmallVector<size_t> shape; | |||||
| std::string shape_size = ""; | |||||
| for (size_t i = 0; i < str.size(); ++i) { | |||||
| char c = str[i]; | |||||
| if ('{' == c || ' ' == c) { | |||||
| continue; | |||||
| } else if (',' == c || '}' == c) { | |||||
| shape.push_back(std::stoul(shape_size)); | |||||
| shape_size = ""; | |||||
| if ('}' == c) { | |||||
| break; | |||||
| } | |||||
| } else { | |||||
| shape_size += c; | |||||
| } | |||||
| } | |||||
| mgb::HostTensorND hv(mgb::CompNode::default_cpu(), shape); | |||||
| mgb::HostTensorStorage storage(mgb::CompNode::default_cpu()); | |||||
| hv.only_reset_raw_storage(storage); | |||||
| inputs.insert(std::make_pair(name, std::move(hv))); | |||||
| } | |||||
| @@ -30,7 +30,10 @@ private: | |||||
| //! parser for .npy data | //! parser for .npy data | ||||
| void parse_npy(const std::string& name, const std::string& path); | void parse_npy(const std::string& name, const std::string& path); | ||||
| //! parser for user define string | |||||
| void parse_string(const std::string name, const std::string& str); | |||||
| //! parser for user defined string | |||||
| void parse_string(const std::string& name, const std::string& str); | |||||
| //! parser for user defined shape | |||||
| void parse_shape(const std::string& name, const std::string& str); | |||||
| }; | }; | ||||
| } // namespace lar | } // namespace lar | ||||
| @@ -73,7 +73,17 @@ void InputOption::config_model_internel<ModelMdl>( | |||||
| tensormap.find(i.first) != tensormap.end(), | tensormap.find(i.first) != tensormap.end(), | ||||
| "can't find tesnor named %s", i.first.c_str()); | "can't find tesnor named %s", i.first.c_str()); | ||||
| auto& in = tensormap.find(i.first)->second; | auto& in = tensormap.find(i.first)->second; | ||||
| in->copy_from(i.second); | |||||
| if (i.second.storage().empty()) { | |||||
| mgb::HostTensorND hv; | |||||
| hv.comp_node(mgb::CompNode::default_cpu(), true) | |||||
| .dtype(in->dtype()) | |||||
| .resize(i.second.shape()); | |||||
| mgb::dt_byte* raw_ptr = hv.raw_ptr(); | |||||
| memset((char*)raw_ptr, 1, hv.layout().total_nr_elems()); | |||||
| in->copy_from(hv); | |||||
| } else { | |||||
| in->copy_from(i.second); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -39,10 +39,24 @@ void GoptLayoutOption::config_model_internel<ModelLite>( | |||||
| template <> | template <> | ||||
| void GoptLayoutOption::config_model_internel<ModelMdl>( | void GoptLayoutOption::config_model_internel<ModelMdl>( | ||||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | ||||
| if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | |||||
| if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||||
| if (m_layout_transform) { | if (m_layout_transform) { | ||||
| mgb_log_warn("using global layout transform optimization\n"); | |||||
| mgb_log_debug("update input shape for global layout transform\n"); | |||||
| auto&& load_result = model->get_mdl_load_result(); | auto&& load_result = model->get_mdl_load_result(); | ||||
| if (m_force_batch_size > 0) { | |||||
| for (auto&& i : load_result.tensor_map) { | |||||
| auto& in = i.second; | |||||
| mgb::TensorShape new_shape = in->shape(); | |||||
| new_shape[0] = m_force_batch_size; | |||||
| mgb::HostTensorND new_tensor; | |||||
| new_tensor.comp_node(mgb::CompNode::default_cpu(), true) | |||||
| .dtype(in->dtype()) | |||||
| .resize(new_shape); | |||||
| mgb::dt_byte* raw_ptr = new_tensor.raw_ptr(); | |||||
| memset((char*)raw_ptr, 1, new_tensor.layout().total_nr_elems()); | |||||
| in->copy_from(new_tensor); | |||||
| } | |||||
| } | |||||
| for (auto&& item : load_result.output_var_list) { | for (auto&& item : load_result.output_var_list) { | ||||
| if (item.shape()[0] > 1) { | if (item.shape()[0] > 1) { | ||||
| mgb_log_warn( | mgb_log_warn( | ||||
| @@ -81,7 +95,11 @@ void GoptLayoutOption::config_model_internel<ModelMdl>( | |||||
| } | } | ||||
| load_result.output_var_list = output_vars; | load_result.output_var_list = output_vars; | ||||
| } | } | ||||
| } | |||||
| } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | |||||
| if (m_layout_transform) { | |||||
| mgb_log_warn("using global layout transform optimization\n"); | |||||
| auto&& load_result = model->get_mdl_load_result(); | |||||
| load_result.output_var_list = mgb::gopt::layout_transform( | load_result.output_var_list = mgb::gopt::layout_transform( | ||||
| load_result.output_var_list, m_layout_transform_target); | load_result.output_var_list, m_layout_transform_target); | ||||
| @@ -156,6 +174,8 @@ GoptLayoutOption::GoptLayoutOption() { | |||||
| } | } | ||||
| m_layout_transform_dump_file = FLAGS_layout_transform_dump; | m_layout_transform_dump_file = FLAGS_layout_transform_dump; | ||||
| m_force_batch_size = FLAGS_layout_transform_batch_size; | |||||
| m_option = { | m_option = { | ||||
| {"layout_transform", lar::String::make("")}, | {"layout_transform", lar::String::make("")}, | ||||
| }; | }; | ||||
| @@ -182,6 +202,14 @@ bool GoptLayoutOption::is_valid() { | |||||
| } | } | ||||
| } | } | ||||
| ret = ret || !FLAGS_layout_transform_dump.empty(); | ret = ret || !FLAGS_layout_transform_dump.empty(); | ||||
| if (FLAGS_layout_transform_batch_size > 0) { | |||||
| mgb_assert( | |||||
| FLAGS_layout_transform_batch_size > 0 && | |||||
| !FLAGS_layout_transform.empty(), | |||||
| "\"layout-transform-batch-size\" should be set with " | |||||
| "\"layout-transform\""); | |||||
| ret = ret || FLAGS_layout_transform_batch_size > 0; | |||||
| } | |||||
| return ret || m_valid; | return ret || m_valid; | ||||
| } | } | ||||
| @@ -233,5 +261,8 @@ DEFINE_string( | |||||
| "The computing graph after global layout transform will be dumped to the given " | "The computing graph after global layout transform will be dumped to the given " | ||||
| "file path."); | "file path."); | ||||
| DEFINE_int32( | |||||
| layout_transform_batch_size, -1, | |||||
| "the batch size of input for global layout transform optimization working on"); | |||||
| REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option); | REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option); | ||||
| REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid); | REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid); | ||||
| @@ -5,6 +5,7 @@ | |||||
| #include "models/model.h" | #include "models/model.h" | ||||
| #include "option_base.h" | #include "option_base.h" | ||||
| DECLARE_string(layout_transform); | DECLARE_string(layout_transform); | ||||
| DECLARE_int32(layout_transform_batch_size); | |||||
| DECLARE_string(layout_transform_dump); | DECLARE_string(layout_transform_dump); | ||||
| namespace lar { | namespace lar { | ||||
| @@ -38,5 +39,6 @@ private: | |||||
| mgb::gopt::GraphTuningOptions::Target m_layout_transform_target; | mgb::gopt::GraphTuningOptions::Target m_layout_transform_target; | ||||
| static bool m_valid; | static bool m_valid; | ||||
| OptionValMap m_option; | OptionValMap m_option; | ||||
| int32_t m_force_batch_size; | |||||
| }; | }; | ||||
| } // namespace lar | } // namespace lar | ||||