Browse Source

fix summary tage check error

Signed-off-by: candanzg <zhangshucheng@huawei.com>
tags/v0.3.0-alpha
candanzg 5 years ago
parent
commit
b84deeeb5c
2 changed files with 7 additions and 4 deletions
  1. +2
    -2
      mindspore/ccsrc/operator/prim_debug.cc
  2. +5
    -2
      tests/ut/python/train/summary/test_summary_ops_params_valid_check.py

+ 2
- 2
mindspore/ccsrc/operator/prim_debug.cc View File

@@ -51,7 +51,7 @@ AbstractBasePtr InferImplScalarSummary(const AnalysisEnginePtr &, const Primitiv
// Reomve the force check to support batch set summary use 'for' loop // Reomve the force check to support batch set summary use 'for' loop
auto item_v = descriptions->BuildValue(); auto item_v = descriptions->BuildValue();
if (!item_v->isa<StringImm>()) { if (!item_v->isa<StringImm>()) {
MS_LOG(ERROR) << "First parameter shoule be string";
MS_EXCEPTION(TypeError) << "Summary first parameter should be string";
} }


return std::make_shared<AbstractScalar>(kAnyValue, kBool); return std::make_shared<AbstractScalar>(kAnyValue, kBool);
@@ -75,7 +75,7 @@ AbstractBasePtr InferImplTensorSummary(const AnalysisEnginePtr &, const Primitiv
// Reomve the force check to support batch set summary use 'for' loop // Reomve the force check to support batch set summary use 'for' loop
auto item_v = descriptions->BuildValue(); auto item_v = descriptions->BuildValue();
if (!item_v->isa<StringImm>()) { if (!item_v->isa<StringImm>()) {
MS_LOG(WARNING) << "Summary first parameter must be string";
MS_EXCEPTION(TypeError) << "Summary first parameter should be string";
} }


return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<Bool>()); return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<Bool>());


+ 5
- 2
tests/ut/python/train/summary/test_summary_ops_params_valid_check.py View File

@@ -22,6 +22,7 @@ import os
import logging import logging
import random import random
import numpy as np import numpy as np
import pytest
from mindspore.train.summary.summary_record import SummaryRecord from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.nn as nn import mindspore.nn as nn
@@ -180,7 +181,8 @@ def test_summary_use_invalid_tag_None():
def test_summary_use_invalid_tag_Bool(): def test_summary_use_invalid_tag_Bool():
log.debug("begin test_summary_use_invalid_tag_Bool") log.debug("begin test_summary_use_invalid_tag_Bool")
net = SummaryDemoTag(True, True, True) net = SummaryDemoTag(True, True, True)
run_case(net)
with pytest.raises(TypeError):
run_case(net)
log.debug("finished test_summary_use_invalid_tag_Bool") log.debug("finished test_summary_use_invalid_tag_Bool")




@@ -196,7 +198,8 @@ def test_summary_use_invalid_tag_null():
def test_summary_use_invalid_tag_Int(): def test_summary_use_invalid_tag_Int():
log.debug("begin test_summary_use_invalid_tag_Int") log.debug("begin test_summary_use_invalid_tag_Int")
net = SummaryDemoTag(1, 2, 3) net = SummaryDemoTag(1, 2, 3)
run_case(net)
with pytest.raises(TypeError):
run_case(net)
log.debug("finished test_summary_use_invalid_tag_Int") log.debug("finished test_summary_use_invalid_tag_Int")






Loading…
Cancel
Save