@@ -77,7 +77,7 @@ LITE_API bool update_decryption_or_key(
  * other config not inclue in config and networkIO, ParseInfoFunc can fill it
  * with the information in json, now support:
  * "device_id" : int, default 0
- * "number_threads" : size_t, default 1
+ * "number_threads" : uint32_t, default 1
  * "is_inplace_model" : bool, default false
  * "use_tensorrt" : bool, default false
  */
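For reference, the "device" section of an information JSON that exercises these keys might look like the sketch below. This is an illustration, not part of the patch: the nesting under "device" is inferred from the device_json lookups in the default_parse_info hunk further down, and that parser checks "enable_inplace_model" while the comment above says "is_inplace_model", so the parser's key names are used here.

    {
        "device": {
            "device_id": 0,
            "number_threads": 4,
            "enable_inplace_model": false,
            "use_tensorrt": false
        }
    }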
@@ -149,28 +149,42 @@ private:
  */
 class LITE_API LiteAny {
 public:
+    enum Type {
+        STRING = 0,
+        INT32 = 1,
+        UINT32 = 2,
+        UINT8 = 3,
+        INT8 = 4,
+        INT64 = 5,
+        UINT64 = 6,
+        BOOL = 7,
+        VOID_PTR = 8,
+        FLOAT = 9,
+        NONE_SUPPORT = 10,
+    };
     LiteAny() = default;
     template <class T>
     LiteAny(T value) : m_holder(new AnyHolder<T>(value)) {
-        m_is_string = std::is_same<std::string, T>();
+        m_type = get_type<T>();
     }
     LiteAny(const LiteAny& any) {
         m_holder = any.m_holder->clone();
-        m_is_string = any.is_string();
+        m_type = any.m_type;
     }
     LiteAny& operator=(const LiteAny& any) {
         m_holder = any.m_holder->clone();
-        m_is_string = any.is_string();
+        m_type = any.m_type;
         return *this;
     }
-    bool is_string() const { return m_is_string; }
+    template <class T>
+    Type get_type() const;
     class HolderBase {
     public:
         virtual ~HolderBase() = default;
         virtual std::shared_ptr<HolderBase> clone() = 0;
-        virtual size_t type_length() const = 0;
     };
     template <class T>
@@ -180,7 +194,6 @@ public:
     virtual std::shared_ptr<HolderBase> clone() override {
         return std::make_shared<AnyHolder>(m_value);
     }
-    virtual size_t type_length() const override { return sizeof(T); }
 public:
     T m_value;
@@ -188,14 +201,21 @@ public:
     //! if type is miss matching, it will throw
     void type_missmatch(size_t expect, size_t get) const;
-    //! only check the storage type and the visit type length, so it's not safe
     template <class T>
-    T unsafe_cast() const {
-        if (sizeof(T) != m_holder->type_length()) {
-            type_missmatch(m_holder->type_length(), sizeof(T));
+    T safe_cast() const {
+        if (get_type<T>() != m_type) {
+            type_missmatch(m_type, get_type<T>());
         }
         return static_cast<LiteAny::AnyHolder<T>*>(m_holder.get())->m_value;
     }
+    template <class T>
+    bool try_cast() const {
+        if (get_type<T>() == m_type) {
+            return true;
+        } else {
+            return false;
+        }
+    }
     //! only check the storage type and the visit type length, so it's not safe
     void* cast_void_ptr() const {
         return &static_cast<LiteAny::AnyHolder<char>*>(m_holder.get())->m_value;
@@ -203,7 +223,7 @@ public:
 private:
     std::shared_ptr<HolderBase> m_holder;
-    bool m_is_string = false;
+    Type m_type = NONE_SUPPORT;
 };
 /*********************** special tensor function ***************/
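Taken together, the header changes replace the old length-only check with an exact type tag. A minimal usage sketch of the new interface follows; it is a standalone illustration (the function name is made up) and assumes the LiteAny declarations above are in scope.

    #include <cstdint>
    // (the header declaring lite::LiteAny is also assumed to be included)

    void liteany_sketch() {
        lite::LiteAny any(static_cast<uint32_t>(4));   // constructor records UINT32
        if (any.try_cast<uint32_t>()) {                // true: stored tag matches
            uint32_t n = any.safe_cast<uint32_t>();    // returns 4
            (void)n;
        }
        // any.safe_cast<int32_t>() would call type_missmatch(UINT32, INT32) and
        // throw, whereas the old unsafe_cast<int32_t>() passed the sizeof() check
        // and silently reinterpreted the stored value.
    }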
@@ -127,7 +127,8 @@ int LITE_register_parse_info_func(
         separate_config_map["device_id"] = device_id;
     }
     if (nr_threads != 1) {
-        separate_config_map["nr_threads"] = nr_threads;
+        separate_config_map["nr_threads"] =
+                static_cast<uint32_t>(nr_threads);
     }
     if (is_cpu_inplace_mode != false) {
         separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode;
@@ -352,19 +352,19 @@ void NetworkImplDft::load_model(
     //! config some flag get from json config file
     if (separate_config_map.find("device_id") != separate_config_map.end()) {
-        set_device_id(separate_config_map["device_id"].unsafe_cast<int>());
+        set_device_id(separate_config_map["device_id"].safe_cast<int>());
     }
     if (separate_config_map.find("number_threads") != separate_config_map.end() &&
-        separate_config_map["number_threads"].unsafe_cast<size_t>() > 1) {
+        separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
         set_cpu_threads_number(
-                separate_config_map["number_threads"].unsafe_cast<size_t>());
+                separate_config_map["number_threads"].safe_cast<uint32_t>());
     }
     if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() &&
-        separate_config_map["enable_inplace_model"].unsafe_cast<bool>()) {
+        separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
         set_cpu_inplace_mode();
     }
     if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
-        separate_config_map["use_tensorrt"].unsafe_cast<bool>()) {
+        separate_config_map["use_tensorrt"].safe_cast<bool>()) {
         use_tensorrt();
     }
@@ -84,7 +84,7 @@ bool default_parse_info(
     }
     if (device_json.contains("number_threads")) {
         separate_config_map["number_threads"] =
-                static_cast<size_t>(device_json["number_threads"]);
+                static_cast<uint32_t>(device_json["number_threads"]);
     }
     if (device_json.contains("enable_inplace_model")) {
         separate_config_map["enable_inplace_model"] =
@@ -277,10 +277,28 @@ void Tensor::update_from_implement() {
 void LiteAny::type_missmatch(size_t expect, size_t get) const {
     LITE_THROW(ssprintf(
             "The type store in LiteAny is not match the visit type, type of "
-            "storage length is %zu, type of visit length is %zu.",
+            "storage enum is %zu, type of visit enum is %zu.",
             expect, get));
 }
+namespace lite {
+#define GET_TYPE(ctype, ENUM)                        \
+    template <>                                      \
+    LiteAny::Type LiteAny::get_type<ctype>() const { \
+        return ENUM;                                 \
+    }
+GET_TYPE(std::string, STRING)
+GET_TYPE(int32_t, INT32)
+GET_TYPE(uint32_t, UINT32)
+GET_TYPE(int8_t, INT8)
+GET_TYPE(uint8_t, UINT8)
+GET_TYPE(int64_t, INT64)
+GET_TYPE(uint64_t, UINT64)
+GET_TYPE(float, FLOAT)
+GET_TYPE(bool, BOOL)
+GET_TYPE(void*, VOID_PTR)
+}  // namespace lite
 std::shared_ptr<Tensor> TensorUtils::concat(
         const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device,
         int dst_device_id) {