浏览代码

develop op ScatterMax and dock ge process

tags/v0.3.0-alpha
buxue 5 年前
父节点
当前提交
ac86996746
共有 6 个文件被更改,包括 128 次插入54 次删除
  1. +2
    -0
      mindspore/ccsrc/transform/convert.cc
  2. +5
    -0
      mindspore/ccsrc/transform/op_declare.cc
  3. +2
    -0
      mindspore/ccsrc/transform/op_declare.h
  4. +2
    -1
      mindspore/ops/operations/__init__.py
  5. +48
    -1
      mindspore/ops/operations/array_ops.py
  6. +69
    -52
      tests/ut/python/ops/test_ops.py

+ 2
- 0
mindspore/ccsrc/transform/convert.cc 查看文件

@@ -102,6 +102,7 @@ const char kNameReLU6Grad[] = "ReLU6Grad";
const char kNameElu[] = "Elu"; const char kNameElu[] = "Elu";
const char kNameEluGrad[] = "EluGrad"; const char kNameEluGrad[] = "EluGrad";
const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
const char kNameScatterMax[] = "ScatterMax";
const char kNameNMSWithMask[] = "NMSWithMask"; const char kNameNMSWithMask[] = "NMSWithMask";
const char kNameCheckValid[] = "CheckValid"; const char kNameCheckValid[] = "CheckValid";
const char kNameSmoothL1Loss[] = "SmoothL1Loss"; const char kNameSmoothL1Loss[] = "SmoothL1Loss";
@@ -253,6 +254,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameZerosLike), ADPT_DESC(ZerosLike)}, {string(kNameZerosLike), ADPT_DESC(ZerosLike)},
{string(kNameOnesLike), ADPT_DESC(OnesLike)}, {string(kNameOnesLike), ADPT_DESC(OnesLike)},
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
{string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)},
{string(kNameCheckValid), ADPT_DESC(CheckValid)}, {string(kNameCheckValid), ADPT_DESC(CheckValid)},
{string(kNameSmoothL1Loss), ADPT_DESC(SmoothL1Loss)}, {string(kNameSmoothL1Loss), ADPT_DESC(SmoothL1Loss)},


+ 5
- 0
mindspore/ccsrc/transform/op_declare.cc 查看文件

@@ -530,6 +530,11 @@ INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3
ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ScatterNdUpdate) = {{0, OUTPUT_DESC(var)}}; OUTPUT_MAP(ScatterNdUpdate) = {{0, OUTPUT_DESC(var)}};


// ScatterMax
INPUT_MAP(ScatterMax) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(ScatterMax) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ScatterMax) = {{0, OUTPUT_DESC(var)}};

// CheckValid // CheckValid
INPUT_MAP(CheckValid) = {{1, INPUT_DESC(bbox_tensor)}, {2, INPUT_DESC(img_metas)}}; INPUT_MAP(CheckValid) = {{1, INPUT_DESC(bbox_tensor)}, {2, INPUT_DESC(img_metas)}};
ATTR_MAP(CheckValid) = EMPTY_ATTR_MAP; ATTR_MAP(CheckValid) = EMPTY_ATTR_MAP;


+ 2
- 0
mindspore/ccsrc/transform/op_declare.h 查看文件

@@ -136,6 +136,8 @@ DECLARE_OP_ADAPTER(OnesLike)
DECLARE_OP_USE_OUTPUT(OnesLike) DECLARE_OP_USE_OUTPUT(OnesLike)
DECLARE_OP_ADAPTER(ScatterNdUpdate) DECLARE_OP_ADAPTER(ScatterNdUpdate)
DECLARE_OP_USE_OUTPUT(ScatterNdUpdate) DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
DECLARE_OP_ADAPTER(ScatterMax)
DECLARE_OP_USE_OUTPUT(ScatterMax)
DECLARE_OP_ADAPTER(NMSWithMask) DECLARE_OP_ADAPTER(NMSWithMask)
DECLARE_OP_USE_OUTPUT(NMSWithMask) DECLARE_OP_USE_OUTPUT(NMSWithMask)
DECLARE_OP_ADAPTER(Unpack) DECLARE_OP_ADAPTER(Unpack)


+ 2
- 1
mindspore/ops/operations/__init__.py 查看文件

