GitOrigin-RevId: ec44429f55
tags/v1.10.0
| @@ -180,7 +180,13 @@ void DataParser::parse_npy(const std::string& name, const std::string& path) { | |||
| inputs.insert(std::make_pair(name, std::move(hv))); | |||
| } | |||
| void DataParser::parse_string(const std::string name, const std::string& str) { | |||
| void DataParser::parse_string(const std::string& name, const std::string& str) { | |||
| //! parse shape | |||
| if ('{' == str[0]) { | |||
| parse_shape(name, str); | |||
| return; | |||
| } | |||
| // data type | |||
| megdnn::DType data_type = mgb::dtype::Int32(); | |||
| if (str.find(".") != std::string::npos or str.find(".") != std::string::npos) { | |||
| @@ -257,3 +263,31 @@ void DataParser::parse_string(const std::string name, const std::string& str) { | |||
| } | |||
| inputs.insert(std::make_pair(name, std::move(hv))); | |||
| } | |||
| void DataParser::parse_shape(const std::string& name, const std::string& str) { | |||
| //! {d0,d1,..,dn} | |||
| mgb_assert( | |||
| "{" == str.substr(0, 1), | |||
| "invalid value: %s for parse_shape, valid format: {d0,d1,..,dn}\n", | |||
| str.c_str()); | |||
| megdnn::SmallVector<size_t> shape; | |||
| std::string shape_size = ""; | |||
| for (size_t i = 0; i < str.size(); ++i) { | |||
| char c = str[i]; | |||
| if ('{' == c || ' ' == c) { | |||
| continue; | |||
| } else if (',' == c || '}' == c) { | |||
| shape.push_back(std::stoul(shape_size)); | |||
| shape_size = ""; | |||
| if ('}' == c) { | |||
| break; | |||
| } | |||
| } else { | |||
| shape_size += c; | |||
| } | |||
| } | |||
| mgb::HostTensorND hv(mgb::CompNode::default_cpu(), shape); | |||
| mgb::HostTensorStorage storage(mgb::CompNode::default_cpu()); | |||
| hv.only_reset_raw_storage(storage); | |||
| inputs.insert(std::make_pair(name, std::move(hv))); | |||
| } | |||
| @@ -30,7 +30,10 @@ private: | |||
| //! parser for .npy data | |||
| void parse_npy(const std::string& name, const std::string& path); | |||
| //! parser for user define string | |||
| void parse_string(const std::string name, const std::string& str); | |||
| //! parser for user defined string | |||
| void parse_string(const std::string& name, const std::string& str); | |||
| //! parser for user defined shape | |||
| void parse_shape(const std::string& name, const std::string& str); | |||
| }; | |||
| } // namespace lar | |||
| @@ -73,7 +73,17 @@ void InputOption::config_model_internel<ModelMdl>( | |||
| tensormap.find(i.first) != tensormap.end(), | |||
| "can't find tesnor named %s", i.first.c_str()); | |||
| auto& in = tensormap.find(i.first)->second; | |||
| in->copy_from(i.second); | |||
| if (i.second.storage().empty()) { | |||
| mgb::HostTensorND hv; | |||
| hv.comp_node(mgb::CompNode::default_cpu(), true) | |||
| .dtype(in->dtype()) | |||
| .resize(i.second.shape()); | |||
| mgb::dt_byte* raw_ptr = hv.raw_ptr(); | |||
| memset((char*)raw_ptr, 1, hv.layout().total_nr_elems()); | |||
| in->copy_from(hv); | |||
| } else { | |||
| in->copy_from(i.second); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -39,10 +39,24 @@ void GoptLayoutOption::config_model_internel<ModelLite>( | |||
| template <> | |||
| void GoptLayoutOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | |||
| if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| if (m_layout_transform) { | |||
| mgb_log_warn("using global layout transform optimization\n"); | |||
| mgb_log_debug("update input shape for global layout transform\n"); | |||
| auto&& load_result = model->get_mdl_load_result(); | |||
| if (m_force_batch_size > 0) { | |||
| for (auto&& i : load_result.tensor_map) { | |||
| auto& in = i.second; | |||
| mgb::TensorShape new_shape = in->shape(); | |||
| new_shape[0] = m_force_batch_size; | |||
| mgb::HostTensorND new_tensor; | |||
| new_tensor.comp_node(mgb::CompNode::default_cpu(), true) | |||
| .dtype(in->dtype()) | |||
| .resize(new_shape); | |||
| mgb::dt_byte* raw_ptr = new_tensor.raw_ptr(); | |||
| memset((char*)raw_ptr, 1, new_tensor.layout().total_nr_elems()); | |||
| in->copy_from(new_tensor); | |||
| } | |||
| } | |||
| for (auto&& item : load_result.output_var_list) { | |||
| if (item.shape()[0] > 1) { | |||
| mgb_log_warn( | |||
| @@ -81,7 +95,11 @@ void GoptLayoutOption::config_model_internel<ModelMdl>( | |||
| } | |||
| load_result.output_var_list = output_vars; | |||
| } | |||
| } | |||
| } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | |||
| if (m_layout_transform) { | |||
| mgb_log_warn("using global layout transform optimization\n"); | |||
| auto&& load_result = model->get_mdl_load_result(); | |||
| load_result.output_var_list = mgb::gopt::layout_transform( | |||
| load_result.output_var_list, m_layout_transform_target); | |||
| @@ -156,6 +174,8 @@ GoptLayoutOption::GoptLayoutOption() { | |||
| } | |||
| m_layout_transform_dump_file = FLAGS_layout_transform_dump; | |||
| m_force_batch_size = FLAGS_layout_transform_batch_size; | |||
| m_option = { | |||
| {"layout_transform", lar::String::make("")}, | |||
| }; | |||
| @@ -182,6 +202,14 @@ bool GoptLayoutOption::is_valid() { | |||
| } | |||
| } | |||
| ret = ret || !FLAGS_layout_transform_dump.empty(); | |||
| if (FLAGS_layout_transform_batch_size > 0) { | |||
| mgb_assert( | |||
| FLAGS_layout_transform_batch_size > 0 && | |||
| !FLAGS_layout_transform.empty(), | |||
| "\"layout-transform-batch-size\" should be set with " | |||
| "\"layout-transform\""); | |||
| ret = ret || FLAGS_layout_transform_batch_size > 0; | |||
| } | |||
| return ret || m_valid; | |||
| } | |||
| @@ -233,5 +261,8 @@ DEFINE_string( | |||
| "The computing graph after global layout transform will be dumped to the given " | |||
| "file path."); | |||
| DEFINE_int32( | |||
| layout_transform_batch_size, -1, | |||
| "the batch size of input for global layout transform optimization working on"); | |||
| REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option); | |||
| REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid); | |||
| @@ -5,6 +5,7 @@ | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_string(layout_transform); | |||
| DECLARE_int32(layout_transform_batch_size); | |||
| DECLARE_string(layout_transform_dump); | |||
| namespace lar { | |||
| @@ -38,5 +39,6 @@ private: | |||
| mgb::gopt::GraphTuningOptions::Target m_layout_transform_target; | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| int32_t m_force_batch_size; | |||
| }; | |||
| } // namespace lar | |||