| @@ -77,7 +77,7 @@ LITE_API bool update_decryption_or_key( | |||
| * other config not inclue in config and networkIO, ParseInfoFunc can fill it | |||
| * with the information in json, now support: | |||
| * "device_id" : int, default 0 | |||
| * "number_threads" : size_t, default 1 | |||
| * "number_threads" : uint32_t, default 1 | |||
| * "is_inplace_model" : bool, default false | |||
| * "use_tensorrt" : bool, default false | |||
| */ | |||
| @@ -149,28 +149,42 @@ private: | |||
| */ | |||
| class LITE_API LiteAny { | |||
| public: | |||
| enum Type { | |||
| STRING = 0, | |||
| INT32 = 1, | |||
| UINT32 = 2, | |||
| UINT8 = 3, | |||
| INT8 = 4, | |||
| INT64 = 5, | |||
| UINT64 = 6, | |||
| BOOL = 7, | |||
| VOID_PTR = 8, | |||
| FLOAT = 9, | |||
| NONE_SUPPORT = 10, | |||
| }; | |||
| LiteAny() = default; | |||
| template <class T> | |||
| LiteAny(T value) : m_holder(new AnyHolder<T>(value)) { | |||
| m_is_string = std::is_same<std::string, T>(); | |||
| m_type = get_type<T>(); | |||
| } | |||
| LiteAny(const LiteAny& any) { | |||
| m_holder = any.m_holder->clone(); | |||
| m_is_string = any.is_string(); | |||
| m_type = any.m_type; | |||
| } | |||
| LiteAny& operator=(const LiteAny& any) { | |||
| m_holder = any.m_holder->clone(); | |||
| m_is_string = any.is_string(); | |||
| m_type = any.m_type; | |||
| return *this; | |||
| } | |||
| bool is_string() const { return m_is_string; } | |||
| template <class T> | |||
| Type get_type() const; | |||
| class HolderBase { | |||
| public: | |||
| virtual ~HolderBase() = default; | |||
| virtual std::shared_ptr<HolderBase> clone() = 0; | |||
| virtual size_t type_length() const = 0; | |||
| }; | |||
| template <class T> | |||
| @@ -180,7 +194,6 @@ public: | |||
| virtual std::shared_ptr<HolderBase> clone() override { | |||
| return std::make_shared<AnyHolder>(m_value); | |||
| } | |||
| virtual size_t type_length() const override { return sizeof(T); } | |||
| public: | |||
| T m_value; | |||
| @@ -188,14 +201,21 @@ public: | |||
| //! if type is miss matching, it will throw | |||
| void type_missmatch(size_t expect, size_t get) const; | |||
| //! only check the storage type and the visit type length, so it's not safe | |||
| template <class T> | |||
| T unsafe_cast() const { | |||
| if (sizeof(T) != m_holder->type_length()) { | |||
| type_missmatch(m_holder->type_length(), sizeof(T)); | |||
| T safe_cast() const { | |||
| if (get_type<T>() != m_type) { | |||
| type_missmatch(m_type, get_type<T>()); | |||
| } | |||
| return static_cast<LiteAny::AnyHolder<T>*>(m_holder.get())->m_value; | |||
| } | |||
| template <class T> | |||
| bool try_cast() const { | |||
| if (get_type<T>() == m_type) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| //! only check the storage type and the visit type length, so it's not safe | |||
| void* cast_void_ptr() const { | |||
| return &static_cast<LiteAny::AnyHolder<char>*>(m_holder.get())->m_value; | |||
| @@ -203,7 +223,7 @@ public: | |||
| private: | |||
| std::shared_ptr<HolderBase> m_holder; | |||
| bool m_is_string = false; | |||
| Type m_type = NONE_SUPPORT; | |||
| }; | |||
| /*********************** special tensor function ***************/ | |||
| @@ -127,7 +127,8 @@ int LITE_register_parse_info_func( | |||
| separate_config_map["device_id"] = device_id; | |||
| } | |||
| if (nr_threads != 1) { | |||
| separate_config_map["nr_threads"] = nr_threads; | |||
| separate_config_map["nr_threads"] = | |||
| static_cast<uint32_t>(nr_threads); | |||
| } | |||
| if (is_cpu_inplace_mode != false) { | |||
| separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode; | |||
| @@ -352,19 +352,19 @@ void NetworkImplDft::load_model( | |||
| //! config some flag get from json config file | |||
| if (separate_config_map.find("device_id") != separate_config_map.end()) { | |||
| set_device_id(separate_config_map["device_id"].unsafe_cast<int>()); | |||
| set_device_id(separate_config_map["device_id"].safe_cast<int>()); | |||
| } | |||
| if (separate_config_map.find("number_threads") != separate_config_map.end() && | |||
| separate_config_map["number_threads"].unsafe_cast<size_t>() > 1) { | |||
| separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) { | |||
| set_cpu_threads_number( | |||
| separate_config_map["number_threads"].unsafe_cast<size_t>()); | |||
| separate_config_map["number_threads"].safe_cast<uint32_t>()); | |||
| } | |||
| if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() && | |||
| separate_config_map["enable_inplace_model"].unsafe_cast<bool>()) { | |||
| separate_config_map["enable_inplace_model"].safe_cast<bool>()) { | |||
| set_cpu_inplace_mode(); | |||
| } | |||
| if (separate_config_map.find("use_tensorrt") != separate_config_map.end() && | |||
| separate_config_map["use_tensorrt"].unsafe_cast<bool>()) { | |||
| separate_config_map["use_tensorrt"].safe_cast<bool>()) { | |||
| use_tensorrt(); | |||
| } | |||
| @@ -84,7 +84,7 @@ bool default_parse_info( | |||
| } | |||
| if (device_json.contains("number_threads")) { | |||
| separate_config_map["number_threads"] = | |||
| static_cast<size_t>(device_json["number_threads"]); | |||
| static_cast<uint32_t>(device_json["number_threads"]); | |||
| } | |||
| if (device_json.contains("enable_inplace_model")) { | |||
| separate_config_map["enable_inplace_model"] = | |||
| @@ -277,10 +277,28 @@ void Tensor::update_from_implement() { | |||
| void LiteAny::type_missmatch(size_t expect, size_t get) const { | |||
| LITE_THROW(ssprintf( | |||
| "The type store in LiteAny is not match the visit type, type of " | |||
| "storage length is %zu, type of visit length is %zu.", | |||
| "storage enum is %zu, type of visit enum is %zu.", | |||
| expect, get)); | |||
| } | |||
| namespace lite { | |||
| #define GET_TYPE(ctype, ENUM) \ | |||
| template <> \ | |||
| LiteAny::Type LiteAny::get_type<ctype>() const { \ | |||
| return ENUM; \ | |||
| } | |||
| GET_TYPE(std::string, STRING) | |||
| GET_TYPE(int32_t, INT32) | |||
| GET_TYPE(uint32_t, UINT32) | |||
| GET_TYPE(int8_t, INT8) | |||
| GET_TYPE(uint8_t, UINT8) | |||
| GET_TYPE(int64_t, INT64) | |||
| GET_TYPE(uint64_t, UINT64) | |||
| GET_TYPE(float, FLOAT) | |||
| GET_TYPE(bool, BOOL) | |||
| GET_TYPE(void*, VOID_PTR) | |||
| } // namespace lite | |||
| std::shared_ptr<Tensor> TensorUtils::concat( | |||
| const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device, | |||
| int dst_device_id) { | |||