From: @zhupuxu Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -100,13 +100,13 @@ void FeedTeOpTensorOutputArg(const NotNull<CNodePtr> &cnode, | |||||
| void FeedTeOpConstTensor(const NotNull<CNodePtr> &cnode, const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map, | void FeedTeOpConstTensor(const NotNull<CNodePtr> &cnode, const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map, | ||||
| NotNull<std::map<std::string, optiling::TeConstTensorData> *> const_inputs) { | NotNull<std::map<std::string, optiling::TeConstTensorData> *> const_inputs) { | ||||
| MS_LOG(INFO) << "FeedTeOpConstTensor start, node:" << cnode->fullname_with_scope(); | MS_LOG(INFO) << "FeedTeOpConstTensor start, node:" << cnode->fullname_with_scope(); | ||||
| if (!AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode.get())) { | |||||
| auto depends_list_me = abstract::GetDependsFormMap(cnode); | |||||
| if (depends_list_me.empty()) { | |||||
| MS_LOG(INFO) << "No input depend found, " << cnode->fullname_with_scope(); | MS_LOG(INFO) << "No input depend found, " << cnode->fullname_with_scope(); | ||||
| return; | return; | ||||
| } | } | ||||
| std::vector<int> depends_list; | std::vector<int> depends_list; | ||||
| std::vector<int64_t> depends_list_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode.get(), kDynamicShapeDepends); | |||||
| (void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depends_list), | (void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depends_list), | ||||
| [](const int64_t &value) { return static_cast<int>(value); }); | [](const int64_t &value) { return static_cast<int>(value); }); | ||||
| for (auto index : depends_list) { | for (auto index : depends_list) { | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #include "register/op_tiling.h" | #include "register/op_tiling.h" | ||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| @@ -39,16 +39,14 @@ void DynamicKernel::Initialize() { | |||||
| is_input_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrInputIsDynamicShape); | is_input_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrInputIsDynamicShape); | ||||
| is_output_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrOutputIsDynamicShape); | is_output_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrOutputIsDynamicShape); | ||||
| auto have_depends = AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode_ptr_); | |||||
| if (!have_depends) { | |||||
| auto ret = abstract::GetDependsFormMap(cnode_ptr_); | |||||
| if (ret.empty()) { | |||||
| MS_LOG(DEBUG) << "No dynamic_shape_depends found"; | MS_LOG(DEBUG) << "No dynamic_shape_depends found"; | ||||
| return; | return; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Have depends"; | MS_LOG(INFO) << "Have depends"; | ||||
| std::vector<int64_t> depends_list_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode_ptr_, kDynamicShapeDepends); | |||||
| (void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depend_list_), | |||||
| (void)std::transform(ret.begin(), ret.end(), std::back_inserter(depend_list_), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | [](const int64_t &value) { return static_cast<int>(value); }); | ||||
| MS_LOG(INFO) << "Init End"; | MS_LOG(INFO) << "Init End"; | ||||
| } | } | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include <map> | #include <map> | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| @@ -22,6 +22,25 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||||
| constexpr auto kUnsortedSegmentSum = "UnsortedSegmentSum"; | |||||
| constexpr auto kUnsortedSegmentMin = "UnsortedSegmentMin"; | |||||
| constexpr auto kUnsortedSegmentMax = "UnsortedSegmentMax"; | |||||
| static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = { | |||||
| {kUnsortedSegmentSum, {2}}, {kUnsortedSegmentMin, {2}}, {kUnsortedSegmentMax, {2}}}; | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (cnode->inputs().empty()) { | |||||
| MS_LOG(EXCEPTION) << "Invalid inputs"; | |||||
| } | |||||
| auto primitive = GetValueNode<PrimitivePtr>(cnode->inputs()[0]); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto iter = dynamic_shape_depends.find(primitive->ToString()); | |||||
| if (iter != dynamic_shape_depends.end()) { | |||||
| return iter->second; | |||||
| } | |||||
| return {}; | |||||
| } | |||||
| PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | ||||
| static PrimitiveEvalImplMap prim_eval_implement_map = { | static PrimitiveEvalImplMap prim_eval_implement_map = { | ||||
| // Statements | // Statements | ||||
| @@ -21,6 +21,8 @@ | |||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "base/core_ops.h" | #include "base/core_ops.h" | ||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "ir/anf.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | ||||
| @@ -35,6 +37,8 @@ using PrimitiveEvalImplMap = | |||||
| PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); | PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); | ||||
| std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode); | |||||
| void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg); | void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg); | ||||
| class RegisterStandardPrimitiveEvalHelper { | class RegisterStandardPrimitiveEvalHelper { | ||||
| @@ -1892,7 +1892,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| """Initialize UnsortedSegmentSum""" | """Initialize UnsortedSegmentSum""" | ||||
| self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) | self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) | ||||
| self.add_prim_attr("dynamic_shape_depends", [2]) | |||||
| def __infer__(self, x, segment_ids, num_segments): | def __infer__(self, x, segment_ids, num_segments): | ||||
| x_type = x['dtype'] | x_type = x['dtype'] | ||||