Also, add a CONSTANT value inference tag to outputs of
MultipleDeviceTensorHolder.
GitOrigin-RevId: 82a805ed5f
tags/v1.0.0-rc1
| @@ -570,10 +570,10 @@ void ParamFusePass::apply(OptState &state) const { | |||
| *var->owner_graph(), hv, var_namer.name(var)); | |||
| } else { | |||
| if (is_default_format) { | |||
| new_var = opr::SharedDeviceTensor::make( | |||
| new_var = opr::SharedDeviceTensor::make_const( | |||
| *var->owner_graph(), inferred_val, var_namer.name(var)); | |||
| } else { | |||
| new_var = opr::SharedDeviceTensorWithFormat::make( | |||
| new_var = opr::SharedDeviceTensorWithFormat::make_const( | |||
| *var->owner_graph(), inferred_val, var_namer.name(var)); | |||
| } | |||
| } | |||
| @@ -281,11 +281,11 @@ void Host2DeviceCopy::record_execute_deps(ExecDependencyArray& deps) { | |||
| /* ===================== SharedDeviceTensor related ===================== */ | |||
| intl::SharedDeviceTensorBase::SharedDeviceTensorBase( | |||
| ComputingGraph &graph, const std::shared_ptr<DeviceTensorND> &dev_data, | |||
| const OperatorNodeConfig &config): | |||
| Super{&graph, config, "shared", {}}, | |||
| m_dev_data{dev_data} | |||
| { | |||
| ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| bool const_value, const OperatorNodeConfig& config) | |||
| : Super{&graph, config, "shared", {}}, | |||
| m_dev_data{dev_data}, | |||
| m_const_value(const_value) { | |||
| if (config.has_comp_node_set()) { | |||
| mgb_assert(config.get_single_comp_node() == dev_data->comp_node()); | |||
| } | |||
| @@ -307,26 +307,42 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() { | |||
| comp_node(m_dev_data->comp_node()); | |||
| } | |||
| // Static value inference hook. Reports success only when this operator was | |||
| // created with const_value set. If dest is non-null it is assigned from | |||
| // m_static_infer, which is filled on first use by copying *m_dev_data onto | |||
| // CompNode::default_cpu() and is reused afterwards. | |||
| bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) { | |||
|     if (!m_const_value) { | |||
|         return false; | |||
|     } | |||
|     if (dest == nullptr) { | |||
|         // caller only asks whether a value is available | |||
|         return true; | |||
|     } | |||
|     const bool need_copy = m_static_infer.empty(); | |||
|     if (need_copy) { | |||
|         m_static_infer.comp_node(CompNode::default_cpu()) | |||
|                 .copy_from(*m_dev_data); | |||
|     } | |||
|     *dest = m_static_infer; | |||
|     return true; | |||
| } | |||
| // SharedDeviceTensor reports CONSTANT as its static inference source type. | |||
| cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const { | |||
|     const auto src_type = cg::static_infer::SourceType::CONSTANT; | |||
|     return src_type; | |||
| } | |||
| SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, | |||
| const std::shared_ptr<DeviceTensorND> &dev_data, | |||
| bool const_value, | |||
| const OperatorNodeConfig &config) { | |||
| return graph.insert_opr(std::make_unique<SharedDeviceTensor>( | |||
| graph, dev_data, config))->output(0); | |||
| graph, dev_data, const_value, config))->output(0); | |||
| } | |||
| SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, | |||
| const HostTensorND &value, | |||
| bool const_value, | |||
| const OperatorNodeConfig &config) { | |||
| auto cn = value.comp_node(); | |||
| if (config.has_comp_node_set()) | |||
| cn = config.get_single_comp_node(); | |||
| auto dev_v = std::make_shared<DeviceTensorND>(); | |||
| dev_v->comp_node(cn).copy_from(value).sync(); | |||
| return make(graph, dev_v, config); | |||
| return make(graph, dev_v, const_value, config); | |||
| } | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(SharedDeviceTensor); | |||
| @@ -342,7 +358,7 @@ SymbolVar VolatileSharedDeviceTensor::make(ComputingGraph &graph, | |||
| const std::shared_ptr<DeviceTensorND> &dev_data, | |||
| const OperatorNodeConfig &config) { | |||
| return graph.insert_opr(std::make_unique<VolatileSharedDeviceTensor>( | |||
| graph, dev_data, config))->output(0); | |||
| graph, dev_data, false, config))->output(0); | |||
| } | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(VolatileSharedDeviceTensor); | |||
| @@ -354,10 +370,10 @@ void SharedDeviceTensorWithFormat::init_output_format() { | |||
| SymbolVar SharedDeviceTensorWithFormat::make( | |||
| ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| const OperatorNodeConfig& config) { | |||
| bool const_value, const OperatorNodeConfig& config) { | |||
| auto&& opr = | |||
| graph.insert_opr(std::make_unique<SharedDeviceTensorWithFormat>( | |||
| graph, dev_data, config)) | |||
| graph, dev_data, const_value, config)) | |||
| ->cast_final_safe<SharedDeviceTensorWithFormat>(); | |||
| return opr.output(0); | |||
| } | |||
| @@ -870,6 +886,24 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() { | |||
| }; | |||
| mgr.register_shape_infer(output(i), | |||
| {SourceType::CONSTANT, {}, infer_shp}); | |||
| auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) { | |||
| if (m_host_values.empty()) { | |||
| m_host_values.resize(m_values.size()); | |||
| } | |||
| if (m_host_values[i].empty()) { | |||
| m_host_values[i] | |||
| .comp_node(CompNode::default_cpu()) | |||
| .copy_from(*m_values[i]); | |||
| } | |||
| if (!m_host_values[i].empty()) { | |||
| dest = m_host_values[i]; | |||
| return true; | |||
| } | |||
| return false; | |||
| }; | |||
| mgr.register_value_infer(output(i), | |||
| {SourceType::CONSTANT, {}, infer_val}); | |||
| } | |||
| } | |||
| @@ -79,6 +79,10 @@ namespace serialization { | |||
| HostTensorND val; | |||
| val.copy_from(opr.get_dev_tensor()).sync(); | |||
| ctx.dump_tensor(opr.name(), val, Meth::VALUE_SHARED); | |||
| // Note that we don't persist opr.m_const_value, because it does not | |||
| // affect correctness, and SharedDeviceTensor will be bundled | |||
| // together as MultipleDeviceTensorHolder in optimize_for_inference | |||
| // before being dumped. | |||
| } | |||
| static cg::OperatorNodeBase* load( | |||
| @@ -280,9 +284,10 @@ namespace opr { | |||
| const OperatorNodeConfig &config) { | |||
| mgb_assert(inputs.empty()); | |||
| auto &&opr = opr_.cast_final_safe<Opr>(); | |||
| return Opr::make( | |||
| *ctx.owner_graph(opr, inputs), opr.dev_data(), config). | |||
| node()->owner_opr(); | |||
| return Opr::make(*ctx.owner_graph(opr, inputs), opr.dev_data(), | |||
| opr.const_value(), config) | |||
| .node() | |||
| ->owner_opr(); | |||
| } | |||
| cg::OperatorNodeBase* opr_shallow_copy_immutable_tensor( | |||
| @@ -75,19 +75,22 @@ class DeviceTensorHolder: public HostIONodeBase { | |||
| */ | |||
| MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | |||
| std::shared_ptr<DeviceTensorND> m_dev_data; | |||
| DeviceTensorND m_static_infer; | |||
| bool m_const_value; | |||
| const TensorShape& get_output_shape() override; | |||
| bool fill_in_static_infer(DeviceTensorND* dest) override { | |||
| MGB_MARK_USED_VAR(dest); | |||
| return false; | |||
| } | |||
| bool fill_in_static_infer(DeviceTensorND* dest) override; | |||
| void init_output_comp_node() override; | |||
| public: | |||
| //! const_value marks whether the device value of this operator should | |||
| //! be treated as constant during graph execution. Should be false in | |||
| //! most cases. | |||
| SharedDeviceTensorBase(ComputingGraph &graph, | |||
| const std::shared_ptr<DeviceTensorND> &dev_data, | |||
| bool const_value, | |||
| const OperatorNodeConfig &config); | |||
| const DeviceTensorND& get_dev_tensor() const override { | |||
| @@ -97,6 +100,8 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | |||
| const std::shared_ptr<DeviceTensorND>& dev_data() const { | |||
| return m_dev_data; | |||
| } | |||
| bool const_value() const { return m_const_value; } | |||
| }; | |||
| /*! | |||
| @@ -104,6 +109,7 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | |||
| * device tensors | |||
| * | |||
| * This opr is used to speed up inference by packing params together. | |||
| * This operator assumes the device tensors are constant. | |||
| */ | |||
| MGB_DEFINE_CLS_WITH_SUPER(MultipleDeviceTensorHolderBase, | |||
| cg::OperatorNodeBase) // { | |||
| @@ -125,6 +131,8 @@ private: | |||
| void init_output_comp_node() override; | |||
| void init_output_static_infer_desc() override; | |||
| NodeProp* do_make_node_prop() const override; | |||
| SmallVector<DeviceTensorND> m_host_values; | |||
| }; | |||
| } // namespace intl | |||
| @@ -249,16 +257,43 @@ MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // { | |||
| public: | |||
| using Super::Super; | |||
| static SymbolVar make(ComputingGraph &graph, | |||
| const std::shared_ptr<DeviceTensorND> &dev_data, | |||
| const OperatorNodeConfig &config = {}); | |||
| static SymbolVar make(ComputingGraph& graph, | |||
| const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| bool const_value, | |||
| const OperatorNodeConfig& config); | |||
| //! convenience overload: forwards to the four-argument make() with | |||
| //! const_value switched off | |||
| static SymbolVar make(ComputingGraph& graph, | |||
| const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| const OperatorNodeConfig& config = {}) { | |||
|     constexpr bool const_value = false; | |||
|     return make(graph, dev_data, const_value, config); | |||
| } | |||
| //! convenience overload: forwards to the four-argument make() with | |||
| //! const_value switched on | |||
| static SymbolVar make_const( | |||
| ComputingGraph& graph, | |||
| const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| const OperatorNodeConfig& config = {}) { | |||
|     constexpr bool const_value = true; | |||
|     return make(graph, dev_data, const_value, config); | |||
| } | |||
| /*! | |||
| * \brief make a SharedDeviceTensor by first copying from host to device | |||
| * | |||
| * See SharedDeviceTensorBase::SharedDeviceTensorBase for const_value. | |||
| */ | |||
| static SymbolVar make(ComputingGraph &graph, | |||
| const HostTensorND &value, | |||
| const OperatorNodeConfig &config = {}); | |||
| static SymbolVar make(ComputingGraph& graph, const HostTensorND& value, | |||
| bool const_value, | |||
| const OperatorNodeConfig& config); | |||
| //! host-value overload: forwards to the four-argument make() with | |||
| //! const_value switched off | |||
| static SymbolVar make(ComputingGraph& graph, const HostTensorND& value, | |||
| const OperatorNodeConfig& config = {}) { | |||
|     constexpr bool const_value = false; | |||
|     return make(graph, value, const_value, config); | |||
| } | |||
| //! make a SharedDeviceTensor from a host value and mark it as a constant | |||
| //! value (const_value = true), matching the dev_data make_const() overload. | |||
| //! Passing false here would make this indistinguishable from make(). | |||
| static SymbolVar make_const(ComputingGraph& graph, | |||
| const HostTensorND& value, | |||
| const OperatorNodeConfig& config = {}) { | |||
|     return make(graph, value, true, config); | |||
| } | |||
| }; | |||
| /*! | |||
| @@ -276,7 +311,19 @@ public: | |||
| static SymbolVar make(ComputingGraph& graph, | |||
| const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| const OperatorNodeConfig& config = {}); | |||
| bool const_value, const OperatorNodeConfig& config); | |||
| //! convenience overload: forwards to the four-argument make() with | |||
| //! const_value switched off | |||
| static SymbolVar make(ComputingGraph& graph, | |||
| const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| const OperatorNodeConfig& config = {}) { | |||
|     constexpr bool const_value = false; | |||
|     return make(graph, dev_data, const_value, config); | |||
| } | |||
| //! convenience overload: forwards to the four-argument make() with | |||
| //! const_value switched on | |||
| static SymbolVar make_const(ComputingGraph& graph, | |||
| const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| const OperatorNodeConfig& config = {}) { | |||
|     constexpr bool const_value = true; | |||
|     return make(graph, dev_data, const_value, config); | |||
| } | |||
| }; | |||
| /*! | |||
| @@ -297,6 +344,15 @@ MGB_DEFINE_OPR_CLASS( | |||
| static SymbolVar make(ComputingGraph &graph, | |||
| const std::shared_ptr<DeviceTensorND> &dev_data, | |||
| const OperatorNodeConfig &config = {}); | |||
| //! adapter for io.sereg.h: opr_shallow_copy_shared_device_tensor | |||
| //! | |||
| //! Accepts the same argument list as the other shared-device-tensor | |||
| //! make() functions, asserts that const_value is false, and forwards to | |||
| //! the three-argument make(). | |||
| static SymbolVar make(ComputingGraph& graph, | |||
| const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| bool const_value, | |||
| const OperatorNodeConfig& config) { | |||
|     mgb_assert(!const_value); | |||
|     // Must call the three-argument overload: a four-argument call would | |||
|     // resolve back to this adapter and recurse without end. | |||
|     return make(graph, dev_data, config); | |||
| } | |||
| }; | |||
| /*! | |||