Merge pull request !25592 from wangnan39/reshape_support_tensortags/v1.6.0
| @@ -149,23 +149,16 @@ std::vector<int64_t> GetInputShape(const CNodePtr &cnode, size_t index) { | |||
| size_t SetOutputValue(const CNodePtr &cnode, const std::vector<std::vector<int64_t>> &grad_reduce_idx, size_t index, | |||
| size_t input_num) { | |||
| std::vector<int64_t> output; | |||
| size_t idx_num = grad_reduce_idx[index].size(); | |||
| for (size_t k = 0; k < idx_num; ++k) { | |||
| output.push_back(grad_reduce_idx[index][idx_num - 1 - k]); | |||
| size_t out_size = grad_reduce_idx[index].size(); | |||
| for (size_t k = 0; k < out_size; ++k) { | |||
| output.push_back(grad_reduce_idx[index][out_size - 1 - k]); | |||
| } | |||
| if (out_size == 0) { | |||
| return out_size; | |||
| } | |||
| auto out_addr = AnfAlgo::GetOutputAddr(cnode, index); | |||
| MS_EXCEPTION_IF_NULL(out_addr); | |||
| size_t out_size = idx_num; | |||
| if (idx_num == 0) { | |||
| out_size = input_num; | |||
| for (size_t k = 0; k < input_num; ++k) { | |||
| output.push_back(k); | |||
| } | |||
| } | |||
| std::vector<int64_t> out_shape{SizeToLong(out_size)}; | |||
| auto output_type = TypeId::kNumberTypeInt64; | |||
| auto tensor_for_sync = std::make_shared<tensor::Tensor>(output_type, out_shape); | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/host/dynamic_reshape_kernel.h" | |||
| #include <functional> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "abstract/utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| constexpr size_t kInputNum = 2; | |||
| std::vector<int64_t> GetInputValue(const CNodePtr &cnode, size_t index) { | |||
| auto address_x = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, index); | |||
| auto shape_x = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); | |||
| if (shape_x.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Input" << index << " must be [1-D], but " << shape_x.size() << "-D."; | |||
| } | |||
| session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(cnode, index); | |||
| auto type_x = AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second); | |||
| if (type_x != TypeId::kNumberTypeInt64 && type_x != TypeId::kNumberTypeInt32) { | |||
| MS_LOG(EXCEPTION) << "Input x type must be int64 or int32, but :" << TypeIdToType(type_x); | |||
| } | |||
| size_t x_num = shape_x[0]; | |||
| std::vector<int64_t> x{SizeToLong(x_num)}; | |||
| auto x_shape_value = std::make_shared<tensor::Tensor>(type_x, x); | |||
| x_shape_value->set_device_address(address_x); | |||
| x_shape_value->data_sync(); | |||
| std::vector<int64_t> input_shape; | |||
| if (type_x == TypeId::kNumberTypeInt64) { | |||
| auto x_value = reinterpret_cast<int64_t *>(x_shape_value->data_c()); | |||
| MS_EXCEPTION_IF_NULL(x_value); | |||
| input_shape = {x_value, x_value + x_num}; | |||
| } else { | |||
| auto x_value = reinterpret_cast<int *>(x_shape_value->data_c()); | |||
| MS_EXCEPTION_IF_NULL(x_value); | |||
| for (size_t i = 0; i < x_num; i++) { | |||
| input_shape.push_back(static_cast<int64_t>(*x_value)); | |||
| ++x_value; | |||
| } | |||
| } | |||
| return input_shape; | |||
| } | |||
| } // namespace | |||
| void DynamicReshapeKernel::Execute() { | |||
| MS_LOG(INFO) << "Execute host ReshapeKernel Start"; | |||
| auto cnode = cnode_ptr_.lock(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| if (input_num != kInputNum) { | |||
| MS_LOG(EXCEPTION) << "Invalid Input Num:" << input_num; | |||
| } | |||
| auto address_x = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, 0); | |||
| MS_EXCEPTION_IF_NULL(address_x); | |||
| auto type_x = AnfAlgo::GetOutputInferDataType(cnode, 0); | |||
| auto shape_x = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | |||
| std::vector<int64_t> output_shapes = GetInputValue(cnode, 1); | |||
| int64_t dim_prod = 1; | |||
| int64_t neg_index = -1; | |||
| auto arr_prod = std::accumulate(shape_x.begin(), shape_x.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()); | |||
| for (size_t i = 0; i < output_shapes.size(); ++i) { | |||
| if (output_shapes[i] == -1) { | |||
| neg_index = SizeToLong(i); | |||
| } else { | |||
| dim_prod *= output_shapes[i]; | |||
| } | |||
| } | |||
| if (neg_index != -1) { | |||
| output_shapes[LongToSize(neg_index)] = arr_prod / dim_prod; | |||
| } | |||
| size_t input_size_byte = LongToSize(arr_prod) * abstract::TypeIdSize(type_x); | |||
| auto output_addr = AnfAlgo::GetOutputAddr(cnode, 0); | |||
| MS_EXCEPTION_IF_NULL(output_addr); | |||
| if (!output_addr->SyncDeviceToDevice(output_shapes, input_size_byte, address_x->type_id(), address_x->GetPtr(), | |||
| address_x->format())) { | |||
| MS_LOG(EXCEPTION) << "Host Reshape sync device to device failed."; | |||
| } | |||
| MS_LOG(INFO) << "Execute host ReshapeKernel End"; | |||
| } | |||
| device::DynamicKernelPtr DynamicReshapeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { | |||
| return std::make_shared<DynamicReshapeKernel>(stream_ptr, cnode_ptr); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_RESHAPE_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_RESHAPE_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "runtime/device/ascend/executor/host_dynamic_kernel.h" | |||
| #include "backend/kernel_compiler/host/host_kernel_mod.h" | |||
| using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel; | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class DynamicReshapeKernel : public HostDynamicKernel { | |||
| public: | |||
| DynamicReshapeKernel(void *stream, const CNodePtr &cnode_ptr) : HostDynamicKernel(stream, cnode_ptr) {} | |||
| ~DynamicReshapeKernel() override = default; | |||
| void Execute() override; | |||
| }; | |||
| class DynamicReshapeKernelMod : public HostKernelMod { | |||
| public: | |||
| DynamicReshapeKernelMod() = default; | |||
| ~DynamicReshapeKernelMod() override = default; | |||
| device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override; | |||
| }; | |||
| MS_HOST_REG_KERNEL(DynamicReshape, DynamicReshapeKernelMod); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_RESHAPE_KERNEL_H_ | |||
| @@ -24,8 +24,8 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| static const std::set<std::string> host_kernel = {prim::kPrimDynamicShape->name(), | |||
| prim::kPrimDynamicBroadcastGradientArgs->name()}; | |||
| static const std::set<std::string> host_kernel = { | |||
| prim::kPrimDynamicShape->name(), prim::kPrimDynamicBroadcastGradientArgs->name(), prim::kPrimDynamicReshape->name()}; | |||
| void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { | |||
| MS_LOG(INFO) << "HostMetadataInfo."; | |||
| @@ -144,6 +144,7 @@ | |||
| #include "backend/optimizer/ascend/mindir/update_input_names_strided_slice_grad.h" | |||
| #include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h" | |||
| #include "backend/optimizer/ascend/mindir/bn_grad_unify_mindir.h" | |||
| #include "backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.h" | |||
| #include "backend/optimizer/ascend/mindir/all_to_all_unify_mindir.h" | |||
| #include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h" | |||
| #include "backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h" | |||
| @@ -585,6 +586,7 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &graph) { | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR1>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::BatchNormGradUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::DynamicReshapeUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2UnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradUnifyMindIR>()); | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "utils/utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/trace_base.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| size_t kDynamicReshapeInputNum = 2; | |||
| AnfNodePtr CreateDynamicReshape(const FuncGraphPtr &graph, const CNodePtr &reshape_node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(reshape_node); | |||
| const auto &reshape_node_inputs = reshape_node->inputs(); | |||
| CheckCNodeInputSize(reshape_node, kDynamicReshapeInputNum); | |||
| std::vector<AnfNodePtr> dynamic_reshape_inputs = {NewValueNode(std::make_shared<Primitive>(kDynamicReshapeOpName)), | |||
| reshape_node_inputs[kDim1], reshape_node_inputs[kDim2]}; | |||
| auto dynamic_reshape_node = graph->NewCNode(dynamic_reshape_inputs); | |||
| MS_EXCEPTION_IF_NULL(dynamic_reshape_node); | |||
| dynamic_reshape_node->set_scope(reshape_node->scope()); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(reshape_node, 0)}; | |||
| auto shapes = {AnfAlgo::GetOutputDetailShape(reshape_node, 0)}; | |||
| AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, dynamic_reshape_node.get()); | |||
| AnfAlgo::CopyNodeAttrs(reshape_node, dynamic_reshape_node); | |||
| return dynamic_reshape_node; | |||
| } | |||
| } // namespace | |||
| const BaseRef DynamicReshapeUnifyMindIR::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| auto prim = std::make_shared<Primitive>(kReshapeOpName); | |||
| return VectorRef({prim, Xs}); | |||
| } | |||
| const AnfNodePtr DynamicReshapeUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->size() < kDynamicReshapeInputNum + 1) { | |||
| return nullptr; | |||
| } | |||
| auto shp_input = cnode->input(kDynamicReshapeInputNum); | |||
| if (shp_input->isa<ValueNode>()) { | |||
| return nullptr; | |||
| } | |||
| return CreateDynamicReshape(func_graph, cnode); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DYNAMIC_RESHAPE_UNIFY_MINDIR_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DYNAMIC_RESHAPE_UNIFY_MINDIR_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class DynamicReshapeUnifyMindIR : public PatternProcessPass { | |||
| public: | |||
| explicit DynamicReshapeUnifyMindIR(bool multigraph = true) | |||
| : PatternProcessPass("dynamic_reshape_unify_mindir", multigraph) {} | |||
| ~DynamicReshapeUnifyMindIR() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DYNAMIC_RESHAPE_UNIFY_MINDIR_H_ | |||
| @@ -95,6 +95,7 @@ constexpr auto kUnsortedSegmentProdOpName = "UnsortedSegmentProd"; | |||
| constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin"; | |||
| constexpr auto kFlattenGradOpName = "FlattenGrad"; | |||
| constexpr auto kExpandDimsOpName = "ExpandDims"; | |||
| constexpr auto kDynamicReshapeOpName = "DynamicReshape"; | |||
| constexpr auto kReshapeOpName = "Reshape"; | |||
| constexpr auto kTransposeOpName = "Transpose"; | |||
| constexpr auto kTransposeNODOpName = "TransposeNOD"; | |||
| @@ -827,12 +827,19 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| } else { | |||
| ValuePtr sh = primitive->GetAttr("shape"); | |||
| MS_EXCEPTION_IF_NULL(sh); | |||
| auto reshape_value_tuple = sh->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(reshape_value_tuple); | |||
| auto reshape_tuple = reshape_value_tuple->value(); | |||
| (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape), | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| if (sh->isa<ValueTuple>()) { | |||
| auto reshape_value_tuple = sh->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(reshape_value_tuple); | |||
| auto reshape_tuple = reshape_value_tuple->value(); | |||
| (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape), | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| } else if (sh->isa<tensor::Tensor>()) { | |||
| auto tensor_value = sh->cast<tensor::TensorPtr>(); | |||
| shape = CheckAndConvertUtils::CheckTensorIntValue("shape", sh, "Reshape"); | |||
| } else { | |||
| MS_EXCEPTION(ValueError) << "In stage of execution, the primitive[Reshape]'s input['shape'] must be a tuple or " | |||
| << "constant Tensor."; | |||
| } | |||
| } | |||
| auto max_shape = shape; | |||
| @@ -62,6 +62,7 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| const auto kSlice = prim::kPrimSlice->name(); | |||
| const auto kSliceGrad = prim::kPrimSliceGrad->name(); | |||
| const auto kReshape = prim::kPrimReshape->name(); | |||
| const auto kDynamicReshape = prim::kPrimDynamicReshape->name(); | |||
| // common dynamic shape depends | |||
| static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = {{kUnsortedSegmentSum, {2}}, | |||
| {kUnsortedSegmentMin, {2}}, | |||
| @@ -78,6 +79,7 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| {kStridedSliceGrad, {1, 2, 3, 4}}, | |||
| {kTile, {1}}, | |||
| {kReshape, {1}}, | |||
| {kDynamicReshape, {1}}, | |||
| {kSlice, {1, 2}}, | |||
| {kSliceGrad, {2, 3}}, | |||
| {kDynamicBroadcastTo, {1}}}; | |||
| @@ -150,6 +152,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimDynamicStitch, {InferImplDynamicStitch, nullptr, true}}, | |||
| {prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}}, | |||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}}, | |||
| {prim::kPrimDynamicReshape, {InferImplReshape, nullptr, true}}, | |||
| {prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}}, | |||
| {prim::kPrimSplit, {InferImplSplit, nullptr, true}}, | |||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}}, | |||
| @@ -97,6 +97,7 @@ constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs"; | |||
| constexpr auto kTranspose = "Transpose"; | |||
| constexpr auto kSplitV = "SplitV"; | |||
| constexpr auto kDynamicBroadcastTo = "DynamicBroadcastTo"; | |||
| constexpr auto kDynamicReshape = "DynamicReshape"; | |||
| // NN | |||
| constexpr auto kCTCLoss = "CTCLoss"; | |||
| @@ -209,6 +210,7 @@ inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>( | |||
| inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | |||
| inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | |||
| inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | |||
| inline const PrimitivePtr kPrimDynamicReshape = std::make_shared<Primitive>(kDynamicReshape); | |||
| inline const PrimitivePtr kPrimSubAndFilter = std::make_shared<Primitive>("SubAndFilter"); | |||
| inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx"); | |||
| inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); | |||
| @@ -14,6 +14,6 @@ | |||
| # ============================================================================ | |||
| """ops utils.""" | |||
| from .utils import get_broadcast_shape, get_concat_offset | |||
| from .utils import get_broadcast_shape, get_concat_offset, is_shape_unknown | |||
| __all__ = ['get_broadcast_shape', 'get_concat_offset'] | |||
| @@ -135,3 +135,19 @@ def generate_shape_index(out_shape, indices_shape, axis): | |||
| index = tuple(range(out_rank)) | |||
| perm = perm_part1 + index[:axis] + index[axis + ind_rank:] | |||
| return perm | |||
| @constexpr | |||
| def is_shape_unknown(shape): | |||
| for i in shape: | |||
| if i < 0: | |||
| return True | |||
| return False | |||
| @constexpr | |||
| def is_dim_unknown(shape): | |||
| for i in shape: | |||
| if i == -2: | |||
| return True | |||
| return False | |||
| @@ -28,7 +28,7 @@ import numpy as np | |||
| from mindspore import log as logger | |||
| from mindspore.common.initializer import Zero | |||
| from .. import signature as sig | |||
| from .._utils import get_broadcast_shape | |||
| from .._utils import get_broadcast_shape, is_shape_unknown | |||
| from .._utils import get_concat_offset | |||
| from ..operations.math_ops import _infer_shape_reduce | |||
| from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | |||
| @@ -486,18 +486,49 @@ class Reshape(PrimitiveWithInfer): | |||
| """Initialize Reshape""" | |||
| self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output']) | |||
| def _get_shape_and_range(self, x, shape): | |||
| """ get min and max shape when output shape is dynamic""" | |||
| min_shape = None | |||
| max_shape = None | |||
| x_shp = x['shape'] | |||
| if is_shape_unknown(shape['shape']): | |||
| out_shape = [-2] | |||
| return out_shape, min_shape, max_shape | |||
| shape_rank = shape['shape'][0] | |||
| if not x_shp: | |||
| # x is a scalar, output shape fixed | |||
| out_shape = [1] * shape_rank | |||
| return out_shape, min_shape, max_shape | |||
| out_shape = [-1] * shape_rank | |||
| if "max_value" in shape and "min_value" in shape: | |||
| min_shape = shape["min_value"] | |||
| max_shape = shape["max_value"] | |||
| if len(min_shape) != shape_rank or len(max_shape) != shape_rank: | |||
| raise RuntimeError("The primitive[Reshape]'s input[shape] min or max value not math the shape rank.") | |||
| for i in range(shape_rank): | |||
| if min_shape[i] == max_shape[i]: | |||
| out_shape[i] = min_shape[i] | |||
| elif is_shape_unknown(x_shp) and "max_shape" in x: | |||
| # when dynamic memory allocation is supported, max_shape can be left out | |||
| min_shape = [1] * shape_rank | |||
| max_shape = [int(np.prod(x["max_shape"]))] * shape_rank | |||
| return out_shape, min_shape, max_shape | |||
| def __infer__(self, x, shape): | |||
| shape_v = shape['value'] | |||
| if not shape_v and shape['shape']: | |||
| # for shape is not const value and not none | |||
| shape_rank = shape['shape'][0] | |||
| # unknown dims for shape | |||
| if shape_rank == -1: | |||
| shape_rank = 1 | |||
| out_shape = [-1] * shape_rank | |||
| min_shape = [1] * shape_rank | |||
| max_shape = [1] * shape_rank | |||
| x_shp = x['shape'] | |||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||
| # for shape is not constant | |||
| if shape_v is None: | |||
| out_shape, min_shape, max_shape = self._get_shape_and_range(x, shape) | |||
| if is_shape_unknown(out_shape): | |||
| # `min_shape` and `max_shape` can't be None before dynamic memory allocation is supported | |||
| shape_shp = shape['shape'] | |||
| shape_rank = 1 if is_shape_unknown(shape_shp) else shape_shp[0] | |||
| min_shape = [1] * shape_rank if min_shape is None else min_shape | |||
| max_shape = [1] * shape_rank if max_shape is None else max_shape | |||
| return { | |||
| 'shape': out_shape, | |||
| 'dtype': x['dtype'], | |||
| @@ -506,10 +537,13 @@ class Reshape(PrimitiveWithInfer): | |||
| 'min_shape': min_shape | |||
| } | |||
| x_shp = x['shape'] | |||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||
| validator.check_value_type("shape", shape_v, [tuple], self.name) | |||
| shape_v = list(shape_v) | |||
| if isinstance(shape_v, Tensor_): | |||
| validator.check_tensor_dtype_valid("shape", shape['dtype'], [mstype.int64], self.name) | |||
| shape_v = shape_v.asnumpy().tolist() | |||
| else: | |||
| validator.check_value_type("shape", shape_v, [tuple], self.name) | |||
| shape_v = list(shape_v) | |||
| neg_index = -1 | |||
| dim_prod = 1 | |||
| for i, shp_i in enumerate(shape_v): | |||
| @@ -522,7 +556,7 @@ class Reshape(PrimitiveWithInfer): | |||
| else: | |||
| dim_prod *= shp_i | |||
| if -1 in x_shp: | |||
| if is_shape_unknown(x_shp): | |||
| if 'max_shape' in x: | |||
| x_max_shape = x['max_shape'] | |||
| else: | |||