GitOrigin-RevId: 29f785b701
tags/v1.11.0
@@ -27,7 +27,7 @@ class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
     std::shared_ptr<DeviceTensorND> load_tensor_shared(
             bool copy_immediatly = false) override {
-        (void)copy_immediatly;
+        MGB_MARK_USED_VAR(copy_immediatly);
         mgb_assert(0);
     }
@@ -56,7 +56,7 @@ public:
     }
     void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) {
+            TensorWriteMethod method, TensorFormat format = {}) {
         mgb_assert(0);
     }
     const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
@@ -72,16 +72,20 @@ struct OprLoadDumpImplV2<opr::SharedDeviceTensorWithFormat, 0> {
         auto&& opr = opr_.cast_final_safe<opr::SharedDeviceTensorWithFormat>();
         HostTensorND val;
         val.copy_from(opr.get_dev_tensor()).sync();
-        ctx.dump_tensor({}, val, Meth::VALUE_ANONYMOUS);
+        ctx.dump_tensor(
+                {}, val, Meth::VALUE_ANONYMOUS, opr.get_dev_tensor().layout().format);
     }
     static cg::OperatorNodeBase* load(
             OprLoadContext& ctx, const cg::VarNodeArray& inputs,
             const OperatorNodeConfig& config) {
         mgb_assert(inputs.empty());
-        auto val = ctx.load_tensor();
+        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
+        auto val = fbs_ctx.load_tensor();
+        auto format = fbs_ctx.load_tensor_format(0);
+        TensorLayout layout_with_format = {val->shape(), val->dtype(), format};
         auto dev_val =
-                std::make_shared<DeviceTensorND>(val->comp_node(), val->layout());
+                std::make_shared<DeviceTensorND>(val->comp_node(), layout_with_format);
         dev_val->copy_from_fixlayout(*val);
         auto out_var =
                 opr::SharedDeviceTensorWithFormat::make(ctx.graph(), dev_val, config);
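
Note (not part of the patch): the load path above now rebuilds the layout from three pieces, the value's shape, its dtype, and the format stored alongside it in the V2 flatbuffer, instead of reusing the layout serialized with the value. A minimal sketch of that pattern, assuming the V2 load context (`fbs_ctx`) shown in this hunk:

    // sketch only: mirrors the load() above, with explanatory comments
    auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
    auto val = fbs_ctx.load_tensor();             // host value: shape + dtype + data
    auto format = fbs_ctx.load_tensor_format(0);  // format of this opr's 0-th tensor
    TensorLayout layout_with_format = {val->shape(), val->dtype(), format};
    auto dev_val = std::make_shared<DeviceTensorND>(val->comp_node(), layout_with_format);
    dev_val->copy_from_fixlayout(*val);           // preserve the (possibly image) format on device
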
@@ -136,7 +140,9 @@ struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
             HostTensorND val;
             auto value = *opr.values()[i];
             val.copy_from(value).sync();
-            ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
+            ctx.dump_tensor(
+                    opr.output(i)->name(), val, Meth::VALUE_SHARED,
+                    value.layout().format);
         }
     }
@@ -152,10 +158,12 @@ struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
             nr = fopr->tensors()->size();
         }
         Opr::ValueArray values(nr);
+        size_t id = 0;
         for (auto&& i : values) {
             i = ctx.load_tensor_shared();
             //! set tensor format
-            TensorLayout layout_with_format = i->layout();
+            auto format = fbs_ctx.load_tensor_format(id++);
+            TensorLayout layout_with_format{i->layout(), i->layout().dtype, format};
             if (i->storage().comp_node().mem_node() ==
                 CompNode::default_cpu().mem_node()) {
@@ -498,48 +498,66 @@ TEST(TestOprIO, MultipleDeviceTensorWithFormatHolderCpu) {
     auto fname = GET_OUTPUT_FILE();
     auto cn = CompNode::load("cpu0");
     HostTensorGenerator<> gen;
-    {
-        // dump
-        auto graph = ComputingGraph::make();
-        graph->options().graph_opt_level = 0;
-        auto mkcvar = [&](const char* name, const TensorShape& shp) {
-            return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
+    auto test = [&](serialization::GraphDumpFormat format) {
+        {
+            // dump
+            auto graph = ComputingGraph::make();
+            graph->options().graph_opt_level = 0;
+            auto mkcvar = [&](const char* name, const TensorShape& shp) {
+                return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name);
+            };
+            auto host_x = gen({8, 8, 8, 8}, cn);
+            auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"});
+            opr::Convolution::Param param;
+            param.pad_h = param.pad_w = 0;
+            auto w1 = mkcvar("w1", {4, 8, 3, 3}),
+                 conv1 = opr::Convolution::make(x, w1, param);
+            auto w2 = mkcvar("w2", {4, 4, 3, 3}),
+                 conv2 = opr::Convolution::make(conv1, w2, param);
+            auto y = opr::Elemwise::make({conv2}, opr::Elemwise::Param::Mode::RELU);
+            auto options = gopt::OptimizeForInferenceOptions{};
+            options.enable_nhwcd4();
+            SymbolVar y_opt =
+                    gopt::optimize_for_inference({y}, options)[0].rename("out");
+            auto dumper = serialization::GraphDumper::make(
+                    serialization::OutputFile::make_fs(fname.c_str()), format);
+            serialization::GraphDumper::DumpConfig config;
+            config.keep_param_name = true;
+            dumper->dump({y_opt}, config);
+        }
+        auto loader = serialization::GraphLoader::make(
+                serialization::InputFile::make_fs(fname.c_str()), format);
+        auto load = [&](CompNode dest_cn) {
+            auto dest_cn_loc = dest_cn.locator_logical();
+            auto rst =
+                    loader->load({[&](CompNode::Locator& loc) { loc = dest_cn_loc; }});
+            HostTensorND host_z, host_z_expect;
+            auto func = rst.graph_compile(
+                    {make_callback_copy(rst.output_var_map.at("out"), host_z)});
+            func->execute();
+            func->wait();
+            auto&& shared_tensor_map = loader->shared_tensor_id_map();
+            bool cd4 = false;
+            for (auto&& i : shared_tensor_map) {
+                auto&& shared_tensor = i.second.begin()->second;
+                if (shared_tensor->format().type() ==
+                    TensorFormat::Type::IMAGE2D_PACK4) {
+                    cd4 = true;
+                }
+            }
+            ASSERT_TRUE(cd4);
         };
-        auto host_x = gen({8, 8, 8, 8}, cn);
-        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"});
-        opr::Convolution::Param param;
-        param.pad_h = param.pad_w = 0;
-        auto w1 = mkcvar("w1", {4, 8, 3, 3}),
-             conv1 = opr::Convolution::make(x, w1, param);
-        auto w2 = mkcvar("w2", {4, 4, 3, 3}),
-             conv2 = opr::Convolution::make(conv1, w2, param);
-        auto y = opr::Elemwise::make({conv2}, opr::Elemwise::Param::Mode::RELU);
-        auto options = gopt::OptimizeForInferenceOptions{};
-        options.enable_nhwcd4();
-        SymbolVar y_opt = gopt::optimize_for_inference({y}, options)[0].rename("out");
-        auto dumper = serialization::GraphDumper::make(
-                serialization::OutputFile::make_fs(fname.c_str()));
-        serialization::GraphDumper::DumpConfig config;
-        config.keep_param_name = true;
-        dumper->dump({y_opt}, config);
-    }
-    auto loader = serialization::GraphLoader::make(
-            serialization::InputFile::make_fs(fname.c_str()));
-    auto load = [&](CompNode dest_cn) {
-        auto dest_cn_loc = dest_cn.locator_logical();
-        auto rst = loader->load({[&](CompNode::Locator& loc) { loc = dest_cn_loc; }});
-        HostTensorND host_z, host_z_expect;
-        auto func = rst.graph_compile(
-                {make_callback_copy(rst.output_var_map.at("out"), host_z)});
-        func->execute();
+        load(cn);
     };
-    load(cn);
+    test({});
+    test(serialization::GraphDumpFormat::FLATBUFFERS_V2);
 }
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
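
Note (illustrative, not part of the patch): stripped of the CD4 assertions, the round trip the test now runs for both formats reduces to the sketch below; `y_opt` stands for any output SymbolVar, and the callback helper and file name come from the test above.

    // dump with an explicit serialization format ({} = default, or FLATBUFFERS_V2)
    auto dumper = serialization::GraphDumper::make(
            serialization::OutputFile::make_fs(fname.c_str()), format);
    dumper->dump({y_opt}, config);
    // load with the same format and compile the restored graph
    auto loader = serialization::GraphLoader::make(
            serialization::InputFile::make_fs(fname.c_str()), format);
    auto rst = loader->load();
    auto func = rst.graph_compile(
            {make_callback_copy(rst.output_var_map.at("out"), host_z)});
    func->execute();
    func->wait();
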
@@ -32,7 +32,9 @@ class OprDumpContextMemory final : public OprDumpContextRawPOD {
     }
     void dump_tensor(
-            const std::string&, const HostTensorND&, TensorWriteMethod) override {
+            const std::string&, const HostTensorND&, TensorWriteMethod,
+            TensorFormat format = {}) override {
+        MGB_MARK_USED_VAR(format);
         mgb_throw(GraphError, "OprDumpContextMemory does not support dump tensor");
     }
@@ -92,7 +92,7 @@ public:
     const GraphDumpConfig& config() const override { return m_config; }
     void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) override;
+            TensorWriteMethod method, TensorFormat format = {}) override;
     flatbuffers::FlatBufferBuilder& builder() override { return m_builder; }
     void append_param(uint32_t type, uint32_t value) override {
         static_assert(
@@ -359,7 +359,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
 }
 void GraphDumperOSS::dump_tensor(
-        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
+        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method,
+        TensorFormat) {
     using namespace flatbuffers;
     using Meth = TensorWriteMethod;
     mgb_assert(
@@ -671,17 +672,17 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSS::OprLoadContextImpl::load_tensor_
         sh_reg.first = tensor->name()->str();
     }
-    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+    if (comp_node.mem_node() == CompNode::default_cpu().mem_node() || copy_immediatly) {
         // directly forward CPU memory
         HostTensorND hv{comp_node};
         load_tensor_value(&hv, layout, tensor);
         sh_ptr_ref = std::make_shared<DeviceTensorND>();
-        *sh_ptr_ref = DeviceTensorND::make_proxy(hv);
-    } else if (copy_immediatly) {
-        HostTensorND hv{CompNode::default_cpu()};
-        load_tensor_value(&hv, layout, tensor);
-        sh_ptr_ref = std::make_shared<DeviceTensorND>();
-        sh_ptr_ref->comp_node(comp_node).copy_from(hv).sync();
+        if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+            *sh_ptr_ref = DeviceTensorND::make_proxy(hv);
+        } else {
+            mgb_assert(copy_immediatly);
+            sh_ptr_ref->comp_node(comp_node).copy_from(hv).sync();
+        }
     } else {
         // use lazy load for non-CPU devices
         HostTensorND hv{CompNode::default_cpu()};
@@ -455,7 +455,8 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
 }
 void GraphDumperOSSV2::dump_tensor(
-        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
+        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method,
+        TensorFormat format) {
     using namespace flatbuffers;
     using Meth = TensorWriteMethod;
     mgb_assert(
@@ -510,8 +511,8 @@ void GraphDumperOSSV2::dump_tensor(
             m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
     auto fdtype = build_dtype(layout.dtype);
-    auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
-    auto fformat = build_tensor_format(layout.format);
+    auto fformat_type = get_flatbuffer_tensor_format_type(format);
+    auto fformat = build_tensor_format(format);
     auto serialized_tensor = fbs::v2::CreateTensor(
             m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat, data);
     m_cur_opr_tensor.emplace_back(serialized_tensor);
@@ -605,7 +606,7 @@ CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
     return CompNode::load(loc);
 }
-TensorFormat load_tensor_format(
+TensorFormat get_tensor_format(
         const fbs::v2::TensorFormat fformat_type, const void* fformat,
         const CompNode& comp_node) {
     switch (fformat_type) {
@@ -631,8 +632,7 @@ TensorFormat load_tensor_format(
     }
 }
-TensorLayout load_tensor_layout(
-        const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
+TensorLayout load_tensor_layout_without_format(const fbs::v2::Tensor* tensor) {
     TensorLayout layout;
     if (tensor->shape()) {
         layout.ndim = tensor->shape()->size();
@@ -642,14 +642,21 @@ TensorLayout load_tensor_layout(
         // modify data type inplace for TensorLayout
         layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
     }
-    if (tensor->format() && tensor->format_type()) {
-        layout.format =
-                load_tensor_format(tensor->format_type(), tensor->format(), comp_node);
-    }
     layout.init_contiguous_stride();
     return layout;
 }
+TensorFormat GraphLoaderOSSV2::OprLoadContextImpl::load_tensor_format(size_t id) {
+    mgb_assert(m_current_opr->tensors() && id < m_current_opr->tensors()->size());
+    auto tensor = m_current_opr->tensors()->Get(id);
+    auto comp_node = load_comp_node(tensor->comp_node());
+    TensorFormat format;
+    if (tensor->format() && tensor->format_type()) {
+        format = get_tensor_format(tensor->format_type(), tensor->format(), comp_node);
+    }
+    return format;
+}
 //! the opr loader should make sure the exist of tensors and the number of
 //! tensor, here just assert it.
 std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {
@@ -658,7 +665,7 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
             m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
     auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
     auto comp_node = load_comp_node(tensor->comp_node());
-    auto layout = load_tensor_layout(tensor, comp_node);
+    auto layout = load_tensor_layout_without_format(tensor);
     auto ret = std::make_shared<HostTensorND>(comp_node, layout);
     auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
@@ -692,7 +699,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
             m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
     auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
     auto comp_node = load_comp_node(tensor->comp_node());
-    auto layout = load_tensor_layout(tensor, comp_node);
+    auto layout = load_tensor_layout_without_format(tensor);
     mgb_assert(tensor->data());
     if (m_loader->m_shared_tensor_map.size() <= m_cur_shared_tensor_idx) {
         m_loader->m_shared_tensor_map.resize(m_cur_shared_tensor_idx + 5);
@@ -712,7 +719,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
         shared_pair.first = tensor->name()->str();
     }
-    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+    if (comp_node.mem_node() == CompNode::default_cpu().mem_node() || copy_immediatly) {
         // directly forward CPU memory
         shared_tensor_ref = std::make_shared<DeviceTensorND>();
         HostTensorND hv{comp_node};
@@ -722,18 +729,13 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
                     hv, tensor->data()->data(), tensor->data()->size(),
                     m_loader->m_file->is_shared_memory());
         }
-        *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
-        m_tensor_alignment->add_device_tensor(shared_tensor_ref);
-    } else if (copy_immediatly) {
-        HostTensorND hv{CompNode::default_cpu()};
-        shared_tensor_ref = std::make_shared<DeviceTensorND>();
-        if (tensor->data() && tensor->data()->size() > 0) {
-            hv.dtype(layout.dtype).resize(layout);
-            fill_tensor_memory(
-                    hv, tensor->data()->data(), tensor->data()->size(),
-                    m_loader->m_file->is_shared_memory());
+        if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+            *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
+            m_tensor_alignment->add_device_tensor(shared_tensor_ref);
+        } else {
+            mgb_assert(copy_immediatly);
+            shared_tensor_ref->comp_node(comp_node).copy_from(hv).sync();
         }
-        shared_tensor_ref->comp_node(comp_node).copy_from(hv).sync();
     } else {
         // use lazy load for non-CPU devices
         HostTensorND hv{CompNode::default_cpu()};
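
Note (not part of the patch): the V1 and V2 loaders above now share one shape for the eager path. A rough sketch of the resulting control flow, where `is_default_cpu`, `read_value_into`, and `dst` are hypothetical placeholders for the inline checks, the flatbuffer reads, and the shared tensor reference in the real code:

    // sketch of the consolidated branch, not verbatim library code
    if (is_default_cpu(comp_node) || copy_immediatly) {
        HostTensorND hv{comp_node};
        read_value_into(hv);                        // placeholder for the fbs read above
        if (is_default_cpu(comp_node)) {
            *dst = DeviceTensorND::make_proxy(hv);  // zero-copy forward on CPU
        } else {
            mgb_assert(copy_immediatly);
            dst->comp_node(comp_node).copy_from(hv).sync();  // eager device copy
        }
    } else {
        // lazy load for non-CPU devices (unchanged)
    }
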
@@ -47,7 +47,7 @@ public:
     //! whether this can be write
     virtual bool writable() { return false; }
-    //! whether this file have been wrote
+    //! mark that this file has been written
     virtual void have_modified() {}
     /*!
@@ -63,7 +63,7 @@ public:
     */
    virtual void dump_tensor(
            const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) = 0;
+            TensorWriteMethod method, TensorFormat format = {}) = 0;
    //! get associated global configuration
    virtual const GraphDumpConfig& config() const = 0;
@@ -63,7 +63,7 @@ public:
    void dump_tensor(
            const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) override;
+            TensorWriteMethod method, TensorFormat format = {}) override;
    void append_param(uint32_t type, uint32_t value) override {
        static_assert(
@@ -148,6 +148,8 @@ public:
        return *m_loader->m_cur_load_config;
    }
+    TensorFormat load_tensor_format(size_t id);
    //! shared or copy the loaded flatbuffer memory to the CPU tensor, this can reduce
    //! the memory used when load model, but should consider the memory
    //! alignment