Browse Source

!25592 Reshape: support a variable (non-constant) shape input

Merge pull request !25592 from wangnan39/reshape_support_tensor
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
5233c73805
14 changed files with 353 additions and 38 deletions
  1. +6
    -13
      mindspore/ccsrc/backend/kernel_compiler/host/dynamic_broadcast_gradient_args_kernel.cc
  2. +107
    -0
      mindspore/ccsrc/backend/kernel_compiler/host/dynamic_reshape_kernel.cc
  3. +43
    -0
      mindspore/ccsrc/backend/kernel_compiler/host/dynamic_reshape_kernel.h
  4. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc
  5. +2
    -0
      mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
  6. +73
    -0
      mindspore/ccsrc/backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.cc
  7. +34
    -0
      mindspore/ccsrc/backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.h
  8. +1
    -0
      mindspore/ccsrc/utils/utils.h
  9. +13
    -6
      mindspore/core/abstract/prim_arrays.cc
  10. +3
    -0
      mindspore/core/abstract/primitive_infer_map.cc
  11. +2
    -0
      mindspore/core/base/core_ops.h
  12. +1
    -1
      mindspore/ops/_utils/__init__.py
  13. +16
    -0
      mindspore/ops/_utils/utils.py
  14. +50
    -16
      mindspore/ops/operations/array_ops.py

+ 6
- 13
mindspore/ccsrc/backend/kernel_compiler/host/dynamic_broadcast_gradient_args_kernel.cc View File

