| @@ -31,6 +31,7 @@ | |||||
| #include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" | #include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" | ||||
| #include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" | #include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" | ||||
| #include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h" | #include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h" | ||||
| #include "backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.h" | |||||
| #include "backend/optimizer/pass/communication_op_fusion.h" | #include "backend/optimizer/pass/communication_op_fusion.h" | ||||
| #include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" | #include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" | ||||
| #include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" | #include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" | ||||
| @@ -179,6 +180,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<ConcatFission>()); | ir_fusion_pm->AddPass(std::make_shared<ConcatFission>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>()); | ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<UnsortSegmentSumFission>()); | ir_fusion_pm->AddPass(std::make_shared<UnsortSegmentSumFission>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<GatherV2DsFission>()); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -0,0 +1,177 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| // only pad operator can run in dynamic shape. | |||||
| CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(origin_node); | |||||
| std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), origin_node->input(1)}; | |||||
| auto pad = graph->NewCNode(pad_inputs); | |||||
| MS_EXCEPTION_IF_NULL(pad); | |||||
| pad->set_scope(origin_node->scope()); | |||||
| auto param_abstract_shape = origin_node->input(1)->Shape(); | |||||
| MS_EXCEPTION_IF_NULL(param_abstract_shape); | |||||
| if (!param_abstract_shape->isa<abstract::Shape>()) { | |||||
| MS_LOG(EXCEPTION) << "Gatherv2 's first input has wrong shape type"; | |||||
| } | |||||
| auto param_dyn_shape = param_abstract_shape->cast<abstract::ShapePtr>(); | |||||
| ShapeVector shape(param_dyn_shape->shape()); | |||||
| if (shape.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Gatherv2 's shape is empty"; | |||||
| } | |||||
| if (shape[shape.size() - 1] == -1) { | |||||
| MS_LOG(EXCEPTION) << "Dim needs pad should not be dynamic"; | |||||
| } | |||||
| shape[shape.size() - 1] = pad_dim_size; | |||||
| auto type_id = AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0); | |||||
| auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), shape); | |||||
| if (param_dyn_shape->max_shape().size() == param_dyn_shape->shape().size() && | |||||
| param_dyn_shape->min_shape().size() == param_dyn_shape->shape().size()) { | |||||
| ShapeVector max_shape(param_dyn_shape->max_shape()); | |||||
| ShapeVector min_shape(param_dyn_shape->min_shape()); | |||||
| ShapeVector new_shape(shape); | |||||
| max_shape[max_shape.size() - 1] = pad_dim_size; | |||||
| min_shape[min_shape.size() - 1] = pad_dim_size; | |||||
| abstract->set_shape(std::make_shared<abstract::Shape>(new_shape, min_shape, max_shape)); | |||||
| } | |||||
| pad->set_abstract(abstract); | |||||
| std::vector<ValuePtr> elements; | |||||
| for (size_t i = 0; i < shape.size() - 1; ++i) { | |||||
| ShapeVector padding_vector(2); | |||||
| auto padding_value = MakeValue(padding_vector); | |||||
| elements.push_back(padding_value); | |||||
| } | |||||
| ShapeVector last_padding_vector = {0, SizeToLong(pad_dim_size - 1)}; | |||||
| auto last_padding_value = MakeValue(last_padding_vector); | |||||
| elements.push_back(last_padding_value); | |||||
| ValueTuplePtr paddings = std::make_shared<ValueTuple>(elements); | |||||
| AnfAlgo::SetNodeAttr(kAttrPaddings, paddings, pad); | |||||
| AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), pad); | |||||
| AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), pad); | |||||
| AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), pad); | |||||
| return pad; | |||||
| } | |||||
| CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &pad, | |||||
| const size_t &pad_dim_size) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(origin_node); | |||||
| MS_EXCEPTION_IF_NULL(pad); | |||||
| if (origin_node->size() != 4) { | |||||
| MS_LOG(EXCEPTION) << "In dynamic shape scene, gatherv2 should have 3 inputs"; | |||||
| } | |||||
| std::vector<AnfNodePtr> gatherv2_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimGatherV2->name())), | |||||
| pad, origin_node->input(2), origin_node->input(3)}; | |||||
| auto gather_v2 = graph->NewCNode(gatherv2_inputs); | |||||
| MS_EXCEPTION_IF_NULL(gather_v2); | |||||
| gather_v2->set_scope(origin_node->scope()); | |||||
| auto shape = AnfAlgo::GetOutputInferShape(origin_node, 0); | |||||
| shape[shape.size() - 1] = pad_dim_size; | |||||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape}, gather_v2.get()); | |||||
| AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), gather_v2); | |||||
| AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), gather_v2); | |||||
| auto depends_list_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(origin_node, kAttrDynamicShapeDepends); | |||||
| AnfAlgo::SetNodeAttr(kAttrDynamicShapeDepends, MakeValue(depends_list_me), gather_v2); | |||||
| auto input_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(origin_node, kAttrInputNames); | |||||
| AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), gather_v2); | |||||
| auto output_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(origin_node, kAttrOutputNames); | |||||
| AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), gather_v2); | |||||
| return gather_v2; | |||||
| } | |||||
| CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const CNodePtr &gather_v2_padding_8) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(gather_v2); | |||||
| MS_EXCEPTION_IF_NULL(gather_v2_padding_8); | |||||
| std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), gather_v2_padding_8}; | |||||
| auto slice = graph->NewCNode(slice_inputs); | |||||
| MS_EXCEPTION_IF_NULL(slice); | |||||
| slice->set_scope(gather_v2->scope()); | |||||
| slice->set_abstract(gather_v2->abstract()); | |||||
| auto gather_v2_shape = AnfAlgo::GetOutputInferShape(gather_v2, 0); | |||||
| std::vector<size_t> offsets(gather_v2_shape.size(), 0); | |||||
| AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(offsets)), slice); | |||||
| AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(gather_v2_shape)), slice); | |||||
| return slice; | |||||
| } | |||||
| bool CheckInputs(const CNodePtr &origin_node) { | |||||
| MS_EXCEPTION_IF_NULL(origin_node); | |||||
| if (origin_node->size() != kGatherV2DynInputNum + 1) { | |||||
| MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputNum | |||||
| << ". CNode= " << origin_node->DebugString(); | |||||
| return false; | |||||
| } | |||||
| auto param_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); | |||||
| auto indice_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1); | |||||
| // this optimizer only support embedding_table has dynamic shape | |||||
| if (param_shape.empty() || indice_shape.empty() || AnfAlgo::IsDynamicShape(origin_node->input(2))) { | |||||
| return false; | |||||
| } | |||||
| if (param_shape[param_shape.size() - 1] != 1) { | |||||
| MS_LOG(DEBUG) << "GatherV2 in dynamic shape is not need fission. The last value of input0's shape is " | |||||
| << param_shape[param_shape.size() - 1]; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace | |||||
| const BaseRef GatherV2DsFission::DefinePattern() const { | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| VectorRef pattern({prim::kPrimGatherV2, Xs}); | |||||
| return pattern; | |||||
| } | |||||
| const AnfNodePtr GatherV2DsFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto origin_node = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(origin_node); | |||||
| if (!CheckInputs(origin_node)) { | |||||
| return nullptr; | |||||
| } | |||||
| size_t pad_dim_size; | |||||
| auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0); | |||||
| if (input_dtype == kNumberTypeFloat32) { | |||||
| pad_dim_size = 8; | |||||
| } else if (input_dtype == kNumberTypeFloat16) { | |||||
| pad_dim_size = 16; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "GatherV2 data type not in (float32, float16), no need change"; | |||||
| return nullptr; | |||||
| } | |||||
| CNodePtr gather_v2_8; | |||||
| auto pad = CreatePad(graph, origin_node, pad_dim_size); | |||||
| gather_v2_8 = CreateGatherV2Ds(graph, origin_node, pad, pad_dim_size); | |||||
| return CreateSlice(graph, origin_node, gather_v2_8); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_GATHER_V2_DS_FISSION_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_GATHER_V2_DS_FISSION_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "backend/optimizer/common/helper.h" | |||||
| #include "backend/optimizer/ascend/ascend_helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
// Fission pass for dynamic-shape GatherV2: when the embedding table's last
// dim is 1, rewrites the node into Pad -> GatherV2 -> Slice (only Pad can run
// in dynamic shape per the implementation's note).
class GatherV2DsFission : public PatternProcessPass {
 public:
  explicit GatherV2DsFission(bool multigraph = true) : PatternProcessPass("gather_v2_ds_fission", multigraph) {}
  ~GatherV2DsFission() override = default;
  // Matches any GatherV2 node.
  const BaseRef DefinePattern() const override;
  // Returns the replacement Slice node, or nullptr when fission does not apply.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_GATHER_V2_DS_FISSION_H_ | |||||
| @@ -98,6 +98,7 @@ constexpr size_t kTopkInputNum = 3; | |||||
| constexpr size_t kLarsV2InputNum = 5; | constexpr size_t kLarsV2InputNum = 5; | ||||
| constexpr size_t kFusedMulApplyMomentumOutputNum = 2; | constexpr size_t kFusedMulApplyMomentumOutputNum = 2; | ||||
| constexpr size_t kSplitInputNum = 2; | constexpr size_t kSplitInputNum = 2; | ||||
| constexpr size_t kGatherV2DynInputNum = 3; | |||||
| constexpr size_t kUnsortedSegmentSumInputNum = 2; | constexpr size_t kUnsortedSegmentSumInputNum = 2; | ||||
| enum FusedBatchNormInput { | enum FusedBatchNormInput { | ||||
| @@ -148,6 +148,7 @@ std::string GetRealOpType(const std::string &op_type) { | |||||
| {"SparseApplyFtrl", "SparseApplyFtrlD"}, | {"SparseApplyFtrl", "SparseApplyFtrlD"}, | ||||
| {"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"}, | {"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"}, | ||||
| {"SparseGatherV2", "GatherV2"}, | {"SparseGatherV2", "GatherV2"}, | ||||
| {"Pad", "PadD"}, | |||||
| }; | }; | ||||
| auto iter = kOpTypeMap.find(op_type); | auto iter = kOpTypeMap.find(op_type); | ||||
| if (iter == kOpTypeMap.end()) { | if (iter == kOpTypeMap.end()) { | ||||
| @@ -322,12 +322,14 @@ constexpr auto kAttrT = "T"; | |||||
| constexpr auto kAttrNum = "num"; | constexpr auto kAttrNum = "num"; | ||||
| constexpr auto kAttrRankSize = "rank_size"; | constexpr auto kAttrRankSize = "rank_size"; | ||||
| constexpr auto kAttrPadDimSize = "pad_dim_size"; | constexpr auto kAttrPadDimSize = "pad_dim_size"; | ||||
| constexpr auto kAttrPaddings = "paddings"; | |||||
| constexpr auto kAttrNumSegments = "num_segments"; | constexpr auto kAttrNumSegments = "num_segments"; | ||||
| constexpr auto kAttrBegin = "begin"; | constexpr auto kAttrBegin = "begin"; | ||||
| constexpr auto kAttrSize = "size"; | constexpr auto kAttrSize = "size"; | ||||
| constexpr auto kAttrIsDynamicShape = "is_dynamic_shape"; | constexpr auto kAttrIsDynamicShape = "is_dynamic_shape"; | ||||
| constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape"; | constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape"; | ||||
| constexpr auto kAttrOutputIsDynamicShape = "output_is_dynamic_shape"; | constexpr auto kAttrOutputIsDynamicShape = "output_is_dynamic_shape"; | ||||
| constexpr auto kAttrDynamicShapeDepends = "dynamic_shape_depends"; | |||||
| constexpr auto kAttrPynativeNextOpName = "next_op"; | constexpr auto kAttrPynativeNextOpName = "next_op"; | ||||
| constexpr auto kAttrPynativeNextIndex = "next_index"; | constexpr auto kAttrPynativeNextIndex = "next_index"; | ||||
| constexpr auto kAttrCompileInfo = "compile_info"; | constexpr auto kAttrCompileInfo = "compile_info"; | ||||
| @@ -251,7 +251,8 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| template <typename T> | template <typename T> | ||||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: a tuple or list or dict. | // Inputs: a tuple or list or dict. | ||||
| @@ -470,5 +470,39 @@ AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &prim | |||||
| elements.push_back(args_spec_list[0]->Clone()->Broaden()); | elements.push_back(args_spec_list[0]->Clone()->Broaden()); | ||||
| return std::make_shared<AbstractTuple>(elements); | return std::make_shared<AbstractTuple>(elements); | ||||
| } | } | ||||
| AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | |||||
| auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| auto input_shp = arg->shape()->shape(); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto padding_attr = primitive->GetAttr("paddings"); | |||||
| MS_EXCEPTION_IF_NULL(padding_attr); | |||||
| if (!padding_attr->isa<ValueTuple>()) { | |||||
| MS_LOG(EXCEPTION) << "paddings is not a ValueTuple"; | |||||
| } | |||||
| std::vector<ValuePtr> paddings = padding_attr->cast<ValueTuplePtr>()->value(); | |||||
| std::vector<std::vector<int64_t>> paddings_vec; | |||||
| for (ValuePtr paddings_elements : paddings) { | |||||
| std::vector<ValuePtr> paddings_elements_tuple = paddings_elements->cast<ValueTuplePtr>()->value(); | |||||
| std::vector<int64_t> paddings_vec_item; | |||||
| (void)std::transform(std::begin(paddings_elements_tuple), std::end(paddings_elements_tuple), | |||||
| std::back_inserter(paddings_vec_item), | |||||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||||
| paddings_vec.push_back(paddings_vec_item); | |||||
| } | |||||
| ShapeVector result_shp; | |||||
| size_t length = paddings_vec.size(); | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| if (paddings_vec[i].size() != 2) { | |||||
| MS_LOG(EXCEPTION) << "paddings 's second dim size is not 2"; | |||||
| } | |||||
| result_shp.push_back(input_shp[i] + paddings_vec[i][0] + paddings_vec[i][1]); | |||||
| } | |||||
| return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp)); | |||||
| } | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -50,6 +50,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | ||||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, | {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, | ||||
| {prim::kPrimPack, {InferImplPack, true}}, | {prim::kPrimPack, {InferImplPack, true}}, | ||||
| {prim::kPrimPad, {InferImplPad, true}}, | |||||
| {prim::kPrimUnique, {InferImplUnique, true}}, | {prim::kPrimUnique, {InferImplUnique, true}}, | ||||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | ||||
| {prim::kPrimGatherV2, {InferImplGatherV2, true}}, | {prim::kPrimGatherV2, {InferImplGatherV2, true}}, | ||||
| @@ -101,6 +101,7 @@ inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | |||||
| inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx"); | inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx"); | ||||
| inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); | inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); | ||||
| inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable"); | inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable"); | ||||
| inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice"); | |||||
| inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | ||||
| inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | ||||
| inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); | inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); | ||||
| @@ -193,6 +193,7 @@ from .sigmoid_grad import _sigmoid_grad_tbe | |||||
| from .resize_nearest_neighbor import _resize_nearest_neighbor_tbe | from .resize_nearest_neighbor import _resize_nearest_neighbor_tbe | ||||
| from .resize_nearest_neighbor_grad import _resize_nearest_neighbor_grad_tbe | from .resize_nearest_neighbor_grad import _resize_nearest_neighbor_grad_tbe | ||||
| from .pad_d import _pad_d_tbe | from .pad_d import _pad_d_tbe | ||||
| from .pad_d_ds import _pad_d_ds_tbe | |||||
| from .arg_max_with_value import _arg_max_with_value_tbe | from .arg_max_with_value import _arg_max_with_value_tbe | ||||
| from .arg_min_with_value import _arg_min_with_value_tbe | from .arg_min_with_value import _arg_min_with_value_tbe | ||||
| from .smooth_l1_loss import _smooth_l1_loss_tbe | from .smooth_l1_loss import _smooth_l1_loss_tbe | ||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Pad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
# TBE registration info for the dynamic-shape Pad kernel ("pad_d").
# Registered under op name "Pad"; the backend maps "Pad" to "PadD" when
# selecting the static-shape kernel, so this entry covers the dynamic case.
pad_d_op_info = TBERegOp("Pad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("pad_d.so") \
    .compute_cost(10) \
    .kernel_name("pad_d") \
    .partial_flag(True) \
    .attr("paddings", "optional", "listListInt", "all") \
    .dynamic_shape(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(pad_d_op_info)
def _pad_d_ds_tbe():
    """Pad TBE register"""
    return