Merge pull request !340 from ougongchang/mastertags/v0.2.0-alpha
| @@ -103,7 +103,8 @@ std::string CNode::fullname_with_scope() { | |||
| return fullname_with_scope_; | |||
| } | |||
| if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary)) { | |||
| if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || | |||
| IsApply(prim::kPrimHistogramSummary)) { | |||
| std::string tag = GetValue<std::string>(GetValueNode(input(1))); | |||
| if (tag == "") { | |||
| MS_LOG(EXCEPTION) << "The tag name is null, should be valid string"; | |||
| @@ -111,10 +112,12 @@ std::string CNode::fullname_with_scope() { | |||
| std::string name; | |||
| if (IsApply(prim::kPrimScalarSummary)) { | |||
| name = tag + "[:Scalar]"; | |||
| } else if (IsApply(prim::kPrimTensorSummary)) { | |||
| name = tag + "[:Tensor]"; | |||
| } else { | |||
| } else if (IsApply(prim::kPrimImageSummary)) { | |||
| name = tag + "[:Image]"; | |||
| } else if (IsApply(prim::kPrimHistogramSummary)) { | |||
| name = tag + "[:Histogram]"; | |||
| } else { | |||
| name = tag + "[:Tensor]"; | |||
| } | |||
| fullname_with_scope_ = name; | |||
| } else { | |||
| @@ -236,6 +236,7 @@ const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | |||
| const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary"); | |||
| const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | |||
| const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | |||
| ValuePtr GetPythonOps(const std::string& op_name, const std::string& module_name) { | |||
| py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); | |||
| @@ -225,6 +225,7 @@ extern const PrimitivePtr kPrimStateSetItem; | |||
| extern const PrimitivePtr kPrimScalarSummary; | |||
| extern const PrimitivePtr kPrimImageSummary; | |||
| extern const PrimitivePtr kPrimTensorSummary; | |||
| extern const PrimitivePtr kPrimHistogramSummary; | |||
| extern const PrimitivePtr kPrimBroadcastGradientArgs; | |||
| extern const PrimitivePtr kPrimControlDepend; | |||
| extern const PrimitivePtr kPrimIs_; | |||
| @@ -69,7 +69,7 @@ AbstractBasePtr InferImplTensorSummary(const AnalysisEnginePtr &, const Primitiv | |||
| int tensor_rank = SizeToInt(tensor_value->shape()->shape().size()); | |||
| if (tensor_rank == 0) { | |||
| MS_LOG(EXCEPTION) << "Tensor/Image Summary evaluator second arg should be an tensor, but got a scalar"; | |||
| MS_LOG(EXCEPTION) << op_name << " summary evaluator second arg should be an tensor, but got a scalar, rank is 0"; | |||
| } | |||
| // Reomve the force check to support batch set summary use 'for' loop | |||
| @@ -51,25 +51,14 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { | |||
| // node because it is attribute or ge specific reason. | |||
| // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be | |||
| // converted to switch guarded. | |||
| std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list({{prim::kPrimApplyMomentum, {1, 2}}, | |||
| {prim::kPrimMomentum, {2, 3}}, | |||
| {prim::kPrimStateSetItem, {1}}, | |||
| {prim::kPrimEnvGetItem, {1}}, | |||
| {prim::kPrimEnvSetItem, {1}}, | |||
| {prim::kPrimReduceSum, {2}}, | |||
| {prim::kPrimReduceMean, {2}}, | |||
| {prim::kPrimReduceAll, {2}}, | |||
| {prim::kPrimCast, {2}}, | |||
| {prim::kPrimTranspose, {2}}, | |||
| {prim::kPrimOneHot, {2}}, | |||
| {prim::kPrimGatherV2, {3}}, | |||
| {prim::kPrimReshape, {2}}, | |||
| {prim::kPrimAssign, {1}}, | |||
| {prim::kPrimAssignAdd, {1}}, | |||
| {prim::kPrimAssignSub, {1}}, | |||
| {prim::kPrimTensorSummary, {1}}, | |||
| {prim::kPrimImageSummary, {1}}, | |||
| {prim::kPrimScalarSummary, {1}}}); | |||
| std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list( | |||
| {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}}, {prim::kPrimStateSetItem, {1}}, | |||
| {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}}, {prim::kPrimReduceSum, {2}}, | |||
| {prim::kPrimReduceMean, {2}}, {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}}, | |||
| {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}}, {prim::kPrimGatherV2, {3}}, | |||
| {prim::kPrimReshape, {2}}, {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}}, | |||
| {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}}, {prim::kPrimImageSummary, {1}}, | |||
| {prim::kPrimScalarSummary, {1}}, {prim::kPrimHistogramSummary, {1}}}); | |||
| for (auto &item : white_list) { | |||
| auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { | |||
| return IsPrimitiveCNode(node, item.first) && idx == index; | |||
| @@ -66,6 +66,7 @@ const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM, | |||
| SCALARSUMMARY, | |||
| IMAGESUMMARY, | |||
| TENSORSUMMARY, | |||
| HISTOGRAMSUMMARY, | |||
| COL2IMV1, | |||
| RESOLVE, | |||
| BROADCASTGRADIENTARGS, | |||
| @@ -246,6 +246,7 @@ constexpr char STATESETITEM[] = "state_setitem"; | |||
| constexpr char SCALARSUMMARY[] = "ScalarSummary"; | |||
| constexpr char IMAGESUMMARY[] = "ImageSummary"; | |||
| constexpr char TENSORSUMMARY[] = "TensorSummary"; | |||
| constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary"; | |||
| constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs"; | |||
| constexpr char INVERTPERMUTATION[] = "InvertPermutation"; | |||
| constexpr char CONTROLDEPEND[] = "ControlDepend"; | |||
| @@ -131,6 +131,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimScalarSummary, {InferImplScalarSummary, true}}, | |||
| {prim::kPrimImageSummary, {InferImplTensorSummary, true}}, | |||
| {prim::kPrimTensorSummary, {InferImplTensorSummary, true}}, | |||
| {prim::kPrimHistogramSummary, {InferImplTensorSummary, true}}, | |||
| }; | |||
| return prim_eval_implement_map; | |||
| } | |||
| @@ -714,7 +714,8 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) { | |||
| } | |||
| auto input = cnode->inputs()[0]; | |||
| bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) || | |||
| IsPrimitive(input, prim::kPrimTensorSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || | |||
| IsPrimitive(input, prim::kPrimTensorSummary) || | |||
| IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || | |||
| IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || | |||
| IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || | |||
| IsPrimitive(input, prim::kPrimReturn); | |||
| @@ -45,7 +45,7 @@ void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, s | |||
| for (auto &n : apply_list) { | |||
| MS_EXCEPTION_IF_NULL(n); | |||
| if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || | |||
| IsPrimitiveCNode(n, prim::kPrimImageSummary)) { | |||
| IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { | |||
| int index = 0; | |||
| auto cnode = n->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -83,7 +83,7 @@ bool ExistSummaryNode(const KernelGraph *graph) { | |||
| auto all_nodes = DeepLinkedGraphSearch(ret); | |||
| for (auto &n : all_nodes) { | |||
| if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || | |||
| IsPrimitiveCNode(n, prim::kPrimImageSummary)) { | |||
| IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { | |||
| return true; | |||
| } | |||
| } | |||
| @@ -353,6 +353,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||
| {prim::kPrimScalarSummary->name(), ADPT_DESC(Summary)}, | |||
| {prim::kPrimImageSummary->name(), ADPT_DESC(Summary)}, | |||
| {prim::kPrimTensorSummary->name(), ADPT_DESC(Summary)}, | |||
| {prim::kPrimHistogramSummary->name(), ADPT_DESC(Summary)}, | |||
| {prim::kPrimTensorAdd->name(), | |||
| std::make_shared<OpAdapterDesc>(std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(1)}})), | |||
| std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(1)}})))}, | |||
| @@ -131,7 +131,7 @@ static TensorPtr GetMeTensorForSummary(const std::string& name, const std::share | |||
| auto shape = std::vector<int>({ONE_SHAPE}); | |||
| return TransformUtil::ConvertGeTensor(ge_tensor_ptr, shape); | |||
| } | |||
| if (tname == "[:Tensor]") { | |||
| if (tname == "[:Tensor]" || tname == "[:Histogram]") { | |||
| MS_LOG(DEBUG) << "The summary(" << name << ") is Tensor"; | |||
| // process the tensor summary | |||
| // Now we can't get the real shape, so we keep same shape with GE | |||
| @@ -49,6 +49,15 @@ def get_bprop_image_summary(self): | |||
| return bprop | |||
| @bprop_getters.register(P.HistogramSummary) | |||
| def get_bprop_histogram_summary(self): | |||
| """Generate bprop for HistogramSummary""" | |||
| def bprop(tag, x, out, dout): | |||
| return tag, zeros_like(x) | |||
| return bprop | |||
| @bprop_getters.register(P.InsertGradientOf) | |||
| def get_bprop_insert_gradient_of(self): | |||
| """Generate bprop for InsertGradientOf""" | |||
| @@ -34,7 +34,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice) | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, ScalarSummary, | |||
| TensorSummary, Print) | |||
| TensorSummary, HistogramSummary, Print) | |||
| from .control_ops import ControlDepend, GeSwitch, Merge | |||
| from .inner_ops import ScalarCast | |||
| from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, | |||
| @@ -148,6 +148,7 @@ __all__ = [ | |||
| 'ScalarSummary', | |||
| 'ImageSummary', | |||
| 'TensorSummary', | |||
| 'HistogramSummary', | |||
| "Print", | |||
| 'InsertGradientOf', | |||
| 'InvertPermutation', | |||
| @@ -98,6 +98,33 @@ class TensorSummary(Primitive): | |||
| """init""" | |||
| class HistogramSummary(Primitive): | |||
| """ | |||
| Output tensor to protocol buffer through histogram summary operator. | |||
| Inputs: | |||
| - **name** (str) - The name of the input variable. | |||
| - **value** (Tensor) - The value of tensor, and the rank of tensor should be greater than 0. | |||
| Examples: | |||
| >>> class SummaryDemo(nn.Cell): | |||
| >>> def __init__(self,): | |||
| >>> super(SummaryDemo, self).__init__() | |||
| >>> self.summary = P.HistogramSummary() | |||
| >>> self.add = P.TensorAdd() | |||
| >>> | |||
| >>> def construct(self, x, y): | |||
| >>> x = self.add(x, y) | |||
| >>> name = "x" | |||
| >>> self.summary(name, x) | |||
| >>> return x | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init""" | |||
| class InsertGradientOf(PrimitiveWithInfer): | |||
| """ | |||
| Attach callback to graph node that will be invoked on the node's gradient. | |||
| @@ -24,17 +24,6 @@ from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.train.summary.summary_record import SummaryRecord | |||
| ''' | |||
| This testcase is used for save summary data only. You need install MindData first and uncomment the commented | |||
| packages to analyse summary data. | |||
| Using "minddata start --datalog='./test_me_summary_event_file/' --host=0.0.0.0" to make data visible. | |||
| ''' | |||
| # from minddata.datavisual.data_transform.data_manager import DataManager | |||
| # from minddata.datavisual.visual.train_visual.train_task_manager import TrainTaskManager | |||
| # from minddata.datavisual.visual.train_visual.scalars_processor import ScalarsProcessor | |||
| # from minddata.datavisual.common.enums import PluginNameEnum | |||
| # from minddata.datavisual.common.enums import DataManagerStatus | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| @@ -43,6 +32,7 @@ CUR_DIR = os.getcwd() | |||
| SUMMARY_DIR_ME = CUR_DIR + "/test_me_summary_event_file/" | |||
| SUMMARY_DIR_ME_TEMP = CUR_DIR + "/test_me_temp_summary_event_file/" | |||
| def clean_environment_file(srcDir): | |||
| if os.path.exists(srcDir): | |||
| ls = os.listdir(srcDir) | |||
| @@ -50,6 +40,8 @@ def clean_environment_file(srcDir): | |||
| filePath = os.path.join(srcDir, line) | |||
| os.remove(filePath) | |||
| os.removedirs(srcDir) | |||
| def save_summary_events_file(srcDir, desDir): | |||
| if not os.path.exists(desDir): | |||
| print("-- create desDir") | |||
| @@ -64,12 +56,14 @@ def save_summary_events_file(srcDir, desDir): | |||
| os.remove(filePath) | |||
| os.removedirs(srcDir) | |||
| class SummaryNet(nn.Cell): | |||
| def __init__(self, tag_tuple=None, scalar=1): | |||
| super(SummaryNet, self).__init__() | |||
| self.summary_s = P.ScalarSummary() | |||
| self.summary_i = P.ImageSummary() | |||
| self.summary_t = P.TensorSummary() | |||
| self.histogram_summary = P.HistogramSummary() | |||
| self.add = P.TensorAdd() | |||
| self.tag_tuple = tag_tuple | |||
| self.scalar = scalar | |||
| @@ -79,8 +73,10 @@ class SummaryNet(nn.Cell): | |||
| self.summary_s("x1", x) | |||
| z = self.add(x, y) | |||
| self.summary_t("z1", z) | |||
| self.histogram_summary("histogram", z) | |||
| return z | |||
| def train_summary_record_scalar_for_1(test_writer, steps, fwd_x, fwd_y): | |||
| net = SummaryNet() | |||
| out_me_dict = {} | |||
| @@ -93,6 +89,7 @@ def train_summary_record_scalar_for_1(test_writer, steps, fwd_x, fwd_y): | |||
| out_me_dict[i] = out_put.asnumpy() | |||
| return out_me_dict | |||
| def me_scalar_summary(steps, tag=None, value=None): | |||
| test_writer = SummaryRecord(SUMMARY_DIR_ME_TEMP) | |||
| @@ -104,44 +101,6 @@ def me_scalar_summary(steps, tag=None, value=None): | |||
| test_writer.close() | |||
| return out_me_dict | |||
| def print_scalar_data(): | |||
| print("============start print_scalar_data\n") | |||
| data_manager = DataManager() | |||
| data_manager.start_load_data(path=SUMMARY_DIR_ME) | |||
| while data_manager.get_status() != DataManagerStatus.DONE: | |||
| time.sleep(0.1) | |||
| task_manager = TrainTaskManager(data_manager) | |||
| train_jobs = task_manager.get_all_train_tasks(PluginNameEnum.scalar) | |||
| print(train_jobs) | |||
| """ | |||
| train_jobs | |||
| ['train_jobs': { | |||
| 'id': '12-123', | |||
| 'name': 'train_job_name', | |||
| 'tags': ['x1', 'y1'] | |||
| }] | |||
| """ | |||
| scalar_processor = ScalarsProcessor(data_manager) | |||
| metadata = scalar_processor.get_metadata_list(train_job_ids=train_jobs['train_jobs'][0]['id'], tag=train_jobs['train_jobs'][0]['tags'][0]) | |||
| print(metadata) | |||
| ''' | |||
| metadata | |||
| { | |||
| 'scalars' : [ | |||
| { | |||
| 'train_job_id' : '12-12', | |||
| 'metadatas' : [ | |||
| { | |||
| 'wall_time' : 0.1, | |||
| 'step' : 1, | |||
| 'value' : 0.1 | |||
| } | |||
| ] | |||
| } | |||
| ] | |||
| } | |||
| ''' | |||
| print("============end print_scalar_data\n") | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @@ -621,6 +621,12 @@ TEST_F(TestConvert, TestTensorSummaryOps) { | |||
| ASSERT_TRUE(ret); | |||
| } | |||
| TEST_F(TestConvert, TestHistogramSummaryOps) { | |||
| auto prim = prim::kPrimHistogramSummary; | |||
| bool ret = MakeDfGraph(prim, 2); | |||
| ASSERT_TRUE(ret); | |||
| } | |||
| TEST_F(TestConvert, TestGreaterOps) { | |||
| auto prim = std::make_shared<Primitive>("Greater"); | |||
| bool ret = MakeDfGraph(prim, 2); | |||
| @@ -73,7 +73,8 @@ FuncGraphPtr MakeFuncGraph(const PrimitivePtr prim, unsigned int nparam) { | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim)); | |||
| for (unsigned int i = 0; i < nparam; i++) { | |||
| if ((prim->name() == "ScalarSummary" || prim->name() == "TensorSummary" || prim->name() == "ImageSummary") && | |||
| if ((prim->name() == "ScalarSummary" || prim->name() == "TensorSummary" || | |||
| prim->name() == "ImageSummary" || prim->name() == "HistogramSummary") && | |||
| i == 0) { | |||
| auto input = NewValueNode("testSummary"); | |||
| inputs.push_back(input); | |||
| @@ -198,6 +198,19 @@ class ScalarSummaryNet(nn.Cell): | |||
| return out | |||
| class HistogramSummaryNet(nn.Cell): | |||
| """HistogramSummaryNet definition""" | |||
| def __init__(self): | |||
| super(HistogramSummaryNet, self).__init__() | |||
| self.summary = P.HistogramSummary() | |||
| def construct(self, tensor): | |||
| string_in = "wight_value" | |||
| out = self.summary(string_in, tensor) | |||
| return out | |||
| class FusedBatchNormGrad(nn.Cell): | |||
| """ FusedBatchNormGrad definition """ | |||
| @@ -443,6 +456,10 @@ test_cases = [ | |||
| 'block': ScalarSummaryNet(), | |||
| 'desc_inputs': [2.2], | |||
| }), | |||
| ('HistogramSummary', { | |||
| 'block': HistogramSummaryNet(), | |||
| 'desc_inputs': [[1,2,3]], | |||
| }), | |||
| ('FusedBatchNormGrad', { | |||
| 'block': FusedBatchNormGrad(nn.BatchNorm2d(num_features=512, eps=1e-5, momentum=0.1)), | |||
| 'desc_inputs': [[64, 512, 7, 7], [64, 512, 7, 7]], | |||
| @@ -160,6 +160,19 @@ class SummaryNet(nn.Cell): | |||
| return self.add(x, y) | |||
| class HistogramSummaryNet(nn.Cell): | |||
| def __init__(self,): | |||
| super(HistogramSummaryNet, self).__init__() | |||
| self.summary = P.HistogramSummary() | |||
| self.add = P.TensorAdd() | |||
| def construct(self, x, y): | |||
| out = self.add(x, y) | |||
| string_in = "out" | |||
| self.summary(string_in, out) | |||
| return out | |||
| test_case_math_ops = [ | |||
| ('Neg', { | |||
| 'block': P.Neg(), | |||
| @@ -1104,6 +1117,12 @@ test_case_other_ops = [ | |||
| 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | |||
| Tensor(np.array([1.2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ('HistogramSummary', { | |||
| 'block': HistogramSummaryNet(), | |||
| 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | |||
| Tensor(np.array([1.2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ] | |||
| test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops, test_case_other_ops] | |||
| @@ -132,6 +132,7 @@ class SummaryDemo(nn.Cell): | |||
| def __init__(self,): | |||
| super(SummaryDemo, self).__init__() | |||
| self.s = P.ScalarSummary() | |||
| self.histogram_summary = P.HistogramSummary() | |||
| self.add = P.TensorAdd() | |||
| def construct(self, x, y): | |||
| @@ -139,6 +140,7 @@ class SummaryDemo(nn.Cell): | |||
| z = self.add(x, y) | |||
| self.s("z1", z) | |||
| self.s("y1", y) | |||
| self.histogram_summary("histogram", z) | |||
| return z | |||
| @@ -40,6 +40,7 @@ class SummaryDemoTag(nn.Cell): | |||
| def __init__(self, tag1, tag2, tag3): | |||
| super(SummaryDemoTag, self).__init__() | |||
| self.s = P.ScalarSummary() | |||
| self.histogram_summary = P.HistogramSummary() | |||
| self.add = P.TensorAdd() | |||
| self.tag1 = tag1 | |||
| self.tag2 = tag2 | |||
| @@ -50,6 +51,7 @@ class SummaryDemoTag(nn.Cell): | |||
| z = self.add(x, y) | |||
| self.s(self.tag2, z) | |||
| self.s(self.tag3, y) | |||
| self.histogram_summary(self.tag1, x) | |||
| return z | |||
| @@ -58,6 +60,7 @@ class SummaryDemoTagForSet(nn.Cell): | |||
| def __init__(self, tag_tuple): | |||
| super(SummaryDemoTagForSet, self).__init__() | |||
| self.s = P.ScalarSummary() | |||
| self.histogram_summary = P.HistogramSummary() | |||
| self.add = P.TensorAdd() | |||
| self.tag_tuple = tag_tuple | |||
| @@ -65,6 +68,7 @@ class SummaryDemoTagForSet(nn.Cell): | |||
| z = self.add(x, y) | |||
| for tag in self.tag_tuple: | |||
| self.s(tag, x) | |||
| self.histogram_summary(tag, x) | |||
| return z | |||
| @@ -98,6 +102,19 @@ class SummaryDemoValueForSet(nn.Cell): | |||
| self.s(tag, self.v) | |||
| return z | |||
| class HistogramSummaryNet(nn.Cell): | |||
| "HistogramSummaryNet definition" | |||
| def __init__(self, value): | |||
| self.histogram_summary = P.HistogramSummary() | |||
| self.add = P.TensorAdd() | |||
| self.value = value | |||
| def construct(self, tensors1, tensor2): | |||
| self.histogram_summary("value", self.value) | |||
| return self.add(tensors1, tensor2) | |||
| def run_case(net): | |||
| """ run_case """ | |||
| # step 0: create the thread | |||
| @@ -121,8 +138,8 @@ def run_case(net): | |||
| # Test 1: use the repeat tag | |||
| def test_scalar_summary_use_repeat_tag(): | |||
| log.debug("begin test_scalar_summary_use_repeat_tag") | |||
| def test_summary_use_repeat_tag(): | |||
| log.debug("begin test_summary_use_repeat_tag") | |||
| net = SummaryDemoTag("x", "x", "x") | |||
| try: | |||
| run_case(net) | |||
| @@ -130,12 +147,12 @@ def test_scalar_summary_use_repeat_tag(): | |||
| assert False | |||
| else: | |||
| assert True | |||
| log.debug("finished test_scalar_summary_use_repeat_tag") | |||
| log.debug("finished test_summary_use_repeat_tag") | |||
| # Test 2: repeat tag use for set summary | |||
| def test_scalar_summary_use_repeat_tag_for_set(): | |||
| log.debug("begin test_scalar_summary_use_repeat_tag_for_set") | |||
| def test_summary_use_repeat_tag_for_set(): | |||
| log.debug("begin test_summary_use_repeat_tag_for_set") | |||
| net = SummaryDemoTagForSet(("x", "x", "x")) | |||
| try: | |||
| run_case(net) | |||
| @@ -143,12 +160,12 @@ def test_scalar_summary_use_repeat_tag_for_set(): | |||
| assert False | |||
| else: | |||
| assert True | |||
| log.debug("finished test_scalar_summary_use_repeat_tag_for_set") | |||
| log.debug("finished test_summary_use_repeat_tag_for_set") | |||
| # Test3: test with invalid tag(None, bool, "", int) | |||
| def test_scalar_summary_use_invalid_tag_None(): | |||
| log.debug("begin test_scalar_summary_use_invalid_tag_None") | |||
| def test_summary_use_invalid_tag_None(): | |||
| log.debug("begin test_summary_use_invalid_tag_None") | |||
| net = SummaryDemoTag(None, None, None) | |||
| try: | |||
| run_case(net) | |||
| @@ -156,31 +173,31 @@ def test_scalar_summary_use_invalid_tag_None(): | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("finished test_scalar_summary_use_invalid_tag_None") | |||
| log.debug("finished test_summary_use_invalid_tag_None") | |||
| # Test4: test with invalid tag(None, bool, "", int) | |||
| def test_scalar_summary_use_invalid_tag_Bool(): | |||
| log.debug("begin test_scalar_summary_use_invalid_tag_Bool") | |||
| def test_summary_use_invalid_tag_Bool(): | |||
| log.debug("begin test_summary_use_invalid_tag_Bool") | |||
| net = SummaryDemoTag(True, True, True) | |||
| run_case(net) | |||
| log.debug("finished test_scalar_summary_use_invalid_tag_Bool") | |||
| log.debug("finished test_summary_use_invalid_tag_Bool") | |||
| # Test5: test with invalid tag(None, bool, "", int) | |||
| def test_scalar_summary_use_invalid_tag_null(): | |||
| log.debug("begin test_scalar_summary_use_invalid_tag_null") | |||
| def test_summary_use_invalid_tag_null(): | |||
| log.debug("begin test_summary_use_invalid_tag_null") | |||
| net = SummaryDemoTag("", "", "") | |||
| run_case(net) | |||
| log.debug("finished test_scalar_summary_use_invalid_tag_null") | |||
| log.debug("finished test_summary_use_invalid_tag_null") | |||
| # Test6: test with invalid tag(None, bool, "", int) | |||
| def test_scalar_summary_use_invalid_tag_Int(): | |||
| log.debug("begin test_scalar_summary_use_invalid_tag_Int") | |||
| def test_summary_use_invalid_tag_Int(): | |||
| log.debug("begin test_summary_use_invalid_tag_Int") | |||
| net = SummaryDemoTag(1, 2, 3) | |||
| run_case(net) | |||
| log.debug("finished test_scalar_summary_use_invalid_tag_Int") | |||
| log.debug("finished test_summary_use_invalid_tag_Int") | |||
| # Test7: test with invalid value(None, "") | |||
| @@ -196,7 +213,6 @@ def test_scalar_summary_use_invalid_value_None(): | |||
| log.debug("finished test_scalar_summary_use_invalid_tag_Int") | |||
| # Test8: test with invalid value(None, "") | |||
| def test_scalar_summary_use_invalid_value_None_ForSet(): | |||
| log.debug("begin test_scalar_summary_use_invalid_value_None_ForSet") | |||
| @@ -221,3 +237,30 @@ def test_scalar_summary_use_invalid_value_null(): | |||
| else: | |||
| assert False | |||
| log.debug("finished test_scalar_summary_use_invalid_value_null") | |||
| def test_histogram_summary_use_valid_value(): | |||
| """Test histogram summary with valid value""" | |||
| log.debug("Begin test_histogram_summary_use_valid_value") | |||
| try: | |||
| net = HistogramSummaryNet(Tensor(np.array([1,2,3]))) | |||
| run_case(net) | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("Finished test_histogram_summary_use_valid_value") | |||
| def test_histogram_summary_use_scalar_value(): | |||
| """Test histogram summary use scalar value""" | |||
| log.debug("Begin test_histogram_summary_use_scalar_value") | |||
| try: | |||
| scalar = Tensor(1) | |||
| net = HistogramSummaryNet(scalar) | |||
| run_case(net) | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("Finished test_histogram_summary_use_scalar_value") | |||