@@ -149,23 +149,16 @@ std::vector<int64_t> GetInputShape(const CNodePtr &cnode, size_t index) {
size_t SetOutputValue(const CNodePtr &cnode, const std::vector<std::vector<int64_t>> &grad_reduce_idx, size_t index,
size_t input_num) {
std::vector<int64_t> output;
size_t idx_num = grad_reduce_idx[index].size();
for (size_t k = 0; k < idx_num; ++k) {
output.push_back(grad_reduce_idx[index][idx_num - 1 - k]);
size_t out_size = grad_reduce_idx[index].size();
for (size_t k = 0; k < out_size; ++k) {
output.push_back(grad_reduce_idx[index][out_size - 1 - k]);
}
if (out_size == 0) {
return out_size;
}
auto out_addr = AnfAlgo::GetOutputAddr(cnode, index);
MS_EXCEPTION_IF_NULL(out_addr);
size_t out_size = idx_num;
if (idx_num == 0) {
out_size = input_num;
for (size_t k = 0; k < input_num; ++k) {
output.push_back(k);
}
}
std::vector<int64_t> out_shape{SizeToLong(out_size)};
auto output_type = TypeId::kNumberTypeInt64;
auto tensor_for_sync = std::make_shared<tensor::Tensor>(output_type, out_shape);


+ 107
- 0
mindspore/ccsrc/backend/kernel_compiler/host/dynamic_reshape_kernel.cc View File

@@ -0,0 +1,107 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/host/dynamic_reshape_kernel.h"
#include <functional>
#include <numeric>
#include "backend/session/anf_runtime_algorithm.h"
#include "abstract/utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputNum = 2;
// Reads the runtime value of the cnode's 1-D integer input `index` back from
// device memory and returns it as a vector of int64_t.
// Raises via MS_LOG(EXCEPTION) when the input is not 1-D or not int32/int64.
std::vector<int64_t> GetInputValue(const CNodePtr &cnode, size_t index) {
  auto device_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, index);
  auto value_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  if (value_shape.size() != 1) {
    MS_LOG(EXCEPTION) << "Input" << index << " must be [1-D], but " << value_shape.size() << "-D.";
  }
  session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(cnode, index);
  auto dtype = AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
  if (dtype != TypeId::kNumberTypeInt64 && dtype != TypeId::kNumberTypeInt32) {
    MS_LOG(EXCEPTION) << "Input x type must be int64 or int32, but :" << TypeIdToType(dtype);
  }
  // Build a host tensor bound to the device address and sync it to host so the
  // element values can be read.
  size_t elem_count = value_shape[0];
  std::vector<int64_t> host_shape{SizeToLong(elem_count)};
  auto sync_tensor = std::make_shared<tensor::Tensor>(dtype, host_shape);
  sync_tensor->set_device_address(device_addr);
  sync_tensor->data_sync();
  std::vector<int64_t> result;
  result.reserve(elem_count);
  if (dtype == TypeId::kNumberTypeInt64) {
    auto data = reinterpret_cast<int64_t *>(sync_tensor->data_c());
    MS_EXCEPTION_IF_NULL(data);
    result.assign(data, data + elem_count);
  } else {
    auto data = reinterpret_cast<int *>(sync_tensor->data_c());
    MS_EXCEPTION_IF_NULL(data);
    for (size_t i = 0; i < elem_count; ++i) {
      result.push_back(static_cast<int64_t>(data[i]));
    }
  }
  return result;
}
} // namespace
// Host-side execution of DynamicReshape: resolves the concrete output shape
// from the runtime value of input[1] (one dim may be -1, meaning "infer it"),
// then copies input[0]'s bytes to the output unchanged with the new shape.
void DynamicReshapeKernel::Execute() {
  MS_LOG(INFO) << "Execute host ReshapeKernel Start";
  auto cnode = cnode_ptr_.lock();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input_num = AnfAlgo::GetInputTensorNum(cnode);
  if (input_num != kInputNum) {
    MS_LOG(EXCEPTION) << "Invalid Input Num:" << input_num;
  }
  auto address_x = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, 0);
  MS_EXCEPTION_IF_NULL(address_x);
  auto type_x = AnfAlgo::GetOutputInferDataType(cnode, 0);
  auto shape_x = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  // Requested output shape, read back from device memory of input[1].
  std::vector<int64_t> output_shapes = GetInputValue(cnode, 1);
  int64_t dim_prod = 1;
  int64_t neg_index = -1;
  // Total element count of the input tensor.
  auto arr_prod = std::accumulate(shape_x.begin(), shape_x.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  for (size_t i = 0; i < output_shapes.size(); ++i) {
    if (output_shapes[i] == -1) {
      neg_index = SizeToLong(i);
    } else {
      dim_prod *= output_shapes[i];
    }
  }
  if (neg_index != -1) {
    // Guard the division: if the product of the known dims is 0 the -1 dim
    // cannot be inferred, and the division below would be undefined behavior.
    if (dim_prod == 0) {
      MS_LOG(EXCEPTION) << "Cannot infer the -1 dimension: product of the known output dims is 0.";
    }
    output_shapes[LongToSize(neg_index)] = arr_prod / dim_prod;
  }
  size_t input_size_byte = LongToSize(arr_prod) * abstract::TypeIdSize(type_x);
  auto output_addr = AnfAlgo::GetOutputAddr(cnode, 0);
  MS_EXCEPTION_IF_NULL(output_addr);
  // Reshape never moves data; a device-to-device copy with the new shape suffices.
  if (!output_addr->SyncDeviceToDevice(output_shapes, input_size_byte, address_x->type_id(), address_x->GetPtr(),
                                       address_x->format())) {
    MS_LOG(EXCEPTION) << "Host Reshape sync device to device failed.";
  }
  MS_LOG(INFO) << "Execute host ReshapeKernel End";
}
// Factory hook: wraps the given CNode and stream into a DynamicReshapeKernel,
// the host dynamic kernel whose Execute() performs the reshape at run time.
device::DynamicKernelPtr DynamicReshapeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
  return std::make_shared<DynamicReshapeKernel>(stream_ptr, cnode_ptr);
}
} // namespace kernel
} // namespace mindspore

+ 43
- 0
mindspore/ccsrc/backend/kernel_compiler/host/dynamic_reshape_kernel.h View File

