Browse Source

move dynamic_shape_depend to backend

Signed-off-by: zhupuxu <zhupuxu@huawei.com>
tags/v1.2.0-rc1
zhupuxu 4 years ago
parent
commit
bc0775748e
2 changed files with 10 additions and 13 deletions
  1. +10
    -4
      mindspore/core/abstract/primitive_infer_map.cc
  2. +0
    -9
      mindspore/ops/operations/array_ops.py

+ 10
- 4
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -23,11 +23,17 @@
namespace mindspore {
namespace abstract {
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
constexpr auto kUnsortedSegmentSum = "UnsortedSegmentSum";
constexpr auto kUnsortedSegmentMin = "UnsortedSegmentMin";
constexpr auto kUnsortedSegmentMax = "UnsortedSegmentMax";
const auto kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
const auto kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
const auto kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
const auto kGather = prim::kPrimGather->name();
const auto kGatherV2 = prim::kPrimGatherV2->name();
const auto kDynamicShape = prim::kPrimDynamicShape->name();
const auto kRange = prim::kPrimRange->name();
static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = {
{kUnsortedSegmentSum, {2}}, {kUnsortedSegmentMin, {2}}, {kUnsortedSegmentMax, {2}}};
{kUnsortedSegmentSum, {2}}, {kUnsortedSegmentMin, {2}}, {kUnsortedSegmentMax, {2}}, {kGather, {2}},
{kGatherV2, {2}}, {kDynamicShape, {0}}, {kRange, {0, 1, 2}},
};
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Invalid inputs";


+ 0
- 9
mindspore/ops/operations/array_ops.py View File

@@ -593,7 +593,6 @@ class DynamicShape(Primitive):
"""init Shape"""
self.init_prim_io_names(inputs=['tensor'], outputs=['output'])
self.add_prim_attr('is_dynamic_shape', True)
self.add_prim_attr("dynamic_shape_depends", [0])


class Squeeze(PrimitiveWithInfer):
@@ -811,7 +810,6 @@ class Gather(PrimitiveWithCheck):
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2])

def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
@@ -836,7 +834,6 @@ class GatherV2(PrimitiveWithCheck):
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2])

def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
@@ -1987,7 +1984,6 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
def __init__(self):
"""Initialize UnsortedSegmentMin"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2])

def __check__(self, x, segment_ids, num_segments):
x_shape = x['shape']
@@ -2043,7 +2039,6 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
def __init__(self):
"""Initialize UnsortedSegmentMax"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2])

def __check__(self, x, segment_ids, num_segments):
x_shape = x['shape']
@@ -4980,10 +4975,6 @@ class Range(PrimitiveWithCheck):
self.maxlen = maxlen
self.add_prim_attr('maxlen', maxlen)

self.add_prim_attr("dynamic_shape_depends", [0])
self.add_prim_attr("dynamic_shape_depends", [1])
self.add_prim_attr("dynamic_shape_depends", [2])

def check_shape(self, start_shape, limit_shape, delta_shape):
validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)


Loading…
Cancel
Save