Browse Source

board tensor for pynative infer

tags/v0.7.0-beta
Wei Luning 5 years ago
parent
commit
1d6c76f350
3 changed files with 77 additions and 48 deletions
  1. +12
    -5
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +32
    -24
      mindspore/core/abstract/abstract_value.cc
  3. +33
    -19
      mindspore/core/abstract/abstract_value.h

+ 12
- 5
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -285,12 +285,12 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe


void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info, void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info,
const abstract::AbstractBasePtrList &args_spec_list) { const abstract::AbstractBasePtrList &args_spec_list) {
MS_LOG(DEBUG) << "prim " << prim->name() << "input infer" << mindspore::ToString(args_spec_list);
MS_LOG(DEBUG) << "prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
prim->BeginRecordAddAttr(); prim->BeginRecordAddAttr();
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
prim->EndRecordAddAttr(); prim->EndRecordAddAttr();
op_exec_info->abstract = infer_res; op_exec_info->abstract = infer_res;
MS_LOG(DEBUG) << "prim " << prim->name() << "infer result " << op_exec_info->abstract->ToString();
MS_LOG(DEBUG) << "prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
} }


OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
@@ -632,7 +632,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
auto obj = op_exec_info->op_inputs[i]; auto obj = op_exec_info->op_inputs[i];
bool op_mask = py::hasattr(obj, "__parameter__"); bool op_mask = py::hasattr(obj, "__parameter__");
(*op_masks).push_back(op_mask); (*op_masks).push_back(op_mask);
MS_LOG(DEBUG) << "gen args i " << i << op_exec_info->op_name << " op mask" << op_mask << "grad_flag_" << grad_flag_;
MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
<< grad_flag_;


AnfNodePtr node = nullptr; AnfNodePtr node = nullptr;
abstract::AbstractBasePtr abs = nullptr; abstract::AbstractBasePtr abs = nullptr;
@@ -646,11 +647,17 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
if (node != nullptr && node->abstract() != nullptr) { if (node != nullptr && node->abstract() != nullptr) {
abs = node->abstract(); abs = node->abstract();
} }
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
<< prim->is_const_value();
if (abs == nullptr || prim->is_const_value()) { if (abs == nullptr || prim->is_const_value()) {
MS_LOG(DEBUG) << "MakeCnode get node no in map" << id; MS_LOG(DEBUG) << "MakeCnode get node no in map" << id;
ValuePtr input_value = PyAttrValue(obj); ValuePtr input_value = PyAttrValue(obj);
bool broaden = !prim->is_const_value() && input_value->isa<tensor::Tensor>();
abs = abstract::FromValueInside(input_value, broaden);
abs = input_value->ToAbstract();
if (!prim->is_const_value()) {
auto config = abstract::AbstractBase::kBroadenTensorOnly;
abs = abs->Broaden(config);
MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config;
}
node_abs_map_[id] = abs; node_abs_map_[id] = abs;
} }
(*args_spec_list).push_back(abs); (*args_spec_list).push_back(abs);


+ 32
- 24
mindspore/core/abstract/abstract_value.cc View File

@@ -66,9 +66,12 @@ ValuePtr AbstractBase::BuildValue() const {
return value_; return value_;
} }


AbstractBasePtr AbstractBase::Broaden() const {
AbstractBasePtr AbstractBase::Broaden(uint8_t config) const {
AbstractBasePtr clone = Clone(); AbstractBasePtr clone = Clone();
clone->set_value(kAnyValue);
auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly);
if (not_broaden == 0) {
clone->set_value(kAnyValue);
}
return clone; return clone;
} }


@@ -85,7 +88,7 @@ std::string AbstractBase::ToString() const {
return buffer.str(); return buffer.str();
} }


AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden(); }
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); }


AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other); MS_EXCEPTION_IF_NULL(other);
@@ -224,11 +227,11 @@ AbstractBasePtrList AbstractSequeue::ElementsClone() const {
return ele_list; return ele_list;
} }


