From 1d6c76f35061dbba317810c2f1a589b0e951b43e Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Tue, 18 Aug 2020 20:26:54 +0800 Subject: [PATCH] board tensor for pynative infer --- .../pipeline/pynative/pynative_execute.cc | 17 ++++-- mindspore/core/abstract/abstract_value.cc | 56 +++++++++++-------- mindspore/core/abstract/abstract_value.h | 52 ++++++++++------- 3 files changed, 77 insertions(+), 48 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index e982e22a8d..c797594d99 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -285,12 +285,12 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info, const abstract::AbstractBasePtrList &args_spec_list) { - MS_LOG(DEBUG) << "prim " << prim->name() << "input infer" << mindspore::ToString(args_spec_list); + MS_LOG(DEBUG) << "prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list); prim->BeginRecordAddAttr(); AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); prim->EndRecordAddAttr(); op_exec_info->abstract = infer_res; - MS_LOG(DEBUG) << "prim " << prim->name() << "infer result " << op_exec_info->abstract->ToString(); + MS_LOG(DEBUG) << "prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString(); } OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { @@ -632,7 +632,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v auto obj = op_exec_info->op_inputs[i]; bool op_mask = py::hasattr(obj, "__parameter__"); (*op_masks).push_back(op_mask); - MS_LOG(DEBUG) << "gen args i " << i << op_exec_info->op_name << " op mask" << op_mask << "grad_flag_" << grad_flag_; + MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ " + << grad_flag_; AnfNodePtr node = nullptr; abstract::AbstractBasePtr abs = nullptr; @@ -646,11 +647,17 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v if (node != nullptr && node->abstract() != nullptr) { abs = node->abstract(); } + MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value " + << prim->is_const_value(); if (abs == nullptr || prim->is_const_value()) { MS_LOG(DEBUG) << "MakeCnode get node no in map" << id; ValuePtr input_value = PyAttrValue(obj); - bool broaden = !prim->is_const_value() && input_value->isa(); - abs = abstract::FromValueInside(input_value, broaden); + abs = input_value->ToAbstract(); + if (!prim->is_const_value()) { + auto config = abstract::AbstractBase::kBroadenTensorOnly; + abs = abs->Broaden(config); + MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config; + } node_abs_map_[id] = abs; } (*args_spec_list).push_back(abs); diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index 0fb6759d95..154122c5aa 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -66,9 +66,12 @@ ValuePtr AbstractBase::BuildValue() const { return value_; } -AbstractBasePtr AbstractBase::Broaden() const { +AbstractBasePtr AbstractBase::Broaden(uint8_t config) const { AbstractBasePtr clone = Clone(); - clone->set_value(kAnyValue); + auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly); + if (not_broaden == 0) { + clone->set_value(kAnyValue); + } return clone; } @@ -85,7 +88,7 @@ std::string AbstractBase::ToString() const { return buffer.str(); } -AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden(); } +AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); } AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); @@ -224,11 +227,11 @@ AbstractBasePtrList AbstractSequeue::ElementsClone() const { return ele_list; } -AbstractBasePtrList AbstractSequeue::ElementsBroaden() const { +AbstractBasePtrList AbstractSequeue::ElementsBroaden(uint8_t config) const { AbstractBasePtrList ele_list; for (const auto &ele : elements_) { MS_EXCEPTION_IF_NULL(ele); - AbstractBasePtr broadend = ele->Broaden(); + AbstractBasePtr broadend = ele->Broaden(config); ele_list.push_back(broadend); } return ele_list; @@ -376,13 +379,13 @@ AbstractBasePtr AbstractSlice::Clone() const { return std::make_shared(start, stop, step); } -AbstractBasePtr AbstractSlice::Broaden() const { +AbstractBasePtr AbstractSlice::Broaden(uint8_t config) const { MS_EXCEPTION_IF_NULL(start_); MS_EXCEPTION_IF_NULL(stop_); MS_EXCEPTION_IF_NULL(step_); - AbstractBasePtr start = start_->Broaden(); - AbstractBasePtr stop = stop_->Broaden(); - AbstractBasePtr step = step_->Broaden(); + AbstractBasePtr start = start_->Broaden(config); + AbstractBasePtr stop = stop_->Broaden(config); + AbstractBasePtr step = step_->Broaden(config); return std::make_shared(start, stop, step); } @@ -506,12 +509,15 @@ AbstractBasePtr AbstractTensor::Clone() const { return clone; } -AbstractBasePtr AbstractTensor::Broaden() const { +AbstractBasePtr AbstractTensor::Broaden(uint8_t config) const { MS_EXCEPTION_IF_NULL(element_); auto broaden = std::make_shared(element_->Broaden()); auto shp = shape(); broaden->set_shape(shp->Clone()); - broaden->set_value(kAnyValue); + auto not_broaden = config & kBroadenParameterOnly; + if (not_broaden == 0) { + broaden->set_value(kAnyValue); + } return broaden; } @@ -585,12 +591,12 @@ AbstractBasePtr AbstractDictionary::Clone() const { return std::make_shared(kv); } -AbstractBasePtr AbstractDictionary::Broaden() const { +AbstractBasePtr AbstractDictionary::Broaden(uint8_t config) const { std::vector kv; (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv), - [](const AbstractAttribute &item) { + [config](const AbstractAttribute &item) { MS_EXCEPTION_IF_NULL(item.second); - return std::make_pair(item.first, item.second->Broaden()); + return std::make_pair(item.first, item.second->Broaden(config)); }); return std::make_shared(kv); } @@ -711,11 +717,11 @@ AbstractBasePtr AbstractClass::Clone() const { return std::make_shared(tag_, attributes_clone, methods_); } -AbstractBasePtr AbstractClass::Broaden() const { +AbstractBasePtr AbstractClass::Broaden(uint8_t config) const { std::vector attributes_clone; for (auto attr : attributes_) { MS_EXCEPTION_IF_NULL(attr.second); - AbstractBasePtr clone = attr.second->Broaden(); + AbstractBasePtr clone = attr.second->Broaden(config); AbstractAttribute elem(attr.first, clone); attributes_clone.push_back(elem); } @@ -843,9 +849,8 @@ TypePtr AbstractRef::BuildType() const { } bool AbstractRef::operator==(const AbstractRef &other) const { - return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && + return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) && (!need_cast_ || (*target_type_ == *other.target_type_)); - // not compare the key for reuse the graph (*ref_key_ == *other.ref_key_); } bool AbstractRef::operator==(const AbstractBase &other) const { @@ -921,9 +926,12 @@ std::string AbstractNone::ToString() const { ValuePtr AbstractNone::RealBuildValue() const { return kNone; } -AbstractBasePtr AbstractRefKey::Broaden() const { +AbstractBasePtr AbstractRefKey::Broaden(uint8_t config) const { auto refkey = std::make_shared(); - refkey->set_value(kAnyValue); + auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly); + if (not_broaden == 0) { + refkey->set_value(kAnyValue); + } return refkey; } @@ -1016,9 +1024,9 @@ AbstractBasePtr AbstractKeywordArg::Clone() const { return std::make_shared(arg_name_, arg_value_->Clone()); } -AbstractBasePtr AbstractKeywordArg::Broaden() const { +AbstractBasePtr AbstractKeywordArg::Broaden(uint8_t config) const { MS_EXCEPTION_IF_NULL(arg_value_); - return std::make_shared(arg_name_, arg_value_->Broaden()); + return std::make_shared(arg_name_, arg_value_->Broaden(config)); } std::size_t AbstractKeywordArg::hash() const { @@ -1123,7 +1131,7 @@ AbstractBasePtr AbstractRowTensor::Clone() const { return clone; } -AbstractBasePtr AbstractRowTensor::Broaden() const { +AbstractBasePtr AbstractRowTensor::Broaden(uint8_t config) const { MS_EXCEPTION_IF_NULL(element()); auto broaden = std::make_shared(element()->Broaden()); auto shp = shape(); @@ -1182,7 +1190,7 @@ AbstractBasePtr AbstractSparseTensor::Clone() const { return clone; } -AbstractBasePtr AbstractSparseTensor::Broaden() const { +AbstractBasePtr AbstractSparseTensor::Broaden(uint8_t config) const { MS_EXCEPTION_IF_NULL(element()); auto broaden = std::make_shared(element()->Broaden()); auto shp = shape(); diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index 2c6cf10d6c..3bdbd4b5ee 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -69,7 +69,14 @@ class AbstractBase : public Base { virtual TypePtr BuildType() const = 0; virtual BaseShapePtr BuildShape() const { return kNoShape; } virtual AbstractBasePtr Clone() const = 0; - virtual AbstractBasePtr Broaden() const; + + // mask for Broaden config + inline static const uint8_t kBroadenTensorOnly = 1; + inline static const uint8_t kBroadenParameterOnly = 2; + // Each bit for on config. + // 00000001 -> 1: only boarden tensor + // 00000010 -> 2: only boarden parameter + virtual AbstractBasePtr Broaden(uint8_t config = 0) const; virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base(); } friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &a) { @@ -108,7 +115,7 @@ class AbstractScalar : public AbstractBase { AbstractBasePtr Clone() const override { return std::make_shared(GetValueTrack(), GetTypeTrack()->Clone()); } - AbstractBasePtr Broaden() const override; + AbstractBasePtr Broaden(uint8_t config = 0) const override; AbstractBasePtr Join(const AbstractBasePtr &other) override; }; using AbstractScalarPtr = std::shared_ptr; @@ -128,7 +135,7 @@ class AbstractType : public AbstractBase { TypePtr BuildType() const override { return std::make_shared(); } AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override { return Clone(); } + AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); } }; using AbstractTypePtr = std::shared_ptr; @@ -143,7 +150,7 @@ class AbstractError : public AbstractBase { MS_DECLARE_PARENT(AbstractError, AbstractBase) TypePtr BuildType() const override { return std::make_shared(); } - AbstractBasePtr Broaden() const override { return Clone(); } + AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); } AbstractBasePtr Clone() const override { return std::make_shared(GetValueTrack()->cast(), node_); @@ -180,7 +187,7 @@ class AbstractFunction : public AbstractBase { TypePtr BuildType() const override { return std::make_shared(); } AbstractBasePtr Clone() const override { return Copy(); } // For Function, no need to broaden. - AbstractBasePtr Broaden() const override { + AbstractBasePtr Broaden(uint8_t config = 0) const override { return const_cast(this)->shared_from_base(); } virtual AbstractFunctionPtr Copy() const = 0; @@ -209,7 +216,7 @@ class AbstractKeywordArg : public AbstractBase { TypePtr BuildType() const override; AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; + AbstractBasePtr Broaden(uint8_t config = 0) const override; std::size_t hash() const override; bool operator==(const AbstractKeywordArg &other) const; @@ -275,7 +282,7 @@ class AbstractTensor : public AbstractUndetermined { TypePtr BuildType() const override; BaseShapePtr BuildShape() const override; AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; + AbstractBasePtr Broaden(uint8_t config = 0) const override; AbstractBasePtr BroadenWithShape() const; AbstractBasePtr Join(const AbstractBasePtr &other) final; int format() const { return this->format_; } @@ -312,7 +319,7 @@ class AbstractSequeue : public AbstractBase { TypePtrList ElementsType() const; BaseShapePtrList ElementsShape() const; AbstractBasePtrList ElementsClone() const; - AbstractBasePtrList ElementsBroaden() const; + AbstractBasePtrList ElementsBroaden(uint8_t config = 0) const; template ValuePtr ElementsBuildValue() const; @@ -345,7 +352,9 @@ class AbstractTuple : public AbstractSequeue { AbstractBasePtr Clone() const override { return std::make_shared(ElementsClone()); } - AbstractBasePtr Broaden() const override { return std::make_shared(ElementsBroaden()); } + AbstractBasePtr Broaden(uint8_t config = 0) const override { + return std::make_shared(ElementsBroaden(config)); + } AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin(other); } @@ -372,7 +381,9 @@ class AbstractList : public AbstractSequeue { AbstractBasePtr Clone() const override { return std::make_shared(ElementsClone()); } - AbstractBasePtr Broaden() const override { return std::make_shared(ElementsBroaden()); } + AbstractBasePtr Broaden(uint8_t config = 0) const override { + return std::make_shared(ElementsBroaden(config)); + } AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin(other); } @@ -403,7 +414,7 @@ class AbstractClass : public AbstractBase { AbstractBasePtr GetAttribute(const std::string &name); ValuePtr GetMethod(const std::string &name); AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; + AbstractBasePtr Broaden(uint8_t config = 0) const override; std::string ToString() const override; Named tag() const { return tag_; } std::size_t hash() const override; @@ -428,7 +439,7 @@ class AbstractDictionary : public AbstractBase { bool operator==(const AbstractDictionary &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; + AbstractBasePtr Broaden(uint8_t config = 0) const override; std::string ToString() const override; std::size_t hash() const override; std::size_t size() const { return key_values_.size(); } @@ -452,7 +463,7 @@ class AbstractSlice : public AbstractBase { bool operator==(const AbstractSlice &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; + AbstractBasePtr Broaden(uint8_t config = 0) const override; std::string ToString() const override; std::size_t hash() const override; AbstractBasePtr start() const { return start_; } @@ -478,7 +489,9 @@ class AbstractJTagged : public AbstractBase { TypePtr BuildType() const override; AbstractBasePtr Clone() const override { return std::make_shared(element_->Clone()); } - AbstractBasePtr Broaden() const override { return std::make_shared(element_->Broaden()); } + AbstractBasePtr Broaden(uint8_t config = 0) const override { + return std::make_shared(element_->Broaden(config)); + } AbstractBasePtr Join(const AbstractBasePtr &other) override; bool operator==(const AbstractJTagged &other) const; @@ -558,7 +571,7 @@ class AbstractRefKey : public AbstractBase { } RefKeyPtr ref_key_value() const { return ref_key_value_; } AbstractBasePtr Join(const AbstractBasePtr &other) override; - AbstractBasePtr Broaden() const override; + AbstractBasePtr Broaden(uint8_t config = 0) const override; std::string ToString() const override; private: @@ -588,8 +601,9 @@ class AbstractRef : public AbstractBase { inline RefKeyPtr ref_key_value() const { return ref_key_value_; } inline TypePtr target_type() const { return target_type_; } inline bool need_cast() const { return need_cast_; } - AbstractBasePtr Broaden() const override { - return std::make_shared(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_); + AbstractBasePtr Broaden(uint8_t config = 0) const override { + // always broaden for ref + return std::make_shared(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_); } AbstractBasePtr Join(const AbstractBasePtr &other) override; std::size_t hash() const override { @@ -636,7 +650,7 @@ class AbstractRowTensor : public AbstractUndetermined { void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } TypePtr BuildType() const override; AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; + AbstractBasePtr Broaden(uint8_t config = 0) const override; AbstractBasePtr BroadenWithShape() const; std::string ToString() const override; @@ -665,7 +679,7 @@ class AbstractSparseTensor : public AbstractUndetermined { void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } TypePtr BuildType() const override; AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; + AbstractBasePtr Broaden(uint8_t config = 0) const override; AbstractBasePtr BroadenWithShape() const; std::string ToString() const override;