@@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_RESHAPE_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_RESHAPE_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include "runtime/device/ascend/executor/host_dynamic_kernel.h"
#include "backend/kernel_compiler/host/host_kernel_mod.h"
using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel;
namespace mindspore {
namespace kernel {
// Host-executed dynamic kernel for the DynamicReshape op. Execute() (defined
// in the .cc file) resolves the output shape from the runtime shape input and
// copies the data through unchanged.
class DynamicReshapeKernel : public HostDynamicKernel {
 public:
  DynamicReshapeKernel(void *stream, const CNodePtr &cnode_ptr) : HostDynamicKernel(stream, cnode_ptr) {}
  ~DynamicReshapeKernel() override = default;
  void Execute() override;
};
// KernelMod that produces DynamicReshapeKernel instances for a given CNode
// and stream.
class DynamicReshapeKernelMod : public HostKernelMod {
 public:
  DynamicReshapeKernelMod() = default;
  ~DynamicReshapeKernelMod() override = default;
  device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
};
// Register the kernel mod under the "DynamicReshape" op name.
MS_HOST_REG_KERNEL(DynamicReshape, DynamicReshapeKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_RESHAPE_KERNEL_H_

+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc View File

@@ -24,8 +24,8 @@
namespace mindspore {
namespace kernel {
static const std::set<std::string> host_kernel = {prim::kPrimDynamicShape->name(),
prim::kPrimDynamicBroadcastGradientArgs->name()};
static const std::set<std::string> host_kernel = {
prim::kPrimDynamicShape->name(), prim::kPrimDynamicBroadcastGradientArgs->name(), prim::kPrimDynamicReshape->name()};
void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
MS_LOG(INFO) << "HostMetadataInfo.";


+ 2
- 0
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc View File

@@ -144,6 +144,7 @@
#include "backend/optimizer/ascend/mindir/update_input_names_strided_slice_grad.h"
#include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/bn_grad_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/all_to_all_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h"
#include "backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
@@ -585,6 +586,7 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &graph) {
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR1>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::BatchNormGradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::DynamicReshapeUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2UnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradUnifyMindIR>());


+ 73
- 0
mindspore/ccsrc/backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.cc View File

@@ -0,0 +1,73 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.h"

#include <vector>
#include <memory>

#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/trace_base.h"

namespace mindspore {
namespace opt {
namespace {
size_t kDynamicReshapeInputNum = 2;

// Builds a DynamicReshape CNode that reuses the original Reshape node's data
// and shape inputs, and carries over its scope, inferred output type/shape,
// and attributes.
AnfNodePtr CreateDynamicReshape(const FuncGraphPtr &graph, const CNodePtr &reshape_node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(reshape_node);
  CheckCNodeInputSize(reshape_node, kDynamicReshapeInputNum);
  const auto &orig_inputs = reshape_node->inputs();
  auto new_prim = NewValueNode(std::make_shared<Primitive>(kDynamicReshapeOpName));
  std::vector<AnfNodePtr> new_inputs{new_prim, orig_inputs[kDim1], orig_inputs[kDim2]};
  auto new_node = graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_scope(reshape_node->scope());
  // Propagate the abstract (type + detail shape) so later passes see the same
  // inferred output as the original Reshape.
  auto types = {AnfAlgo::GetOutputInferDataType(reshape_node, 0)};
  auto shapes = {AnfAlgo::GetOutputDetailShape(reshape_node, 0)};
  AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, new_node.get());
  AnfAlgo::CopyNodeAttrs(reshape_node, new_node);
  return new_node;
}
} // namespace

// Match any Reshape node regardless of its inputs; filtering on whether the
// shape input is constant happens in Process().
const BaseRef DynamicReshapeUnifyMindIR::DefinePattern() const {
  VarPtr inputs = std::make_shared<SeqVar>();
  return VectorRef({std::make_shared<Primitive>(kReshapeOpName), inputs});
}

// Pattern-pass callback: returns a replacement DynamicReshape node, or nullptr
// to leave the matched Reshape node untouched.
const AnfNodePtr DynamicReshapeUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Need the data and shape inputs plus the primitive input (hence + 1).
  if (cnode->size() < kDynamicReshapeInputNum + 1) {
    return nullptr;
  }
  // A constant (ValueNode) shape keeps the static Reshape; only a computed,
  // run-time shape input is rewritten to DynamicReshape.
  auto shp_input = cnode->input(kDynamicReshapeInputNum);
  if (shp_input->isa<ValueNode>()) {
    return nullptr;
  }
  return CreateDynamicReshape(func_graph, cnode);
}
} // namespace opt
} // namespace mindspore

+ 34
- 0
mindspore/ccsrc/backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DYNAMIC_RESHAPE_UNIFY_MINDIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DYNAMIC_RESHAPE_UNIFY_MINDIR_H_

#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
// MindIR unify pass: rewrites a Reshape whose shape input is not a constant
// ValueNode into DynamicReshape, so the shape is resolved at run time.
class DynamicReshapeUnifyMindIR : public PatternProcessPass {
 public:
  explicit DynamicReshapeUnifyMindIR(bool multigraph = true)
      : PatternProcessPass("dynamic_reshape_unify_mindir", multigraph) {}
  ~DynamicReshapeUnifyMindIR() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DYNAMIC_RESHAPE_UNIFY_MINDIR_H_

+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -95,6 +95,7 @@ constexpr auto kUnsortedSegmentProdOpName = "UnsortedSegmentProd";
constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
constexpr auto kFlattenGradOpName = "FlattenGrad";
constexpr auto kExpandDimsOpName = "ExpandDims";
constexpr auto kDynamicReshapeOpName = "DynamicReshape";
constexpr auto kReshapeOpName = "Reshape";
constexpr auto kTransposeOpName = "Transpose";
constexpr auto kTransposeNODOpName = "TransposeNOD";


+ 13
- 6
mindspore/core/abstract/prim_arrays.cc View File

@@ -827,12 +827,19 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
} else {
ValuePtr sh = primitive->GetAttr("shape");
MS_EXCEPTION_IF_NULL(sh);
auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(reshape_value_tuple);
auto reshape_tuple = reshape_value_tuple->value();

(void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
if (sh->isa<ValueTuple>()) {
auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(reshape_value_tuple);
auto reshape_tuple = reshape_value_tuple->value();
(void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
} else if (sh->isa<tensor::Tensor>()) {
auto tensor_value = sh->cast<tensor::TensorPtr>();
shape = CheckAndConvertUtils::CheckTensorIntValue("shape", sh, "Reshape");
} else {
MS_EXCEPTION(ValueError) << "In stage of execution, the primitive[Reshape]'s input['shape'] must be a tuple or "
<< "constant Tensor.";
}
}

auto max_shape = shape;


+ 3
- 0
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -62,6 +62,7 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
const auto kSlice = prim::kPrimSlice->name();
const auto kSliceGrad = prim::kPrimSliceGrad->name();
const auto kReshape = prim::kPrimReshape->name();
const auto kDynamicReshape = prim::kPrimDynamicReshape->name();
// common dynamic shape depends
static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = {{kUnsortedSegmentSum, {2}},
{kUnsortedSegmentMin, {2}},
@@ -78,6 +79,7 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
{kStridedSliceGrad, {1, 2, 3, 4}},
{kTile, {1}},
{kReshape, {1}},
{kDynamicReshape, {1}},
{kSlice, {1, 2}},
{kSliceGrad, {2, 3}},
{kDynamicBroadcastTo, {1}}};
@@ -150,6 +152,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimDynamicStitch, {InferImplDynamicStitch, nullptr, true}},
{prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}},
{prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}},
{prim::kPrimDynamicReshape, {InferImplReshape, nullptr, true}},
{prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}},
{prim::kPrimSplit, {InferImplSplit, nullptr, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}},


