GitOrigin-RevId: 8f21fda9d3
tags/v1.10.0
| @@ -1,15 +1,5 @@ | |||
| /** | |||
| * \file inlude/lite/pack_model.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| namespace lite { | |||
| @@ -67,7 +57,7 @@ private: | |||
| Header m_header; | |||
| friend class FbsHelper; | |||
| FbsHelper* m_fbs_helper; | |||
| std::shared_ptr<FbsHelper> m_fbs_helper; | |||
| }; | |||
| } // namespace lite | |||
| @@ -5,7 +5,6 @@ cc_library( | |||
| hdrs = glob(["src/**/*.h"]), | |||
| includes = ["src"], | |||
| features = if_opt([ | |||
| "no_exceptions", | |||
| "no_rtti", | |||
| ]), | |||
| @@ -1,15 +1,7 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/common.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| DECLARE_int32(thread); | |||
| namespace lar { | |||
| /*! | |||
| @@ -71,6 +63,122 @@ enum class OptLayoutType { | |||
| NHWCD4 = 1 << 6, | |||
| NCHW44_DOT = 1 << 7 | |||
| }; | |||
| /** | |||
| * base class to story option value | |||
| */ | |||
| enum class JsonValueType { | |||
| Bool = 0, | |||
| Number, | |||
| NumberInt32, | |||
| NumberUint64, | |||
| String, | |||
| }; | |||
| struct Value { | |||
| virtual JsonValueType get_type() const = 0; | |||
| virtual std::string type_string() const = 0; | |||
| virtual void reset_value() = 0; | |||
| virtual ~Value() = default; | |||
| }; | |||
| /** | |||
| * class for double option | |||
| */ | |||
| struct Number final : public Value { | |||
| Number(double v) : m_val(v), m_default_val(v) {} | |||
| static std::shared_ptr<Number> make(double v) { | |||
| return std::make_shared<Number>(v); | |||
| } | |||
| void set_value(double v) { m_val = v; } | |||
| double get_value() { return m_val; } | |||
| double get_default() { return m_default_val; } | |||
| void reset_value() override { m_val = m_default_val; } | |||
| JsonValueType get_type() const override { return JsonValueType::Number; } | |||
| std::string type_string() const override { return "Number"; } | |||
| private: | |||
| double m_val; | |||
| double m_default_val; | |||
| }; | |||
| /** | |||
| * class for int32_t option | |||
| */ | |||
| struct NumberInt32 final : public Value { | |||
| NumberInt32(int32_t v) : m_val(v), m_default_val(v) {} | |||
| static std::shared_ptr<NumberInt32> make(int32_t v) { | |||
| return std::make_shared<NumberInt32>(v); | |||
| } | |||
| void set_value(int32_t v) { m_val = v; } | |||
| int32_t get_value() { return m_val; } | |||
| int32_t get_default() { return m_default_val; } | |||
| void reset_value() override { m_val = m_default_val; } | |||
| JsonValueType get_type() const override { return JsonValueType::NumberInt32; } | |||
| std::string type_string() const override { return "NumberInt32"; } | |||
| private: | |||
| int32_t m_val; | |||
| int32_t m_default_val; | |||
| }; | |||
| /** | |||
| * class for uint64 option | |||
| */ | |||
| struct NumberUint64 final : public Value { | |||
| NumberUint64(uint64_t v) : m_val(v), m_default_val(v) {} | |||
| static std::shared_ptr<NumberUint64> make(uint64_t v) { | |||
| return std::make_shared<NumberUint64>(v); | |||
| } | |||
| void set_value(uint64_t v) { m_val = v; } | |||
| uint64_t get_value() { return m_val; } | |||
| uint64_t get_default() { return m_default_val; } | |||
| void reset_value() override { m_val = m_default_val; } | |||
| JsonValueType get_type() const override { return JsonValueType::NumberUint64; } | |||
| std::string type_string() const override { return "NumberUint64"; } | |||
| private: | |||
| uint64_t m_val; | |||
| uint64_t m_default_val; | |||
| }; | |||
| /** | |||
| * class for boolean option | |||
| */ | |||
| struct Bool final : public Value { | |||
| Bool(bool v) : m_val(v), m_default_val(v) {} | |||
| static std::shared_ptr<Bool> make(bool v) { return std::make_shared<Bool>(v); } | |||
| void set_value(bool v) { m_val = v; } | |||
| bool get_value() { return m_val; } | |||
| bool get_default() { return m_default_val; } | |||
| void reset_value() override { m_val = m_default_val; } | |||
| JsonValueType get_type() const override { return JsonValueType::Bool; } | |||
| std::string type_string() const override { return "Bool"; } | |||
| private: | |||
| bool m_val; | |||
| bool m_default_val; | |||
| }; | |||
| /** | |||
| * class for string option | |||
| */ | |||
| struct String final : public Value { | |||
| String(std::string v) : m_val(v), m_default_val(v) {} | |||
| static std::shared_ptr<String> make(const std::string& v) { | |||
| return std::make_shared<String>(v); | |||
| } | |||
| void set_value(const std::string& v) { m_val = v; } | |||
| std::string& get_value() { return m_val; } | |||
| std::string get_default() { return m_default_val; } | |||
| void reset_value() override { m_val = m_default_val; } | |||
| JsonValueType get_type() const override { return JsonValueType::String; } | |||
| std::string type_string() const override { return "String"; } | |||
| private: | |||
| std::string m_val; | |||
| std::string m_default_val; | |||
| }; | |||
| using OptionValMap = std::unordered_map<std::string, std::shared_ptr<lar::Value>>; | |||
| } // namespace lar | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -5,9 +5,7 @@ using namespace mgb; | |||
| template <typename T> | |||
| T* JsonLoader::Value::safe_cast() { | |||
| T* ptr = (T*)(this); | |||
| if (nullptr == ptr) { | |||
| fprintf(stderr, "cast ptr is null\n"); | |||
| } | |||
| mgb_assert(nullptr != ptr, "cast ptr is null\n"); | |||
| return ptr; | |||
| } | |||
| @@ -31,6 +29,12 @@ std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader::Value:: | |||
| return t->m_obj; | |||
| } | |||
| std::vector<std::string>& JsonLoader::Value::keys() { | |||
| mgb_assert(Type::OBJECT == m_type); | |||
| auto t = safe_cast<JsonLoader::ObjectValue>(); | |||
| return t->m_keys; | |||
| } | |||
| size_t JsonLoader::Value::len() { | |||
| if (Type::ARRAY == m_type) { | |||
| auto t = safe_cast<JsonLoader::ArrayValue>(); | |||
| @@ -54,6 +58,12 @@ double JsonLoader::Value::number() { | |||
| return t->value(); | |||
| } | |||
| bool JsonLoader::Value::Bool() { | |||
| mgb_assert(Type::BOOL == m_type); | |||
| auto t = safe_cast<JsonLoader::BoolValue>(); | |||
| return t->value(); | |||
| } | |||
| std::string JsonLoader::Value::str() { | |||
| if (Type::STRING == m_type) { | |||
| auto t = safe_cast<StringValue>(); | |||
| @@ -69,7 +79,7 @@ void JsonLoader::expect(char c) { | |||
| void JsonLoader::skip_whitespace() { | |||
| const char* p = m_buf; | |||
| while (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r') { | |||
| while (' ' == *p || '\t' == *p || '\n' == *p || '\r' == *p) { | |||
| ++p; | |||
| } | |||
| m_buf = p; | |||
| @@ -80,11 +90,12 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() { | |||
| skip_whitespace(); | |||
| std::unique_ptr<JsonLoader::Value> ret; | |||
| JsonLoader::ObjectValue* pObject = new JsonLoader::ObjectValue(); | |||
| std::unique_ptr<JsonLoader::ObjectValue> pObject = | |||
| std::make_unique<JsonLoader::ObjectValue>(); | |||
| if ('}' == *m_buf) { | |||
| m_buf = m_buf + 1; | |||
| ret.reset((JsonLoader::Value*)(pObject)); | |||
| ret = std::move(pObject); | |||
| return ret; | |||
| } | |||
| @@ -113,6 +124,7 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() { | |||
| } | |||
| pObject->m_obj.insert(std::make_pair(key->str(), std::move(pVal))); | |||
| pObject->m_keys.push_back(key->str()); | |||
| skip_whitespace(); | |||
| if (',' == (*m_buf)) { | |||
| @@ -126,22 +138,21 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() { | |||
| break; | |||
| } | |||
| } | |||
| ret.reset((JsonLoader::Value*)(pObject)); | |||
| ret = std::move(pObject); | |||
| return ret; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::parse_array() { | |||
| expect('['); | |||
| skip_whitespace(); | |||
| std::unique_ptr<JsonLoader::Value> ret; | |||
| JsonLoader::ArrayValue* pArray = new JsonLoader::ArrayValue(); | |||
| std::unique_ptr<JsonLoader::ArrayValue> pArray = | |||
| std::make_unique<JsonLoader::ArrayValue>(); | |||
| if (']' == *m_buf) { | |||
| m_buf = m_buf + 1; | |||
| ret.reset((JsonLoader::Value*)(pArray)); | |||
| ret = std::move(pArray); | |||
| return ret; | |||
| } | |||
| @@ -168,15 +179,14 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_array() { | |||
| } | |||
| } | |||
| ret.reset((JsonLoader::Value*)(pArray)); | |||
| ret = std::move(pArray); | |||
| return ret; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::parse_string() { | |||
| expect('\"'); | |||
| std::unique_ptr<JsonLoader::Value> ret; | |||
| JsonLoader::StringValue* pStr = new JsonLoader::StringValue(); | |||
| std::unique_ptr<JsonLoader::StringValue> pStr = | |||
| std::make_unique<JsonLoader::StringValue>(); | |||
| const char* p = m_buf; | |||
| while (true) { | |||
| @@ -189,7 +199,7 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_string() { | |||
| } | |||
| } | |||
| m_buf = p; | |||
| ret.reset((JsonLoader::Value*)(pStr)); | |||
| std::unique_ptr<JsonLoader::Value> ret = std::move(pStr); | |||
| return ret; | |||
| } | |||
| @@ -207,31 +217,31 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_number() { | |||
| return; | |||
| }; | |||
| if (*p == '-') | |||
| if ('-' == *p) | |||
| p++; | |||
| if (*p == '0') | |||
| if ('0' == *p) | |||
| p++; | |||
| else { | |||
| loop_digit(std::ref(p)); | |||
| } | |||
| if (*p == '.') { | |||
| if ('.' == *p) { | |||
| p++; | |||
| loop_digit(std::ref(p)); | |||
| } | |||
| if (*p == 'e' || *p == 'E') { | |||
| if ('e' == *p || 'E' == *p) { | |||
| p++; | |||
| if (*p == '+' || *p == '-') | |||
| if ('+' == *p || '-' == *p) | |||
| p++; | |||
| loop_digit(std::ref(p)); | |||
| } | |||
| JsonLoader::NumberValue* pNum = new JsonLoader::NumberValue(); | |||
| std::unique_ptr<JsonLoader::NumberValue> pNum = | |||
| std::make_unique<JsonLoader::NumberValue>(); | |||
| pNum->m_value = strtod(m_buf, nullptr); | |||
| m_buf = p; | |||
| std::unique_ptr<JsonLoader::Value> ret; | |||
| ret.reset((JsonLoader::Value*)(pNum)); | |||
| std::unique_ptr<JsonLoader::Value> ret = std::move(pNum); | |||
| return ret; | |||
| } | |||
| @@ -243,6 +253,10 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_value() { | |||
| return parse_object(); | |||
| case '\"': | |||
| return parse_string(); | |||
| case 't': | |||
| return parse_bool(); | |||
| case 'f': | |||
| return parse_bool(); | |||
| case '\0': | |||
| m_state = State::BAD_TYPE; | |||
| break; | |||
| @@ -252,6 +266,37 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_value() { | |||
| return nullptr; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::parse_bool() { | |||
| const char* p = m_buf; | |||
| std::string value; | |||
| if ('t' == *p) { | |||
| value = ""; | |||
| for (size_t idx = 0; idx < 4; ++idx) { | |||
| value += *p++; | |||
| } | |||
| } else if ('f' == *p) { | |||
| value = ""; | |||
| for (size_t idx = 0; idx < 5; ++idx) { | |||
| value += *p++; | |||
| } | |||
| } | |||
| bool val = false; | |||
| if ("true" == value) { | |||
| val = true; | |||
| } else if ("false" == value) { | |||
| val = false; | |||
| } else { | |||
| mgb_log_error("invalid value: %s for possible bool value", value.c_str()); | |||
| } | |||
| std::unique_ptr<JsonLoader::BoolValue> pBool = | |||
| std::make_unique<JsonLoader::BoolValue>(); | |||
| pBool->m_value = val; | |||
| m_buf = p; | |||
| std::unique_ptr<JsonLoader::Value> ret = std::move(pBool); | |||
| return ret; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::load( | |||
| const char* content, const size_t size) { | |||
| m_buf = content; | |||
| @@ -18,7 +18,7 @@ public: | |||
| // base class for different value format | |||
| class Value { | |||
| protected: | |||
| enum struct Type : uint8_t { UNKNOWN, NUMBER, STRING, OBJECT, ARRAY }; | |||
| enum struct Type : uint8_t { UNKNOWN, NUMBER, STRING, OBJECT, ARRAY, BOOL }; | |||
| Type m_type; | |||
| public: | |||
| @@ -39,12 +39,16 @@ public: | |||
| bool is_str() { return Type::STRING == m_type; } | |||
| bool is_bool() { return Type::BOOL == m_type; } | |||
| std::unique_ptr<Value>& operator[](const std::string& key); | |||
| std::unique_ptr<Value>& operator[](const size_t index); | |||
| std::map<std::string, std::unique_ptr<Value>>& objects(); | |||
| std::vector<std::string>& keys(); | |||
| size_t len(); | |||
| megdnn::SmallVector<std::unique_ptr<Value>>& array(); | |||
| @@ -52,6 +56,8 @@ public: | |||
| double number(); | |||
| std::string str(); | |||
| bool Bool(); | |||
| }; | |||
| void expect(char c); | |||
| @@ -68,6 +74,8 @@ public: | |||
| std::unique_ptr<Value> parse_value(); | |||
| std::unique_ptr<Value> parse_bool(); | |||
| enum struct State : uint8_t { | |||
| OK = 0, | |||
| BAD_TYPE, | |||
| @@ -137,21 +145,26 @@ public: | |||
| class ObjectValue final : public Value { | |||
| std::map<std::string, std::unique_ptr<Value>> m_obj; | |||
| std::vector<std::string> m_keys; | |||
| public: | |||
| ObjectValue() : Value(Type::OBJECT) {} | |||
| ObjectValue(ObjectValue& arr) : Value(arr) { | |||
| m_obj.clear(); | |||
| m_keys.clear(); | |||
| for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | |||
| m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | |||
| m_keys.push_back(itra->first); | |||
| } | |||
| } | |||
| ObjectValue(ObjectValue&& arr) : Value(arr) { | |||
| m_obj.clear(); | |||
| m_keys.clear(); | |||
| for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | |||
| m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | |||
| m_keys.push_back(itra->first); | |||
| } | |||
| } | |||
| @@ -160,9 +173,19 @@ public: | |||
| const std::string&); | |||
| friend std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader:: | |||
| Value::objects(); | |||
| friend std::vector<std::string>& JsonLoader::Value::keys(); | |||
| friend size_t JsonLoader::Value::len(); | |||
| }; | |||
| class BoolValue final : public Value { | |||
| bool m_value; | |||
| public: | |||
| BoolValue() : Value(Type::BOOL) {} | |||
| bool value() { return m_value; } | |||
| friend std::unique_ptr<Value> JsonLoader::parse_bool(); | |||
| }; | |||
| private: | |||
| const char* m_buf; | |||
| State m_state; | |||
| @@ -0,0 +1,362 @@ | |||
| #include "utils.h" | |||
| using namespace lar; | |||
| /////////////////// JsonOptionsCoder /////////////////// | |||
| #if MGB_ENABLE_JSON | |||
| //! encode option | |||
| void encode_single_options( | |||
| std::pair<std::string, std::shared_ptr<lar::Value>> item, | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>& | |||
| list, | |||
| bool encode_all) { | |||
| auto type = item.second->get_type(); | |||
| if (type == JsonValueType::Bool) { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||
| return; | |||
| } | |||
| list.push_back( | |||
| {mgb::json::String(item.first), | |||
| mgb::json::Bool::make(val_ptr->get_value())}); | |||
| } else if (type == JsonValueType::NumberInt32) { | |||
| auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(item.second); | |||
| if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||
| return; | |||
| } | |||
| list.push_back( | |||
| {mgb::json::String(item.first), | |||
| mgb::json::NumberInt::make( | |||
| static_cast<int64_t>(val_ptr->get_value()))}); | |||
| } else if (type == JsonValueType::NumberUint64) { | |||
| auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(item.second); | |||
| list.push_back( | |||
| {mgb::json::String(item.first), | |||
| mgb::json::NumberInt::make( | |||
| static_cast<int64_t>(val_ptr->get_value()))}); | |||
| } else if (type == JsonValueType::Number) { | |||
| auto val_ptr = std::static_pointer_cast<lar::Number>(item.second); | |||
| list.push_back( | |||
| {mgb::json::String(item.first), | |||
| mgb::json::Number::make(val_ptr->get_value())}); | |||
| } else if (type == JsonValueType::String) { | |||
| auto val_ptr = std::static_pointer_cast<lar::String>(item.second); | |||
| if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||
| return; | |||
| } | |||
| list.push_back( | |||
| {mgb::json::String(item.first), | |||
| mgb::json::String::make(val_ptr->get_value())}); | |||
| } else { | |||
| mgb_log_error( | |||
| "unsupport JsonValueType:%s for lar::Value", | |||
| item.second->type_string().c_str()); | |||
| } | |||
| } | |||
| std::string JsonOptionsCoder::encode(OptionValMap& option_val_map, bool encode_all) { | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| json_options; | |||
| for (auto& item : option_val_map) { | |||
| encode_single_options(item, json_options, encode_all); | |||
| } | |||
| auto json_obj = mgb::json::Object::make( | |||
| {{"options", mgb::json::Object::make(json_options)}}); | |||
| return json_obj->to_string(1); | |||
| } | |||
| //! encode device | |||
| std::vector<std::shared_ptr<mgb::json::Object>> JsonOptionsCoder::encode( | |||
| OptionValMap& option_val_map) { | |||
| std::vector<std::shared_ptr<mgb::json::Object>> info; | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| json_device; | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| json_options; | |||
| for (auto& item : option_val_map) { | |||
| if ((item.first == "cpu" || item.first == "cpu_default" || | |||
| item.first == "multithread" || item.first == "multithread_default")) { | |||
| auto type = item.second->get_type(); | |||
| if (type == JsonValueType::Bool) { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (val_ptr->get_value() == val_ptr->get_default()) | |||
| continue; | |||
| } | |||
| if (type == JsonValueType::NumberInt32) { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (val_ptr->get_value() == val_ptr->get_default()) | |||
| continue; | |||
| } | |||
| json_device.push_back( | |||
| {mgb::json::String("type"), mgb::json::String::make("CPU")}); | |||
| if (item.first == "cpu_default" || item.first == "multithread_default") { | |||
| json_device.push_back( | |||
| {mgb::json::String("enable_inplace_model"), | |||
| mgb::json::Bool::make(true)}); | |||
| } | |||
| if (item.first == "multithread" || item.first == "multithread_default") { | |||
| json_device.push_back( | |||
| {mgb::json::String("number_threads"), | |||
| mgb::json::NumberInt::make( | |||
| std::static_pointer_cast<lar::NumberInt32>(item.second) | |||
| ->get_value())}); | |||
| if (item.first == "multithread") { | |||
| json_device.push_back( | |||
| {mgb::json::String("device_id"), | |||
| mgb::json::NumberInt::make(0)}); | |||
| } | |||
| } | |||
| } else if (item.first == "cuda") { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (val_ptr->get_value() == val_ptr->get_default()) | |||
| continue; | |||
| json_device.push_back( | |||
| {mgb::json::String("type"), mgb::json::String::make("CUDA")}); | |||
| json_device.push_back( | |||
| {mgb::json::String("device_id"), mgb::json::NumberInt::make(0)}); | |||
| } else if (item.first == "opencl") { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (val_ptr->get_value() == val_ptr->get_default()) | |||
| continue; | |||
| json_device.push_back( | |||
| {mgb::json::String("type"), mgb::json::String::make("OPENCL")}); | |||
| } else if ( | |||
| item.first == "record_comp_seq" || item.first == "record_comp_seq2") { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (val_ptr->get_value() == val_ptr->get_default()) | |||
| continue; | |||
| int comp_node_seq_record_level = item.first == "record_comp_seq" ? 1 : 2; | |||
| json_options.push_back( | |||
| {mgb::json::String("comp_node_seq_record_level"), | |||
| mgb::json::NumberInt::make(comp_node_seq_record_level)}); | |||
| } else if (item.first == "fake_first") { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (val_ptr->get_value() == val_ptr->get_default()) | |||
| continue; | |||
| json_options.push_back( | |||
| {mgb::json::String("fake_next_exec"), | |||
| mgb::json::Bool::make(val_ptr->get_value())}); | |||
| } else if (item.first == "no_sanity_check") { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (val_ptr->get_value() == val_ptr->get_default()) | |||
| continue; | |||
| json_options.push_back( | |||
| {mgb::json::String("var_sanity_check_first_run"), | |||
| mgb::json::Bool::make(!val_ptr->get_value())}); | |||
| } else if (item.first == "weight_preprocess") { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (val_ptr->get_value() == val_ptr->get_default()) | |||
| continue; | |||
| json_options.push_back( | |||
| {mgb::json::String("weight_preprocess"), | |||
| mgb::json::Bool::make(val_ptr->get_value())}); | |||
| } | |||
| } | |||
| info.push_back(mgb::json::Object::make( | |||
| {{"options", mgb::json::Object::make(json_options)}})); | |||
| info.push_back(mgb::json::Object::make( | |||
| {{"device", mgb::json::Object::make(json_device)}})); | |||
| return info; | |||
| } | |||
| //! decode options note string into option map | |||
| OptionValMap& JsonOptionsCoder::decode( | |||
| const std::string& code, OptionValMap& option_val_map) { | |||
| std::shared_ptr<mgb::JsonLoader::Value> root = | |||
| m_json_loader.load(code.c_str(), code.size()); | |||
| for (auto& item : root->objects()) { | |||
| auto& value = *item.second; | |||
| //! get all keys in json object | |||
| auto keys = value.keys(); | |||
| //! set the json format options into internal options | |||
| for (auto& val : keys) { | |||
| if (value[val]->is_bool()) { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(option_val_map[val]); | |||
| val_ptr->set_value(value[val]->Bool()); | |||
| } else if (value[val]->is_number()) { | |||
| auto type = option_val_map[val]->get_type(); | |||
| if (type == JsonValueType::Number) { | |||
| auto val_ptr = | |||
| std::static_pointer_cast<lar::Number>(option_val_map[val]); | |||
| val_ptr->set_value(value[val]->number()); | |||
| } else if (type == JsonValueType::NumberInt32) { | |||
| auto val_ptr = std::static_pointer_cast<lar::NumberInt32>( | |||
| option_val_map[val]); | |||
| val_ptr->set_value(static_cast<int32_t>(value[val]->number())); | |||
| } else if (type == JsonValueType::NumberUint64) { | |||
| auto val_ptr = std::static_pointer_cast<lar::NumberUint64>( | |||
| option_val_map[val]); | |||
| val_ptr->set_value(static_cast<uint64_t>(value[val]->number())); | |||
| } else { | |||
| mgb_log_error( | |||
| "invalid number type:%s to set", | |||
| option_val_map[val]->type_string().c_str()); | |||
| } | |||
| } else if (value[val]->is_str()) { | |||
| auto val_ptr = | |||
| std::static_pointer_cast<lar::String>(option_val_map[val]); | |||
| val_ptr->set_value(value[val]->str()); | |||
| } else { | |||
| mgb_log_error("invalid value type for JsonLoader"); | |||
| } | |||
| } | |||
| } | |||
| return option_val_map; | |||
| } | |||
| #endif | |||
| std::string GflagsOptionsCoder::encode(OptionValMap& option_val_map, bool encode_all) { | |||
| std::vector<std::string> gflags_options; | |||
| for (auto& item : option_val_map) { | |||
| auto type = item.second->get_type(); | |||
| std::string val = "--"; | |||
| if (type == JsonValueType::Bool) { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||
| if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||
| continue; | |||
| } | |||
| val += item.first; | |||
| val += "="; | |||
| val += val_ptr->get_value() ? "true" : "false"; | |||
| gflags_options.push_back(val); | |||
| } else if (type == JsonValueType::NumberInt32) { | |||
| auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(item.second); | |||
| if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||
| continue; | |||
| } | |||
| val += item.first; | |||
| val += "="; | |||
| val += std::to_string(val_ptr->get_value()); | |||
| gflags_options.push_back(val); | |||
| } else if (type == JsonValueType::NumberUint64) { | |||
| auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(item.second); | |||
| val += item.first; | |||
| val += "="; | |||
| val += std::to_string(val_ptr->get_value()); | |||
| gflags_options.push_back(val); | |||
| } else if (type == JsonValueType::Number) { | |||
| auto val_ptr = std::static_pointer_cast<lar::Number>(item.second); | |||
| val += item.first; | |||
| val += "="; | |||
| val += std::to_string(val_ptr->get_value()); | |||
| gflags_options.push_back(val); | |||
| } else if (type == JsonValueType::String) { | |||
| auto val_ptr = std::static_pointer_cast<lar::String>(item.second); | |||
| if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||
| continue; | |||
| } | |||
| val += item.first; | |||
| val += "=\""; | |||
| val += val_ptr->get_value(); | |||
| val += "\""; | |||
| gflags_options.push_back(val); | |||
| } else { | |||
| mgb_log_error( | |||
| "unsupport JsonValueType:%s for lar::Value", | |||
| item.second->type_string().c_str()); | |||
| } | |||
| } | |||
| std::string ret; | |||
| for (auto& item : gflags_options) { | |||
| ret += item; | |||
| ret += "\n"; | |||
| } | |||
| return ret; | |||
| } | |||
| //! decode options note string into option map | |||
| OptionValMap& GflagsOptionsCoder::decode( | |||
| const std::string& code, OptionValMap& option_val_map) { | |||
| std::unordered_map<std::string, std::string> gflags_map; | |||
| auto to_raw_string = [](const std::string& str) { | |||
| auto size = str.size(); | |||
| std::string ret; | |||
| if ('\"' == str[0] && '\"' == str[size - 1]) { | |||
| ret = str.substr(1, size - 2); | |||
| } else { | |||
| ret = str; | |||
| } | |||
| return ret; | |||
| }; | |||
| size_t start = 0; | |||
| size_t end = code.find("\n", start); | |||
| while (end != std::string::npos) { | |||
| auto str = code.substr(start, end - start); | |||
| if (str.substr(0, 2) == "--") { | |||
| size_t idx = str.find("=", 0); | |||
| gflags_map.insert( | |||
| {str.substr(2, idx - 2), to_raw_string(str.substr(idx + 1))}); | |||
| } else { | |||
| mgb_log_error("invaid gflags argument %s", str.c_str()); | |||
| } | |||
| start = end + 1; | |||
| end = code.find("\n", start); | |||
| } | |||
| for (auto& item : gflags_map) { | |||
| if (option_val_map.count(item.first) != 0) { | |||
| auto& option_val = option_val_map[item.first]; | |||
| auto type = option_val->get_type(); | |||
| if (type == JsonValueType::Bool) { | |||
| auto val_ptr = std::static_pointer_cast<lar::Bool>(option_val); | |||
| if (item.second == "true" || item.second == "false") { | |||
| auto val = item.second == "true"; | |||
| val_ptr->set_value(val); | |||
| } | |||
| } else if (type == JsonValueType::NumberInt32) { | |||
| auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(option_val); | |||
| MGB_TRY { | |||
| int32_t val = std::stoi(item.second); | |||
| val_ptr->set_value(val); | |||
| } | |||
| MGB_CATCH(std::exception & exc, { | |||
| mgb_log_error( | |||
| "invaid value: %s for %s", item.second.c_str(), | |||
| item.first.c_str()); | |||
| }); | |||
| } else if (type == JsonValueType::NumberUint64) { | |||
| auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(option_val); | |||
| MGB_TRY { | |||
| uint64_t val = std::stoull(item.second); | |||
| val_ptr->set_value(val); | |||
| } | |||
| MGB_CATCH(std::exception & exc, { | |||
| mgb_log_error( | |||
| "invaid value: %s for %s", item.second.c_str(), | |||
| item.first.c_str()); | |||
| }); | |||
| } else if (type == JsonValueType::Number) { | |||
| auto val_ptr = std::static_pointer_cast<lar::Number>(option_val); | |||
| MGB_TRY { | |||
| double val = std::stod(item.second); | |||
| val_ptr->set_value(val); | |||
| } | |||
| MGB_CATCH(std::exception & exc, { | |||
| mgb_log_error( | |||
| "invaid value: %s for %s", item.second.c_str(), | |||
| item.first.c_str()); | |||
| }); | |||
| } else if (type == JsonValueType::String) { | |||
| auto val_ptr = std::static_pointer_cast<lar::String>(option_val); | |||
| val_ptr->set_value(item.second); | |||
| } else { | |||
| mgb_log_error( | |||
| "unsupport JsonValueType:%s for lar::Value", | |||
| option_val->type_string().c_str()); | |||
| } | |||
| } else { | |||
| mgb_log_error("invalid gflags when set runtime options in fitting mode"); | |||
| } | |||
| } | |||
| return option_val_map; | |||
| } | |||
| @@ -0,0 +1,69 @@ | |||
| #pragma once | |||
| #include <vector> | |||
| #include "common.h" | |||
| #include "json_loader.h" | |||
| #include "megbrain/utils/json.h" | |||
| namespace lar { | |||
| /** | |||
| * fitting profiler type | |||
| */ | |||
| enum class ProiflerType { | |||
| TIME_PROFILER = 0, | |||
| UNSPEC_PROFILER = 1, | |||
| }; | |||
| /** | |||
| * option coder type | |||
| */ | |||
| enum class CoderType { | |||
| GFLAGS = 0, | |||
| JSON = 1, | |||
| UNSPEC = 2, | |||
| }; | |||
| /** | |||
| * option coder to transform internal option val into differnet form | |||
| */ | |||
| class OptionsCoder { | |||
| public: | |||
| OptionsCoder(){}; | |||
| //! encode options into given format | |||
| virtual std::string encode(OptionValMap&, bool) = 0; | |||
| //! decode options with given format into option map | |||
| virtual OptionValMap& decode(const std::string&, OptionValMap& val_map) = 0; | |||
| //! destructor | |||
| virtual ~OptionsCoder() = default; | |||
| }; | |||
| #if MGB_ENABLE_JSON | |||
| class JsonOptionsCoder final : public OptionsCoder { | |||
| public: | |||
| JsonOptionsCoder(){}; | |||
| //! encode given options into json format | |||
| std::string encode(OptionValMap&, bool encode_all) override; | |||
| std::vector<std::shared_ptr<mgb::json::Object>> encode(OptionValMap&); | |||
| //! decode given json format options into given options map | |||
| OptionValMap& decode(const std::string&, OptionValMap&) override; | |||
| private: | |||
| mgb::JsonLoader m_json_loader; | |||
| }; | |||
| #endif | |||
| class GflagsOptionsCoder final : public OptionsCoder { | |||
| public: | |||
| GflagsOptionsCoder(){}; | |||
| //! encode given options into gflags format | |||
| std::string encode(OptionValMap&, bool encode_all = false) override; | |||
| //! decode given gflags format options into given options maps | |||
| OptionValMap& decode(const std::string&, OptionValMap&) override; | |||
| }; | |||
| } // namespace lar | |||
| @@ -1,20 +1,37 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/main.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <gflags/gflags.h> | |||
| #include <string> | |||
| #include "strategys/strategy.h" | |||
| std::string simple_usage = R"( | |||
| load_and_run: load_and_run <model_path> [options Flags...] | |||
| Flags from lite/load_and_run/src/models/model.cpp: | |||
| -lite type: bool default: false use megengine lite interface to run model | |||
| Flags from lite/load_and_run/src/options/strategy_options.cpp: | |||
| -iter type: int32 default: 10 iteration number for run model | |||
| -thread type: int32 default: 1 thread number for run model when <thread> is supported | |||
| -warmup_iter type: int32 default: 1 iteration number for warm up model before run | |||
| Flags from com_github_gflags_gflags/src/gflags.cc: | |||
| -flagfile type: string default: "" load flags from file | |||
| -fromenv type: string default: "" set flags from the environment [use 'export FLAGS_flag1=value'] | |||
| ... | |||
| Flags from com_github_gflags_gflags/src/gflags_reporting.cc: | |||
| -help type: bool default: false show help on all flags | |||
| -helpmatch type: string default: "" show help on modules whose name contains the specified substr | |||
| -version type: bool default: false show version and build info and exit | |||
| ... | |||
| More details using "--help" to get!! | |||
| )"; | |||
| int main(int argc, char** argv) { | |||
| std::string usage = "load_and_run <model_path> [options...]"; | |||
| std::string usage = "load_and_run <model_path> [options Flags...]"; | |||
| if (argc < 2) { | |||
| printf("usage: %s\n", usage.c_str()); | |||
| printf("usage: %s\n", simple_usage.c_str()); | |||
| return -1; | |||
| } | |||
| gflags::SetUsageMessage(usage); | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "model.h" | |||
| #include <iostream> | |||
| #include <memory> | |||
| @@ -56,5 +47,5 @@ std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) { | |||
| return nullptr; | |||
| } | |||
| } | |||
| DEFINE_bool(lite, false, "using lite model to run mdl model"); | |||
| DEFINE_bool(lite, false, "use megengine lite interface to run model"); | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -1,17 +1,8 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include <string> | |||
| #include "helpers/common.h" | |||
| #include "megbrain/utils/json.h" | |||
| DECLARE_bool(lite); | |||
| namespace lar { | |||
| @@ -45,6 +36,12 @@ public: | |||
| virtual ~ModelBase() = default; | |||
| virtual const std::string& get_model_path() const = 0; | |||
| virtual std::vector<uint8_t> get_model_data() = 0; | |||
| #if MGB_ENABLE_JSON | |||
| //! get model io information | |||
| virtual std::shared_ptr<mgb::json::Object> get_io_info() = 0; | |||
| #endif | |||
| }; | |||
| } // namespace lar | |||
| @@ -1,14 +1,7 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model_lite.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "model_lite.h" | |||
| #include <gflags/gflags.h> | |||
| #include <cstring> | |||
| #include <map> | |||
| #include "misc.h" | |||
| DECLARE_bool(share_param_mem); | |||
| @@ -51,3 +44,75 @@ void ModelLite::run_model() { | |||
| void ModelLite::wait() { | |||
| m_network->wait(); | |||
| } | |||
| #if MGB_ENABLE_JSON | |||
| std::shared_ptr<mgb::json::Object> ModelLite::get_io_info() { | |||
| std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make(); | |||
| std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make(); | |||
| auto get_dtype = [&](lite::Layout& layout) { | |||
| std::map<LiteDataType, std::string> type_map = { | |||
| {LiteDataType::LITE_FLOAT, "float32"}, | |||
| {LiteDataType::LITE_HALF, "float16"}, | |||
| {LiteDataType::LITE_INT64, "int64"}, | |||
| {LiteDataType::LITE_INT, "int32"}, | |||
| {LiteDataType::LITE_UINT, "uint32"}, | |||
| {LiteDataType::LITE_INT16, "int16"}, | |||
| {LiteDataType::LITE_UINT16, "uint16"}, | |||
| {LiteDataType::LITE_INT8, "int8"}, | |||
| {LiteDataType::LITE_UINT8, "uint8"}}; | |||
| return type_map[layout.data_type]; | |||
| }; | |||
| auto make_shape = [](lite::Layout& layout) { | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| shape; | |||
| for (size_t i = 0; i < layout.ndim; ++i) { | |||
| std::string lable = "dim"; | |||
| lable += std::to_string(layout.ndim - i - 1); | |||
| shape.push_back( | |||
| {mgb::json::String(lable), | |||
| mgb::json::NumberInt::make(layout.shapes[layout.ndim - i - 1])}); | |||
| } | |||
| return shape; | |||
| }; | |||
| auto input_name = m_network->get_all_input_name(); | |||
| for (auto& i : input_name) { | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| json_inp; | |||
| auto layout = m_network->get_io_tensor(i)->get_layout(); | |||
| json_inp.push_back( | |||
| {mgb::json::String("shape"), | |||
| mgb::json::Object::make(make_shape(layout))}); | |||
| json_inp.push_back( | |||
| {mgb::json::String("dtype"), | |||
| mgb::json::String::make(get_dtype(layout))}); | |||
| json_inp.push_back({mgb::json::String("name"), mgb::json::String::make(i)}); | |||
| inputs->add(mgb::json::Object::make(json_inp)); | |||
| } | |||
| auto output_name = m_network->get_all_output_name(); | |||
| for (auto& i : output_name) { | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| json_out; | |||
| auto layout = m_network->get_io_tensor(i)->get_layout(); | |||
| json_out.push_back( | |||
| {mgb::json::String("shape"), | |||
| mgb::json::Object::make(make_shape(layout))}); | |||
| json_out.push_back( | |||
| {mgb::json::String("dtype"), | |||
| mgb::json::String::make(get_dtype(layout))}); | |||
| json_out.push_back({mgb::json::String("name"), mgb::json::String::make(i)}); | |||
| inputs->add(mgb::json::Object::make(json_out)); | |||
| } | |||
| return mgb::json::Object::make( | |||
| {{"IO", | |||
| mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}}); | |||
| } | |||
| #endif | |||
| std::vector<uint8_t> ModelLite::get_model_data() { | |||
| std::vector<uint8_t> out_data; | |||
| LITE_THROW("unsupported interface: ModelLite::get_model_data() \n"); | |||
| return out_data; | |||
| } | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model_lite.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <string> | |||
| @@ -39,6 +30,10 @@ public: | |||
| //! wait the end of asynchronous function execution | |||
| void wait() override; | |||
| #if MGB_ENABLE_JSON | |||
| std::shared_ptr<mgb::json::Object> get_io_info() override; | |||
| #endif | |||
| //! enable global layout transform | |||
| void set_layout_transform(bool state) { enable_layout_transform = state; } | |||
| @@ -62,6 +57,8 @@ public: | |||
| const std::string& get_model_path() const override { return model_path; } | |||
| std::vector<uint8_t> get_model_data() override; | |||
| private: | |||
| bool share_model_mem; | |||
| bool enable_layout_transform; | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model_mdl.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "model_mdl.h" | |||
| #include <gflags/gflags.h> | |||
| #include <iostream> | |||
| @@ -109,3 +100,76 @@ void ModelMdl::run_model() { | |||
| void ModelMdl::wait() { | |||
| m_asyc_exec->wait(); | |||
| } | |||
| #if MGB_ENABLE_JSON | |||
| std::shared_ptr<mgb::json::Object> ModelMdl::get_io_info() { | |||
| std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make(); | |||
| std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make(); | |||
| auto get_dtype = [&](megdnn::DType data_type) { | |||
| std::map<megdnn::DTypeEnum, std::string> type_map = { | |||
| {mgb::dtype::Float32().enumv(), "float32"}, | |||
| {mgb::dtype::Int32().enumv(), "int32"}, | |||
| {mgb::dtype::Int16().enumv(), "int16"}, | |||
| {mgb::dtype::Uint16().enumv(), "uint16"}, | |||
| {mgb::dtype::Int8().enumv(), "int8"}, | |||
| {mgb::dtype::Uint8().enumv(), "uint8"}}; | |||
| return type_map[data_type.enumv()]; | |||
| }; | |||
| auto make_shape = [](mgb::TensorShape& shape_) { | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| shape; | |||
| for (size_t i = 0; i < shape_.ndim; ++i) { | |||
| std::string lable = "dim"; | |||
| lable += std::to_string(shape_.ndim - i - 1); | |||
| shape.push_back( | |||
| {mgb::json::String(lable), | |||
| mgb::json::NumberInt::make(shape_[shape_.ndim - i - 1])}); | |||
| } | |||
| return shape; | |||
| }; | |||
| for (auto&& i : m_load_result.tensor_map) { | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| json_inp; | |||
| auto shape_ = i.second->shape(); | |||
| json_inp.push_back( | |||
| {mgb::json::String("shape"), | |||
| mgb::json::Object::make(make_shape(shape_))}); | |||
| json_inp.push_back( | |||
| {mgb::json::String("dtype"), | |||
| mgb::json::String::make(get_dtype(i.second->dtype()))}); | |||
| json_inp.push_back( | |||
| {mgb::json::String("name"), mgb::json::String::make(i.first)}); | |||
| inputs->add(mgb::json::Object::make(json_inp)); | |||
| } | |||
| for (auto&& i : m_load_result.output_var_list) { | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| json_out; | |||
| auto shape_ = i.shape(); | |||
| json_out.push_back( | |||
| {mgb::json::String("shape"), | |||
| mgb::json::Object::make(make_shape(shape_))}); | |||
| json_out.push_back( | |||
| {mgb::json::String("dtype"), | |||
| mgb::json::String::make(get_dtype(i.dtype()))}); | |||
| json_out.push_back( | |||
| {mgb::json::String("name"), mgb::json::String::make(i.node()->name())}); | |||
| outputs->add(mgb::json::Object::make(json_out)); | |||
| } | |||
| return mgb::json::Object::make( | |||
| {{"IO", | |||
| mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}}); | |||
| } | |||
| #endif | |||
| std::vector<uint8_t> ModelMdl::get_model_data() { | |||
| std::vector<uint8_t> out_data; | |||
| auto out_file = mgb::serialization::OutputFile::make_vector_proxy(&out_data); | |||
| using DumpConfig = mgb::serialization::GraphDumper::DumpConfig; | |||
| DumpConfig config{1, false, false}; | |||
| auto dumper = | |||
| mgb::serialization::GraphDumper::make(std::move(out_file), m_format.val()); | |||
| dumper->dump(m_load_result.output_var_list, config); | |||
| return out_data; | |||
| } | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model_mdl.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <string> | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| @@ -42,6 +33,10 @@ public: | |||
| void wait() override; | |||
| #if MGB_ENABLE_JSON | |||
| std::shared_ptr<mgb::json::Object> get_io_info() override; | |||
| #endif | |||
| //! get load result for megDL model | |||
| mgb::serialization::GraphLoader::LoadResult& get_mdl_load_result() { | |||
| return m_load_result; | |||
| @@ -109,6 +104,8 @@ public: | |||
| const std::string& get_model_path() const override { return model_path; } | |||
| std::vector<uint8_t> get_model_data() override; | |||
| private: | |||
| bool share_model_mem; | |||
| std::string model_path; | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/device_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <iostream> | |||
| #include <sstream> | |||
| #include "lite/global.h" | |||
| @@ -76,7 +67,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( | |||
| loc.type = mgb::CompNode::DeviceType::CPU; | |||
| }; | |||
| } | |||
| #if MGB_CUDA | |||
| #if LITE_WITH_CUDA | |||
| if (enable_cuda) { | |||
| mgb_log_warn("using cuda device\n"); | |||
| model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||
| @@ -134,7 +125,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( | |||
| XPUDeviceOption::XPUDeviceOption() { | |||
| m_option_name = "xpu_device"; | |||
| enable_cpu = FLAGS_cpu; | |||
| #if MGB_CUDA | |||
| #if LITE_WITH_CUDA | |||
| enable_cuda = FLAGS_cuda; | |||
| #endif | |||
| enable_cpu_default = FLAGS_cpu_default; | |||
| @@ -165,18 +156,41 @@ XPUDeviceOption::XPUDeviceOption() { | |||
| "core ids number should be same with thread number set before"); | |||
| enable_set_core_ids = true; | |||
| } | |||
| } | |||
| m_option = { | |||
| {"cpu", lar::Bool::make(false)}, | |||
| #if LITE_WITH_CUDA | |||
| {"cuda", lar::Bool::make(false)}, | |||
| #endif | |||
| {"cpu_default", lar::Bool::make(false)}, | |||
| {"multithread", lar::NumberInt32::make(-1)}, | |||
| {"multithread_default", lar::NumberInt32::make(-1)}, | |||
| {"multi_thread_core_ids", lar::String::make("")}, | |||
| }; | |||
| std::static_pointer_cast<lar::Bool>(m_option["cpu"])->set_value(FLAGS_cpu); | |||
| #if LITE_WITH_CUDA | |||
| std::static_pointer_cast<lar::Bool>(m_option["cuda"])->set_value(FLAGS_cuda); | |||
| #endif | |||
| std::static_pointer_cast<lar::Bool>(m_option["cpu_default"]) | |||
| ->set_value(FLAGS_cpu_default); | |||
| std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"]) | |||
| ->set_value(FLAGS_multithread); | |||
| std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"]) | |||
| ->set_value(FLAGS_multithread_default); | |||
| std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"]) | |||
| ->set_value(FLAGS_multi_thread_core_ids); | |||
| } | |||
| bool XPUDeviceOption::m_valid; | |||
| bool XPUDeviceOption::is_valid() { | |||
| bool ret = FLAGS_cpu || FLAGS_cpu_default; | |||
| #if MGB_CUDA | |||
| #if LITE_WITH_CUDA | |||
| ret = ret || FLAGS_cuda; | |||
| #endif | |||
| ret = ret || FLAGS_multithread >= 0; | |||
| ret = ret || FLAGS_multithread_default >= 0; | |||
| ret = ret || !FLAGS_multi_thread_core_ids.empty(); | |||
| return ret; | |||
| return ret || m_valid; | |||
| } | |||
| std::shared_ptr<OptionBase> XPUDeviceOption::create_option() { | |||
| @@ -190,11 +204,46 @@ std::shared_ptr<OptionBase> XPUDeviceOption::create_option() { | |||
| void XPUDeviceOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| enable_cpu = std::static_pointer_cast<lar::Bool>(m_option["cpu"])->get_value(); | |||
| #if LITE_WITH_CUDA | |||
| enable_cuda = std::static_pointer_cast<lar::Bool>(m_option["cuda"])->get_value(); | |||
| #endif | |||
| enable_cpu_default = | |||
| std::static_pointer_cast<lar::Bool>(m_option["cpu_default"])->get_value(); | |||
| int32_t num_of_thread = | |||
| std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"]) | |||
| ->get_value(); | |||
| enable_multithread = num_of_thread >= 0; | |||
| num_of_thread = | |||
| std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"]) | |||
| ->get_value(); | |||
| enable_multithread_default = num_of_thread >= 0; | |||
| thread_num = num_of_thread >= 0 ? num_of_thread : 0; | |||
| std::string core_id_str = | |||
| std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"]) | |||
| ->get_value(); | |||
| if (!core_id_str.empty()) { | |||
| mgb_assert( | |||
| enable_multithread || enable_multithread_default, | |||
| "core ids should be set after --multithread or --multithread-default"); | |||
| std::stringstream id_stream(core_id_str); | |||
| std::string id; | |||
| size_t thread_cnt = 0; | |||
| while (getline(id_stream, id, ',')) { | |||
| thread_cnt++; | |||
| core_ids.push_back(atoi(id.c_str())); | |||
| } | |||
| mgb_assert( | |||
| thread_cnt == thread_num, | |||
| "core ids number should be same with thread number set before"); | |||
| enable_set_core_ids = true; | |||
| } | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///////////////////////// xpu gflags //////////////////////////// | |||
| DEFINE_bool(cpu, false, "set CPU device as running device"); | |||
| #if MGB_CUDA || LITE_WITH_CUDA | |||
| #if LITE_WITH_CUDA | |||
| DEFINE_bool(cuda, false, "set CUDA device as running device "); | |||
| #endif | |||
| DEFINE_bool(cpu_default, false, "set running device as CPU device with inplace mode"); | |||
| @@ -204,3 +253,4 @@ DEFINE_int32( | |||
| "set multithread device as running device with inplace mode"); | |||
| DEFINE_string(multi_thread_core_ids, "", "set multithread core id"); | |||
| REGIST_OPTION_CREATOR(xpu_device, lar::XPUDeviceOption::create_option); | |||
| REGIST_OPTION_VALIDATER(xpu_device, lar::XPUDeviceOption::set_valid); | |||
| @@ -1,18 +1,10 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/device_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_bool(cpu); | |||
| #if MGB_CUDA || LITE_WITH_CUDA | |||
| #if LITE_WITH_CUDA | |||
| DECLARE_bool(cuda); | |||
| #endif | |||
| DECLARE_bool(cpu_default); | |||
| @@ -29,12 +21,16 @@ public: | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| static void set_valid(bool val) { m_valid = val; } | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| XPUDeviceOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| bool enable_cpu; | |||
| #if MGB_CUDA || LITE_WITH_CUDA | |||
| #if LITE_WITH_CUDA | |||
| bool enable_cuda; | |||
| #endif | |||
| bool enable_cpu_default; | |||
| @@ -44,5 +40,8 @@ private: | |||
| size_t thread_num; | |||
| std::vector<int> core_ids; | |||
| std::string m_option_name; | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| }; | |||
| } // namespace lar | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/extern_c_opr_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "extern_c_opr_options.h" | |||
| #include "megbrain/utils/debug.h" | |||
| #include "misc.h" | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/extern_c_opr_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "megbrain/graph/extern_copr_api.h" | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/fastrun_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <gflags/gflags.h> | |||
| #if defined(_WIN32) | |||
| @@ -153,7 +144,7 @@ void FastRunOption::config_model_internel<ModelMdl>( | |||
| } // namespace lar | |||
| using namespace lar; | |||
| bool FastRunOption::m_valid; | |||
| FastRunOption::FastRunOption() { | |||
| m_option_name = "fastrun"; | |||
| #if MGB_ENABLE_FASTRUN | |||
| @@ -164,6 +155,25 @@ FastRunOption::FastRunOption() { | |||
| enable_reproducible = FLAGS_reproducible; | |||
| m_fast_run_cache = FLAGS_fast_run_algo_policy; | |||
| share_batch_size = FLAGS_fast_run_shared_batch_size; | |||
| m_option = { | |||
| #if MGB_ENABLE_FASTRUN | |||
| {"fast_run", lar::Bool::make(false)}, | |||
| {"full_run", lar::Bool::make(false)}, | |||
| #endif | |||
| {"binary_equal_between_batch", lar::Bool::make(false)}, | |||
| {"reproducible", lar::Bool::make(false)} | |||
| }; | |||
| #if MGB_ENABLE_FASTRUN | |||
| std::static_pointer_cast<lar::Bool>(m_option["fast_run"]) | |||
| ->set_value(FLAGS_fast_run); | |||
| std::static_pointer_cast<lar::Bool>(m_option["full_run"]) | |||
| ->set_value(FLAGS_full_run); | |||
| #endif | |||
| std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"]) | |||
| ->set_value(FLAGS_binary_equal_between_batch); | |||
| std::static_pointer_cast<lar::Bool>(m_option["reproducible"]) | |||
| ->set_value(FLAGS_reproducible); | |||
| #if MGB_ENABLE_FASTRUN | |||
| //! while fastrun cache file path is not empty and can't be accessed | |||
| if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) { | |||
| @@ -191,7 +201,7 @@ bool FastRunOption::is_valid() { | |||
| ret = ret || FLAGS_reproducible; | |||
| ret = ret || FLAGS_fast_run_algo_policy.size() > 0; | |||
| return ret; | |||
| return ret || m_valid; | |||
| } | |||
| std::shared_ptr<OptionBase> FastRunOption::create_option() { | |||
| @@ -205,6 +215,21 @@ std::shared_ptr<OptionBase> FastRunOption::create_option() { | |||
| void FastRunOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| #if MGB_ENABLE_FASTRUN | |||
| enable_fast_run = | |||
| std::static_pointer_cast<lar::Bool>(m_option["fast_run"])->get_value(); | |||
| enable_full_run = | |||
| std::static_pointer_cast<lar::Bool>(m_option["full_run"])->get_value(); | |||
| mgb_throw_if( | |||
| enable_fast_run && enable_full_run, mgb::AssertionError, | |||
| "invalid options of both fast-run and full-run"); | |||
| #endif | |||
| batch_binary_equal = | |||
| std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"]) | |||
| ->get_value(); | |||
| enable_reproducible = | |||
| std::static_pointer_cast<lar::Bool>(m_option["reproducible"])->get_value(); | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| @@ -228,4 +253,5 @@ DEFINE_bool( | |||
| DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun"); | |||
| DEFINE_string(fast_run_algo_policy, "", "fast-run cache path."); | |||
| REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option); | |||
| REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option); | |||
| REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid); | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/fastrun_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| @@ -38,6 +29,10 @@ public: | |||
| //! get options name for quickly search | |||
| std::string option_name() const override { return m_option_name; } | |||
| static void set_valid(bool val) { m_valid = val; } | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| FastRunOption(); | |||
| //! config template for different model | |||
| @@ -53,5 +48,8 @@ private: | |||
| size_t share_batch_size; //! fast run strategy share batch size setting | |||
| std::string m_fast_run_cache; //! fast run cache file path | |||
| std::string m_option_name; //! option name | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| }; | |||
| } // namespace lar | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/io_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <map> | |||
| #include "helpers/data_parser.h" | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/io_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "helpers/outdumper.h" | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/layout_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <gflags/gflags.h> | |||
| #include "misc.h" | |||
| @@ -24,7 +15,7 @@ void LayoutOption::config_model_internel<ModelLite>( | |||
| model->get_config().options.enable_##layout = true; \ | |||
| break; | |||
| switch (option_flag) { | |||
| switch (m_option_flag) { | |||
| case OptLayoutType::NCHW4: | |||
| ENABLE_LAYOUT(nchw4) | |||
| @@ -59,13 +50,12 @@ template <> | |||
| void lar::LayoutOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| mgb_log_debug("mdl layout config start"); | |||
| #define ENABLE_LAYOUT(layout) \ | |||
| mgb_log_warn("enable " #layout " optimization"); \ | |||
| model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \ | |||
| break; | |||
| switch (option_flag) { | |||
| switch (m_option_flag) { | |||
| case OptLayoutType::NCHW4: | |||
| ENABLE_LAYOUT(nchw4) | |||
| @@ -93,7 +83,6 @@ void lar::LayoutOption::config_model_internel<ModelMdl>( | |||
| default: | |||
| break; | |||
| } | |||
| mgb_log_debug("mdl layout config end"); | |||
| #undef ENABLE_LAYOUT | |||
| } | |||
| @@ -101,48 +90,68 @@ void lar::LayoutOption::config_model_internel<ModelMdl>( | |||
| } // namespace lar | |||
| using namespace lar; | |||
| OptLayoutType LayoutOption::option_flag; | |||
| bool LayoutOption::m_valid; | |||
| LayoutOption::LayoutOption() { | |||
| m_option_name = "layout"; | |||
| m_option_flag = static_cast<OptLayoutType>(0); | |||
| m_option = { | |||
| {"enable_nchw4", lar::Bool::make(false)}, | |||
| {"enable_chwn4", lar::Bool::make(false)}, | |||
| {"enable_nchw44", lar::Bool::make(false)}, | |||
| {"enable_nchw88", lar::Bool::make(false)}, | |||
| {"enable_nchw32", lar::Bool::make(false)}, | |||
| {"enable_nchw64", lar::Bool::make(false)}, | |||
| {"enable_nhwcd4", lar::Bool::make(false)}, | |||
| {"enable_nchw44_dot", lar::Bool::make(false)}, | |||
| }; | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"]) | |||
| ->set_value(FLAGS_enable_nchw4); | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"]) | |||
| ->set_value(FLAGS_enable_chwn4); | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"]) | |||
| ->set_value(FLAGS_enable_nchw44); | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"]) | |||
| ->set_value(FLAGS_enable_nchw88); | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"]) | |||
| ->set_value(FLAGS_enable_nchw32); | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_nchw64"]) | |||
| ->set_value(FLAGS_enable_nchw64); | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_nhwcd4"]) | |||
| ->set_value(FLAGS_enable_nhwcd4); | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44_dot"]) | |||
| ->set_value(FLAGS_enable_nchw44_dot); | |||
| } | |||
| bool LayoutOption::is_valid() { | |||
| size_t valid_flag = 0; | |||
| if (FLAGS_enable_nchw4) { | |||
| valid_flag = valid_flag | (1 << 0); | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4); | |||
| } | |||
| if (FLAGS_enable_chwn4) { | |||
| valid_flag = valid_flag | (1 << 1); | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4); | |||
| } | |||
| if (FLAGS_enable_nchw44) { | |||
| valid_flag = valid_flag | (1 << 2); | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44); | |||
| } | |||
| if (FLAGS_enable_nchw88) { | |||
| valid_flag = valid_flag | (1 << 3); | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88); | |||
| } | |||
| if (FLAGS_enable_nchw32) { | |||
| valid_flag = valid_flag | (1 << 4); | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW32); | |||
| } | |||
| if (FLAGS_enable_nchw64) { | |||
| valid_flag = valid_flag | (1 << 5); | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW64); | |||
| } | |||
| if (FLAGS_enable_nhwcd4) { | |||
| valid_flag = valid_flag | (1 << 6); | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NHWCD4); | |||
| } | |||
| if (FLAGS_enable_nchw44_dot) { | |||
| valid_flag = valid_flag | (1 << 7); | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44_DOT); | |||
| } | |||
| //! only one flag is valid | |||
| bool ret = valid_flag && !(valid_flag & (valid_flag - 1)); | |||
| if (ret) { | |||
| option_flag = static_cast<OptLayoutType>(valid_flag); | |||
| } else { | |||
| option_flag = static_cast<OptLayoutType>(0); | |||
| } | |||
| return ret; | |||
| return ret | m_valid; | |||
| }; | |||
| std::shared_ptr<OptionBase> LayoutOption::create_option() { | |||
| @@ -156,6 +165,37 @@ std::shared_ptr<OptionBase> LayoutOption::create_option() { | |||
| void LayoutOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| size_t valid_flag = 0; | |||
| if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) { | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4); | |||
| } | |||
| if (std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) { | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4); | |||
| } | |||
| if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) { | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44); | |||
| } | |||
| if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) { | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88); | |||
| } | |||
| if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"])->get_value()) { | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW32); | |||
| } | |||
| if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw64"])->get_value()) { | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW64); | |||
| } | |||
| if (std::static_pointer_cast<lar::Bool>(m_option["enable_nhwcd4"])->get_value()) { | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NHWCD4); | |||
| } | |||
| if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44_dot"]) | |||
| ->get_value()) { | |||
| valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44_DOT); | |||
| } | |||
| mgb_throw_if( | |||
| valid_flag && (valid_flag & (valid_flag - 1)), mgb::AssertionError, | |||
| "invalid options of layout transform 0x%lx", valid_flag); | |||
| m_option_flag = static_cast<OptLayoutType>(valid_flag); | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| @@ -168,4 +208,5 @@ DEFINE_bool(enable_nchw64, false, "enable nchw64 layout optimization!!"); | |||
| DEFINE_bool(enable_nhwcd4, false, "enable nhwcd4 layout optimization!!"); | |||
| DEFINE_bool(enable_nchw44_dot, false, "enable nchw444-dot layout optimization!!"); | |||
| REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option); | |||
| REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option); | |||
| REGIST_OPTION_VALIDATER(layout, lar::LayoutOption::set_valid); | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/layout_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| @@ -42,6 +33,10 @@ public: | |||
| //! get option name | |||
| std::string option_name() const override { return m_option_name; }; | |||
| static void set_valid(bool val) { m_valid = val; } | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| //! Constructor | |||
| LayoutOption(); | |||
| @@ -50,7 +45,9 @@ private: | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| static OptLayoutType option_flag; | |||
| OptLayoutType m_option_flag; | |||
| std::string m_option_name; | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| }; | |||
| } // namespace lar | |||
| @@ -1,11 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/layout_trans_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "layout_trans_options.h" | |||
| #include <gflags/gflags.h> | |||
| #include "megbrain/serialization/serializer.h" | |||
| @@ -19,6 +11,7 @@ void GoptLayoutOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (m_layout_transform) { | |||
| LITE_WARN("using global layout transform optimization\n"); | |||
| if (m_layout_transform_target == | |||
| mgb::gopt::GraphTuningOptions::Target::CPU) { | |||
| model->get_config().device_type = LiteDeviceType::LITE_CPU; | |||
| @@ -48,7 +41,47 @@ void GoptLayoutOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | |||
| if (m_layout_transform) { | |||
| mgb_log_warn("using global layout transform optimization\n"); | |||
| auto&& load_result = model->get_mdl_load_result(); | |||
| for (auto&& item : load_result.output_var_list) { | |||
| if (item.shape()[0] > 1) { | |||
| mgb_log_warn( | |||
| " model may be dumped with multi batch and will cost lots " | |||
| "of time to profile during global layout transform!!!\n"); | |||
| } | |||
| } | |||
| //! update output varlist when input shape maybe change(some pass excution | |||
| //! time depends on the shape of init input) | |||
| mgb::thin_hash_table::ThinHashMap<mgb::cg::SymbolVar, mgb::cg::SymbolVar> | |||
| varmap; | |||
| mgb::cg::DepOprIter dep([&](mgb::cg::OperatorNodeBase* opr) { | |||
| if (auto h2d = opr->try_cast_final<mgb::opr::Host2DeviceCopy>()) { | |||
| auto param = h2d->param(); | |||
| mgb::TensorShape new_shape = h2d->host_data()->shape(); | |||
| std::shared_ptr<mgb::HostTensorND> new_tensor = | |||
| std::make_shared<mgb::HostTensorND>( | |||
| h2d->host_data()->comp_node(), new_shape, | |||
| h2d->host_data()->dtype()); | |||
| new_tensor->only_reset_raw_storage(h2d->host_data()->storage()); | |||
| auto h2d_opr = mgb::opr::Host2DeviceCopy::make( | |||
| *h2d->owner_graph(), new_tensor, param, h2d->config()); | |||
| varmap[h2d->output(0)] = h2d_opr; | |||
| } | |||
| }); | |||
| for (auto&& i : load_result.output_var_list) | |||
| dep.add(i); | |||
| if (!varmap.empty()) { | |||
| auto output_vars = | |||
| mgb::cg::replace_vars(load_result.output_var_list, varmap); | |||
| for (size_t i = 0; i < load_result.output_var_list.size(); ++i) { | |||
| output_vars[i].rename( | |||
| load_result.output_var_list[i].node()->name()); | |||
| } | |||
| load_result.output_var_list = output_vars; | |||
| } | |||
| load_result.output_var_list = mgb::gopt::layout_transform( | |||
| load_result.output_var_list, m_layout_transform_target); | |||
| @@ -98,7 +131,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>( | |||
| } // namespace lar | |||
| using namespace lar; | |||
| bool GoptLayoutOption::m_valid; | |||
| GoptLayoutOption::GoptLayoutOption() { | |||
| m_option_name = "gopt_layout"; | |||
| if (FLAGS_layout_transform != "cpu" | |||
| @@ -122,6 +155,12 @@ GoptLayoutOption::GoptLayoutOption() { | |||
| #endif | |||
| } | |||
| m_layout_transform_dump_file = FLAGS_layout_transform_dump; | |||
| m_option = { | |||
| {"layout_transform", lar::String::make("")}, | |||
| }; | |||
| std::static_pointer_cast<lar::String>(m_option["layout_transform"]) | |||
| ->set_value(FLAGS_layout_transform); | |||
| } | |||
| bool GoptLayoutOption::is_valid() { | |||
| @@ -143,7 +182,7 @@ bool GoptLayoutOption::is_valid() { | |||
| } | |||
| } | |||
| ret = ret || !FLAGS_layout_transform_dump.empty(); | |||
| return ret; | |||
| return ret || m_valid; | |||
| } | |||
| std::shared_ptr<OptionBase> GoptLayoutOption::create_option() { | |||
| @@ -157,6 +196,26 @@ std::shared_ptr<OptionBase> GoptLayoutOption::create_option() { | |||
| void GoptLayoutOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| auto value = std::static_pointer_cast<lar::String>(m_option["layout_transform"]) | |||
| ->get_value(); | |||
| if (value.empty()) { | |||
| return; | |||
| } | |||
| if (value == "cpu") { | |||
| m_layout_transform = true; | |||
| m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU; | |||
| } | |||
| #if LITE_WITH_CUDA | |||
| else if (value == "cuda") { | |||
| m_layout_transform = true; | |||
| m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA; | |||
| } | |||
| #endif | |||
| else { | |||
| mgb_throw( | |||
| mgb::AssertionError, "invalid options of global layout transform %s", | |||
| value.c_str()); | |||
| } | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| @@ -175,3 +234,4 @@ DEFINE_string( | |||
| "file path."); | |||
| REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option); | |||
| REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid); | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/layout_trans_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| @@ -32,6 +23,10 @@ public: | |||
| //! get options name for quickly search | |||
| std::string option_name() const override { return m_option_name; } | |||
| static void set_valid(bool val) { m_valid = val; } | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| GoptLayoutOption(); | |||
| //! config template for different model | |||
| @@ -41,5 +36,7 @@ private: | |||
| std::string m_option_name; | |||
| std::string m_layout_transform_dump_file; | |||
| mgb::gopt::GraphTuningOptions::Target m_layout_transform_target; | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| }; | |||
| } // namespace lar | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/model_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "model_options.h" | |||
| #include "device_options.h" | |||
| #include "lite/pack_model.h" | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/model_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "megbrain/graph/operator_node.h" | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/optimize_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "megbrain/gopt/inference.h" | |||
| #if MGB_ENABLE_TENSOR_RT | |||
| #include "megbrain/tensorrt/tensorrt_engine_cache.h" | |||
| @@ -43,15 +34,18 @@ void FusePreprocessOption::config_model_internel<ModelMdl>( | |||
| } | |||
| } // namespace lar | |||
| using namespace lar; | |||
| bool FusePreprocessOption::m_valid; | |||
| FusePreprocessOption::FusePreprocessOption() { | |||
| m_option_name = "fuse_preprocess"; | |||
| enable_fuse_preprocess = FLAGS_enable_fuse_preprocess; | |||
| m_option = {{"enable_fuse_preprocess", lar::Bool::make(false)}}; | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_preprocess"]) | |||
| ->set_value(FLAGS_enable_fuse_preprocess); | |||
| } | |||
| bool FusePreprocessOption::is_valid() { | |||
| bool ret = FLAGS_enable_fuse_preprocess; | |||
| return ret; | |||
| return ret || m_valid; | |||
| } | |||
| std::shared_ptr<OptionBase> FusePreprocessOption::create_option() { | |||
| @@ -65,10 +59,14 @@ std::shared_ptr<OptionBase> FusePreprocessOption::create_option() { | |||
| void FusePreprocessOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| enable_fuse_preprocess = | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_preprocess"]) | |||
| ->get_value(); | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///////////////////////// weight preprocess optimize options /////////////// | |||
| bool WeightPreprocessOption::m_valid; | |||
| namespace lar { | |||
| template <> | |||
| void WeightPreprocessOption::config_model_internel<ModelLite>( | |||
| @@ -97,11 +95,14 @@ void WeightPreprocessOption::config_model_internel<ModelMdl>( | |||
| WeightPreprocessOption::WeightPreprocessOption() { | |||
| m_option_name = "weight_preprocess"; | |||
| weight_preprocess = FLAGS_weight_preprocess; | |||
| m_option = {{"weight_preprocess", lar::Bool::make(false)}}; | |||
| std::static_pointer_cast<lar::Bool>(m_option["weight_preprocess"]) | |||
| ->set_value(FLAGS_weight_preprocess); | |||
| } | |||
| bool WeightPreprocessOption::is_valid() { | |||
| bool ret = FLAGS_weight_preprocess; | |||
| return ret; | |||
| return ret || m_valid; | |||
| } | |||
| std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() { | |||
| @@ -115,10 +116,14 @@ std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() { | |||
| void WeightPreprocessOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| weight_preprocess = | |||
| std::static_pointer_cast<lar::Bool>(m_option["weight_preprocess"]) | |||
| ->get_value(); | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///// fuse conv bias and nonlinear activation opr optimize options //////// | |||
| bool FuseConvBiasNonlinearOption::m_valid; | |||
| namespace lar { | |||
| template <> | |||
| void FuseConvBiasNonlinearOption::config_model_internel<ModelLite>( | |||
| @@ -145,13 +150,16 @@ void FuseConvBiasNonlinearOption::config_model_internel<ModelMdl>( | |||
| } // namespace lar | |||
| FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() { | |||
| m_option_name = "fuse_conv_bias_nonlinear"; | |||
| m_option_name = "fuse_conv_bias_nonlinearity"; | |||
| enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity; | |||
| m_option = {{"enable_fuse_conv_bias_nonlinearity", lar::Bool::make(false)}}; | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_conv_bias_nonlinearity"]) | |||
| ->set_value(FLAGS_enable_fuse_conv_bias_nonlinearity); | |||
| } | |||
| bool FuseConvBiasNonlinearOption::is_valid() { | |||
| bool ret = FLAGS_enable_fuse_conv_bias_nonlinearity; | |||
| return ret; | |||
| return ret || m_valid; | |||
| } | |||
| std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() { | |||
| @@ -166,10 +174,15 @@ std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() { | |||
| void FuseConvBiasNonlinearOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| enable_fuse_conv_bias_nonlinearity = | |||
| std::static_pointer_cast<lar::Bool>( | |||
| m_option["enable_fuse_conv_bias_nonlinearity"]) | |||
| ->get_value(); | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///////////////////////// fuse and preprocess optimize options /////////////// | |||
| bool FuseConvBiasElemwiseAddOption::m_valid; | |||
| namespace lar { | |||
| template <> | |||
| void FuseConvBiasElemwiseAddOption::config_model_internel<ModelLite>( | |||
| @@ -198,13 +211,16 @@ void FuseConvBiasElemwiseAddOption::config_model_internel<ModelMdl>( | |||
| } // namespace lar | |||
| FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() { | |||
| m_option_name = "fuse_conv_bias_z"; | |||
| m_option_name = "fuse_conv_bias_with_z"; | |||
| enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z; | |||
| m_option = {{"enable_fuse_conv_bias_with_z", lar::Bool::make(false)}}; | |||
| std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_conv_bias_with_z"]) | |||
| ->set_value(FLAGS_enable_fuse_conv_bias_with_z); | |||
| } | |||
| bool FuseConvBiasElemwiseAddOption::is_valid() { | |||
| bool ret = FLAGS_enable_fuse_conv_bias_with_z; | |||
| return ret; | |||
| return ret || m_valid; | |||
| } | |||
| std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() { | |||
| @@ -219,10 +235,14 @@ std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() { | |||
| void FuseConvBiasElemwiseAddOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| enable_fuse_conv_bias_with_z = std::static_pointer_cast<lar::Bool>( | |||
| m_option["enable_fuse_conv_bias_with_z"]) | |||
| ->get_value(); | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///////////////////////// graph retrict options ///////////////////////// | |||
| bool GraphRecordOption::m_valid; | |||
| namespace lar { | |||
| template <> | |||
| void GraphRecordOption::config_model_internel<ModelLite>( | |||
| @@ -299,6 +319,23 @@ GraphRecordOption::GraphRecordOption() { | |||
| if (FLAGS_record_comp_seq2) { | |||
| m_record_comp_seq = 2; | |||
| } | |||
| m_option = { | |||
| {"record_comp_seq", lar::Bool::make(false)}, | |||
| {"record_comp_seq2", lar::Bool::make(false)}, | |||
| {"const_shape", lar::Bool::make(false)}, | |||
| {"fake_first", lar::Bool::make(false)}, | |||
| {"no_sanity_check", lar::Bool::make(false)}}; | |||
| std::static_pointer_cast<lar::Bool>(m_option["const_shape"]) | |||
| ->set_value(FLAGS_const_shape); | |||
| std::static_pointer_cast<lar::Bool>(m_option["fake_first"]) | |||
| ->set_value(FLAGS_fake_first); | |||
| std::static_pointer_cast<lar::Bool>(m_option["no_sanity_check"]) | |||
| ->set_value(FLAGS_no_sanity_check); | |||
| std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq"]) | |||
| ->set_value(FLAGS_record_comp_seq); | |||
| std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq2"]) | |||
| ->set_value(FLAGS_record_comp_seq2); | |||
| } | |||
| bool GraphRecordOption::is_valid() { | |||
| @@ -307,7 +344,7 @@ bool GraphRecordOption::is_valid() { | |||
| ret = ret || FLAGS_no_sanity_check; | |||
| ret = ret || FLAGS_record_comp_seq; | |||
| ret = ret || FLAGS_record_comp_seq2; | |||
| return ret; | |||
| return ret || m_valid; | |||
| } | |||
| std::shared_ptr<OptionBase> GraphRecordOption::create_option() { | |||
| @@ -321,6 +358,22 @@ std::shared_ptr<OptionBase> GraphRecordOption::create_option() { | |||
| void GraphRecordOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| const_shape = | |||
| std::static_pointer_cast<lar::Bool>(m_option["const_shape"])->get_value(); | |||
| fake_first = | |||
| std::static_pointer_cast<lar::Bool>(m_option["fake_first"])->get_value(); | |||
| no_sanity_check = std::static_pointer_cast<lar::Bool>(m_option["no_sanity_check"]) | |||
| ->get_value(); | |||
| m_record_comp_seq = std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq"]) | |||
| ->get_value() | |||
| ? 1 | |||
| : 0; | |||
| m_record_comp_seq = | |||
| std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq2"]) | |||
| ->get_value() | |||
| ? 2 | |||
| : 0; | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///////////////////////// graph retrict options ///////////////////////// | |||
| @@ -569,13 +622,26 @@ DEFINE_string( | |||
| "Set the TensorRT engine cache path for serialized prebuilt " | |||
| "ICudaEngine"); | |||
| #endif | |||
| REGIST_OPTION_CREATOR(fuse_preprocess, lar::FusePreprocessOption::create_option); | |||
| REGIST_OPTION_VALIDATER(fuse_preprocess, lar::FusePreprocessOption::set_valid); | |||
| REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option); | |||
| REGIST_OPTION_VALIDATER(weight_preprocess, lar::WeightPreprocessOption::set_valid); | |||
| REGIST_OPTION_CREATOR( | |||
| fuse_conv_bias_nonlinear, lar::FuseConvBiasNonlinearOption::create_option); | |||
| fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::create_option); | |||
| REGIST_OPTION_VALIDATER( | |||
| fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::set_valid); | |||
| REGIST_OPTION_CREATOR( | |||
| fuse_conv_bias_z, lar::FuseConvBiasElemwiseAddOption::create_option); | |||
| fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::create_option); | |||
| REGIST_OPTION_VALIDATER( | |||
| fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::set_valid); | |||
| REGIST_OPTION_CREATOR(graph_record, lar::GraphRecordOption::create_option); | |||
| REGIST_OPTION_VALIDATER(graph_record, lar::GraphRecordOption::set_valid); | |||
| REGIST_OPTION_CREATOR(memory_optimize, lar::MemoryOptimizeOption::create_option); | |||
| REGIST_OPTION_CREATOR(JIT, lar::JITOption::create_option); | |||
| #if MGB_ENABLE_TENSOR_RT | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/optimize_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "helpers/common.h" | |||
| @@ -44,6 +35,10 @@ public: | |||
| std::string option_name() const override { return m_option_name; }; | |||
| static void set_valid(bool val) { m_valid = val; } | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| FusePreprocessOption(); | |||
| template <typename ModelImpl> | |||
| @@ -51,6 +46,8 @@ private: | |||
| std::string m_option_name; | |||
| bool enable_fuse_preprocess; | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| }; | |||
| ///////////////////////// weight preprocess optimize options ////////////// | |||
| @@ -64,6 +61,9 @@ public: | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| static void set_valid(bool val) { m_valid = val; }; | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| WeightPreprocessOption(); | |||
| @@ -72,6 +72,8 @@ private: | |||
| std::string m_option_name; | |||
| bool weight_preprocess; | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| }; | |||
| /////////////// fuse_conv_bias_nonlinearity optimize options /////////////// | |||
| @@ -85,6 +87,9 @@ public: | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| static void set_valid(bool val) { m_valid = val; } | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| FuseConvBiasNonlinearOption(); | |||
| @@ -93,6 +98,8 @@ private: | |||
| std::string m_option_name; | |||
| bool enable_fuse_conv_bias_nonlinearity; | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| }; | |||
| ///////////////////////// fuse_conv_bias_with_z optimize options ////////////// | |||
| @@ -106,6 +113,9 @@ public: | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| static void set_valid(bool val) { m_valid = val; } | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| FuseConvBiasElemwiseAddOption(); | |||
| @@ -113,6 +123,8 @@ private: | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| bool enable_fuse_conv_bias_with_z; | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| }; | |||
| ///////////////////////// graph record options /////////////////////////// | |||
| @@ -127,6 +139,10 @@ public: | |||
| std::string option_name() const override { return m_option_name; }; | |||
| static void set_valid(bool val) { m_valid = val; } | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| GraphRecordOption(); | |||
| template <typename ModelImpl> | |||
| @@ -137,6 +153,8 @@ private: | |||
| bool const_shape; | |||
| bool fake_first; | |||
| bool no_sanity_check; | |||
| static bool m_valid; | |||
| OptionValMap m_option; | |||
| }; | |||
| ///////////////////////// memory optimize options ///////////////////////// | |||
| @@ -1,22 +1,13 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/option_base.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "megbrain/common.h" | |||
| #include "helpers/common.h" | |||
| #include "helpers/utils.h" | |||
| #include "models/model.h" | |||
| namespace lar { | |||
| @@ -34,6 +25,9 @@ public: | |||
| //! get option name | |||
| virtual std::string option_name() const = 0; | |||
| //! get option map | |||
| virtual OptionValMap* get_option() { return nullptr; } | |||
| virtual ~OptionBase() = default; | |||
| }; | |||
| @@ -43,7 +37,10 @@ public: | |||
| class OptionFactory { | |||
| public: | |||
| using OptionCreator = std::function<std::shared_ptr<OptionBase>()>; | |||
| using OptionMap = std::unordered_map<std::string, OptionCreator>; | |||
| using OptionValidater = std::function<void(bool)>; | |||
| using OptionCreatorMap = std::unordered_map<std::string, OptionCreator>; | |||
| using OptionValidaterMap = std::unordered_map<std::string, OptionValidater>; | |||
| //! get Singleton option factory | |||
| static OptionFactory& get_Instance() { | |||
| @@ -52,29 +49,49 @@ public: | |||
| } | |||
| //! registe option creator into option map | |||
| void registe_options(std::string name, OptionCreator creator) { | |||
| if (option_creator_map.count(name) == 0) { | |||
| option_creator_map[name] = creator; | |||
| void registe_options_creator(std::string name, OptionCreator creator) { | |||
| if (m_option_creator_map.count(name) == 0) { | |||
| m_option_creator_map[name] = creator; | |||
| } | |||
| } | |||
| //! registe option validater into option map | |||
| void registe_options_validater(std::string name, OptionValidater validater) { | |||
| if (m_option_validater_map.count(name) == 0) { | |||
| m_option_validater_map[name] = validater; | |||
| } | |||
| } | |||
| //! get creator map | |||
| OptionMap* get_option_creator_map() { return &option_creator_map; } | |||
| OptionCreatorMap* get_option_creator_map() { return &m_option_creator_map; } | |||
| //! get validater map | |||
| OptionValidaterMap* get_option_validater_map() { return &m_option_validater_map; } | |||
| private: | |||
| OptionFactory(){}; | |||
| OptionMap option_creator_map; | |||
| OptionCreatorMap m_option_creator_map; | |||
| OptionValidaterMap m_option_validater_map; | |||
| }; | |||
| } // namespace lar | |||
| #define REGIST_OPTION_CREATOR(name_, creator_) \ | |||
| struct OptionRegister_##name_ { \ | |||
| OptionRegister_##name_() { \ | |||
| lar::OptionFactory::get_Instance().registe_options(#name_, creator_); \ | |||
| } \ | |||
| }; \ | |||
| OptionRegister_##name_ name_; | |||
| #define REGIST_OPTION_CREATOR(_name, _creator) \ | |||
| struct CreatorRegister_##_name { \ | |||
| CreatorRegister_##_name() { \ | |||
| lar::OptionFactory::get_Instance().registe_options_creator( \ | |||
| #_name, _creator); \ | |||
| } \ | |||
| }; \ | |||
| CreatorRegister_##_name creator_##_name; | |||
| #define REGIST_OPTION_VALIDATER(_name, _validater) \ | |||
| struct ValitaterRegister_##_name { \ | |||
| ValitaterRegister_##_name() { \ | |||
| lar::OptionFactory::get_Instance().registe_options_validater( \ | |||
| #_name, _validater); \ | |||
| } \ | |||
| }; \ | |||
| ValitaterRegister_##_name validater_##_name; | |||
| #define CONFIG_MODEL_FUN \ | |||
| if (model->type() == ModelType::LITE_MODEL) { \ | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/plugin_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "plugin_options.h" | |||
| #include <map> | |||
| #include "misc.h" | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/plugin_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #if __linux__ || __unix__ | |||
| @@ -1,24 +1,21 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/strategy_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "strategy_options.h" | |||
| #include "models/model_mdl.h" | |||
| using namespace lar; | |||
| DECLARE_bool(c_opr_lib_with_param); | |||
| DECLARE_bool(fitting); | |||
| StrategyOption::StrategyOption() { | |||
| m_option_name = "run_strategy"; | |||
| warmup_iter = FLAGS_warmup_iter; | |||
| run_iter = FLAGS_iter; | |||
| threads = FLAGS_thread; | |||
| warmup_iter = FLAGS_fitting ? 3 : FLAGS_warmup_iter; | |||
| run_iter = FLAGS_fitting ? 10 : FLAGS_iter; | |||
| threads = FLAGS_fitting ? 1 : FLAGS_thread; | |||
| m_option = { | |||
| {"iter", lar::NumberInt32::make(run_iter)}, | |||
| {"warmup_iter", lar::NumberInt32::make(warmup_iter)}, | |||
| {"thread", lar::NumberInt32::make(threads)}, | |||
| }; | |||
| } | |||
| std::shared_ptr<OptionBase> StrategyOption::create_option() { | |||
| @@ -60,8 +57,7 @@ void TestcaseOption::config_model( | |||
| if (model->type() == ModelType::MEGDL_MODEL) { | |||
| auto model_ptr = std::static_pointer_cast<ModelMdl>(model); | |||
| if (model_ptr->get_testcase_num() && !FLAGS_c_opr_lib_with_param) { | |||
| if (runtime_param.stage == RunStage::MODEL_RUNNING) { | |||
| auto load_result = model_ptr->get_mdl_load_result(); | |||
| if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| auto input_tensor = model_ptr->get_test_input(); | |||
| auto loader = model_ptr->reset_loader(); | |||
| auto testcase = loader->load(model_ptr->get_mdl_config(), false); | |||
| @@ -1,12 +1,3 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/strategy_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <gflags/gflags.h> | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| @@ -32,6 +23,8 @@ public: | |||
| //! get option name | |||
| std::string option_name() const override { return m_option_name; }; | |||
| OptionValMap* get_option() override { return &m_option; } | |||
| private: | |||
| //! Constructor | |||
| StrategyOption(); | |||
| @@ -43,6 +36,7 @@ private: | |||
| size_t run_iter; //! iteration number for running model | |||
| size_t threads; //! thread number for running model (NOTE:it's different | |||
| //! from multithread device ) | |||
| OptionValMap m_option; | |||
| }; | |||
| class TestcaseOption final : public OptionBase { | |||
| @@ -1,18 +1,10 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/strategys/strategy.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "strategy.h" | |||
| #include <iostream> | |||
| #include "strategy_fitting.h" | |||
| #include "strategy_normal.h" | |||
| using namespace lar; | |||
| DECLARE_bool(fitting); | |||
| std::shared_ptr<StrategyBase> StrategyBase::create_strategy(std::string model_path) { | |||
| if (FLAGS_fitting) { | |||
| return std::make_shared<FittingStrategy>(model_path); | |||
| @@ -1,23 +1,11 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/strategys/strategy.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "helpers/common.h" | |||
| #include "models/model.h" | |||
| #include "options/option_base.h" | |||
| DECLARE_bool(fitting); | |||
| namespace lar { | |||
| using OptionMap = std::unordered_map<std::string, std::shared_ptr<OptionBase>>; | |||
| /*! | |||
| * \brief: load and run strategy base class | |||
| */ | |||
| @@ -30,34 +18,10 @@ public: | |||
| virtual ~StrategyBase() = default; | |||
| RuntimeParam m_runtime_param; | |||
| std::unordered_map<std::string, std::shared_ptr<OptionBase>> m_options; | |||
| }; | |||
| /*! | |||
| * \brief: normal strategy for running | |||
| */ | |||
| class NormalStrategy : public StrategyBase { | |||
| public: | |||
| NormalStrategy(std::string model_path); | |||
| //! run model with runtime parameter | |||
| void run() override; | |||
| private: | |||
| //! run model subline for multiple thread | |||
| void run_subline(); | |||
| std::string m_model_path; | |||
| std::shared_ptr<OptionMap> m_options; | |||
| }; | |||
| /*! | |||
| * \brief: Fitting strategy for running | |||
| */ | |||
| class FittingStrategy : public StrategyBase { | |||
| public: | |||
| FittingStrategy(std::string model_path); | |||
| void run() override; | |||
| }; | |||
| } // namespace lar | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -1,24 +1,590 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/strategys/strategy_fitting.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "strategy.h" | |||
| #include "strategy_fitting.h" | |||
| #if defined(_WIN32) | |||
| #include <io.h> | |||
| #define F_OK 0 | |||
| #define access(a, b) _access(a, b) | |||
| #elif __linux__ || __unix__ || __APPLE__ | |||
| #include <unistd.h> | |||
| #endif | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <list> | |||
| #include <regex> | |||
| #include <thread> | |||
| #include "lite/pack_model.h" | |||
| #include "megbrain/common.h" | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/exception.h" | |||
| #include "megbrain/utils/timer.h" | |||
| #include "megbrain/version.h" | |||
| #include "megdnn/version.h" | |||
| #include "misc.h" | |||
| DECLARE_bool(cpu); | |||
| using namespace lar; | |||
| FittingStrategy::FittingStrategy(std::string) { | |||
| mgb_assert("this version don't support Fitting Strategy"); | |||
| // /////////////////// OptionsFastManager /////////////////// | |||
| void OptionsFastManager::init(std::shared_ptr<OptionMap>& options) { | |||
| m_option_group_cnt = 0; | |||
| m_fixed_option_cnt = 0; | |||
| m_internal_options_name = { | |||
| {"enable_fuse_conv_bias_with_z"}, | |||
| {"enable_fuse_preprocess"}, | |||
| {"record_comp_seq"}}; | |||
| //! record the independent option value | |||
| for (auto& option : *options) { | |||
| auto option_vals = option.second->get_option(); | |||
| if (option_vals) { | |||
| for (auto& item : *option_vals) { | |||
| m_valid_option_vals.insert(item); | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| std::string OptionsFastManager::set_next_fixed_options() { | |||
| reset_option(); | |||
| auto& fixed_options_name = m_fixed_options_name[m_fixed_option_cnt]; | |||
| for (auto& item : fixed_options_name) { | |||
| if (m_valid_option_vals.find(item) != m_valid_option_vals.end()) { | |||
| auto& option_val = m_valid_option_vals[item]; | |||
| auto type = option_val->get_type(); | |||
| if (type == JsonValueType::Bool) { | |||
| auto option_val_ptr = std::static_pointer_cast<lar::Bool>(option_val); | |||
| option_val_ptr->set_value(true); | |||
| } else if (type == JsonValueType::String && item == "layout_transform") { | |||
| auto option_val_ptr = std::static_pointer_cast<lar::String>(option_val); | |||
| //! device type | |||
| option_val_ptr->set_value(fixed_options_name[0]); | |||
| } else { | |||
| mgb_log_error( | |||
| "invalid JsonValueType:%s to set next value for fitting mode", | |||
| option_val->type_string().c_str()); | |||
| } | |||
| } | |||
| } | |||
| ++m_fixed_option_cnt; | |||
| std::string code = m_gflags_coder.encode(m_valid_option_vals); | |||
| return code; | |||
| } | |||
| std::string OptionsFastManager::set_next_options() { | |||
| reset_option(); | |||
| auto& constraint = m_internal_options_name[m_option_group_cnt]; | |||
| for (auto& item : constraint) { | |||
| if (m_valid_option_vals.find(item) != m_valid_option_vals.end()) { | |||
| auto& option_val = m_valid_option_vals[item]; | |||
| auto type = option_val->get_type(); | |||
| if (type == JsonValueType::Bool) { | |||
| auto option_val_ptr = std::static_pointer_cast<lar::Bool>(option_val); | |||
| option_val_ptr->set_value(true); | |||
| } else { | |||
| mgb_log_error( | |||
| "invalid JsonValueType: %s to set next value for fitting mode", | |||
| option_val->type_string().c_str()); | |||
| } | |||
| } | |||
| } | |||
| ++m_option_group_cnt; | |||
| std::string code = m_gflags_coder.encode(m_valid_option_vals); | |||
| return code; | |||
| } | |||
| bool OptionsFastManager::is_end_options() { | |||
| return m_option_group_cnt == m_internal_options_name.size(); | |||
| } | |||
| bool OptionsFastManager::is_fixed_end() { | |||
| return m_fixed_option_cnt == m_fixed_options_name.size(); | |||
| } | |||
| void OptionsFastManager::set_options(const std::string& code) { | |||
| reset_option(); | |||
| #if MGB_ENABLE_JSON | |||
| const std::regex json_regex(".\\{"); | |||
| #endif | |||
| const std::regex gflags_regex("--.*=.*"); | |||
| if (std::regex_search(code, gflags_regex)) { | |||
| m_gflags_coder.decode(code, m_valid_option_vals); | |||
| } | |||
| #if MGB_ENABLE_JSON | |||
| else if (std::regex_search(code, json_regex)) { | |||
| m_json_coder.decode(code, m_valid_option_vals); | |||
| } | |||
| #endif | |||
| else { | |||
| mgb_log_error("invalid options code format \"%s\" to decode", code.c_str()); | |||
| } | |||
| } | |||
| void OptionsFastManager::registe_fixed_options( | |||
| const std::vector<std::string>& option_name) { | |||
| m_fixed_options_name.push_back(option_name); | |||
| } | |||
| std::string OptionsFastManager::get_curr_options_code(CoderType type, bool encode_all) { | |||
| if (type == CoderType::GFLAGS) { | |||
| return m_gflags_coder.encode(m_valid_option_vals, encode_all); | |||
| } | |||
| #if MGB_ENABLE_JSON | |||
| else if (type == CoderType::JSON) { | |||
| return m_json_coder.encode(m_valid_option_vals, encode_all); | |||
| } | |||
| #endif | |||
| else { | |||
| mgb_log_error("coder should be implemented in furture"); | |||
| return ""; | |||
| } | |||
| } | |||
| #if MGB_ENABLE_JSON | |||
| std::vector<std::shared_ptr<mgb::json::Object>> OptionsFastManager::get_json() { | |||
| std::vector<std::shared_ptr<mgb::json::Object>> ret = | |||
| m_json_coder.encode(m_valid_option_vals); | |||
| return ret; | |||
| } | |||
| #endif | |||
| void OptionsFastManager::reset_option() { | |||
| for (auto& option : m_valid_option_vals) { | |||
| option.second->reset_value(); | |||
| } | |||
| } | |||
| ////////////////// OptionsTimeProfiler ////////////////// | |||
| void OptionsTimeProfiler::profile_with_given_options( | |||
| const std::string& model_path, std::shared_ptr<OptionMap>& given_options, | |||
| const std::string& option_code) { | |||
| RuntimeParam runtime_param; | |||
| auto model = ModelBase::create_model(model_path); | |||
| mgb::RealTimer timer; | |||
| auto stage_config_model = [&]() { | |||
| for (auto& option : *given_options) { | |||
| option.second->config_model(runtime_param, model); | |||
| } | |||
| }; | |||
| auto warm_up = [&]() { | |||
| for (size_t i = 0; i < runtime_param.warmup_iter; i++) { | |||
| auto start = timer.get_msecs(); | |||
| model->run_model(); | |||
| model->wait(); | |||
| mgb_log_warn("warm up %ld time %f ms", i, timer.get_msecs() - start); | |||
| } | |||
| }; | |||
| double inference_time = 0.0; | |||
| auto run_iter = [&]() { | |||
| for (size_t i = 0; i < runtime_param.run_iter; i++) { | |||
| auto start = timer.get_msecs(); | |||
| model->run_model(); | |||
| model->wait(); | |||
| auto end = timer.get_msecs(); | |||
| mgb_log_warn("run iter %ld time %f ms", i, end - start); | |||
| inference_time += end - start; | |||
| mgb_throw_if( | |||
| inference_time > TIME_OUT, mgb::TimeoutError, | |||
| "time out while using fitting"); | |||
| } | |||
| }; | |||
| //! model with testcase | |||
| size_t case_num = runtime_param.testcase_num; | |||
| bool exception_state = false; | |||
| MGB_TRY { | |||
| timer.reset(); | |||
| runtime_param.stage = RunStage::BEFORE_MODEL_LOAD; | |||
| stage_config_model(); | |||
| model->load_model(); | |||
| //! after load configure | |||
| auto config_model_before_runing = [&]() { | |||
| for (auto stage : | |||
| {RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION, | |||
| RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET, | |||
| RunStage::MODEL_RUNNING}) { | |||
| runtime_param.stage = stage; | |||
| stage_config_model(); | |||
| } | |||
| }; | |||
| timer.reset(); | |||
| for (size_t idx = 0; idx < case_num; idx++) { | |||
| auto start = timer.get_msecs(); | |||
| config_model_before_runing(); | |||
| auto end = timer.get_msecs(); | |||
| mgb_log_warn("config model time %f ms", end - start); | |||
| warm_up(); | |||
| run_iter(); | |||
| } | |||
| runtime_param.stage = RunStage::AFTER_MODEL_RUNNING; | |||
| stage_config_model(); | |||
| } | |||
| MGB_CATCH(std::exception & exc, { | |||
| mgb_log_error("catch exception: %s", exc.what()); | |||
| exception_state = true; | |||
| }); | |||
| auto average = inference_time / runtime_param.run_iter; | |||
| if (exception_state) { | |||
| average = TIME_OUT; | |||
| } | |||
| //! record profile result | |||
| printf("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(), average); | |||
| m_options_profile_result.insert({option_code, average}); | |||
| //! record the best result | |||
| if (average < m_best_setting.second) { | |||
| m_best_setting.first = option_code; | |||
| m_best_setting.second = average; | |||
| } | |||
| } | |||
| /////////////////////////// UserInfoParser ///////////////////////////// | |||
| void UserInfoParser::get_user_info() { | |||
| //! register user information tips | |||
| std::vector<std::pair<std::string, std::string>> info_tips; | |||
| m_user_info["fitting_preference"] = "Inferspeed"; | |||
| info_tips.push_back( | |||
| {"use_const_shape", "whether the input shape is constant?(yes/no)?"}); | |||
| for (auto& tip : info_tips) { | |||
| std::cout << tip.second; | |||
| std::string answer = ""; | |||
| std::cin >> answer; | |||
| m_user_info[tip.first] = answer; | |||
| } | |||
| } | |||
| void UserInfoParser::parse_info(std::shared_ptr<OptionsFastManager>& manager) { | |||
| std::vector<std::string> fixed_options; | |||
| if (m_user_info["use_const_shape"] == "yes") { | |||
| fixed_options.push_back("const_shape"); | |||
| } else if (m_user_info["use_const_shape"] != "no") { | |||
| mgb_log_error("invalid user information for \"use_const_shape\""); | |||
| } | |||
| fixed_options.push_back("enable_fuse_conv_bias_nonlinearity"); | |||
| std::vector<std::string> tmp_options; | |||
| auto insert_common_cpu_options = [&]() { | |||
| tmp_options = {"cpu"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cpu", "weight_preprocess", "fast_run"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cpu", "layout_transform"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cpu", "layout_transform", "weight_preprocess"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| }; | |||
| #if (MEGDNN_AARCH64 || MEGDNN_ARMV7) | |||
| //! arm cpu device | |||
| insert_common_cpu_options(); | |||
| tmp_options = {"cpu", "enable_nchw44"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cpu", "enable_nchw44", "weight_preprocess", "fast_run"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cpu", "enable_nchw44_dot"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cpu", "enable_nchw44_dot", "weight_preprocess", "fast_run"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| #else | |||
| #if LITE_WITH_CUDA | |||
| //! build with cuda and not force to use cpu device | |||
| if (!FLAGS_cpu) { | |||
| tmp_options = {"cuda"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cuda", "enable_nchw4"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cuda", "enable_chwn4"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cuda", "enable_nchw64"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cuda", "enable_nchw32"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cuda", "layout_transform"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cuda", "layout_transform", "weight_preprocess"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| } | |||
| #endif | |||
| #if LITE_WITH_CUDA | |||
| //! build with cuda force to use cpu | |||
| if (FLAGS_cpu) { | |||
| #endif | |||
| //!x86 cpu options | |||
| insert_common_cpu_options(); | |||
| tmp_options = {"cpu", "enable_nchw88"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| tmp_options = {"cpu", "enable_nchw88", "weight_preprocess", "fast_run"}; | |||
| tmp_options.insert( | |||
| tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||
| manager->registe_fixed_options(tmp_options); | |||
| #if LITE_WITH_CUDA | |||
| } | |||
| #endif | |||
| #endif | |||
| m_proifler_type = ProiflerType::TIME_PROFILER; | |||
| } | |||
| // /////////////////// FittingStrategy ////////////////////////////////// | |||
| FittingStrategy::FittingStrategy(std::string model_path) { | |||
| m_manager = std::make_shared<OptionsFastManager>(); | |||
| m_dumped_model = FLAGS_dump_fitting_model; | |||
| mgb::set_log_level(mgb::LogLevel::WARN); | |||
| m_options = std::make_shared<OptionMap>(); | |||
| m_model_path = model_path; | |||
| auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | |||
| auto option_validater_map = | |||
| OptionFactory::get_Instance().get_option_validater_map(); | |||
| //! validate option used in fitting | |||
| auto validate_option = [&](std::string name) -> void { | |||
| if (option_validater_map->find(name) != option_validater_map->end()) { | |||
| auto& validater = (*option_validater_map).at(name); | |||
| if (validater) { | |||
| validater(true); | |||
| } | |||
| } | |||
| }; | |||
| //! construct option which is valid | |||
| auto construct_option = [&](std::string name) -> void { | |||
| auto& creator = (*option_creator_map)[name]; | |||
| auto option = creator(); | |||
| if (option) { | |||
| m_options->insert({name, option}); | |||
| } | |||
| }; | |||
| //! get all options which is valid | |||
| for (auto& creator : *option_creator_map) { | |||
| auto name = creator.first; | |||
| if (m_options->count(name) == 0) { | |||
| validate_option(name); | |||
| construct_option(name); | |||
| } | |||
| } | |||
| m_manager->init(m_options); | |||
| } | |||
| void FittingStrategy::dump_best_options_with_model() { | |||
| std::vector<uint8_t> info_algo_policy_data; | |||
| std::vector<uint8_t> info_binary_cache_data; | |||
| auto model = ModelBase::create_model(m_model_path); | |||
| RuntimeParam runtime_param; | |||
| auto stage_config_model = [&]() { | |||
| for (auto& option : *m_options) { | |||
| option.second->config_model(runtime_param, model); | |||
| } | |||
| }; | |||
| runtime_param.stage = RunStage::BEFORE_MODEL_LOAD; | |||
| stage_config_model(); | |||
| model->load_model(); | |||
| //! get json info vector | |||
| std::string json_info_str; | |||
| #if MGB_ENABLE_JSON | |||
| std::shared_ptr<mgb::json::Object> code_json = model->get_io_info(); | |||
| m_packed_info.push_back({mgb::json::String("IO"), (*code_json)["IO"]}); | |||
| auto info_json = m_manager->get_json(); | |||
| m_packed_info.push_back({mgb::json::String("options"), (*info_json[0])["options"]}); | |||
| m_packed_info.push_back({mgb::json::String("device"), (*info_json[1])["device"]}); | |||
| m_packed_info.push_back( | |||
| {mgb::json::String("backend"), mgb::json::String::make("MGE")}); | |||
| int lite_major, lite_minor, lite_patch; | |||
| lite::get_version(lite_major, lite_minor, lite_patch); | |||
| std::string version = std::to_string(lite_major); | |||
| version += "."; | |||
| version += std::to_string(lite_minor) + "."; | |||
| version += std::to_string(lite_patch); | |||
| m_packed_info.push_back( | |||
| {mgb::json::String("version"), mgb::json::String::make(version)}); | |||
| m_packed_info.push_back({mgb::json::String("valid"), mgb::json::Bool::make(true)}); | |||
| m_packed_info.push_back( | |||
| {mgb::json::String("name"), mgb::json::String::make("packed_model")}); | |||
| auto obj = mgb::json::Object::make(m_packed_info); | |||
| json_info_str = obj->to_string(); | |||
| #endif | |||
| std::vector<uint8_t> json_info(json_info_str.begin(), json_info_str.end()); | |||
| //! get model binary data after optimized | |||
| for (auto stage : | |||
| {RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION, | |||
| RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET, | |||
| RunStage::MODEL_RUNNING}) { | |||
| runtime_param.stage = stage; | |||
| stage_config_model(); | |||
| } | |||
| model->run_model(); | |||
| model->wait(); | |||
| std::vector<uint8_t> model_data = model->get_model_data(); | |||
| mgb_log_warn("model_data size=%zu", model_data.size()); | |||
| mgb_log_warn("json_info size=%zu", json_info.size()); | |||
| mgb_log_warn("info_algo_policy_data size=%zu", info_algo_policy_data.size()); | |||
| mgb_log_warn("info_binary_cache_data size=%zu", info_binary_cache_data.size()); | |||
| lite::ModelPacker packer( | |||
| model_data, m_dumped_model, json_info, info_algo_policy_data, | |||
| info_binary_cache_data); | |||
| packer.set_header(); | |||
| packer.pack_model(); | |||
| } | |||
| ///////////////////////// AutoCleanFile/////////////////////////// | |||
| FittingStrategy::AutoCleanFile::AutoCleanFile( | |||
| const std::string& model_path, std::shared_ptr<OptionMap>& options) | |||
| : m_model_path(model_path), m_options(options) { | |||
| m_filename = "fitting_tmp_model"; | |||
| if (!access(m_filename.c_str(), F_OK)) { | |||
| remove(m_filename.c_str()); | |||
| } | |||
| } | |||
| FittingStrategy::AutoCleanFile::~AutoCleanFile() { | |||
| if (!access(m_filename.c_str(), F_OK)) { | |||
| remove(m_filename.c_str()); | |||
| } | |||
| } | |||
| void FittingStrategy::AutoCleanFile::dump_model() { | |||
| auto model = ModelBase::create_model(m_model_path); | |||
| RuntimeParam runtime_param; | |||
| auto stage_config_model = [&]() { | |||
| for (auto& option : *m_options) { | |||
| option.second->config_model(runtime_param, model); | |||
| } | |||
| }; | |||
| runtime_param.stage = RunStage::BEFORE_MODEL_LOAD; | |||
| stage_config_model(); | |||
| model->load_model(); | |||
| //! get model binary data after optimized | |||
| for (auto stage : | |||
| {RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION, | |||
| RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET, | |||
| RunStage::MODEL_RUNNING}) { | |||
| runtime_param.stage = stage; | |||
| stage_config_model(); | |||
| } | |||
| model->run_model(); | |||
| model->wait(); | |||
| std::vector<uint8_t> model_data = model->get_model_data(); | |||
| mgb_log_warn("dumped model_data size=%zu\n", model_data.size()); | |||
| auto fp = fopen(m_filename.c_str(), "wb"); | |||
| fwrite(model_data.data(), 1, model_data.size(), fp); | |||
| fclose(fp); | |||
| } | |||
| void FittingStrategy::run() { | |||
| mgb_assert("this version don't support Fitting Strategy"); | |||
| }; | |||
| auto mgb_version = mgb::get_version(); | |||
| auto dnn_version = megdnn::get_version(); | |||
| printf("megbrain/lite/load_and_run:\nusing MegBrain " | |||
| "%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||
| mgb_version.major, mgb_version.minor, mgb_version.patch, mgb_version.is_dev, | |||
| dnn_version.major, dnn_version.minor, dnn_version.patch); | |||
| // ! create profiler with given user info | |||
| m_info_parser.get_user_info(); | |||
| m_info_parser.parse_info(m_manager); | |||
| auto profiler = m_info_parser.create_profiler(); | |||
| mgb_throw_if( | |||
| profiler == nullptr, mgb::AssertionError, | |||
| "get empty profiler for fittting\n"); | |||
| //! profile model with fixed options | |||
| while (!m_manager->is_fixed_end()) { | |||
| std::string option_str = m_manager->set_next_fixed_options(); | |||
| profiler->profile_with_given_options(m_model_path, m_options, option_str); | |||
| #if (MEGDNN_AARCH64 || MEGDNN_ARMV7) | |||
| //! sleep to keep machine with stable cpu frequence | |||
| usleep(500000); | |||
| #endif | |||
| } | |||
| std::string m_tmp_model = m_model_path; | |||
| const std::regex layout_regex("layout_transform"); | |||
| auto best_fixed_options = profiler->get_best_setting(); | |||
| m_manager->set_options(best_fixed_options); | |||
| //! dump model for global layout transform | |||
| auto m_tmp_file = AutoCleanFile(m_model_path, m_options); | |||
| if (std::regex_search(best_fixed_options, layout_regex)) { | |||
| m_tmp_file.dump_model(); | |||
| m_model_path = m_tmp_file.filename(); | |||
| } | |||
| //! profile model with given profiler | |||
| while (!m_manager->is_end_options()) { | |||
| std::string curr_option_str = m_manager->set_next_options(); | |||
| //! set option with current option and fixed options | |||
| if (m_model_path == m_tmp_model) { | |||
| auto total_option_str = curr_option_str + best_fixed_options; | |||
| m_manager->set_options(total_option_str); | |||
| } | |||
| curr_option_str += best_fixed_options; | |||
| profiler->profile_with_given_options(m_model_path, m_options, curr_option_str); | |||
| #if (MEGDNN_AARCH64 || MEGDNN_ARMV7) | |||
| usleep(500000); | |||
| #endif | |||
| } | |||
| //! set with best options and inference | |||
| m_model_path = m_tmp_model; | |||
| auto best_options = profiler->get_best_setting(); | |||
| m_manager->set_options(best_options); | |||
| profiler->profile_with_given_options(m_model_path, m_options, best_options); | |||
| //! save best options into given dir | |||
| std::cout << "the best options:\n" << best_options << std::endl; | |||
| if (!m_dumped_model.empty()) { | |||
| dump_best_options_with_model(); | |||
| } | |||
| } | |||
| DEFINE_bool( | |||
| fitting, false, | |||
| "whether to use the fitting model, which will auto profile and get " | |||
| "the best option set!"); | |||
| fitting, false, "use the fitting mode profile and get the best option set."); | |||
| DEFINE_string(dump_fitting_model, "", "dump the best option and algo cache into model"); | |||
| @@ -0,0 +1,152 @@ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "helpers/utils.h" | |||
| #include "strategy.h" | |||
| DECLARE_bool(fitting); | |||
| DECLARE_string(dump_fitting_model); | |||
| #define TIME_OUT 10000 | |||
| namespace lar { | |||
| class OptionsFastManager { | |||
| public: | |||
| using ConstraintMap = std::unordered_map<std::string, bool>; | |||
| OptionsFastManager(){}; | |||
| //! init the options value map with given options | |||
| void init(std::shared_ptr<OptionMap>&); | |||
| //! set next options group cyclely | |||
| std::string set_next_options(); | |||
| std::string set_next_fixed_options(); | |||
| //! check the end of options group | |||
| bool is_end_options(); | |||
| bool is_fixed_end(); | |||
| std::string get_curr_options_code(CoderType, bool encode_all = false); | |||
| //! set current options with given options | |||
| void set_options(const std::string&); | |||
| void registe_fixed_options(const std::vector<std::string>&); | |||
| #if MGB_ENABLE_JSON | |||
| std::vector<std::shared_ptr<mgb::json::Object>> get_json(); | |||
| #endif | |||
| private: | |||
| void reset_option(); | |||
| size_t m_option_group_cnt; | |||
| size_t m_fixed_option_cnt; | |||
| OptionValMap m_valid_option_vals; | |||
| std::vector<std::vector<std::string>> m_internal_options_name; | |||
| std::vector<std::vector<std::string>> m_fixed_options_name; | |||
| #if MGB_ENABLE_JSON | |||
| JsonOptionsCoder m_json_coder; | |||
| #endif | |||
| GflagsOptionsCoder m_gflags_coder; | |||
| }; | |||
| //! Options proifler to get the best settings with different evaluate standard | |||
| class OptionsProfiler { | |||
| public: | |||
| OptionsProfiler(){}; | |||
| //! run with m_options | |||
| virtual void profile_with_given_options( | |||
| const std::string&, std::shared_ptr<OptionMap>&, const std::string&) = 0; | |||
| //! get the best setting and inference time | |||
| virtual std::string get_best_setting() { return ""; } | |||
| virtual ~OptionsProfiler() = default; | |||
| }; | |||
| /** | |||
| * profiler to get the fast setting | |||
| */ | |||
| class OptionsTimeProfiler final : public OptionsProfiler { | |||
| public: | |||
| OptionsTimeProfiler(){}; | |||
| void profile_with_given_options( | |||
| const std::string&, std::shared_ptr<OptionMap>&, | |||
| const std::string&) override; | |||
| std::string get_best_setting() override { return m_best_setting.first; } | |||
| private: | |||
| std::unordered_map<std::string, double> m_options_profile_result; | |||
| std::pair<std::string, double> m_best_setting = {"", TIME_OUT}; | |||
| }; | |||
| /** | |||
| * parse information from user given | |||
| */ | |||
| class UserInfoParser { | |||
| public: | |||
| UserInfoParser(){}; | |||
| void get_user_info(); | |||
| void parse_info(std::shared_ptr<OptionsFastManager>&); | |||
| std::shared_ptr<OptionsProfiler> create_profiler() { | |||
| switch (m_proifler_type) { | |||
| case ProiflerType::TIME_PROFILER: | |||
| return std::make_shared<OptionsTimeProfiler>(); | |||
| case ProiflerType::UNSPEC_PROFILER: | |||
| return nullptr; | |||
| default: | |||
| return nullptr; | |||
| } | |||
| } | |||
| private: | |||
| ProiflerType m_proifler_type; | |||
| std::unordered_map<std::string, std::string> m_user_info; | |||
| }; | |||
| /*! | |||
| * \brief: Fitting strategy for running | |||
| */ | |||
| class FittingStrategy : public StrategyBase { | |||
| public: | |||
| class AutoCleanFile { | |||
| public: | |||
| AutoCleanFile( | |||
| const std::string& model_path, std::shared_ptr<OptionMap>& options); | |||
| void dump_model(); | |||
| std::string filename() { return m_filename; } | |||
| ~AutoCleanFile(); | |||
| private: | |||
| std::string m_model_path; | |||
| std::shared_ptr<OptionMap> m_options; | |||
| std::string m_filename; | |||
| }; | |||
| FittingStrategy(std::string model_path); | |||
| void run() override; | |||
| void dump_best_options_with_model(); | |||
| void dump_model(); | |||
| private: | |||
| std::string m_model_path; | |||
| std::string m_dumped_model; | |||
| std::shared_ptr<OptionsFastManager> m_manager; | |||
| UserInfoParser m_info_parser; | |||
| #if MGB_ENABLE_JSON | |||
| std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||
| m_packed_info; | |||
| #endif | |||
| }; | |||
| } // namespace lar | |||
| @@ -1,11 +1,4 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/strategys/strategy_normal.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "strategy_normal.h" | |||
| #include <iostream> | |||
| #include <thread> | |||
| #include "megbrain/common.h" | |||
| @@ -13,13 +6,13 @@ | |||
| #include "megbrain/version.h" | |||
| #include "megdnn/version.h" | |||
| #include "misc.h" | |||
| #include "strategy.h" | |||
| using namespace lar; | |||
| NormalStrategy::NormalStrategy(std::string model_path) { | |||
| mgb::set_log_level(mgb::LogLevel::WARN); | |||
| lite::set_log_level(LiteLogLevel::WARN); | |||
| m_options = std::make_shared<OptionMap>(); | |||
| m_model_path = model_path; | |||
| auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | |||
| mgb_log_debug("option map size: %lu", option_creator_map->size()); | |||
| @@ -27,13 +20,13 @@ NormalStrategy::NormalStrategy(std::string model_path) { | |||
| auto& creator = (*option_creator_map)[name]; | |||
| auto option = creator(); | |||
| if (option) { | |||
| m_options.insert({name, option}); | |||
| m_options->insert({name, option}); | |||
| } | |||
| }; | |||
| for (auto& creator : *option_creator_map) { | |||
| auto name = creator.first; | |||
| if (m_options.count(name) == 0) { | |||
| if (m_options->count(name) == 0) { | |||
| construct_option(name); | |||
| } | |||
| } | |||
| @@ -44,7 +37,7 @@ void NormalStrategy::run_subline() { | |||
| mgb_assert(model != nullptr, "create model failed!!"); | |||
| auto stage_config_model = [&]() { | |||
| for (auto& option : m_options) { | |||
| for (auto& option : *m_options) { | |||
| option.second->config_model(m_runtime_param, model); | |||
| } | |||
| }; | |||
| @@ -57,18 +50,14 @@ void NormalStrategy::run_subline() { | |||
| printf("load model: %.3fms\n", timer.get_msecs_reset()); | |||
| //! after load configure | |||
| m_runtime_param.stage = RunStage::AFTER_MODEL_LOAD; | |||
| stage_config_model(); | |||
| m_runtime_param.stage = RunStage::GLOBAL_OPTIMIZATION; | |||
| stage_config_model(); | |||
| m_runtime_param.stage = RunStage::BEFORE_OUTSPEC_SET; | |||
| stage_config_model(); | |||
| // for get static memmory information options | |||
| m_runtime_param.stage = RunStage::AFTER_OUTSPEC_SET; | |||
| stage_config_model(); | |||
| auto config_after_load = [&]() { | |||
| for (auto stage : | |||
| {RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION, | |||
| RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET}) { | |||
| m_runtime_param.stage = stage; | |||
| stage_config_model(); | |||
| } | |||
| }; | |||
| auto warm_up = [&]() { | |||
| auto warmup_num = m_runtime_param.warmup_iter; | |||
| @@ -117,6 +106,8 @@ void NormalStrategy::run_subline() { | |||
| double tot_time = 0; | |||
| for (size_t idx = 0; idx < iter_num; idx++) { | |||
| //! config model | |||
| config_after_load(); | |||
| //! config when running model | |||
| mgb_log_warn("run testcase: %zu ", idx); | |||
| m_runtime_param.stage = RunStage::MODEL_RUNNING; | |||
| @@ -0,0 +1,22 @@ | |||
| #pragma once | |||
| #include "strategy.h" | |||
| namespace lar { | |||
| /*! | |||
| * \brief: normal strategy for running | |||
| */ | |||
| class NormalStrategy : public StrategyBase { | |||
| public: | |||
| NormalStrategy(std::string model_path); | |||
| //! run model with runtime parameter | |||
| void run() override; | |||
| private: | |||
| //! run model subline for multiple thread | |||
| void run_subline(); | |||
| std::string m_model_path; | |||
| }; | |||
| } // namespace lar | |||
| @@ -1,14 +1,3 @@ | |||
| /** | |||
| * \file src/pack_model/pack_model.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "lite/pack_model.h" | |||
| #include "../misc.h" | |||
| #if LITE_BUILD_WITH_MGE | |||
| @@ -192,7 +181,7 @@ ModelPacker::ModelPacker( | |||
| std::string info_data_path, std::string info_algo_policy_path, | |||
| std::string info_binary_cache_path) | |||
| : m_packed_model_path(packed_model_path) { | |||
| m_fbs_helper = new FbsHelper(this, model_path); | |||
| m_fbs_helper = std::make_shared<FbsHelper>(this, model_path); | |||
| std::vector<uint8_t> empty_vec; | |||
| m_info_data = info_data_path.empty() ? empty_vec : read_file(info_data_path); | |||
| m_algo_policy_data = info_algo_policy_path.empty() | |||
| @@ -207,7 +196,7 @@ ModelPacker::ModelPacker( | |||
| std::vector<uint8_t> model_data, std::string packed_model_path, | |||
| std::vector<uint8_t> info_data, std::vector<uint8_t> info_algo_policy_data, | |||
| std::vector<uint8_t> info_binary_cache_data) { | |||
| m_fbs_helper = new FbsHelper(this, model_data); | |||
| m_fbs_helper = std::make_shared<FbsHelper>(this, model_data); | |||
| m_packed_model_path = packed_model_path; | |||
| m_info_data = info_data; | |||
| m_algo_policy_data = info_algo_policy_data; | |||
| @@ -1,14 +1,3 @@ | |||
| /** | |||
| * \file src/parse_info/cache_parse.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "lite/global.h" | |||
| #if LITE_BUILD_WITH_MGE | |||