GitOrigin-RevId: e4771d6bc4
tags/v1.7.0.m1
| @@ -33,9 +33,6 @@ class ConverterWriter(IndentWriterBase): | |||||
| self._last_param = p | self._last_param = p | ||||
| self._param_fields = [] | self._param_fields = [] | ||||
| self._fb_fields = ["builder"] | self._fb_fields = ["builder"] | ||||
| if p.is_legacy: | |||||
| self._skip_current_param = True | |||||
| return | |||||
| self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | ||||
| p.name, indent=1) | p.name, indent=1) | ||||
| self._write("using MegDNNType = megdnn::param::%s;", p.name) | self._write("using MegDNNType = megdnn::param::%s;", p.name) | ||||
| @@ -80,9 +80,6 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
| def _on_param_begin(self, p): | def _on_param_begin(self, p): | ||||
| self._last_param = p | self._last_param = p | ||||
| self._cur_const_val = {} | self._cur_const_val = {} | ||||
| if p.is_legacy: | |||||
| self._skip_current_param = True | |||||
| return | |||||
| self._write_doc(p.name) | self._write_doc(p.name) | ||||
| self._write("table %s {", p.name, indent=1) | self._write("table %s {", p.name, indent=1) | ||||
| @@ -52,9 +52,6 @@ class ConverterWriter(IndentWriterBase): | |||||
| def _on_param_begin(self, p): | def _on_param_begin(self, p): | ||||
| self._last_param = p | self._last_param = p | ||||
| if p.is_legacy: | |||||
| self._skip_current_param = True | |||||
| return | |||||
| self._packed = True | self._packed = True | ||||
| self._current_tparams = [] | self._current_tparams = [] | ||||
| self._const = set() | self._const = set() | ||||
| @@ -62,6 +62,37 @@ struct PersistentAddUpdateParam { | |||||
| } // namespace opr_add_update | } // namespace opr_add_update | ||||
| // Old SerializedDType used in MegBrain 7.22.0 - 7.23.1 | |||||
| // Should be kept as-is even if there are new dtypes. | |||||
| struct SerializedDTypeV1 { | |||||
| static constexpr uint32_t TAG = megdnn::param::FakeSerializedDType::TAG; | |||||
| DTypeEnum enumv; | |||||
| union { | |||||
| megdnn::DTypeParam<dtype::Quantized8Asymm> Quantized8Asymm; | |||||
| megdnn::DTypeParam<dtype::QuantizedS8> QuantizedS8; | |||||
| megdnn::DTypeParam<dtype::QuantizedS32> QuantizedS32; | |||||
| } param; | |||||
| operator DType() const { | |||||
| switch (enumv) { | |||||
| #define cb(_dt) \ | |||||
| case DTypeEnum::_dt: \ | |||||
| return DType::from_enum(enumv); | |||||
| MEGDNN_FOREACH_DTYPE_NAME(cb) | |||||
| #undef cb | |||||
| case DTypeEnum::Quantized8Asymm: | |||||
| return dtype::Quantized8Asymm{param.Quantized8Asymm}; | |||||
| case DTypeEnum::QuantizedS8: | |||||
| return dtype::QuantizedS8{param.QuantizedS8}; | |||||
| case DTypeEnum::QuantizedS32: | |||||
| return dtype::QuantizedS32{param.QuantizedS32}; | |||||
| default: | |||||
| mgb_assert( | |||||
| false, "unexpected old serialized dtype: invalid enumv %d", | |||||
| static_cast<uint32_t>(enumv)); | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <> | template <> | ||||
| struct OprPersistentParam<opr::AddUpdate> { | struct OprPersistentParam<opr::AddUpdate> { | ||||
| using Param = opr_add_update::PersistentAddUpdateParam; | using Param = opr_add_update::PersistentAddUpdateParam; | ||||
| @@ -104,7 +135,18 @@ struct ParamConverter<megdnn::DType> { | |||||
| return fbs::intl::build_dtype(builder, dtype); | return fbs::intl::build_dtype(builder, dtype); | ||||
| } | } | ||||
| }; | }; | ||||
| } // namespace fbs | |||||
| template <> | |||||
| struct ParamConverter<SerializedDTypeV1> { | |||||
| using FlatBufferType = SerializedDTypeV1; | |||||
| static SerializedDTypeV1 to_param(const FlatBufferType* fb) { | |||||
| mgb_assert( | |||||
| false, | |||||
| "You are calling SerializedDTypeV1 in flatbuffer, you should not call " | |||||
| "here, this code is just to avoid compiling errors, but not be used in " | |||||
| "flatbuffer."); | |||||
| } | |||||
| }; | |||||
| }; // namespace fbs | |||||
| #endif | #endif | ||||
| template <> | template <> | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
| #include "megbrain/opr/tensor_gen.h" | #include "megbrain/opr/tensor_gen.h" | ||||
| #include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
| #include "megbrain/serialization/serializer.h" | |||||
| #include "megbrain/test/autocheck.h" | #include "megbrain/test/autocheck.h" | ||||
| #include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
| #include "megbrain/test/megdnn_helper.h" | #include "megbrain/test/megdnn_helper.h" | ||||
| @@ -907,5 +908,39 @@ TEST(TestOprBlas, MatrixMulExePolicy) { | |||||
| } | } | ||||
| #endif | #endif | ||||
| #if MGB_ENABLE_FBS_SERIALIZATION | |||||
| TEST(TestOprDNN, MatrixMulSerialization) { | |||||
| using namespace serialization; | |||||
| auto fname = output_file("MatrixMulSerializationTest"); | |||||
| auto dump = [&]() { | |||||
| opr::MatrixMul::Param param; | |||||
| auto cn = CompNode::load("cpu0"); | |||||
| auto graph = ComputingGraph::make(); | |||||
| HostTensorND a_host{cn, {24, 24}, dtype::Float32()}; | |||||
| HostTensorND b_host{cn, {24, 24}, dtype::Float32()}; | |||||
| auto a = opr::ImmutableTensor::make(*graph, a_host); | |||||
| auto b = opr::ImmutableTensor::make(*graph, b_host); | |||||
| auto opr = opr::MatrixMul::make(a, b, param, {}); | |||||
| auto dumper = GraphDumper::make( | |||||
| OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); | |||||
| auto rst = dumper->dump({opr}); | |||||
| ASSERT_EQ(rst.outputs.size(), 1u); | |||||
| }; | |||||
| auto load = [&]() { | |||||
| auto loader = GraphLoader::make( | |||||
| InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); | |||||
| auto rst = loader->load(); | |||||
| ASSERT_EQ(rst.output_var_list.size(), 1u); | |||||
| auto opr = rst.output_var_list[0].node()->owner_opr(); | |||||
| ASSERT_TRUE(opr->same_type<opr::MatrixMul>()); | |||||
| }; | |||||
| dump(); | |||||
| load(); | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| // | |||||
| // | |||||
| @@ -47,7 +47,13 @@ namespace { | |||||
| constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH; | constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH; | ||||
| constexpr uint32_t MGB_MAGIC = 0x5342474D; | |||||
| constexpr uint32_t MGB_MAGIC = 0x4342474D; | |||||
| // In order to maintain compatibility and to allow old models to be loaded, we keep | |||||
| // the old magic (MAGIC_V0) value and create a new magic (MGB_MAGIC). | |||||
| constexpr uint32_t MAGIC_V0 = 0x5342474D; | |||||
| // Records whether the magic is old or new: true for the new magic (MGB_MAGIC), | |||||
| // false for the old magic (MAGIC_V0). | |||||
| bool magic_compare = true; | |||||
| template <typename T> | template <typename T> | ||||
| bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) { | bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) { | ||||
| @@ -79,6 +85,18 @@ void check_tensor_value_valid(const std::string& name, const HostTensorND& tenso | |||||
| } | } | ||||
| } | } | ||||
| //! feature bits for backward compatibility; default value should be 0 | |||||
| struct FeatureBits64 { | |||||
| //! reserved for new fields | |||||
| uint64_t : 64; | |||||
| static void write(OutputFile& fout) { | |||||
| static_assert(sizeof(FeatureBits64) == 8, "bad feature bits"); | |||||
| FeatureBits64 fb64; | |||||
| memset(&fb64, 0, sizeof(fb64)); | |||||
| fout.write(&fb64, 8); | |||||
| } | |||||
| }; | |||||
| } // namespace | } // namespace | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -266,7 +284,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( | |||||
| } | } | ||||
| fbs::OperatorBuilder builder(m_builder); | fbs::OperatorBuilder builder(m_builder); | ||||
| builder.add_type_id(registry->unversioned_type_id); | |||||
| builder.add_type_id(registry->persist_type_id); | |||||
| builder.add_inputs(inputs); | builder.add_inputs(inputs); | ||||
| if (m_config.keep_opr_priority) { | if (m_config.keep_opr_priority) { | ||||
| builder.add_priority(opr->node_prop().attribute().priority); | builder.add_priority(opr->node_prop().attribute().priority); | ||||
| @@ -322,6 +340,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump( | |||||
| uint32_t magic = MGB_MAGIC; | uint32_t magic = MGB_MAGIC; | ||||
| m_file->write(&magic, sizeof(magic)); | m_file->write(&magic, sizeof(magic)); | ||||
| // write FeatureBits | |||||
| FeatureBits64::write(*m_file); | |||||
| // Padding | // Padding | ||||
| uint32_t reserved = 0; | uint32_t reserved = 0; | ||||
| m_file->write(&reserved, sizeof(reserved)); | m_file->write(&reserved, sizeof(reserved)); | ||||
| @@ -459,6 +479,7 @@ void GraphDumperOSS::dump_buf_with_len(const void* data, uint32_t size) { | |||||
| class GraphLoaderOSS final : public GraphLoader { | class GraphLoaderOSS final : public GraphLoader { | ||||
| const LoadConfig* m_cur_load_config = nullptr; | const LoadConfig* m_cur_load_config = nullptr; | ||||
| std::unique_ptr<InputFile> m_file; | std::unique_ptr<InputFile> m_file; | ||||
| FeatureBits64 m_feature_bits; | |||||
| SharedBuffer m_graph_buf{{}, 0}; | SharedBuffer m_graph_buf{{}, 0}; | ||||
| const fbs::Graph* m_graph; | const fbs::Graph* m_graph; | ||||
| SharedTensorIDMap m_shared_tensor_map; | SharedTensorIDMap m_shared_tensor_map; | ||||
| @@ -754,8 +775,12 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(const fbs::Operator* fb | |||||
| } | } | ||||
| config.comp_node_arr(comp_node_arr); | config.comp_node_arr(comp_node_arr); | ||||
| } | } | ||||
| auto registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); | |||||
| const OprRegistry* registry; | |||||
| if (magic_compare) { | |||||
| registry = OprRegistry::find_by_id(fbopr->type_id()); | |||||
| } else { | |||||
| registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); | |||||
| } | |||||
| mgb_throw_if( | mgb_throw_if( | ||||
| !registry, SerializationError, | !registry, SerializationError, | ||||
| "failed to find opr with type %s, use python env " | "failed to find opr with type %s, use python env " | ||||
| @@ -841,10 +866,17 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi | |||||
| uint32_t magic; | uint32_t magic; | ||||
| m_file->read(&magic, sizeof(magic)); | m_file->read(&magic, sizeof(magic)); | ||||
| mgb_throw_if( | mgb_throw_if( | ||||
| magic != MGB_MAGIC, SerializationError, | |||||
| "wrong magic: wanted %#08x, actual %#08x (not a invalid fbs " | |||||
| (magic != MGB_MAGIC) && (magic != MAGIC_V0), SerializationError, | |||||
| "wrong magic: wanted %#08x or %#08x, actual %#08x (not a invalid fbs " | |||||
| "model?)", | "model?)", | ||||
| MGB_MAGIC, magic); | |||||
| MGB_MAGIC, MAGIC_V0, magic); | |||||
| if (magic == MGB_MAGIC) { | |||||
| // read FeatureBits | |||||
| magic_compare = true; | |||||
| m_file->read(&m_feature_bits, 8); | |||||
| } else { | |||||
| magic_compare = false; | |||||
| } | |||||
| m_file->skip(4); | m_file->skip(4); | ||||
| uint64_t offset_to_fbs; | uint64_t offset_to_fbs; | ||||
| @@ -929,7 +961,7 @@ bool is_fbs_file(InputFile& file) { | |||||
| uint64_t magic_with_reserved = 0; | uint64_t magic_with_reserved = 0; | ||||
| file.read(&magic_with_reserved, sizeof(magic_with_reserved)); | file.read(&magic_with_reserved, sizeof(magic_with_reserved)); | ||||
| file.skip(-sizeof(magic_with_reserved)); | file.skip(-sizeof(magic_with_reserved)); | ||||
| return magic_with_reserved == MGB_MAGIC; | |||||
| return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0); | |||||
| } | } | ||||
| } // namespace serialization | } // namespace serialization | ||||
| @@ -199,7 +199,7 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | |||||
| static ser::OprWithOutputAccessor compat_loader( \ | static ser::OprWithOutputAccessor compat_loader( \ | ||||
| ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | ||||
| const mgb::cg::OperatorNodeConfig& config) { \ | const mgb::cg::OperatorNodeConfig& config) { \ | ||||
| auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \ | |||||
| auto&& ctx_ = static_cast<ser::OprLoadContext&>(ctx); \ | |||||
| return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), _accessor); \ | return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), _accessor); \ | ||||
| } \ | } \ | ||||
| static void entry() { \ | static void entry() { \ | ||||