+ 2
- 0
mindspore/core/base/core_ops.h View File

@@ -97,6 +97,7 @@ constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
constexpr auto kTranspose = "Transpose";
constexpr auto kSplitV = "SplitV";
constexpr auto kDynamicBroadcastTo = "DynamicBroadcastTo";
constexpr auto kDynamicReshape = "DynamicReshape";

// NN
constexpr auto kCTCLoss = "CTCLoss";
@@ -209,6 +210,7 @@ inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>(
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
inline const PrimitivePtr kPrimDynamicReshape = std::make_shared<Primitive>(kDynamicReshape);
inline const PrimitivePtr kPrimSubAndFilter = std::make_shared<Primitive>("SubAndFilter");
inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx");
inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache");


+ 1
- 1
mindspore/ops/_utils/__init__.py View File

@@ -14,6 +14,6 @@
# ============================================================================

"""ops utils."""
from .utils import get_broadcast_shape, get_concat_offset
from .utils import get_broadcast_shape, get_concat_offset, is_shape_unknown

__all__ = ['get_broadcast_shape', 'get_concat_offset']

+ 16
- 0
mindspore/ops/_utils/utils.py View File

@@ -135,3 +135,19 @@ def generate_shape_index(out_shape, indices_shape, axis):
index = tuple(range(out_rank))
perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
return perm


@constexpr
def is_shape_unknown(shape):
    """Return True if any dimension of `shape` is dynamic (negative, e.g. -1 or -2)."""
    return any(dim < 0 for dim in shape)


@constexpr
def is_dim_unknown(shape):
    """Return True if `shape` has unknown rank, i.e. contains the sentinel -2."""
    return any(dim == -2 for dim in shape)

+ 50
- 16
mindspore/ops/operations/array_ops.py View File

@@ -28,7 +28,7 @@ import numpy as np
from mindspore import log as logger
from mindspore.common.initializer import Zero
from .. import signature as sig
from .._utils import get_broadcast_shape
from .._utils import get_broadcast_shape, is_shape_unknown
from .._utils import get_concat_offset
from ..operations.math_ops import _infer_shape_reduce
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
@@ -486,18 +486,49 @@ class Reshape(PrimitiveWithInfer):
"""Initialize Reshape"""
self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])