AbstractBasePtrList AbstractSequeue::ElementsBroaden() const {
AbstractBasePtrList AbstractSequeue::ElementsBroaden(uint8_t config) const {
AbstractBasePtrList ele_list; AbstractBasePtrList ele_list;
for (const auto &ele : elements_) { for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele); MS_EXCEPTION_IF_NULL(ele);
AbstractBasePtr broadend = ele->Broaden();
AbstractBasePtr broadend = ele->Broaden(config);
ele_list.push_back(broadend); ele_list.push_back(broadend);
} }
return ele_list; return ele_list;
@@ -376,13 +379,13 @@ AbstractBasePtr AbstractSlice::Clone() const {
return std::make_shared<AbstractSlice>(start, stop, step); return std::make_shared<AbstractSlice>(start, stop, step);
} }


AbstractBasePtr AbstractSlice::Broaden() const {
AbstractBasePtr AbstractSlice::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(start_); MS_EXCEPTION_IF_NULL(start_);
MS_EXCEPTION_IF_NULL(stop_); MS_EXCEPTION_IF_NULL(stop_);
MS_EXCEPTION_IF_NULL(step_); MS_EXCEPTION_IF_NULL(step_);
AbstractBasePtr start = start_->Broaden();
AbstractBasePtr stop = stop_->Broaden();
AbstractBasePtr step = step_->Broaden();
AbstractBasePtr start = start_->Broaden(config);
AbstractBasePtr stop = stop_->Broaden(config);
AbstractBasePtr step = step_->Broaden(config);
return std::make_shared<AbstractSlice>(start, stop, step); return std::make_shared<AbstractSlice>(start, stop, step);
} }


@@ -506,12 +509,15 @@ AbstractBasePtr AbstractTensor::Clone() const {
return clone; return clone;
} }


AbstractBasePtr AbstractTensor::Broaden() const {
AbstractBasePtr AbstractTensor::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(element_); MS_EXCEPTION_IF_NULL(element_);
auto broaden = std::make_shared<AbstractTensor>(element_->Broaden()); auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
auto shp = shape(); auto shp = shape();
broaden->set_shape(shp->Clone()); broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
auto not_broaden = config & kBroadenParameterOnly;
if (not_broaden == 0) {
broaden->set_value(kAnyValue);
}
return broaden; return broaden;
} }


@@ -585,12 +591,12 @@ AbstractBasePtr AbstractDictionary::Clone() const {
return std::make_shared<AbstractDictionary>(kv); return std::make_shared<AbstractDictionary>(kv);
} }


AbstractBasePtr AbstractDictionary::Broaden() const {
AbstractBasePtr AbstractDictionary::Broaden(uint8_t config) const {
std::vector<AbstractAttribute> kv; std::vector<AbstractAttribute> kv;
(void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv), (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const AbstractAttribute &item) {
[config](const AbstractAttribute &item) {
MS_EXCEPTION_IF_NULL(item.second); MS_EXCEPTION_IF_NULL(item.second);
return std::make_pair(item.first, item.second->Broaden());
return std::make_pair(item.first, item.second->Broaden(config));
}); });
return std::make_shared<AbstractDictionary>(kv); return std::make_shared<AbstractDictionary>(kv);
} }
@@ -711,11 +717,11 @@ AbstractBasePtr AbstractClass::Clone() const {
return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_); return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
} }


