You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

model.cpp 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. /**
  2. * \file lite/load_and_run/src/models/model.cpp
  3. *
  4. * This file is part of MegEngine, a deep learning framework developed by
  5. * Megvii.
  6. *
  7. * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
  8. */
  9. #include "model.h"
  10. #include <iostream>
  11. #include <memory>
  12. #include "model_lite.h"
  13. #include "model_mdl.h"
  14. using namespace lar;
  15. ModelType ModelBase::get_model_type(std::string model_path) {
  16. //! read magic number of dump file
  17. FILE* fin = fopen(model_path.c_str(), "rb");
  18. mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
  19. char buf[16];
  20. mgb_assert(fread(buf, 1, 16, fin) == 16, "read model failed");
  21. fclose(fin);
  22. // get model type
  23. std::string tag(buf);
  24. ModelType type;
  25. if (tag.substr(0, 7) == std::string("mgb0001") ||
  26. tag.substr(0, 8) == std::string("mgb0000a") ||
  27. tag.substr(0, 4) == std::string("MGBS") ||
  28. tag.substr(0, 4) == std::string("MGBC") ||
  29. tag.substr(0, 8) == std::string("mgbtest0")) {
  30. type = ModelType::MEGDL_MODEL;
  31. } else {
  32. type = ModelType::LITE_MODEL;
  33. }
  34. return type;
  35. }
  36. std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) {
  37. mgb_log_debug("model path %s\n", model_path.c_str());
  38. auto model_type = get_model_type(model_path);
  39. if (ModelType::LITE_MODEL == model_type) {
  40. return std::make_shared<ModelLite>(model_path);
  41. } else if (ModelType::MEGDL_MODEL == model_type) {
  42. if (FLAGS_lite)
  43. return std::make_shared<ModelLite>(model_path);
  44. else
  45. return std::make_shared<ModelMdl>(model_path);
  46. } else {
  47. return nullptr;
  48. }
  49. }
  50. DEFINE_bool(lite, false, "using lite model to run mdl model");
  51. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}