Browse Source

!6590 fix bugs of op Debug, ReLUV2, EditDistance and Dense

Merge pull request !6590 from lihongkang/v2_master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
d56683157d
5 changed files with 11 additions and 39 deletions
  1. +1
    -1
      mindspore/nn/layer/basic.py
  2. +0
    -9
      mindspore/ops/_grad/grad_debug_ops.py
  3. +2
    -2
      mindspore/ops/operations/__init__.py
  4. +7
    -0
      mindspore/ops/operations/array_ops.py
  5. +1
    -27
      mindspore/ops/operations/debug_ops.py

+ 1
- 1
mindspore/nn/layer/basic.py View File

@@ -165,7 +165,7 @@ class Dense(Cell):
\text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

where :math:`\text{activation}` is the activation function passed as the activation
argument (if passed in), :math:`\text{activation}` is a weight matrix with the same
argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
with the same data type as the inputs created by the layer (only if has_bias is True).



+ 0
- 9
mindspore/ops/_grad/grad_debug_ops.py View File

@@ -66,12 +66,3 @@ def get_bprop_insert_gradient_of(self):
def bprop(x, out, dout):
return (f(dout),)
return bprop


@bprop_getters.register(P.Debug)
def get_bprop_debug(self):
"""Generate bprop for Debug"""

def bprop(x, out, dout):
return dout
return bprop

+ 2
- 2
mindspore/ops/operations/__init__.py View File

@@ -39,7 +39,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast
_VirtualDiv, _GetTensorSlice,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Debug, Print, Assert)
TensorSummary, HistogramSummary, Print, Assert)
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast

@@ -200,7 +200,6 @@ __all__ = [
'ImageSummary',
'TensorSummary',
'HistogramSummary',
"Debug",
"Print",
"Assert",
'InsertGradientOf',
@@ -375,6 +374,7 @@ __all__ = [
"ParallelConcat",
"Push",
"Pull",
"ReLUV2",
'SparseToDense',
]



+ 7
- 0
mindspore/ops/operations/array_ops.py View File

@@ -3619,6 +3619,12 @@ class EditDistance(PrimitiveWithInfer):
Tensor, a dense tensor with rank `R-1` and float32 data type.

Examples:
>>> import numpy as np
>>> from mindspore import context
>>> from mindspore import Tensor
>>> import mindspore.nn as nn
>>> import mindspore.ops.operations as P
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class EditDistance(nn.Cell):
>>> def __init__(self, hypothesis_shape, truth_shape, normalize=True):
>>> super(EditDistance, self).__init__()
@@ -3645,6 +3651,7 @@ class EditDistance(PrimitiveWithInfer):
def __init__(self, normalize=True):
"""Initialize EditDistance"""
self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name)
self.set_const_input_indexes([2, 5])

def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape):
validator.check_const_input('hypothesis_shape', h_shape['value'], self.name)


+ 1
- 27
mindspore/ops/operations/debug_ops.py View File

@@ -18,7 +18,7 @@ from types import FunctionType, MethodType
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
from ..primitive import prim_attr_register, PrimitiveWithInfer


def _check_summary_param(name, value, class_name):
@@ -342,32 +342,6 @@ class Print(PrimitiveWithInfer):
return mstype.int32


class Debug(Primitive):
"""
Prints tensor value.

Inputs:
- **value** (Tensor) - The value of tensor.

Examples:
>>> class DebugNN(nn.Cell):
>>> def __init__(self,):
>>> self.debug = nn.Debug()
>>>
>>> def construct(self, x, y):
>>> x = self.add(x, y)
>>> self.debug(x)
>>> return x
"""

@prim_attr_register
def __init__(self):
"""init"""

def __call__(self, *args, **kwargs):
pass


class Assert(PrimitiveWithInfer):
"""
Asserts that the given condition is true.


Loading…
Cancel
Save