Browse Source

!2475 fix the summary operator is not work in constant folding scene

Merge pull request !2475 from ougongchang/master
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
c454e1045f
1 changed files with 11 additions and 4 deletions
  1. +11
    -4
      mindspore/ops/operations/debug_ops.py

+ 11
- 4
mindspore/ops/operations/debug_ops.py View File

@@ -32,6 +32,13 @@ def _check_summary_param(name, value, class_name):
validator.check_value_type('value', v_type, [type(mstype.tensor)], class_name) validator.check_value_type('value', v_type, [type(mstype.tensor)], class_name)




# Note: The return value of the summary operator is not used,
# so there's nothing special about the return `dtype` or `shape`, any value is ok.
# The `value` should be set to None, else summary operators may be optimized at compile graph phase,
# it cause summary operators can not record data in constant folding scene.
SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None}


class ScalarSummary(PrimitiveWithInfer): class ScalarSummary(PrimitiveWithInfer):
""" """
Output scalar to protocol buffer through scalar summary operator. Output scalar to protocol buffer through scalar summary operator.
@@ -67,7 +74,7 @@ class ScalarSummary(PrimitiveWithInfer):
raise ValueError(f"For 'value' the type should be scalar, " raise ValueError(f"For 'value' the type should be scalar, "
f"shape should be [] or [1] in {self.__class__.__name__}, but got {v_shape}.") f"shape should be [] or [1] in {self.__class__.__name__}, but got {v_shape}.")


return value
return SUMMARY_RETURN_VALUE




class ImageSummary(PrimitiveWithInfer): class ImageSummary(PrimitiveWithInfer):
@@ -104,7 +111,7 @@ class ImageSummary(PrimitiveWithInfer):
raise ValueError(f"For 'value' the dim should be {image_dim} in {self.__class__.__name__}," raise ValueError(f"For 'value' the dim should be {image_dim} in {self.__class__.__name__},"
f" but got {len(v_shape)}.") f" but got {len(v_shape)}.")


return value
return SUMMARY_RETURN_VALUE




class TensorSummary(PrimitiveWithInfer): class TensorSummary(PrimitiveWithInfer):
@@ -142,7 +149,7 @@ class TensorSummary(PrimitiveWithInfer):
raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, " raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, "
f"shape should not be [].") f"shape should not be [].")


return value
return SUMMARY_RETURN_VALUE




class HistogramSummary(PrimitiveWithInfer): class HistogramSummary(PrimitiveWithInfer):
@@ -180,7 +187,7 @@ class HistogramSummary(PrimitiveWithInfer):
raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, " raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, "
f"shape should not be [].") f"shape should not be [].")


return value
return SUMMARY_RETURN_VALUE




class InsertGradientOf(PrimitiveWithInfer): class InsertGradientOf(PrimitiveWithInfer):


Loading…
Cancel
Save