GitOrigin-RevId: bb9150eb83
tags/v1.0.0-rc1
| @@ -537,14 +537,28 @@ std::shared_ptr<json::Value> ComputingGraphImpl::ComputingSequence::to_json() | |||||
| comp_seq->add(json::String::make(i->id_str())); | comp_seq->add(json::String::make(i->id_str())); | ||||
| } | } | ||||
| // expand opr and var nodes that do not appear in comp seq | |||||
| // expand opr and var nodes that do not appear in comp seq, | |||||
| // also expand var nodes which are only used in static infer | |||||
| { | { | ||||
| VarNodeArray new_var_node; | VarNodeArray new_var_node; | ||||
| auto&& mgr = m_owner_graph->static_infer_manager_impl(); | |||||
| auto check_opr_input = [&](OperatorNodeBase* opr) { | auto check_opr_input = [&](OperatorNodeBase* opr) { | ||||
| auto update = [&](VarNode* var) { | |||||
| if (!(all_var_node.count(var))) { | |||||
| all_var_node.insert(var); | |||||
| new_var_node.push_back(var); | |||||
| } | |||||
| }; | |||||
| for (auto i : opr->input()) { | for (auto i : opr->input()) { | ||||
| if (!(all_var_node.count(i))) { | |||||
| all_var_node.insert(i); | |||||
| new_var_node.push_back(i); | |||||
| update(i); | |||||
| } | |||||
| for (auto &&out : opr->output()) { | |||||
| using DepType = static_infer::DepType; | |||||
| for (auto&& i : mgr.get_deps({out, DepType::SHAPE})) { | |||||
| update(i.dest); | |||||
| } | |||||
| for (auto&& i : mgr.get_deps({out, DepType::VALUE})) { | |||||
| update(i.dest); | |||||
| } | } | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -245,6 +245,9 @@ MGB_DEFINE_CLS_WITH_SUPER(StaticInferManagerImpl::TagTraitMutableBase, | |||||
| return m_infer_withoutexc_ret; | return m_infer_withoutexc_ret; | ||||
| } | } | ||||
| //! original deps given in the InferDesc by the caller | |||||
| virtual const DepVal& raw_deps() = 0; | |||||
| protected: | protected: | ||||
| //! current infer result, to be used by dependents | //! current infer result, to be used by dependents | ||||
| InpElement m_inp_element; | InpElement m_inp_element; | ||||
| @@ -300,9 +303,6 @@ MGB_DEFINE_CLS_WITH_SUPER(StaticInferManagerImpl::TagTraitMutableBase, | |||||
| //! all missing inputs | //! all missing inputs | ||||
| SharedSet<TagHandler*, TagHandlerSet> m_missing_input; | SharedSet<TagHandler*, TagHandlerSet> m_missing_input; | ||||
| //! original deps given in the InferDesc by the caller | |||||
| virtual const DepVal& raw_deps() = 0; | |||||
| //! recursively set m_inp_element_synced of this and all receivers to | //! recursively set m_inp_element_synced of this and all receivers to | ||||
| //! false | //! false | ||||
| void reset_inp_element_synced(); | void reset_inp_element_synced(); | ||||
| @@ -1027,6 +1027,14 @@ void StaticInferManagerImpl::update_mutable_src_shape(Tag dest) { | |||||
| MGB_CATCH(MegBrainError & exc, { update_rethrow_exc(dest, exc); }) | MGB_CATCH(MegBrainError & exc, { update_rethrow_exc(dest, exc); }) | ||||
| } | } | ||||
| DepVal StaticInferManagerImpl::get_deps(const DepElement &elem) { | |||||
| auto trait_base = get_tag_trait_container(elem.dest).select(elem.type); | |||||
| if (!trait_base || trait_base->is_const()) | |||||
| return {}; | |||||
| return trait_base->as_mutable_safe()->raw_deps(); | |||||
| } | |||||
| /* ===================== CompSeqManager ===================== */ | /* ===================== CompSeqManager ===================== */ | ||||
| class CompSeqManager::VersionedTagTrait { | class CompSeqManager::VersionedTagTrait { | ||||
| @@ -99,6 +99,17 @@ class StaticInferManagerImpl final: public StaticInferManager { | |||||
| */ | */ | ||||
| void update_mutable_src_shape(Tag tag); | void update_mutable_src_shape(Tag tag); | ||||
| /*! | |||||
| * \brief get original deps given in the InferDesc which is registered | |||||
| * by register_shape_infer or register_value_infer | |||||
| * | |||||
| * Note: the \p elem with DepType::SHAPE and InferType::CONST shows no | |||||
| * deps since the StaticInferManagerImpl folds the inference chain of | |||||
| * the const var shape | |||||
| */ | |||||
| DepVal get_deps(const DepElement &elem); | |||||
| private: | private: | ||||
| friend class CompSeqManager; | friend class CompSeqManager; | ||||
| @@ -396,6 +396,108 @@ VarNode& VarNode::comp_node(const CompNode &cn) { | |||||
| } | } | ||||
| #if MGB_ENABLE_JSON | #if MGB_ENABLE_JSON | ||||
| std::shared_ptr<json::Value> | |||||
| VarNode::dump_static_infer_info_to_json() const { | |||||
| using namespace cg::static_infer; | |||||
| auto&& mgr = static_cast<cg::ComputingGraphImpl*>( | |||||
| owner_graph())->static_infer_manager_impl(); | |||||
| auto get_dep_type = [](const DepType& type) -> std::string { | |||||
| switch (type) { | |||||
| #define cb(name) \ | |||||
| case DepType::name: \ | |||||
| return #name; | |||||
| cb(SHAPE) | |||||
| cb(VALUE) | |||||
| #undef cb | |||||
| default: | |||||
| mgb_throw(MegBrainError, "unknown dep type"); | |||||
| } | |||||
| }; | |||||
| auto get_infer_type = [](const InferType::Flag& type) { | |||||
| switch (type) { | |||||
| #define cb(name) \ | |||||
| case InferType::Flag::name: \ | |||||
| return json::String::make(#name); | |||||
| cb(NO_DESC) | |||||
| cb(CONST) | |||||
| cb(RT_STATIC) | |||||
| cb(MISSING_INP) | |||||
| #undef cb | |||||
| default: | |||||
| mgb_throw(MegBrainError, "unknown infer type"); | |||||
| } | |||||
| }; | |||||
| auto make_tag = [&](const DepType& type) { | |||||
| VarNode* self = const_cast<VarNode*>(this); | |||||
| auto c_deps = mgr.get_deps({self, type}); | |||||
| auto deps = json::Array::make(); | |||||
| for (auto&& i : c_deps) { | |||||
| mgb_assert(i.dest); | |||||
| deps->add(json::Object::make({ | |||||
| {"var", json::String::make(i.dest->id_str())}, | |||||
| {"dep_type", json::String::make(get_dep_type(i.type))} | |||||
| })); | |||||
| } | |||||
| auto infer_type_handle = mgr.get_infer_type(self); | |||||
| auto inferred_result = json::Null::make(); | |||||
| auto infer_type = type == DepType::SHAPE ? infer_type_handle.shape | |||||
| : infer_type_handle.value; | |||||
| if (infer_type != InferType::Flag::NO_DESC) { | |||||
| if (type == DepType::SHAPE) { | |||||
| if (auto shape = mgr.infer_shape_fallible(self)) { | |||||
| auto inferred_shape = json::Array::make(); | |||||
| for (size_t i = 0; i < shape->ndim; ++ i) { | |||||
| inferred_shape->add(json::Number::make((*shape)[i])); | |||||
| } | |||||
| inferred_result = inferred_shape; | |||||
| } | |||||
| } else { | |||||
| if (auto p = mgr.infer_value_fallible(self)) { | |||||
| auto&& dev = *p; | |||||
| if (dev.shape().ndim == 1 && | |||||
| dev.shape(0) < TensorShape::MAX_NDIM && | |||||
| mgb_likely(dev.comp_node() == CompNode::default_cpu())) { | |||||
| MGB_TRY { | |||||
| size_t nr_elems = dev.shape(0); | |||||
| auto&& dtype = dev.dtype(); | |||||
| void* vptr = dev.raw_ptr(); | |||||
| double data[nr_elems]; | |||||
| HostTensorND contig; | |||||
| if (!dev.layout().is_contiguous()) { | |||||
| // both src and dst are placed on default cpu, | |||||
| // no need for sync | |||||
| contig.copy_from(dev); | |||||
| mgb_assert(contig.layout().is_contiguous()); | |||||
| vptr = contig.raw_ptr(); | |||||
| } | |||||
| static_cast_dtype(data, dtype, vptr, nr_elems); | |||||
| auto inferred_value = json::Array::make(); | |||||
| for (size_t i = 0; i < nr_elems; ++ i) { | |||||
| inferred_value->add(json::Number::make(data[i])); | |||||
| } | |||||
| inferred_result = inferred_value; | |||||
| } | |||||
| MGB_CATCH(ConversionError&, {}); | |||||
| } else { | |||||
| inferred_result = json::String::make("Large Array"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return json::Object::make({ | |||||
| {"node_type", json::String::make("static_infer_tag")}, | |||||
| {"infer_type", get_infer_type(infer_type)}, | |||||
| {"inferred_result", inferred_result}, | |||||
| {"deps", deps} | |||||
| }); | |||||
| }; | |||||
| return json::Object::make({ | |||||
| #define TAG(type) {get_dep_type(type), make_tag(type)} | |||||
| TAG(DepType::SHAPE), TAG(DepType::VALUE) | |||||
| #undef TAG | |||||
| }); | |||||
| } | |||||
| std::shared_ptr<json::Value> VarNode::to_json() const { | std::shared_ptr<json::Value> VarNode::to_json() const { | ||||
| auto get_var = [](VarNode *p) -> std::shared_ptr<json::Value> { | auto get_var = [](VarNode *p) -> std::shared_ptr<json::Value> { | ||||
| if(p) | if(p) | ||||
| @@ -443,8 +545,10 @@ std::shared_ptr<json::Value> VarNode::to_json() const { | |||||
| {"dev_ptr", json::Null::make()}, | {"dev_ptr", json::Null::make()}, | ||||
| {"prev_dev_ptr", json::NumberInt::make(reinterpret_cast<size_t>( | {"prev_dev_ptr", json::NumberInt::make(reinterpret_cast<size_t>( | ||||
| m_prev_dev_ptr))}, | m_prev_dev_ptr))}, | ||||
| {"flag", flag} | |||||
| {"flag", flag}, | |||||
| {"static_infer_tags", dump_static_infer_info_to_json()} | |||||
| }); | }); | ||||
| if (m_prev_dev_ptr) { | if (m_prev_dev_ptr) { | ||||
| (*rst)["prev_dev_ptr_end"] = json::NumberInt::make( | (*rst)["prev_dev_ptr_end"] = json::NumberInt::make( | ||||
| reinterpret_cast<size_t>(m_prev_dev_ptr) + | reinterpret_cast<size_t>(m_prev_dev_ptr) + | ||||
| @@ -575,6 +575,10 @@ class VarNode final: public GraphNodeBase { | |||||
| void assign_dev_tensor_from_tensor(const DeviceTensorND &value); | void assign_dev_tensor_from_tensor(const DeviceTensorND &value); | ||||
| #if MGB_ENABLE_JSON | |||||
| std::shared_ptr<json::Value> dump_static_infer_info_to_json() const; | |||||
| #endif | |||||
| friend class static_infer::StaticInferManagerImpl; | friend class static_infer::StaticInferManagerImpl; | ||||
| friend class VarNodeMemManager; | friend class VarNodeMemManager; | ||||
| friend class VarDevMemDefragmenter; | friend class VarDevMemDefragmenter; | ||||