@@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Fill, GatherNd, GatherV2, InvertPermutation, Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape,
SameTypeShape, ScatterMax,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile, Squeeze, StridedSlice, Tile,
@@ -184,6 +184,7 @@ __all__ = [
'BoundingBoxDecode', 'BoundingBoxDecode',
'L2Normalize', 'L2Normalize',
'ScatterNd', 'ScatterNd',
'ScatterMax',
'ResizeNearestNeighbor', 'ResizeNearestNeighbor',
'Pad', 'Pad',
'MirrorPad', 'MirrorPad',


+ 48
- 1
mindspore/ops/operations/array_ops.py 查看文件

@@ -1953,7 +1953,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
Using given values to update tensor value, along with the input indices. Using given values to update tensor value, along with the input indices.


Args: Args:
use_locking (bool): Whether protect the assignment by a lock. Defaule: True.
use_locking (bool): Whether protect the assignment by a lock. Default: True.


Inputs: Inputs:
- **input_x** (Tensor) - The target tensor. - **input_x** (Tensor) - The target tensor.
@@ -1995,6 +1995,53 @@ class ScatterNdUpdate(PrimitiveWithInfer):
return x_dtype return x_dtype




class ScatterMax(PrimitiveWithInfer):
"""
Update the value of the input tensor through the max operation.

Using given values to update tensor value through the max operation, along with the input indices,.

Args:
use_locking (bool): Whether protect the assignment by a lock. Default: True.

Inputs:
- **input_x** (Tensor) - The target tensor.
- **indices** (Tensor) - The index to do max operation whose data type should be int.
- **updates** (Tensor) - The tensor doing the maximum operation with 'input_x',
the data type is same as 'input_x', the shape is 'indices_shape + x_shape[1:]'.

Outputs:
Tensor, has the same shape and data type as `input_x`.

Examples:
>>> input_x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
>>> scatter_max = P.ScatterMax()
>>> output = scatter_max(input_x, indices, update)
[[88.0, 88.0, 88.0], [88.0, 88.0, 88.0]]
"""

@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterMax"""
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
validator.check_value_type('use_locking', use_locking, (bool,), self.name)

def infer_shape(self, x_shape, indices_shape, updates_shape):
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{self.name}', the shape of update should be [] or "
f"update_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, update_shape: {updates_shape}.")
return x_shape

def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype


class SpaceToDepth(PrimitiveWithInfer): class SpaceToDepth(PrimitiveWithInfer):
r""" r"""
Rearrange blocks of spatial data into depth. Rearrange blocks of spatial data into depth.


+ 69
- 52
tests/ut/python/ops/test_ops.py 查看文件

@@ -15,7 +15,7 @@
""" test ops """ """ test ops """
import functools import functools
import numpy as np import numpy as np
from mindspore import ops
from mindspore import ops, Parameter, context
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
@@ -26,10 +26,10 @@ from mindspore.common import dtype as mstype
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine


from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward\
from ....mindspore_test_framework.pipeline.forward.compile_forward \
import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config, import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config,
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
from ....mindspore_test_framework.pipeline.gradient.compile_gradient\
from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config import pipeline_for_compile_grad_ge_graph_for_case_by_case_config




@@ -150,7 +150,7 @@ class CumSumNet(nn.Cell):




class SummaryNet(nn.Cell): class SummaryNet(nn.Cell):
def __init__(self,):
def __init__(self):
super(SummaryNet, self).__init__() super(SummaryNet, self).__init__()
self.s = P.ScalarSummary() self.s = P.ScalarSummary()
self.add = P.TensorAdd() self.add = P.TensorAdd()
@@ -161,7 +161,7 @@ class SummaryNet(nn.Cell):




class HistogramSummaryNet(nn.Cell): class HistogramSummaryNet(nn.Cell):
def __init__(self,):
def __init__(self):
super(HistogramSummaryNet, self).__init__() super(HistogramSummaryNet, self).__init__()
self.summary = P.HistogramSummary() self.summary = P.HistogramSummary()
self.add = P.TensorAdd() self.add = P.TensorAdd()
@@ -173,6 +173,19 @@ class HistogramSummaryNet(nn.Cell):
return out return out




class ScatterMax(nn.Cell):
"""ScatterMax net definition"""

def __init__(self):
super(ScatterMax, self).__init__()
self.scatter_max = P.ScatterMax()
self.ref = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], np.float32)), name="ref")

def construct(self, indices, updates):
out = self.scatter_max(self.ref, indices, updates)
return out


test_case_math_ops = [ test_case_math_ops = [
('Neg', { ('Neg', {
'block': P.Neg(), 'block': P.Neg(),
@@ -298,28 +311,28 @@ test_case_math_ops = [
('StridedSlice', { ('StridedSlice', {
'block': P.StridedSlice(), 'block': P.StridedSlice(),
'desc_const': [(0, 1, 2, 1), 'desc_const': [(0, 1, 2, 1),
(2, 3, 3, 4),
(1, 1, 1, 1)],
(2, 3, 3, 4),
(1, 1, 1, 1)],
'desc_inputs': [[2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5]],
'desc_bprop': [[2, 2, 1, 3]]}), 'desc_bprop': [[2, 2, 1, 3]]}),
('Slice_1', { ('Slice_1', {
'block': P.Slice(), 'block': P.Slice(),
'desc_const': [(0, 1, 2, 1), 'desc_const': [(0, 1, 2, 1),
(1, 1, 1, 2)],
(1, 1, 1, 2)],
'desc_inputs': [[2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5]],
'desc_bprop': [[1, 1, 1, 2]]}), 'desc_bprop': [[1, 1, 1, 2]]}),
('StridedSliceGrad', { ('StridedSliceGrad', {
'block': G.StridedSliceGrad(), 'block': G.StridedSliceGrad(),
'desc_const': [(64, 1, 1024), 'desc_const': [(64, 1, 1024),
(0, 1, 0),
(64, 2, 1024),
(1, 1, 1)],
(0, 1, 0),
(64, 2, 1024),
(1, 1, 1)],
'desc_inputs': [[64, 128, 1024]], 'desc_inputs': [[64, 128, 1024]],
'skip': ['backward']}), 'skip': ['backward']}),
('RandomChoiceWithMask', { ('RandomChoiceWithMask', {
'block': P.RandomChoiceWithMask(256), 'block': P.RandomChoiceWithMask(256),
'desc_inputs': [Tensor(np.random.rand(24000, 4).astype(np.bool_))], 'desc_inputs': [Tensor(np.random.rand(24000, 4).astype(np.bool_))],
'desc_bprop': [[256,4], [256,4]],
'desc_bprop': [[256, 4], [256, 4]],
'skip': ['backward']}), 'skip': ['backward']}),
('LessEqual', { ('LessEqual', {
'block': P.LessEqual(), 'block': P.LessEqual(),
@@ -419,7 +432,7 @@ test_case_math_ops = [
'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))]}), 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))]}),
('NotEqual_0', { ('NotEqual_0', {
'block': P.NotEqual(), 'block': P.NotEqual(),
'desc_inputs': [ 1, [2, 3, 4, 5]],
'desc_inputs': [1, [2, 3, 4, 5]],
'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))], 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))],
'skip': ['backward']}), 'skip': ['backward']}),
('Greater', { ('Greater', {
@@ -435,13 +448,13 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_))], 'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_))],
'desc_bprop': [Tensor(np.ones((3, 4, 5), np.bool_))]}), 'desc_bprop': [Tensor(np.ones((3, 4, 5), np.bool_))]}),
('LogicalAnd', { ('LogicalAnd', {
'block': P.LogicalAnd(),
'desc_inputs': [Tensor(np.zeros((2, 3, 4), np.bool_)), Tensor(np.ones((1), np.bool_))],
'desc_bprop': [Tensor(np.zeros((2, 3, 4), np.bool_))]}),
'block': P.LogicalAnd(),
'desc_inputs': [Tensor(np.zeros((2, 3, 4), np.bool_)), Tensor(np.ones((1), np.bool_))],
'desc_bprop': [Tensor(np.zeros((2, 3, 4), np.bool_))]}),
('LogicalOr', { ('LogicalOr', {
'block': P.LogicalOr(),
'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_)), Tensor(np.ones((3, 1, 1), np.bool_))],
'desc_bprop': [Tensor(np.zeros((3, 4, 5), np.bool_))]}),
'block': P.LogicalOr(),
'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_)), Tensor(np.ones((3, 1, 1), np.bool_))],
'desc_bprop': [Tensor(np.zeros((3, 4, 5), np.bool_))]}),
('NpuAllocFloatStatus', { ('NpuAllocFloatStatus', {
'block': P.NPUAllocFloatStatus(), 'block': P.NPUAllocFloatStatus(),
'desc_inputs': [], 'desc_inputs': [],
@@ -476,8 +489,8 @@ test_case_math_ops = [
('CumSum', { ('CumSum', {
'block': P.CumSum(), 'block': P.CumSum(),
'desc_const': [0], 'desc_const': [0],
'desc_inputs': [Tensor(np.array([[3, 4],[1, 6]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[3, 4],[4, 10]]).astype(np.float16))]}),
'desc_inputs': [Tensor(np.array([[3, 4], [1, 6]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[3, 4], [4, 10]]).astype(np.float16))]}),
('ReduceSum_3', { ('ReduceSum_3', {
'block': P.ReduceSum(), 'block': P.ReduceSum(),
'desc_const': [0], 'desc_const': [0],
@@ -717,8 +730,8 @@ test_case_nn_ops = [
('UnsortedSegmentSum', { ('UnsortedSegmentSum', {
'block': P.UnsortedSegmentSum(), 'block': P.UnsortedSegmentSum(),
'desc_const': [1280], 'desc_const': [1280],
'desc_inputs': [[1280,1024], Tensor(np.ones(1280).astype(np.int32))],
'desc_bprop': [[8192,1024]],
'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))],
'desc_bprop': [[8192, 1024]],
'skip': ['backward']}), 'skip': ['backward']}),
('UnsortedSegmentSum_1', { ('UnsortedSegmentSum_1', {
'block': P.UnsortedSegmentSum(), 'block': P.UnsortedSegmentSum(),
@@ -821,19 +834,20 @@ test_case_nn_ops = [
'skip': ['backward']}), 'skip': ['backward']}),
('ArgmaxNet', { ('ArgmaxNet', {
'block': ArgmaxNet(), 'block': ArgmaxNet(),
'desc_inputs': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))],
'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
'skip': ['backward']}), 'skip': ['backward']}),
('ArgminNet', { ('ArgminNet', {
'block': ArgminNet(), 'block': ArgminNet(),
'desc_inputs': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))],
'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
'skip': ['backward']}), 'skip': ['backward']}),
('CumSumNet', { ('CumSumNet', {
'block': CumSumNet(), 'block': CumSumNet(),
'desc_const': [0], 'desc_const': [0],
'desc_inputs': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))]}),
'desc_inputs': [Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))],
'desc_bprop': [
Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))]}),
('OneHot', { ('OneHot', {
'block': P.OneHot(), 'block': P.OneHot(),
'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)], 'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)],
@@ -1021,31 +1035,31 @@ test_case_array_ops = [
'desc_inputs': [(Tensor(np.array([1], np.float32)), 'desc_inputs': [(Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)), Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)))], Tensor(np.array([1], np.float32)))],
'desc_bprop': [[3,]]}),
'desc_bprop': [[3, ]]}),
('Pack_0', { ('Pack_0', {
'block': NetForPackInput(P.Pack()), 'block': NetForPackInput(P.Pack()),
'desc_inputs':[[2, 2], [2, 2], [2, 2]],
'desc_bprop':[[3, 2, 2]],
'desc_inputs': [[2, 2], [2, 2], [2, 2]],
'desc_bprop': [[3, 2, 2]],
}), }),
('Pack_1', { ('Pack_1', {
'block': NetForPackInput(P.Pack(axis=-2)), 'block': NetForPackInput(P.Pack(axis=-2)),
'desc_inputs':[[3, 2, 3], [3, 2, 3], [3, 2, 3]],
'desc_bprop':[[3, 2, 3, 3]],
'desc_inputs': [[3, 2, 3], [3, 2, 3], [3, 2, 3]],
'desc_bprop': [[3, 2, 3, 3]],
}), }),
('Pack_2', { ('Pack_2', {
'block': NetForPackInput(P.Pack()), 'block': NetForPackInput(P.Pack()),
'desc_inputs':[[128, 128], [128, 128]],
'desc_bprop':[[2, 128, 128]],
'desc_inputs': [[128, 128], [128, 128]],
'desc_bprop': [[2, 128, 128]],
}), }),
('Unpack_0', { ('Unpack_0', {
'block': NetForUnpackInput(P.Unpack(axis=0)), 'block': NetForUnpackInput(P.Unpack(axis=0)),
'desc_inputs':[[2, 4]],
'desc_bprop':[[4], [4]],
'desc_inputs': [[2, 4]],
'desc_bprop': [[4], [4]],
}), }),
('Unpack_1', { ('Unpack_1', {
'block': NetForUnpackInput(P.Unpack(axis=-1)), 'block': NetForUnpackInput(P.Unpack(axis=-1)),
'desc_inputs':[Tensor(np.array([[1, 1, 1]], np.float32))],
'desc_bprop':[[1], [1], [1]],
'desc_inputs': [Tensor(np.array([[1, 1, 1]], np.float32))],
'desc_bprop': [[1], [1], [1]],
}), }),
('Diag_1', { ('Diag_1', {
'block': P.Diag(), 'block': P.Diag(),
@@ -1117,6 +1131,11 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.ones((2, 2), np.int32)), 'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
Tensor(np.ones((2,), np.int32))), Tensor(np.ones((2,), np.int32))),
'desc_bprop': [([3, 3], {'dtype': np.int32})]}), 'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
('ScatterMax', {
'block': ScatterMax(),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
Tensor(np.ones([2, 2, 3], np.float32) * 99)),
'skip': ['backward']}),
('SmoothL1Loss', { ('SmoothL1Loss', {
'block': P.SmoothL1Loss(), 'block': P.SmoothL1Loss(),
'desc_inputs': [[256, 4], [256, 4]], 'desc_inputs': [[256, 4], [256, 4]],
@@ -1131,17 +1150,17 @@ test_case_other_ops = [
Tensor(np.array([1.2]).astype(np.float32))], Tensor(np.array([1.2]).astype(np.float32))],
'skip': ['backward']}), 'skip': ['backward']}),
('ConfusionMulGrad_1', { ('ConfusionMulGrad_1', {
'block': P.ConfusionMulGrad(axis = [0], keep_dims = False),
'block': P.ConfusionMulGrad(axis=[0], keep_dims=False),
'desc_inputs': [[3, 2], [3, 2], [3, 2]], 'desc_inputs': [[3, 2], [3, 2], [3, 2]],
'desc_bprop': [[3, 2], [2]], 'desc_bprop': [[3, 2], [2]],
'skip': ['backward']}), 'skip': ['backward']}),
('ConfusionMulGrad_2', { ('ConfusionMulGrad_2', {
'block': P.ConfusionMulGrad(axis = [0], keep_dims = True),
'block': P.ConfusionMulGrad(axis=[0], keep_dims=True),
'desc_inputs': [[3, 2], [3, 2], [3, 2]], 'desc_inputs': [[3, 2], [3, 2], [3, 2]],
'desc_bprop': [[3, 2], [1, 2]], 'desc_bprop': [[3, 2], [1, 2]],
'skip': ['backward']}), 'skip': ['backward']}),
('ConfusionMulGrad_3', { ('ConfusionMulGrad_3', {
'block': P.ConfusionMulGrad(axis = (), keep_dims = True),
'block': P.ConfusionMulGrad(axis=(), keep_dims=True),
'desc_inputs': [[2, 3, 4], [2, 3, 4], [2, 3, 4]], 'desc_inputs': [[2, 3, 4], [2, 3, 4], [2, 3, 4]],
'desc_bprop': [[2, 3, 4], [1, 1, 1]], 'desc_bprop': [[2, 3, 4], [1, 1, 1]],
'skip': ['backward']}), 'skip': ['backward']}),
@@ -1150,7 +1169,7 @@ test_case_other_ops = [
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
Tensor(np.array([1.2]).astype(np.float32))], Tensor(np.array([1.2]).astype(np.float32))],
'skip': ['backward']}), 'skip': ['backward']}),
] ]


test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops, test_case_other_ops] test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops, test_case_other_ops]
@@ -1162,15 +1181,13 @@ test_case = functools.reduce(lambda x, y: x + y, test_case_lists)
test_exec_case = test_case test_exec_case = test_case


test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
'backward' not in x[1]['skip'], test_case)

'backward' not in x[1]['skip'], test_case)


import mindspore.context as context


@non_graph_engine @non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec(): def test_exec():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
return test_exec_case return test_exec_case




@@ -1207,12 +1224,12 @@ raise_set = [
'desc_bprop': [[2, 3]]}), 'desc_bprop': [[2, 3]]}),
('Pack', { ('Pack', {
'block': (NetForPackInput(P.Pack()), {'exception': ValueError}), 'block': (NetForPackInput(P.Pack()), {'exception': ValueError}),
'desc_inputs':[[2, 2]],
'desc_bprop':[[1, 2, 2]]}),
'desc_inputs': [[2, 2]],
'desc_bprop': [[1, 2, 2]]}),
('PReLU', { ('PReLU', {
'block': (P.PReLU(), {'exception': ValueError}), 'block': (P.PReLU(), {'exception': ValueError}),
'desc_inputs':[[2], [1]],
'desc_bprop':[[1]]}),
'desc_inputs': [[2], [1]],
'desc_bprop': [[1]]}),


] ]




正在加载...
取消
保存