AbstractBasePtr AbstractClass::Broaden() const {
AbstractBasePtr AbstractClass::Broaden(uint8_t config) const {
std::vector<AbstractAttribute> attributes_clone; std::vector<AbstractAttribute> attributes_clone;
for (auto attr : attributes_) { for (auto attr : attributes_) {
MS_EXCEPTION_IF_NULL(attr.second); MS_EXCEPTION_IF_NULL(attr.second);
AbstractBasePtr clone = attr.second->Broaden();
AbstractBasePtr clone = attr.second->Broaden(config);
AbstractAttribute elem(attr.first, clone); AbstractAttribute elem(attr.first, clone);
attributes_clone.push_back(elem); attributes_clone.push_back(elem);
} }
@@ -843,9 +849,8 @@ TypePtr AbstractRef::BuildType() const {
} }


bool AbstractRef::operator==(const AbstractRef &other) const { bool AbstractRef::operator==(const AbstractRef &other) const {
return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) &&
return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) &&
(!need_cast_ || (*target_type_ == *other.target_type_)); (!need_cast_ || (*target_type_ == *other.target_type_));
// not compare the key for reuse the graph (*ref_key_ == *other.ref_key_);
} }


bool AbstractRef::operator==(const AbstractBase &other) const { bool AbstractRef::operator==(const AbstractBase &other) const {
@@ -921,9 +926,12 @@ std::string AbstractNone::ToString() const {


ValuePtr AbstractNone::RealBuildValue() const { return kNone; } ValuePtr AbstractNone::RealBuildValue() const { return kNone; }


AbstractBasePtr AbstractRefKey::Broaden() const {
AbstractBasePtr AbstractRefKey::Broaden(uint8_t config) const {
auto refkey = std::make_shared<AbstractRefKey>(); auto refkey = std::make_shared<AbstractRefKey>();
refkey->set_value(kAnyValue);
auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly);
if (not_broaden == 0) {
refkey->set_value(kAnyValue);
}
return refkey; return refkey;
} }


@@ -1016,9 +1024,9 @@ AbstractBasePtr AbstractKeywordArg::Clone() const {
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone()); return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
} }


AbstractBasePtr AbstractKeywordArg::Broaden() const {
AbstractBasePtr AbstractKeywordArg::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(arg_value_); MS_EXCEPTION_IF_NULL(arg_value_);
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden(config));
} }


std::size_t AbstractKeywordArg::hash() const { std::size_t AbstractKeywordArg::hash() const {
@@ -1123,7 +1131,7 @@ AbstractBasePtr AbstractRowTensor::Clone() const {
return clone; return clone;
} }


AbstractBasePtr AbstractRowTensor::Broaden() const {
AbstractBasePtr AbstractRowTensor::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(element()); MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden()); auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
auto shp = shape(); auto shp = shape();
@@ -1182,7 +1190,7 @@ AbstractBasePtr AbstractSparseTensor::Clone() const {
return clone; return clone;
} }


AbstractBasePtr AbstractSparseTensor::Broaden() const {
AbstractBasePtr AbstractSparseTensor::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(element()); MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden()); auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
auto shp = shape(); auto shp = shape();


+ 33
- 19
mindspore/core/abstract/abstract_value.h View File

@@ -69,7 +69,14 @@ class AbstractBase : public Base {
virtual TypePtr BuildType() const = 0; virtual TypePtr BuildType() const = 0;
virtual BaseShapePtr BuildShape() const { return kNoShape; } virtual BaseShapePtr BuildShape() const { return kNoShape; }
virtual AbstractBasePtr Clone() const = 0; virtual AbstractBasePtr Clone() const = 0;
virtual AbstractBasePtr Broaden() const;

// mask for Broaden config
inline static const uint8_t kBroadenTensorOnly = 1;
inline static const uint8_t kBroadenParameterOnly = 2;
// Each bit for on config.
// 00000001 -> 1: only boarden tensor
// 00000010 -> 2: only boarden parameter
virtual AbstractBasePtr Broaden(uint8_t config = 0) const;
virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); } virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); }


friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) { friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) {
@@ -108,7 +115,7 @@ class AbstractScalar : public AbstractBase {
AbstractBasePtr Clone() const override { AbstractBasePtr Clone() const override {
return std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone()); return std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone());
} }
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Join(const AbstractBasePtr &other) override; AbstractBasePtr Join(const AbstractBasePtr &other) override;
}; };
using AbstractScalarPtr = std::shared_ptr<AbstractScalar>; using AbstractScalarPtr = std::shared_ptr<AbstractScalar>;
@@ -128,7 +135,7 @@ class AbstractType : public AbstractBase {


TypePtr BuildType() const override { return std::make_shared<TypeType>(); } TypePtr BuildType() const override { return std::make_shared<TypeType>(); }
AbstractBasePtr Clone() const override; AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override { return Clone(); }
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); }
}; };
using AbstractTypePtr = std::shared_ptr<AbstractType>; using AbstractTypePtr = std::shared_ptr<AbstractType>;


@@ -143,7 +150,7 @@ class AbstractError : public AbstractBase {
MS_DECLARE_PARENT(AbstractError, AbstractBase) MS_DECLARE_PARENT(AbstractError, AbstractBase)


TypePtr BuildType() const override { return std::make_shared<Problem>(); } TypePtr BuildType() const override { return std::make_shared<Problem>(); }
AbstractBasePtr Broaden() const override { return Clone(); }
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); }


AbstractBasePtr Clone() const override { AbstractBasePtr Clone() const override {
return std::make_shared<AbstractError>(GetValueTrack()->cast<StringImmPtr>(), node_); return std::make_shared<AbstractError>(GetValueTrack()->cast<StringImmPtr>(), node_);
@@ -180,7 +187,7 @@ class AbstractFunction : public AbstractBase {
TypePtr BuildType() const override { return std::make_shared<Function>(); } TypePtr BuildType() const override { return std::make_shared<Function>(); }
AbstractBasePtr Clone() const override { return Copy(); } AbstractBasePtr Clone() const override { return Copy(); }
// For Function, no need to broaden. // For Function, no need to broaden.
AbstractBasePtr Broaden() const override {
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>(); return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>();
} }
virtual AbstractFunctionPtr Copy() const = 0; virtual AbstractFunctionPtr Copy() const = 0;
@@ -209,7 +216,7 @@ class AbstractKeywordArg : public AbstractBase {


TypePtr BuildType() const override; TypePtr BuildType() const override;
AbstractBasePtr Clone() const override; AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::size_t hash() const override; std::size_t hash() const override;


bool operator==(const AbstractKeywordArg &other) const; bool operator==(const AbstractKeywordArg &other) const;
@@ -275,7 +282,7 @@ class AbstractTensor : public AbstractUndetermined {
TypePtr BuildType() const override; TypePtr BuildType() const override;
BaseShapePtr BuildShape() const override; BaseShapePtr BuildShape() const override;
AbstractBasePtr Clone() const override; AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr BroadenWithShape() const; AbstractBasePtr BroadenWithShape() const;
AbstractBasePtr Join(const AbstractBasePtr &other) final; AbstractBasePtr Join(const AbstractBasePtr &other) final;
int format() const { return this->format_; } int format() const { return this->format_; }
@@ -312,7 +319,7 @@ class AbstractSequeue : public AbstractBase {
TypePtrList ElementsType() const; TypePtrList ElementsType() const;
BaseShapePtrList ElementsShape() const; BaseShapePtrList ElementsShape() const;
AbstractBasePtrList ElementsClone() const; AbstractBasePtrList ElementsClone() const;
AbstractBasePtrList ElementsBroaden() const;
AbstractBasePtrList ElementsBroaden(uint8_t config = 0) const;


template <typename T> template <typename T>
ValuePtr ElementsBuildValue() const; ValuePtr ElementsBuildValue() const;
@@ -345,7 +352,9 @@ class AbstractTuple : public AbstractSequeue {


AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); } AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); }


AbstractBasePtr Broaden() const override { return std::make_shared<AbstractTuple>(ElementsBroaden()); }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return std::make_shared<AbstractTuple>(ElementsBroaden(config));
}


AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); } AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); }


@@ -372,7 +381,9 @@ class AbstractList : public AbstractSequeue {


AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); } AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); }


