From: @david-he91 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -21,6 +21,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| template <typename T> | template <typename T> | ||||
| void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | ||||
| node_ = kernel_node; | |||||
| CheckParam(kernel_node); | CheckParam(kernel_node); | ||||
| axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS)); | axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS)); | ||||
| @@ -28,27 +29,28 @@ void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||||
| if (axis_ < 0) { | if (axis_ < 0) { | ||||
| axis_ = axis_ + SizeToInt(input_1_shape.size()); | axis_ = axis_ + SizeToInt(input_1_shape.size()); | ||||
| } | } | ||||
| input_num_ = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| for (size_t i = 0; i < input_num_; i++) { | |||||
| auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); | |||||
| auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_); | |||||
| input_flat_shape_list_.push_back(flat_shape); | |||||
| } | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | const std::vector<kernel::AddressPtr> & /*workspace*/, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| size_t input_num = AnfAlgo::GetInputTensorNum(node_); | |||||
| std::vector<std::vector<size_t>> input_flat_shape_list; | |||||
| for (size_t i = 0; i < input_num; i++) { | |||||
| auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(node_, i); | |||||
| auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_); | |||||
| input_flat_shape_list.push_back(flat_shape); | |||||
| } | |||||
| auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); | auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); | ||||
| auto buff_size = outputs[0]->size; | auto buff_size = outputs[0]->size; | ||||
| // each input's row of shape after flat are same | // each input's row of shape after flat are same | ||||
| auto before_axis = input_flat_shape_list_[0][0]; | |||||
| auto before_axis = input_flat_shape_list[0][0]; | |||||
| for (size_t i = 0; i < before_axis; ++i) { | for (size_t i = 0; i < before_axis; ++i) { | ||||
| for (size_t j = 0; j < input_num_; ++j) { | |||||
| for (size_t j = 0; j < input_num; ++j) { | |||||
| auto input_j_addr = reinterpret_cast<T *>(inputs[j]->addr); | auto input_j_addr = reinterpret_cast<T *>(inputs[j]->addr); | ||||
| auto copy_num = input_flat_shape_list_[j][1]; | |||||
| auto copy_num = input_flat_shape_list[j][1]; | |||||
| auto offset = copy_num * i; | auto offset = copy_num * i; | ||||
| auto ret = memcpy_s(output_addr, buff_size, input_j_addr + offset, copy_num * sizeof(T)); | auto ret = memcpy_s(output_addr, buff_size, input_j_addr + offset, copy_num * sizeof(T)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| @@ -36,8 +36,7 @@ class ConcatCPUKernel : public CPUKernel { | |||||
| private: | private: | ||||
| void CheckParam(const CNodePtr &kernel_node); | void CheckParam(const CNodePtr &kernel_node); | ||||
| int axis_ = 0; | int axis_ = 0; | ||||
| size_t input_num_ = 1; | |||||
| std::vector<std::vector<size_t>> input_flat_shape_list_; | |||||
| CNodePtr node_ = nullptr; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL_T( | MS_REG_CPU_KERNEL_T( | ||||
| @@ -140,7 +140,7 @@ void OpTilingCalculater::Init() { | |||||
| tiling_func_map_ = optiling::OpTilingRegistryInterf::RegisteredOpInterf(); | tiling_func_map_ = optiling::OpTilingRegistryInterf::RegisteredOpInterf(); | ||||
| MS_LOG(INFO) << "tiling_func_map_ size:" << tiling_func_map_.size(); | MS_LOG(INFO) << "tiling_func_map_ size:" << tiling_func_map_.size(); | ||||
| for (const auto &iter : tiling_func_map_) { | for (const auto &iter : tiling_func_map_) { | ||||
| MS_LOG(INFO) << "Regist tiling func:" << iter.first; | |||||
| MS_LOG(INFO) << "Register tiling func:" << iter.first; | |||||
| } | } | ||||
| } | } | ||||
| @@ -150,6 +150,7 @@ std::string GetRealOpType(const std::string &op_type) { | |||||
| {"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"}, | {"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"}, | ||||
| {"SparseGatherV2", "GatherV2"}, | {"SparseGatherV2", "GatherV2"}, | ||||
| {"Pad", "PadD"}, | {"Pad", "PadD"}, | ||||
| {"Concat", "ConcatD"}, | |||||
| }; | }; | ||||
| auto iter = kOpTypeMap.find(op_type); | auto iter = kOpTypeMap.find(op_type); | ||||
| if (iter == kOpTypeMap.end()) { | if (iter == kOpTypeMap.end()) { | ||||
| @@ -30,6 +30,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // op name. Op which not exists in operator/ops.h, so define it's name here | // op name. Op which not exists in operator/ops.h, so define it's name here | ||||
| constexpr auto kConcatOpName = "Concat"; | |||||
| constexpr auto kUniqueOpName = "Unique"; | constexpr auto kUniqueOpName = "Unique"; | ||||
| constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits"; | constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits"; | ||||
| constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder"; | constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder"; | ||||
| @@ -494,7 +495,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalH | |||||
| const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; | const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; | ||||
| const std::set<std::string> DynamicShapeConstInputToAttr = { | const std::set<std::string> DynamicShapeConstInputToAttr = { | ||||
| kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName}; | |||||
| kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, | |||||
| kTransposeOpName, kReduceSumOpName, kConcatOpName}; | |||||
| static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { | static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { | ||||
| try { | try { | ||||
| @@ -291,6 +291,8 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -954,6 +954,90 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive | |||||
| return std::make_shared<AbstractTensor>(kBool, output_shape); | return std::make_shared<AbstractTensor>(kBool, output_shape); | ||||
| } | } | ||||
| AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| const std::string op_name = primitive->name(); | |||||
| if (args_spec_list.empty()) { | |||||
| MS_LOG(EXCEPTION) << "args_spec_list is empty."; | |||||
| } | |||||
| AbstractTuplePtr arg = nullptr; | |||||
| AbstractTensorPtr tensor_base = nullptr; | |||||
| size_t tuple_len = 0; | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||||
| if (args_spec_list[0]->isa<AbstractTuple>()) { | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | |||||
| arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); | |||||
| tuple_len = arg->elements().size(); | |||||
| tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0); | |||||
| } else if (args_spec_list[0]->isa<AbstractTensor>()) { | |||||
| tuple_len = args_spec_list.size(); | |||||
| tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(tensor_base); | |||||
| ShapeVector shape_base = tensor_base->shape()->shape(); | |||||
| int64_t rank_base = SizeToLong(shape_base.size()); | |||||
| ShapeVector min_shape_base = tensor_base->shape()->min_shape(); | |||||
| ShapeVector max_shape_base = tensor_base->shape()->max_shape(); | |||||
| (void)CheckMinMaxShape(shape_base, &min_shape_base, &max_shape_base); | |||||
| primitive->set_attr("T", tensor_base->element()->BuildType()); | |||||
| primitive->set_attr("inputNums", MakeValue(SizeToLong(tuple_len))); | |||||
| ValuePtr axis = primitive->GetAttr("axis"); | |||||
| // Axis value should be in [-(rank_base + 1), rank_base). | |||||
| int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base); | |||||
| // If axis is negative, add offset(rank_base) to turn it to positive. | |||||
| axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base)); | |||||
| int64_t all_shp = shape_base[axis_value]; | |||||
| int64_t min_all_shp = min_shape_base[axis_value]; | |||||
| int64_t max_all_shp = max_shape_base[axis_value]; | |||||
| for (size_t i = 1; i < tuple_len; ++i) { | |||||
| AbstractTensorPtr tensor = nullptr; | |||||
| if (args_spec_list[0]->isa<AbstractTuple>()) { | |||||
| tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i); | |||||
| } else if (args_spec_list[0]->isa<AbstractTensor>()) { | |||||
| tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, i); | |||||
| } | |||||
| ShapeVector shape_tensor = tensor->shape()->shape(); | |||||
| int64_t rank_tensor = SizeToLong(shape_tensor.size()); | |||||
| ShapeVector min_shape_tensor = tensor->shape()->min_shape(); | |||||
| ShapeVector max_shape_tensor = tensor->shape()->max_shape(); | |||||
| (void)CheckMinMaxShape(shape_tensor, &min_shape_tensor, &max_shape_tensor); | |||||
| (void)CheckDtypeSame(op_name, tensor_base, tensor); | |||||
| if (rank_tensor != rank_base) { | |||||
| MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Rank"; | |||||
| } | |||||
| for (int j = 0; j < rank_base; ++j) { | |||||
| if (j != axis_value && shape_tensor[j] != shape_base[j]) { | |||||
| MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Size"; | |||||
| } | |||||
| } | |||||
| if (all_shp == -1 || shape_base[axis_value] == -1) { | |||||
| all_shp = -1; | |||||
| } else { | |||||
| all_shp += shape_tensor[axis_value]; | |||||
| } | |||||
| min_all_shp += min_shape_tensor[axis_value]; | |||||
| max_all_shp += max_shape_tensor[axis_value]; | |||||
| } | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden()); | |||||
| MS_EXCEPTION_IF_NULL(ret); | |||||
| auto shape = ret->shape()->shape(); | |||||
| auto min_shape = ret->shape()->min_shape(); | |||||
| auto max_shape = ret->shape()->max_shape(); | |||||
| (void)CheckMinMaxShape(shape, &min_shape, &max_shape); | |||||
| shape[axis_value] = all_shp; | |||||
| min_shape[axis_value] = min_all_shp; | |||||
| max_shape[axis_value] = max_all_shp; | |||||
| ret->set_shape(std::make_shared<Shape>(shape, min_shape, max_shape)); | |||||
| return ret; | |||||
| } | |||||
| AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| const std::string &op_name = primitive->name(); | const std::string &op_name = primitive->name(); | ||||
| @@ -84,6 +84,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | ||||
| {prim::kPrimSplit, {InferImplSplit, true}}, | {prim::kPrimSplit, {InferImplSplit, true}}, | ||||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | ||||
| {prim::kPrimConcat, {InferImplConcat, true}}, | |||||
| {prim::kPrimRange, {InferImplRange, true}}, | {prim::kPrimRange, {InferImplRange, true}}, | ||||
| // Structure | // Structure | ||||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | ||||
| @@ -170,6 +170,7 @@ from .minimum_ds import _minimum_ds_tbe | |||||
| from .minimum_grad import _minimum_grad_tbe | from .minimum_grad import _minimum_grad_tbe | ||||
| from .maximum_grad import _maximum_grad_tbe | from .maximum_grad import _maximum_grad_tbe | ||||
| from .concat import _concat_tbe | from .concat import _concat_tbe | ||||
| from .concat_ds import _concat_ds_tbe | |||||
| from .slice import _slice_tbe | from .slice import _slice_tbe | ||||
| from .sign import _sign_tbe | from .sign import _sign_tbe | ||||
| from .greater import _greater_tbe | from .greater import _greater_tbe | ||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Concat op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| concat_ds_op_info = TBERegOp("Concat") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("concat_d.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("concat_d") \ | |||||
| .partial_flag(True) \ | |||||
| .dynamic_shape(True) \ | |||||
| .attr("axis", "required", "int", "all") \ | |||||
| .input(0, "input_values", False, "dynamic", "all") \ | |||||
| .output(0, "output_data", False, "required", "all") \ | |||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.None_None, DataType.None_None) \ | |||||
| .get_op_info() | |||||
| @op_info_register(concat_ds_op_info) | |||||
| def _concat_ds_tbe(): | |||||
| """Concat TBE register""" | |||||
| return | |||||
| @@ -2149,6 +2149,19 @@ class Concat(PrimitiveWithInfer): | |||||
| out = {'shape': ret_shp, | out = {'shape': ret_shp, | ||||
| 'dtype': x_type[0], | 'dtype': x_type[0], | ||||
| 'value': value} | 'value': value} | ||||
| if -1 in x_shp[0]: | |||||
| x_min_shp = input_x['min_shape'] | |||||
| ret_min_shp = x_min_shp[0].copy() | |||||
| ret_min_shp[axis] = 0 | |||||
| for all_min_shp in x_min_shp: | |||||
| ret_min_shp[axis] += all_min_shp[axis] | |||||
| out['min_shape'] = ret_min_shp | |||||
| x_max_shp = input_x['max_shape'] | |||||
| ret_max_shp = x_max_shp[0].copy() | |||||
| ret_max_shp[axis] = 0 | |||||
| for all_max_shp in x_max_shp: | |||||
| ret_max_shp[axis] += all_max_shp[axis] | |||||
| out['max_shape'] = ret_max_shp | |||||
| return out | return out | ||||
| @@ -2790,7 +2803,7 @@ class StridedSlice(PrimitiveWithInfer): | |||||
| if has_ellipsis: | if has_ellipsis: | ||||
| # When there is ellipsis, handle the second half of the ellipsis split. | # When there is ellipsis, handle the second half of the ellipsis split. | ||||
| ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \ | ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \ | ||||
| len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len]))) | |||||
| len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len]))) | |||||
| ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims]) | ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims]) | ||||
| j += 1 | j += 1 | ||||
| i += ellipsis_occupied_dims | i += ellipsis_occupied_dims | ||||
| @@ -3988,7 +4001,7 @@ class SpaceToBatchND(PrimitiveWithInfer): | |||||
| offset = 1 | offset = 1 | ||||
| for i in range(len(self.block_shape)): | for i in range(len(self.block_shape)): | ||||
| padded = out_shape[i + offset] + self.paddings[i][0] + \ | padded = out_shape[i + offset] + self.paddings[i][0] + \ | ||||
| self.paddings[i][1] | |||||
| self.paddings[i][1] | |||||
| if padded % self.block_shape[i] != 0: | if padded % self.block_shape[i] != 0: | ||||
| raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' | raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' | ||||
| f'block_shape[{i}] {self.block_shape[i]}') | f'block_shape[{i}] {self.block_shape[i]}') | ||||
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, axis=0): | |||||
| super(Net, self).__init__() | |||||
| self.unique = P.Unique() | |||||
| self.reshape = P.Reshape() | |||||
| self.concat = P.Concat(axis=axis) | |||||
| def construct(self, x1, x2): | |||||
| out1_unique, _ = self.unique(x1) | |||||
| out2_unique, _ = self.unique(x2) | |||||
| out1_shape = self.reshape(out1_unique, (1, -1, 2)) | |||||
| out2_shape = self.reshape(out2_unique, (1, -1, 2)) | |||||
| return self.concat((out1_shape, out2_shape)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_dynamic_concat(): | |||||
| x1 = Tensor(np.array([1, 2, 3, 1, 4, 2]), mstype.int32) | |||||
| x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.int32) | |||||
| net = Net(axis=1) | |||||
| output = net(x1, x2) | |||||
| expect = np.array([[[1, 2], [3, 4], [1, 2], [3, 4], [5, 6]]]) | |||||
| assert (output.asnumpy() == expect).all() | |||||
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, axis=0): | |||||
| super(Net, self).__init__() | |||||
| self.unique = P.Unique() | |||||
| self.reshape = P.Reshape() | |||||
| self.concat = P.Concat(axis=axis) | |||||
| def construct(self, x1, x2): | |||||
| out1_unique, _ = self.unique(x1) | |||||
| out2_unique, _ = self.unique(x2) | |||||
| out1_shape = self.reshape(out1_unique, (1, -1, 2)) | |||||
| out2_shape = self.reshape(out2_unique, (1, -1, 2)) | |||||
| return self.concat((out1_shape, out2_shape)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_dynamic_concat_cpu(): | |||||
| x1 = Tensor(np.array([1, 2, 3, 1, 4, 2]), mstype.int32) | |||||
| x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.int32) | |||||
| net = Net(axis=1) | |||||
| output = net(x1, x2) | |||||
| expect = np.array([[[1, 2], [3, 4], [1, 2], [3, 4], [5, 6]]]) | |||||
| assert (output.asnumpy() == expect).all() | |||||
| @@ -835,14 +835,14 @@ def test_mixed_precision_cast(): | |||||
| assert z.dtype == mstype.float16 | assert z.dtype == mstype.float16 | ||||
| def test_while_concat(): | |||||
| def test_while_add(): | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self, data): | def __init__(self, data): | ||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| self.start = Tensor(0, dtype=mstype.int32) | self.start = Tensor(0, dtype=mstype.int32) | ||||
| self.end = Tensor(2, dtype=mstype.int32) | self.end = Tensor(2, dtype=mstype.int32) | ||||
| self.out = Tensor(np.zeros([2, 3], dtype=np.float32)) | self.out = Tensor(np.zeros([2, 3], dtype=np.float32)) | ||||
| self.concat = P.Concat() | |||||
| self.add = P.TensorAdd() | |||||
| def construct(self, inputs): | def construct(self, inputs): | ||||
| idx = self.start | idx = self.start | ||||
| @@ -850,7 +850,7 @@ def test_while_concat(): | |||||
| out = self.out | out = self.out | ||||
| while idx < end: | while idx < end: | ||||
| xi = inputs[idx, :, :] | xi = inputs[idx, :, :] | ||||
| out = self.concat((out, xi)) | |||||
| out = self.add(out, xi) | |||||
| idx = idx + 1 | idx = idx + 1 | ||||
| return out | return out | ||||