| @@ -64,8 +64,9 @@ constexpr auto kDynamicStitch = "DynamicStitch"; | |||
| constexpr auto kSearchSorted = "SearchSorted"; | |||
| constexpr auto kResizeBilinear = "ResizeBilinear"; | |||
| constexpr auto kResizeBilinearGrad = "ResizeBilinearGrad"; | |||
| const std::set<std::string> kCustAiCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch, | |||
| kSearchSorted, kResizeBilinear, kResizeBilinearGrad}; | |||
| constexpr auto kScatterElements = "ScatterElements"; | |||
| const std::set<std::string> kCustAiCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch, | |||
| kSearchSorted, kResizeBilinear, kResizeBilinearGrad, kScatterElements}; | |||
| const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, | |||
| kPadAndShift, kDropout3D, kDropout2D}; | |||
| const std::set<std::string> kDynamicInputOps{ | |||
| @@ -332,11 +332,12 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::t | |||
| if (!shape->isa<abstract::Shape>()) { | |||
| MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString(); | |||
| } | |||
| auto shape_me = shape->cast<abstract::ShapePtr>()->shape(); | |||
| auto shape_ge = py::cast<Tensor &>(data[*count]).shape(); | |||
| if (shape_ge != shape_me) { | |||
| MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge | |||
| << " is not the same as the shape of the tensor derived: " << shape_me; | |||
| if (shape_ge != shape_me) { // dynamic shape | |||
| MS_LOG(WARNING) << "The shape of the " << *count << "th tensor returned: " << shape_ge | |||
| << " is not the same as the shape of the tensor derived: " << shape_me; | |||
| } | |||
| return data[(*count)++]; | |||
| @@ -27,6 +27,8 @@ std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"Ba | |||
| {"DynamicRNN", "ND"}, | |||
| {"DynamicRNNGrad", "ND"}, | |||
| {"MatMul", "ND"}, | |||
| {"BatchMatMul", "ND"}, | |||
| {"BatchMatMulV2", "ND"}, | |||
| {"Quant", "ND"}, | |||
| {"BasicLSTMCellWeightGrad", "HWCN"}, | |||
| {"ExtractImagePatches", "NCHW"}, | |||
| @@ -43,6 +43,7 @@ constexpr const char kNameSquare[] = "Square"; | |||
| constexpr const char kNameSquaredDifference[] = "SquaredDifference"; | |||
| constexpr const char kNamePow[] = "Pow"; | |||
| constexpr const char kNameBatchMatMul[] = "BatchMatMul"; | |||
| constexpr const char kNameBatchMatMulV2[] = "BatchMatMulV2"; | |||
| constexpr const char kNameStridedSlice[] = "StridedSlice"; | |||
| constexpr const char kNameStridedSliceGrad[] = "StridedSliceGrad"; | |||
| constexpr const char kNameExpandDims[] = "ExpandDims"; | |||
| @@ -97,6 +98,8 @@ constexpr const char kNameSoftplusGrad[] = "SoftplusGrad"; | |||
| constexpr const char kNameElu[] = "Elu"; | |||
| constexpr const char kNameEluGrad[] = "EluGrad"; | |||
| constexpr const char kNameTensorScatterUpdate[] = "TensorScatterUpdate"; | |||
| constexpr const char kNameScatterElements[] = "ScatterElements"; | |||
| constexpr const char kNameNonZero[] = "NonZero"; | |||
| constexpr const char kNameScatterUpdate[] = "ScatterUpdate"; | |||
| constexpr const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; | |||
| constexpr const char kNameScatterMax[] = "ScatterMax"; | |||
| @@ -102,6 +102,12 @@ ATTR_MAP(EditDistance) = {{"normalize", ATTR_DESC(normalize, AnyTraits<bool>())} | |||
| OUTPUT_MAP(EditDistance) = {{0, OUTPUT_DESC(output)}}; | |||
| REG_ADPT_DESC(EditDistance, kNameEditDistance, ADPT_DESC(EditDistance)) | |||
| // NonZero | |||
| INPUT_MAP(NonZero) = {{1, INPUT_DESC(x)}}; | |||
| ATTR_MAP(NonZero) = {{"transpose", ATTR_DESC(transpose, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(NonZero) = {{0, OUTPUT_DESC(y)}}; | |||
| REG_ADPT_DESC(NonZero, kNameNonZero, ADPT_DESC(NonZero)) | |||
| // Unsqueeze | |||
| INPUT_MAP(Unsqueeze) = {{1, INPUT_DESC(x)}}; | |||
| ATTR_MAP(Unsqueeze) = {{"axis", ATTR_DESC(axes, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())}}; | |||
| @@ -61,6 +61,9 @@ DECLARE_OP_USE_OUTPUT(ReverseSequence) | |||
| DECLARE_OP_ADAPTER(EditDistance) | |||
| DECLARE_OP_USE_OUTPUT(EditDistance) | |||
| DECLARE_OP_ADAPTER(NonZero) | |||
| DECLARE_OP_USE_OUTPUT(NonZero) | |||
| DECLARE_OP_ADAPTER(Unsqueeze) | |||
| DECLARE_OP_USE_OUTPUT(Unsqueeze) | |||
| } // namespace mindspore::transform | |||
| @@ -126,7 +126,14 @@ INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | |||
| ATTR_MAP(BatchMatMul) = {{"transpose_x1", ATTR_DESC(adj_x1, AnyTraits<bool>())}, | |||
| {"transpose_x2", ATTR_DESC(adj_x2, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(BatchMatMul) = {{0, OUTPUT_DESC(y)}}; | |||
| REG_ADPT_DESC(BatchMatMul, kNameBatchMatMul, ADPT_DESC(BatchMatMul)) | |||
| // BatchMatMul->BatchMatMulV2 | |||
| INPUT_MAP(BatchMatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | |||
| ATTR_MAP(BatchMatMulV2) = {{"transpose_x1", ATTR_DESC(adj_x1, AnyTraits<bool>())}, | |||
| {"transpose_x2", ATTR_DESC(adj_x2, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(BatchMatMulV2) = {{0, OUTPUT_DESC(y)}}; | |||
| REG_ADPT_DESC(BatchMatMul, kNameBatchMatMul, ADPT_DESC(BatchMatMulV2)) | |||
| REG_ADPT_DESC(BatchMatMulV2, kNameBatchMatMulV2, ADPT_DESC(BatchMatMulV2)) | |||
| // L2Loss | |||
| INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; | |||
| @@ -134,6 +141,12 @@ ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}}; | |||
| REG_ADPT_DESC(L2Loss, kNameL2Loss, ADPT_DESC(L2Loss)) | |||
| // ScatterElements | |||
| INPUT_MAP(ScatterElements) = {{1, INPUT_DESC(data)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | |||
| ATTR_MAP(ScatterElements) = {{"axis", ATTR_DESC(axis, AnyTraits<int64_t>())}}; | |||
| OUTPUT_MAP(ScatterElements) = {{0, OUTPUT_DESC(y)}}; | |||
| REG_ADPT_DESC(ScatterElements, kNameScatterElements, ADPT_DESC(ScatterElements)) | |||
| // FullyConnection | |||
| INPUT_MAP(FullyConnection) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(offset_w)}}; | |||
| @@ -59,6 +59,9 @@ DECLARE_OP_USE_OUTPUT(ScatterNdSub) | |||
| DECLARE_OP_ADAPTER(BatchMatMul) | |||
| DECLARE_OP_USE_OUTPUT(BatchMatMul) | |||
| DECLARE_OP_ADAPTER(BatchMatMulV2) | |||
| DECLARE_OP_USE_OUTPUT(BatchMatMulV2) | |||
| DECLARE_OP_ADAPTER(MatMul) | |||
| DECLARE_OP_USE_OUTPUT(MatMul) | |||
| @@ -80,6 +83,9 @@ DECLARE_OP_USE_OUTPUT(DiagPart) | |||
| DECLARE_OP_ADAPTER(L2Loss) | |||
| DECLARE_OP_USE_OUTPUT(L2Loss) | |||
| DECLARE_OP_ADAPTER(ScatterElements) | |||
| DECLARE_OP_USE_OUTPUT(ScatterElements) | |||
| DECLARE_OP_ADAPTER(FullyConnection) | |||
| DECLARE_OP_USE_OUTPUT(FullyConnection) | |||
| } // namespace mindspore::transform | |||
| @@ -175,6 +175,8 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplNonZero(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -185,6 +187,8 @@ AbstractBasePtr InferImplScatterSub(const AnalysisEnginePtr &, const PrimitivePt | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplScatterElements(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -198,6 +198,30 @@ AbstractBasePtr InferImplPadAndShift(const AnalysisEnginePtr &, const PrimitiveP | |||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape)); | |||
| } | |||
| AbstractBasePtr InferImplNonZero(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| const size_t size_expected = 1; | |||
| CheckArgsSize(op_name, args_spec_list, size_expected); | |||
| AbstractTensorPtr x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| auto x_shape = x->shape(); | |||
| MS_EXCEPTION_IF_NULL(x_shape); | |||
| ShapeVector y_shape; | |||
| int64_t rank_base = SizeToLong(x_shape->shape().size()); | |||
| int64_t max_size = std::accumulate(x_shape->shape().begin(), x_shape->shape().end(), int64_t(1), std::multiplies<int64_t>()); |||
| y_shape.emplace_back(rank_base); | |||
| // Indices of elements that are non-zero | |||
| y_shape.emplace_back(Shape::SHP_ANY); | |||
| ShapeVector min_shape = {rank_base, 1}; | |||
| ShapeVector max_shape = {rank_base, max_size}; | |||
| return std::make_shared<AbstractTensor>(kInt64, std::make_shared<Shape>(y_shape, min_shape, max_shape)); | |||
| } | |||
| AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // inputs: a 1-d Tensor | |||
| @@ -426,6 +450,20 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| } | |||
| AbstractBasePtr InferImplScatterElements(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckRequiredArgsSize(op_name, args_spec_list, 3); | |||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||
| ShapeVector shape = x->shape()->shape(); | |||
| ShapeVector min_shape = x->shape()->min_shape(); | |||
| ShapeVector max_shape = x->shape()->max_shape(); | |||
| CheckMinMaxShape(shape, &min_shape, &max_shape); | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| } | |||
| AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| @@ -117,6 +117,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, nullptr, true}}, | |||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, nullptr, true}}, | |||
| {prim::kPrimScatterSub, {InferImplScatterSub, nullptr, true}}, | |||
| {prim::kPrimScatterElements, {InferImplScatterElements, nullptr, true}}, | |||
| {prim::kPrimSubAndFilter, {InferImplSubAndFilter, nullptr, true}}, | |||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, nullptr, true}}, | |||
| {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, nullptr, true}}, | |||
| @@ -133,6 +134,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimSort, {InferImplSort, nullptr, true}}, | |||
| {prim::kPrimMaskedSelect, {InferImplMaskedSelect, nullptr, true}}, | |||
| {prim::kPrimTensorCopySlices, {InferImplTensorCopySlices, nullptr, true}}, | |||
| {prim::kPrimNonZero, {InferImplNonZero, nullptr, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}}, | |||
| @@ -223,6 +223,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("D | |||
| inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd"); | |||
| inline const PrimitivePtr kPrimScatterSub = std::make_shared<Primitive>("ScatterSub"); | |||
| inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate"); | |||
| inline const PrimitivePtr kPrimScatterElements = std::make_shared<Primitive>("ScatterElements"); | |||
| inline const PrimitivePtr kPrimTensorCopySlices = std::make_shared<Primitive>("TensorCopySlices"); | |||
| inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform"); | |||
| inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split"); | |||
| @@ -248,6 +249,7 @@ inline const PrimitivePtr kPrimMaskedFill = std::make_shared<Primitive>("MaskedF | |||
| inline const PrimitivePtr kPrimMaskedSelect = std::make_shared<Primitive>("MaskedSelect"); | |||
| inline const PrimitivePtr kPrimDiag = std::make_shared<Primitive>(kDiag); | |||
| inline const PrimitivePtr kPrimDiagPart = std::make_shared<Primitive>(kDiagPart); | |||
| inline const PrimitivePtr kPrimNonZero = std::make_shared<Primitive>("NonZero"); | |||
| // NN | |||
| inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam"); | |||
| @@ -36,15 +36,10 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive, | |||
| auto y_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape()); | |||
| auto x_shp = x_shape_map[kShape]; | |||
| auto y_shp = y_shape_map[kShape]; | |||
| if (x_shp.size() != y_shp.size() || x_shp.size() < 3) { | |||
| MS_EXCEPTION(ValueError) << "For BatchMatMul, input x, y should have the same dimension size and should be greater" | |||
| << "or equal to 3, while x size = " << x_shp.size() << ", y size = " << y_shp.size(); | |||
| } | |||
| for (size_t i = 0; i < x_shp.size() - 2; ++i) { | |||
| if (x_shp[i] != y_shp[i]) { | |||
| MS_EXCEPTION(ValueError) << "For " << prim_name << " shapes in dim[" << i << "] are not the same" | |||
| << "while x1 is " << x_shp[i] << ", x2 is " << y_shp[i]; | |||
| } | |||
| if (x_shp.size() < 3 || y_shp.size() < 2) { | |||
| MS_EXCEPTION(ValueError) << "For BatchMatMul, input x should be greater or equal to 3, input y should be greater " | |||
| "or equal to 2 while x size = " | |||
| << x_shp.size() << ", y size = " << y_shp.size(); | |||
| } | |||
| std::vector<int> x_last(x_shp.end() - 2, x_shp.end()); | |||
| std::vector<int> y_last(y_shp.end() - 2, y_shp.end()); | |||
| @@ -78,9 +73,10 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive, | |||
| bool y_not_dyn = | |||
| std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != abstract::Shape::SHP_ANY; }); | |||
| if (x_not_dyn && y_not_dyn) { | |||
| size_t offset = x_shp.size() - 2; | |||
| auto x_c = x_shp[offset + (transpose_a ? 0 : 1)]; | |||
| auto y_r = y_shp[offset + (transpose_b ? 1 : 0)]; | |||
| size_t x_offset = x_shp.size() - 2; | |||
| size_t y_offset = y_shp.size() - 2; | |||
| auto x_c = x_shp[x_offset + (transpose_a ? 0 : 1)]; | |||
| auto y_r = y_shp[y_offset + (transpose_b ? 1 : 0)]; | |||
| if (x_c != y_r) { | |||
| MS_LOG(EXCEPTION) << "BatchMatMul shape error, got x_col: " << x_c << ", y_row: " << y_r | |||
| << ". In BatchMatMul x_col and y_row should be equal."; | |||
| @@ -92,18 +88,16 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive, | |||
| auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp, | |||
| const ShapeVector yshp) -> void { | |||
| for (size_t i = 0; i < xshp.size() - 2; i++) { | |||
| if (xshp[i] != yshp[i]) { | |||
| if (xshp[i] > 0 && yshp[i] > 0) { | |||
| MS_LOG(EXCEPTION) << "BatchMatMul input x, y are different at index " << i << "."; | |||
| } | |||
| if (xshp[i] < 0) { | |||
| output.push_back(abstract::Shape::SHP_ANY); | |||
| } else { | |||
| output.push_back(xshp[i]); | |||
| } | |||
| } | |||
| size_t offset = xshp.size() - 2; | |||
| output.push_back(xshp[offset + (transpose_a ? 1 : 0)]); | |||
| output.push_back(yshp[offset + (transpose_b ? 0 : 1)]); | |||
| size_t x_offset = xshp.size() - 2; | |||
| size_t y_offset = yshp.size() - 2; | |||
| output.push_back(xshp[x_offset + (transpose_a ? 1 : 0)]); | |||
| output.push_back(yshp[y_offset + (transpose_b ? 0 : 1)]); | |||
| return; | |||
| }; | |||
| make_shape(ret_shape, x_shp, y_shp); | |||
| @@ -79,3 +79,4 @@ from .stack_push_pop import _stack_destroy_aicpu | |||
| from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu | |||
| from .resize_bilinear import _resize_bilinear_aicpu | |||
| from .resize_bilinear_grad import _resize_bilinear_grad_aicpu | |||
| from .scatter_elements import _scatter_elements_aicpu | |||
| @@ -0,0 +1,35 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ScatterElements op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| scatter_elements_op_info = AiCPURegOp("ScatterElements") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .attr("axis", "int") \ | |||
| .input(0, "data", "required") \ | |||
| .input(1, "indices", "required") \ | |||
| .input(2, "updates", "required") \ | |||
| .output(0, "y", "required") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(scatter_elements_op_info) | |||
| def _scatter_elements_aicpu(): | |||
| """ScatterElements AiCPU register""" | |||
| return | |||
| @@ -414,3 +414,4 @@ from .hsigmoid import _hsigmoid_tbe | |||
| from .hshrink import _hshrink_tbe | |||
| from .hshrink_grad import _hshrink_grad_tbe | |||
| from .new_im2col import _new_im2col_tbe | |||
| from .non_zero_ds import _non_zero_ds_tbe | |||
| @@ -0,0 +1,37 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """NonZero op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| non_zero_op_info = TBERegOp("NonZero") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("non_zero.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("non_zero") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .attr("transpose", "optional", "bool", "all", "false") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_Default, DataType.I64_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(non_zero_op_info) | |||
| def _non_zero_ds_tbe(): | |||
| """NonZero TBE register""" | |||
| return | |||
| @@ -34,7 +34,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta | |||
| UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, | |||
| BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, | |||
| EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted, | |||
| TensorScatterMax, TensorScatterMin, TensorScatterSub) | |||
| TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements) | |||
| from .comm_ops import (AllGather, AllReduce, NeighborExchange, AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, | |||
| _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad, | |||
| @@ -120,7 +120,7 @@ from .sponge_update_ops import (ConstrainForceCycleWithVirial, RefreshUintCrd, L | |||
| Dihedral14ForceWithAtomEnergyVirial, PMEEnergyUpdate, | |||
| ConstrainForceVirial, ConstrainForce, Constrain) | |||
| from .rl_ops import (BufferAppend, BufferGetItem, BufferSample) | |||
| from ._inner_ops import (MatmulDDS, DSDMatmul) | |||
| from ._inner_ops import (MatmulDDS, DSDMatmul, NonZero) | |||
| __all__ = [ | |||
| 'Unique', | |||
| @@ -468,6 +468,8 @@ __all__ = [ | |||
| "TensorScatterMax", | |||
| "TensorScatterMin", | |||
| "TensorScatterSub", | |||
| "ScatterElements", | |||
| "NonZero", | |||
| "SoftShrink", | |||
| "FFT3D", | |||
| "IFFT3D", | |||
| @@ -1345,3 +1345,35 @@ class MatmulDDSGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad): | |||
| return q, k | |||
| class NonZero(Primitive): | |||
| """ | |||
| Returns the indices of the elements that are non-zero (in row-major order - by dimension). | |||
| Args: | |||
| transpose (bool): Permutes the dimensions of the output tensor according to input permutation, | |||
| default is false | |||
| Inputs: | |||
| - **x** (Tensor), input array of rank >= 2. | |||
| Outputs: | |||
| 2-D Tensor, int64, indices of elements that are non-zero. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> op = ops.NonZero(True) | |||
| >>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32) | |||
| >>> output = op(data) | |||
| >>> print(output) | |||
| [[ 0 0] | |||
| [ 1 1]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, transpose=False): | |||
| """Initialize ScatterElements""" | |||
| validator.check_value_type("transpose", transpose, [bool], self.name) | |||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | |||
| @@ -6403,3 +6403,53 @@ class SplitV(Primitive): | |||
| validator.check_value_type("num_split", num_split, [int], self.name) | |||
| validator.check_positive_int(num_split, "num_split", self.name) | |||
| self.init_prim_io_names(inputs=['input_x'], outputs=['output']) | |||
| class ScatterElements(Primitive): | |||
| """ | |||
| ScatterElements takes three inputs data, updates, and indices of the same rank r >= 1 | |||
| and an optional attribute axis that identifies an axis of data (default is 0). | |||
| The output of the operation is produced by creating a copy of the input data, and then updating its value to | |||
| values specified by updates at specific index positions specified by indices. | |||
| Args: | |||
| axis (int): which axis to scatter, default is 0. | |||
| Inputs: | |||
| - **data** (Tensor) - The target tensor. |||
| - **indices** (Tensor) - The index of input tensor whose data type is int32 or int64. | |||
| - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, | |||
| and update.shape should be equal to indices.shape. | |||
| Outputs: | |||
| Tensor, has the same shape and type as `data`. | |||
| Raises: | |||
| TypeError: If dtype of `indices` is neither int32 nor int64. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> op = ops.ScatterElements(0) | |||
| >>> data = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32) | |||
| >>> indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]]), mindspore.int32) | |||
| >>> updates = Tensor(np.array([[0, 0, 0], [0, 0, 0]]), mindspore.float32) | |||
| >>> output = op(data, indices, updates) | |||
| >>> print(output) | |||
| [[ 0.0 0.0 3.0] | |||
| [ 0.0 5.0 0.0] | |||
| [ 7.0 0.0 0.0]] | |||
| >>> op = ops.ScatterElements(1) | |||
| >>> data = Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32) |||
| >>> indices = Tensor(np.array([[2, 4]]), mindspore.int32) |||
| >>> updates = Tensor(np.array([[8, 8]]), mindspore.int32) | |||
| >>> output = op(data, indices, updates) | |||
| >>> print(output) | |||
| [[ 1 2 8 4 8]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, axis=0): | |||
| """Initialize ScatterElements""" | |||
| validator.check_value_type("axis", axis, [int], self.name) | |||
| self.init_prim_io_names(inputs=['data', 'indices', 'updates'], outputs=['y']) | |||
| @@ -1117,11 +1117,6 @@ class MatMul(PrimitiveWithCheck): | |||
| def check_shape(self, x1, x2): | |||
| self.check_shape_size(x1, x2) | |||
| cls_name = self.name | |||
| # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two | |||
| for i in range(len(x1) - 2): | |||
| if x1[i] != x2[i]: | |||
| raise ValueError(f"For '{cls_name}', the dim[{i}] of 'x' should be equal to the dim[{i}] of 'y', " | |||
| f"but got 'x[{i}]': {x1[i]} and 'y[{i}]': {x2[i]}.") | |||
| # validate whether last two dims satisfying matrix multiply | |||
| x1_last = x1[-2:] | |||
| @@ -1150,7 +1145,7 @@ class BatchMatMul(MatMul): | |||
| \\text{output}[..., :, :] = \\text{matrix}(x[..., :, :]) * \\text{matrix}(y[..., :, :]) | |||
| The two input tensors must have the same rank and the rank must be not less than `3`. | |||
| The first input tensor must be not less than `3` and the second input must be not less than `2`. | |||
| Args: | |||
| transpose_x (bool): If true, the last two dimensions of `x` is transposed before multiplication. | |||
| @@ -1214,9 +1209,9 @@ class BatchMatMul(MatMul): | |||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | |||
| def check_shape_size(self, x, y): | |||
| if len(x) != len(y) or len(x) < 3: | |||
| raise ValueError(f"For '{self.name}', input 'x', 'y' should be the same dimension size and should be " | |||
| f"greater than or equal to 3, but got 'x' size: {len(x)}, 'y' size: {len(y)}.") | |||
| if len(x) < 3 or len(y) < 2: | |||
| raise ValueError(f"For '{self.name}', input 'x' should be greater than or equal to 3, input 'y' should " | |||
| f"be greater than or equal to 2, but got 'x' size: {len(x)}, 'y' size: {len(y)}.") | |||
| class CumSum(PrimitiveWithInfer): | |||
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.scatter_elements = P.ScatterElements(0) |||
| def construct(self, data, indices, updates): | |||
| return self.scatter_elements(data, indices, updates) | |||
| def test_net(): | |||
| data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32) | |||
| indices = np.array([[1, 0, 2], [0, 2, 1]]).astype(np.int32) | |||
| updates = np.array([[0, 0, 0], [0, 0, 0]]).astype(np.float32) | |||
| net = Net() | |||
| tdata = Tensor(data) | |||
| tindices = Tensor(indices) | |||
| tupdates = Tensor(updates) |||
| output = net(tdata, tindices, tupdates) | |||
| print(output.asnumpy()) | |||
| assert np.all([[0.0, 0.0, 3.0], [0.0, 5.0, 0.0], [7.0, 0.0, 0.0]] == output.asnumpy()) | |||