Browse Source

import comment and function of op print

tags/v0.2.0-alpha
guohongzilong 5 years ago
parent
commit
0eb72d76f0
2 changed files with 10 additions and 6 deletions
  1. +1
    -1
      mindspore/ccsrc/transform/op_adapter.h
  2. +9
    -5
      mindspore/ops/operations/debug_ops.py

+ 1
- 1
mindspore/ccsrc/transform/op_adapter.h View File

@@ -513,7 +513,7 @@ class OpAdapter : public BaseOpAdapter {
return; return;
} }
} else { } else {
MS_LOG(ERROR) << "Update output desc failed, unknow output shape type";
MS_LOG(WARNING) << "Update output desc failed, unknow output shape type";
return; return;
} }
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);


+ 9
- 5
mindspore/ops/operations/debug_ops.py View File

@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================


"""debug_ops""" """debug_ops"""
from ..._checkparam import ParamValidator as validator
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer


@@ -157,19 +158,20 @@ class InsertGradientOf(PrimitiveWithInfer):


class Print(PrimitiveWithInfer): class Print(PrimitiveWithInfer):
""" """
Output tensor to stdout.
Output tensor or string to stdout.


Inputs: Inputs:
- **input_x** (Tensor) - The graph node to attach to.
- **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports
multiple strings and tensors which are separated by ','.


Examples: Examples:
>>> class PrintDemo(nn.Cell): >>> class PrintDemo(nn.Cell):
>>> def __init__(self,):
>>> def __init__(self):
>>> super(PrintDemo, self).__init__() >>> super(PrintDemo, self).__init__()
>>> self.print = P.Print() >>> self.print = P.Print()
>>> >>>
>>> def construct(self, x):
>>> self.print(x)
>>> def construct(self, x, y):
>>> self.print('Print Tensor x and Tensor y:', x, y)
>>> return x >>> return x
""" """


@@ -181,4 +183,6 @@ class Print(PrimitiveWithInfer):
return [1] return [1]


def infer_dtype(self, *inputs): def infer_dtype(self, *inputs):
for dtype in inputs:
validator.check_subclass("input", dtype, (mstype.tensor, mstype.string))
return mstype.int32 return mstype.int32

Loading…
Cancel
Save