GitOrigin-RevId: e4771d6bc4
tags/v1.7.0.m1
| @@ -33,9 +33,6 @@ class ConverterWriter(IndentWriterBase): | |||
| self._last_param = p | |||
| self._param_fields = [] | |||
| self._fb_fields = ["builder"] | |||
| if p.is_legacy: | |||
| self._skip_current_param = True | |||
| return | |||
| self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | |||
| p.name, indent=1) | |||
| self._write("using MegDNNType = megdnn::param::%s;", p.name) | |||
| @@ -80,9 +80,6 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| def _on_param_begin(self, p): | |||
| self._last_param = p | |||
| self._cur_const_val = {} | |||
| if p.is_legacy: | |||
| self._skip_current_param = True | |||
| return | |||
| self._write_doc(p.name) | |||
| self._write("table %s {", p.name, indent=1) | |||
| @@ -52,9 +52,6 @@ class ConverterWriter(IndentWriterBase): | |||
| def _on_param_begin(self, p): | |||
| self._last_param = p | |||
| if p.is_legacy: | |||
| self._skip_current_param = True | |||
| return | |||
| self._packed = True | |||
| self._current_tparams = [] | |||
| self._const = set() | |||
| @@ -62,6 +62,37 @@ struct PersistentAddUpdateParam { | |||
| } // namespace opr_add_update | |||
| // Old SerializedDType used in MegBrain 7.22.0 - 7.23.1 | |||
| // Should be kept as-is even if there are new dtypes. | |||
| struct SerializedDTypeV1 { | |||
| static constexpr uint32_t TAG = megdnn::param::FakeSerializedDType::TAG; | |||
| DTypeEnum enumv; | |||
| union { | |||
| megdnn::DTypeParam<dtype::Quantized8Asymm> Quantized8Asymm; | |||
| megdnn::DTypeParam<dtype::QuantizedS8> QuantizedS8; | |||
| megdnn::DTypeParam<dtype::QuantizedS32> QuantizedS32; | |||
| } param; | |||
| operator DType() const { | |||
| switch (enumv) { | |||
| #define cb(_dt) \ | |||
| case DTypeEnum::_dt: \ | |||
| return DType::from_enum(enumv); | |||
| MEGDNN_FOREACH_DTYPE_NAME(cb) | |||
| #undef cb | |||
| case DTypeEnum::Quantized8Asymm: | |||
| return dtype::Quantized8Asymm{param.Quantized8Asymm}; | |||
| case DTypeEnum::QuantizedS8: | |||
| return dtype::QuantizedS8{param.QuantizedS8}; | |||
| case DTypeEnum::QuantizedS32: | |||
| return dtype::QuantizedS32{param.QuantizedS32}; | |||
| default: | |||
| mgb_assert( | |||
| false, "unexpected old serialized dtype: invalid enumv %d", | |||
| static_cast<uint32_t>(enumv)); | |||
| } | |||
| } | |||
| }; | |||
| template <> | |||
| struct OprPersistentParam<opr::AddUpdate> { | |||
| using Param = opr_add_update::PersistentAddUpdateParam; | |||
| @@ -104,7 +135,18 @@ struct ParamConverter<megdnn::DType> { | |||
| return fbs::intl::build_dtype(builder, dtype); | |||
| } | |||
| }; | |||
| } // namespace fbs | |||
| template <> | |||
| struct ParamConverter<SerializedDTypeV1> { | |||
| using FlatBufferType = SerializedDTypeV1; | |||
| static SerializedDTypeV1 to_param(const FlatBufferType* fb) { | |||
| mgb_assert( | |||
| false, | |||
| "You are calling SerializedDTypeV1 in flatbuffer, you should not call " | |||
| "here, this code is just to avoid compiling errors, but not be used in " | |||
| "flatbuffer."); | |||
| } | |||
| }; | |||
| }; // namespace fbs | |||
| #endif | |||
| template <> | |||
| @@ -16,6 +16,7 @@ | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/opr/tensor_gen.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/serialization/serializer.h" | |||
| #include "megbrain/test/autocheck.h" | |||
| #include "megbrain/test/helper.h" | |||
| #include "megbrain/test/megdnn_helper.h" | |||
| @@ -907,5 +908,39 @@ TEST(TestOprBlas, MatrixMulExePolicy) { | |||
| } | |||
| #endif | |||
| #if MGB_ENABLE_FBS_SERIALIZATION | |||
| TEST(TestOprDNN, MatrixMulSerialization) { | |||
| using namespace serialization; | |||
| auto fname = output_file("MatrixMulSerializationTest"); | |||
| auto dump = [&]() { | |||
| opr::MatrixMul::Param param; | |||
| auto cn = CompNode::load("cpu0"); | |||
| auto graph = ComputingGraph::make(); | |||
| HostTensorND a_host{cn, {24, 24}, dtype::Float32()}; | |||
| HostTensorND b_host{cn, {24, 24}, dtype::Float32()}; | |||
| auto a = opr::ImmutableTensor::make(*graph, a_host); | |||
| auto b = opr::ImmutableTensor::make(*graph, b_host); | |||
| auto opr = opr::MatrixMul::make(a, b, param, {}); | |||
| auto dumper = GraphDumper::make( | |||
| OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); | |||
| auto rst = dumper->dump({opr}); | |||
| ASSERT_EQ(rst.outputs.size(), 1u); | |||
| }; | |||
| auto load = [&]() { | |||
| auto loader = GraphLoader::make( | |||
| InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); | |||
| auto rst = loader->load(); | |||
| ASSERT_EQ(rst.output_var_list.size(), 1u); | |||
| auto opr = rst.output_var_list[0].node()->owner_opr(); | |||
| ASSERT_TRUE(opr->same_type<opr::MatrixMul>()); | |||
| }; | |||
| dump(); | |||
| load(); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| // | |||
| // | |||
| @@ -47,7 +47,13 @@ namespace { | |||
| constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH; | |||
| constexpr uint32_t MGB_MAGIC = 0x5342474D; | |||
| constexpr uint32_t MGB_MAGIC = 0x4342474D; | |||
| // In order to maintain compatibility and to allow old models to be loaded, we keep | |||
| // the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC) | |||
| constexpr uint32_t MAGIC_V0 = 0x5342474D; | |||
| // Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the | |||
| // old magic(MAGIC_V0) is false. | |||
| bool magic_compare = true; | |||
| template <typename T> | |||
| bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) { | |||
| @@ -79,6 +85,18 @@ void check_tensor_value_valid(const std::string& name, const HostTensorND& tenso | |||
| } | |||
| } | |||
| //! feature bits for backward compatibility; default value should be 0 | |||
| struct FeatureBits64 { | |||
| //! reserved for new fields | |||
| uint64_t : 64; | |||
| static void write(OutputFile& fout) { | |||
| static_assert(sizeof(FeatureBits64) == 8, "bad feature bits"); | |||
| FeatureBits64 fb64; | |||
| memset(&fb64, 0, sizeof(fb64)); | |||
| fout.write(&fb64, 8); | |||
| } | |||
| }; | |||
| } // namespace | |||
| namespace mgb { | |||
| @@ -266,7 +284,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( | |||
| } | |||
| fbs::OperatorBuilder builder(m_builder); | |||
| builder.add_type_id(registry->unversioned_type_id); | |||
| builder.add_type_id(registry->persist_type_id); | |||
| builder.add_inputs(inputs); | |||
| if (m_config.keep_opr_priority) { | |||
| builder.add_priority(opr->node_prop().attribute().priority); | |||
| @@ -322,6 +340,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump( | |||
| uint32_t magic = MGB_MAGIC; | |||
| m_file->write(&magic, sizeof(magic)); | |||
| // write FeatureBits | |||
| FeatureBits64::write(*m_file); | |||
| // Padding | |||
| uint32_t reserved = 0; | |||
| m_file->write(&reserved, sizeof(reserved)); | |||
| @@ -459,6 +479,7 @@ void GraphDumperOSS::dump_buf_with_len(const void* data, uint32_t size) { | |||
| class GraphLoaderOSS final : public GraphLoader { | |||
| const LoadConfig* m_cur_load_config = nullptr; | |||
| std::unique_ptr<InputFile> m_file; | |||
| FeatureBits64 m_feature_bits; | |||
| SharedBuffer m_graph_buf{{}, 0}; | |||
| const fbs::Graph* m_graph; | |||
| SharedTensorIDMap m_shared_tensor_map; | |||
| @@ -754,8 +775,12 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(const fbs::Operator* fb | |||
| } | |||
| config.comp_node_arr(comp_node_arr); | |||
| } | |||
| auto registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); | |||
| const OprRegistry* registry; | |||
| if (magic_compare) { | |||
| registry = OprRegistry::find_by_id(fbopr->type_id()); | |||
| } else { | |||
| registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); | |||
| } | |||
| mgb_throw_if( | |||
| !registry, SerializationError, | |||
| "failed to find opr with type %s, use python env " | |||
| @@ -841,10 +866,17 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi | |||
| uint32_t magic; | |||
| m_file->read(&magic, sizeof(magic)); | |||
| mgb_throw_if( | |||
| magic != MGB_MAGIC, SerializationError, | |||
| "wrong magic: wanted %#08x, actual %#08x (not a invalid fbs " | |||
| (magic != MGB_MAGIC) && (magic != MAGIC_V0), SerializationError, | |||
| "wrong magic: wanted %#08x or %#08x, actual %#08x (not a invalid fbs " | |||
| "model?)", | |||
| MGB_MAGIC, magic); | |||
| MGB_MAGIC, MAGIC_V0, magic); | |||
| if (magic == MGB_MAGIC) { | |||
| // read FeatureBits | |||
| magic_compare = true; | |||
| m_file->read(&m_feature_bits, 8); | |||
| } else { | |||
| magic_compare = false; | |||
| } | |||
| m_file->skip(4); | |||
| uint64_t offset_to_fbs; | |||
| @@ -929,7 +961,7 @@ bool is_fbs_file(InputFile& file) { | |||
| uint64_t magic_with_reserved = 0; | |||
| file.read(&magic_with_reserved, sizeof(magic_with_reserved)); | |||
| file.skip(-sizeof(magic_with_reserved)); | |||
| return magic_with_reserved == MGB_MAGIC; | |||
| return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0); | |||
| } | |||
| } // namespace serialization | |||
| @@ -199,7 +199,7 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | |||
| static ser::OprWithOutputAccessor compat_loader( \ | |||
| ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||
| const mgb::cg::OperatorNodeConfig& config) { \ | |||
| auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \ | |||
| auto&& ctx_ = static_cast<ser::OprLoadContext&>(ctx); \ | |||
| return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), _accessor); \ | |||
| } \ | |||
| static void entry() { \ | |||