def _get_shape_and_range(self, x, shape):
    """Infer the output shape and its min/max bounds when `shape` is not a constant.

    Args:
        x (dict): abstract info of the data input; reads 'shape' and, when
            present, 'max_shape'.
        shape (dict): abstract info of the shape input; reads 'shape' and,
            when present, 'min_value'/'max_value'.

    Returns:
        tuple: (out_shape, min_shape, max_shape); min_shape/max_shape may be None.

    Raises:
        RuntimeError: if min_value/max_value lengths do not match the shape rank.
    """
    min_shape = None
    max_shape = None
    x_shp = x['shape']
    if is_shape_unknown(shape['shape']):
        # Even the rank of the output is unknown: use the -2 sentinel.
        out_shape = [-2]
        return out_shape, min_shape, max_shape

    shape_rank = shape['shape'][0]
    if not x_shp:
        # x is a scalar, output shape fixed
        out_shape = [1] * shape_rank
        return out_shape, min_shape, max_shape

    out_shape = [-1] * shape_rank
    if "max_value" in shape and "min_value" in shape:
        min_shape = shape["min_value"]
        max_shape = shape["max_value"]
        if len(min_shape) != shape_rank or len(max_shape) != shape_rank:
            raise RuntimeError("The primitive[Reshape]'s input[shape] min or max value not match the shape rank.")
        for i in range(shape_rank):
            # A dimension whose bounds coincide is effectively static.
            if min_shape[i] == max_shape[i]:
                out_shape[i] = min_shape[i]
    elif is_shape_unknown(x_shp) and "max_shape" in x:
        # when dynamic memory allocation is supported, max_shape can be left out
        min_shape = [1] * shape_rank
        max_shape = [int(np.prod(x["max_shape"]))] * shape_rank
    return out_shape, min_shape, max_shape

def __infer__(self, x, shape):
shape_v = shape['value']
if not shape_v and shape['shape']:
# for shape is not const value and not none
shape_rank = shape['shape'][0]

# unknown dims for shape
if shape_rank == -1:
shape_rank = 1
out_shape = [-1] * shape_rank
min_shape = [1] * shape_rank
max_shape = [1] * shape_rank
x_shp = x['shape']
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
# for shape is not constant
if shape_v is None:
out_shape, min_shape, max_shape = self._get_shape_and_range(x, shape)
if is_shape_unknown(out_shape):
# `min_shape` and `max_shape` can't be None before dynamic memory allocation is supported
shape_shp = shape['shape']
shape_rank = 1 if is_shape_unknown(shape_shp) else shape_shp[0]
min_shape = [1] * shape_rank if min_shape is None else min_shape
max_shape = [1] * shape_rank if max_shape is None else max_shape
return {
'shape': out_shape,
'dtype': x['dtype'],
@@ -506,10 +537,13 @@ class Reshape(PrimitiveWithInfer):
'min_shape': min_shape
}

x_shp = x['shape']
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
validator.check_value_type("shape", shape_v, [tuple], self.name)
shape_v = list(shape_v)
if isinstance(shape_v, Tensor_):
validator.check_tensor_dtype_valid("shape", shape['dtype'], [mstype.int64], self.name)
shape_v = shape_v.asnumpy().tolist()
else:
validator.check_value_type("shape", shape_v, [tuple], self.name)
shape_v = list(shape_v)

neg_index = -1
dim_prod = 1
for i, shp_i in enumerate(shape_v):
@@ -522,7 +556,7 @@ class Reshape(PrimitiveWithInfer):
else:
dim_prod *= shp_i

if -1 in x_shp:
if is_shape_unknown(x_shp):
if 'max_shape' in x:
x_max_shape = x['max_shape']
else:


Loading…
Cancel
Save