GitOrigin-RevId: 2232195c50
tags/v0.5.0
| @@ -485,14 +485,14 @@ OperatorNodeConfig& OperatorNodeConfig::comp_node_arr( | |||
| size_t OperatorNodeConfig::hash() const { | |||
| return hash_pair_combine( | |||
| hash_pair_combine(mgb::hash(m_instance_id), mgb::hash(m_comp_node)), | |||
| hash_pair_combine(m_instance_id_hashed, mgb::hash(m_comp_node)), | |||
| mgb::hash(m_output_dtype.handle())); | |||
| } | |||
| bool OperatorNodeConfig::is_same_st(const Hashable &rhs_) const { | |||
| auto &&rhs = static_cast<const OperatorNodeConfig&>(rhs_); | |||
| return m_comp_node == rhs.m_comp_node && | |||
| m_instance_id == rhs.m_instance_id && | |||
| m_instance_id_hashed == rhs.m_instance_id_hashed && | |||
| m_output_dtype == rhs.m_output_dtype; | |||
| } | |||
| @@ -1225,14 +1225,17 @@ bool SeqModifierForSublinearMemory::replace_vars(const VarNodeArray& inputs) { | |||
| OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs( | |||
| OperatorNodeBase* opr, bool recomp) { | |||
| auto config = opr->config(); | |||
| // set operator instance id to bypass the shallow copy's cache if | |||
| // update operator instance id to bypass the shallow copy's cache if | |||
| // it's a dup-opr-copying due to discarding. | |||
| // Don't set instance id(nullptr) if it's a recomp-opr-copying, because: | |||
| // Don't update instance id by `this` pointer if it's a recomp-opr-copying | |||
| // because: | |||
| // 0) recomp-opr would be copied iff its input vars are changed | |||
| // 1) some pairs of recomp-opr and dup-opr have the same inputs, params | |||
| // and config; we use the instance id to differentiate them. | |||
| config.name(opr->name() + (recomp ? ":recomp" : ":dup")) | |||
| .instance_id(recomp ? nullptr : this); | |||
| config.name(opr->name() + (recomp ? ":recomp" : ":dup")); | |||
| if (!recomp) { | |||
| config.update_instance_id(this); | |||
| } | |||
| // Note: if all outputs of op were placed on the same comp_node, since its | |||
| // stream maybe changed during seq_comp_node_opt, output's comp_node has | |||
| @@ -70,24 +70,36 @@ class OperatorNodeConfig final: public Hashable { | |||
| } | |||
| /*! | |||
| * \brief set instance id | |||
| * \brief update instance ID | |||
| * | |||
| * Instance id is used to differentiate multiple instances of the same | |||
| * operator (with same inputs, params and config), so the deduplication | |||
| * system can be bypassed. | |||
| * Instance ID is a hashed value used to differentiate multiple | |||
| * instances of the same operator (with same inputs, params and | |||
| * config), so the deduplication system can be bypassed. | |||
| * | |||
| * Currently only used for sublinear memory optimization. | |||
| * This method always updates the underlying instance ID. | |||
| */ | |||
| OperatorNodeConfig& instance_id(const void *id) { | |||
| m_instance_id = id; | |||
| template<typename T> | |||
| OperatorNodeConfig& update_instance_id(const T& p) { | |||
| static_assert(std::is_pointer<T>::value, | |||
| "update_instance_id can only accept a pointer"); | |||
| m_instance_id_hashed = hash_pair_combine( | |||
| m_instance_id_hashed, mgb::hash(p)); | |||
| return *this; | |||
| } | |||
| /*! | |||
| * \brief get current instance ID | |||
| * \brief reset instance ID to the initial value | |||
| */ | |||
| const void* instance_id() const { | |||
| return m_instance_id; | |||
| OperatorNodeConfig& reset_instance_id() { | |||
| m_instance_id_hashed = sm_initial_instance_id; | |||
| return *this; | |||
| } | |||
| /*! | |||
| * \brief get current hashed instance ID | |||
| */ | |||
| size_t instance_id() const { | |||
| return m_instance_id_hashed; | |||
| } | |||
| /*! | |||
| @@ -133,9 +145,10 @@ class OperatorNodeConfig final: public Hashable { | |||
| bool is_same_st(const Hashable &rhs) const override; | |||
| private: | |||
| static constexpr size_t sm_initial_instance_id = 1333331; | |||
| Maybe<std::string> m_name; | |||
| CompNodeArray m_comp_node; | |||
| const void *m_instance_id = nullptr; | |||
| size_t m_instance_id_hashed = sm_initial_instance_id; | |||
| DType m_output_dtype; | |||
| }; | |||
| @@ -1777,4 +1777,41 @@ TEST(TestGraph, In2OutOpStreamPropagate) { | |||
| } | |||
| } | |||
| TEST(TestGraph, OperatorNodeConfigInstanceID) { | |||
| OperatorNodeConfig config0, config1; | |||
| void *p0 = &config0, *p1 = &config1; | |||
| { // update then reset: IDs start equal, differ after one update, and are equal again after reset | |||
| ASSERT_EQ(config0.instance_id(), config1.instance_id()); | |||
| config0.update_instance_id(p0); | |||
| ASSERT_NE(config0.instance_id(), config1.instance_id()); | |||
| config0.reset_instance_id(); | |||
| ASSERT_EQ(config0.instance_id(), config1.instance_id()); | |||
| } | |||
| { // same pointer: two freshly reset configs updated with the same pointer must yield equal IDs | |||
| config0.reset_instance_id(); | |||
| config0.update_instance_id(p1); | |||
| config1.reset_instance_id(); | |||
| config1.update_instance_id(p1); | |||
| ASSERT_EQ(config0.instance_id(), config1.instance_id()); | |||
| } | |||
| { // update semantics: a second update combines with the first instead of overwriting it, so (p0, p1) != (p1) | |||
| config0.reset_instance_id(); | |||
| config0.update_instance_id(p0); | |||
| config1.reset_instance_id(); | |||
| config1.update_instance_id(p1); | |||
| ASSERT_NE(config0.instance_id(), config1.instance_id()); | |||
| config0.update_instance_id(p1); | |||
| ASSERT_NE(config0.instance_id(), config1.instance_id()); | |||
| } | |||
| { // order sensitivity: updating with (p1, p0) must differ from updating with (p0, p1) | |||
| config0.reset_instance_id(); | |||
| config0.update_instance_id(p1); | |||
| config0.update_instance_id(p0); | |||
| config1.reset_instance_id(); | |||
| config1.update_instance_id(p0); | |||
| config1.update_instance_id(p1); | |||
| ASSERT_NE(config0.instance_id(), config1.instance_id()); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -361,7 +361,7 @@ void RecompTypeCvtPass::apply(OptState& opt) const { | |||
| size_t prev_step = iter->second; | |||
| if (step - prev_step > m_threshold) { | |||
| OperatorNodeConfig config = opr->config(); | |||
| config.instance_id(opr); | |||
| config.update_instance_id(opr); | |||
| opt.call_with_opr(typecvt, [&]{ | |||
| auto new_typecvt = | |||
| opr::TypeCvt::make( | |||
| @@ -261,7 +261,7 @@ TEST_PASS(RecompTypeCvtPass, Basic) { | |||
| } | |||
| auto for_pass = f + x_fp32; | |||
| OperatorNodeConfig config = x_fp32.node()->owner_opr()->config(); | |||
| config.instance_id(for_pass.node()->owner_opr()); | |||
| config.update_instance_id(for_pass.node()->owner_opr()); | |||
| auto expected = f + opr::TypeCvt::make(sin_x, dtype::Float32(), | |||
| config); | |||
| @@ -92,8 +92,8 @@ VarNode* InternalGraphGenerator::replace_graph_by_placeholder() { | |||
| auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr, | |||
| const VarNodeArray& inputs) { | |||
| OperatorNodeConfig config = opr->config(); | |||
| // remove instance_id. | |||
| config.instance_id(nullptr); | |||
| // reset instance_id. | |||
| config.reset_instance_id(); | |||
| if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) { | |||
| HostTensorND hval{cpu_default}; | |||
| hval.copy_from(imm->value()).sync(); | |||