AbstractBasePtr Broaden() const override { return std::make_shared<AbstractList>(ElementsBroaden()); }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return std::make_shared<AbstractList>(ElementsBroaden(config));
}


AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); } AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); }


@@ -403,7 +414,7 @@ class AbstractClass : public AbstractBase {
AbstractBasePtr GetAttribute(const std::string &name); AbstractBasePtr GetAttribute(const std::string &name);
ValuePtr GetMethod(const std::string &name); ValuePtr GetMethod(const std::string &name);
AbstractBasePtr Clone() const override; AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::string ToString() const override; std::string ToString() const override;
Named tag() const { return tag_; } Named tag() const { return tag_; }
std::size_t hash() const override; std::size_t hash() const override;
@@ -428,7 +439,7 @@ class AbstractDictionary : public AbstractBase {
bool operator==(const AbstractDictionary &other) const; bool operator==(const AbstractDictionary &other) const;
bool operator==(const AbstractBase &other) const override; bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override; AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::string ToString() const override; std::string ToString() const override;
std::size_t hash() const override; std::size_t hash() const override;
std::size_t size() const { return key_values_.size(); } std::size_t size() const { return key_values_.size(); }
@@ -452,7 +463,7 @@ class AbstractSlice : public AbstractBase {
bool operator==(const AbstractSlice &other) const; bool operator==(const AbstractSlice &other) const;
bool operator==(const AbstractBase &other) const override; bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override; AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::string ToString() const override; std::string ToString() const override;
std::size_t hash() const override; std::size_t hash() const override;
AbstractBasePtr start() const { return start_; } AbstractBasePtr start() const { return start_; }
@@ -478,7 +489,9 @@ class AbstractJTagged : public AbstractBase {


TypePtr BuildType() const override; TypePtr BuildType() const override;
AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); } AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); }
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractJTagged>(element_->Broaden()); }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return std::make_shared<AbstractJTagged>(element_->Broaden(config));
}
AbstractBasePtr Join(const AbstractBasePtr &other) override; AbstractBasePtr Join(const AbstractBasePtr &other) override;


bool operator==(const AbstractJTagged &other) const; bool operator==(const AbstractJTagged &other) const;
@@ -558,7 +571,7 @@ class AbstractRefKey : public AbstractBase {
} }
RefKeyPtr ref_key_value() const { return ref_key_value_; } RefKeyPtr ref_key_value() const { return ref_key_value_; }
AbstractBasePtr Join(const AbstractBasePtr &other) override; AbstractBasePtr Join(const AbstractBasePtr &other) override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::string ToString() const override; std::string ToString() const override;


private: private:
@@ -588,8 +601,9 @@ class AbstractRef : public AbstractBase {
inline RefKeyPtr ref_key_value() const { return ref_key_value_; } inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
inline TypePtr target_type() const { return target_type_; } inline TypePtr target_type() const { return target_type_; }
inline bool need_cast() const { return need_cast_; } inline bool need_cast() const { return need_cast_; }
AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_);
AbstractBasePtr Broaden(uint8_t config = 0) const override {
// always broaden for ref
return std::make_shared<AbstractRef>(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_);
} }
AbstractBasePtr Join(const AbstractBasePtr &other) override; AbstractBasePtr Join(const AbstractBasePtr &other) override;
std::size_t hash() const override { std::size_t hash() const override {
@@ -636,7 +650,7 @@ class AbstractRowTensor : public AbstractUndetermined {
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override; TypePtr BuildType() const override;
AbstractBasePtr Clone() const override; AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr BroadenWithShape() const; AbstractBasePtr BroadenWithShape() const;


std::string ToString() const override; std::string ToString() const override;
@@ -665,7 +679,7 @@ class AbstractSparseTensor : public AbstractUndetermined {
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override; TypePtr BuildType() const override;
AbstractBasePtr Clone() const override; AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr BroadenWithShape() const; AbstractBasePtr BroadenWithShape() const;


std::string ToString() const override; std::string ToString() const override;


Loading…
Cancel
Save