You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.train.callback.txt 7.6 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. Class mindspore.train.callback.SummaryCollector(summary_dir, collect_freq=10, collect_specified_data=None, keep_default_action=True, custom_lineage_data=None, collect_tensor_freq=None, max_file_size=None, export_options=None)
  2. SummaryCollector可以收集一些常用信息。
  3. 它可以帮助收集loss、学习率、计算图等。
  4. SummaryCollector还可以允许summary算子将数据收集到summary文件中。
  5. 注:
  6. 1. 不允许在回调列表中存在多个SummaryCollector实例。
  7. 2. 并非所有信息都可以在训练阶段或评估阶段收集的。
  8. 3. SummaryCollector始终记录summary算子收集的数据。
  9. 4. SummaryCollector仅支持Linux系统。
  10. 参数:
  11. summary_dir (str):收集的数据将存储到此目录。
  12. 如果目录不存在,将自动创建。
  13. collect_freq (int):设置数据收集的频率,频率应大于零,单位为`step`。如果设置了频率,将在(current steps % freq)等于0时收集数据,并且将随时收集第一个step。
  14. 需要注意的是,如果使用数据下沉模式,单位将变成`epoch`。
  15. 不建议过于频繁地收集数据,因为这可能会影响性能。默认值:10。
  16. collect_specified_data (Union[None, dict]):对收集的数据进行自定义操作。
  17. 默认情况下,如果该参数设为None,则默认收集所有数据。
  18. 您可以使用字典自定义需要收集的数据类型。
  19. 例如,您可以设置{'collect_metric':False}不去收集metrics。
  20. 支持控制的数据如下。默认值:None。
  21. - collect_metric (bool):表示是否收集训练metrics,目前只收集loss。
  22. 把第一个输出视为loss,并且算出其平均数。
  23. 可选值:True/False。默认值:True。
  24. - collect_graph (bool):表示是否收集计算图。目前只收集训练计算图。可选值:True/False。默认值:True。
  25. - collect_train_lineage (bool):表示是否收集训练阶段的lineage数据,该字段将显示在MindInsight的lineage页面上。可选值:True/False。默认值:True。
  26. - collect_eval_lineage (bool):表示是否收集评估阶段的lineage数据,该字段将显示在MindInsight的lineage页面上。可选值:True/False。默认值:True。
  27. - collect_input_data (bool):表示是否为每次训练收集数据集。
  28. 目前仅支持图像数据。
  29. 如果数据集中有多列数据,则第一列应为图像数据。
  30. 可选值:True/False。默认值:True。
  31. - collect_dataset_graph (bool):表示是否收集训练阶段的数据集图。
  32. 可选值:True/False。默认值:True。
  33. - histogram_regular (Union[str, None]):收集参数分布页面的权重和偏置,并在MindInsight中展示。此字段允许常规字符串控制要收集的参数。
  34. 不建议一次收集太多参数,因为这会影响性能。
  35. 注:如果收集的参数太多并且内存不足,训练将会失败。
  36. 默认值:None,表示只收集前五个参数。
  37. keep_default_action (bool):此字段影响collect_specified_data字段的收集行为。
  38. True:表示设置指定数据后,默认收集非指定数据。
  39. False:表示设置指定数据后,只收集指定数据,不收集其他数据。可选值:True/False,默认值:True。
  40. custom_lineage_data (Union[dict, None]):允许您自定义数据并将数据显示在MingInsight的lineage页面上。在自定义数据中,key支持str类型,value支持str、int和float类型。默认值:None,表示不存在自定义数据。
  41. collect_tensor_freq (Optional[int]):语义与`collect_freq`的相同,但仅控制TensorSummary。
  42. 由于TensorSummary数据太大,无法与其他summary数据进行比较,因此此参数用于降低收集量。默认情况下,收集TensorSummary数据的最大step数量为20,但不会超过收集其他summary数据的step数量。
  43. 例如,给定`collect_freq=10`,当总step数量为600时,TensorSummary将收集20个step,而收集其他summary数据时会收集61个step。但当总step数量为为20时,TensorSummary和其他summary将收集3个step。
  44. 另外请注意,在并行模式下,会平均分配总的step数量,这会影响TensorSummary收集的step的数量。
  45. 默认值:None,表示要遵循上述规则。
  46. max_file_size (Optional[int]):可写入磁盘的每个文件的最大大小(以字节为单位)。
  47. 例如,如果不大于4GB,则设置`max_file_size=4*1024**3`。
  48. 默认值:None,表示无限制。
  49. export_options (Union[None, dict]):表示对导出的数据执行自定义操作。
  50. 注:导出的文件的大小不受max_file_size的限制。
  51. 您可以使用字典自定义导出的数据。例如,您可以设置{'tensor_format':'npy'}将tensor导出为NPY文件。
  52. 支持控制的数据如下所示。
  53. 默认值:None,表示不导出数据。
  54. - tensor_format (Union[str, None]):自定义导出的tensor的格式。支持["npy", None]。
  55. 默认值:None,表示不导出tensor。
  56. - npy:将tensor导出为NPY文件。
  57. 异常:
  58. ValueError:参数值与预期的不同。
  59. TypeError:参数类型与预期的不同。
  60. RuntimeError:数据采集过程中出现错误。
  61. 示例:
  62. >>> import mindspore.nn as nn
  63. >>> from mindspore import context
  64. >>> from mindspore.train.callback import SummaryCollector
  65. >>> from mindspore import Model
  66. >>> from mindspore.nn import Accuracy
  67. >>>
  68. >>> if __name__ == '__main__':
  69. ... # 如果device_target是GPU,则将device_target设为GPU。
  70. ... context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  71. ... mnist_dataset_dir = '/path/to/mnist_dataset_directory'
  72. ... # model_zoo.office.cv.lenet.src.dataset.py中显示的create_dataset方法的详细信息
  73. ... ds_train = create_dataset(mnist_dataset_dir, 32)
  74. ... # model_zoo.official.cv.lenet.src.lenet.py中显示的LeNet5的详细信息
  75. ... network = LeNet5(10)
  76. ... net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  77. ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
  78. ... model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
  79. ...
  80. ... # 简单用法:
  81. ... summary_collector = SummaryCollector(summary_dir='./summary_dir')
  82. ... model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=False)
  83. ...
  84. ... # 不收集metric,收集第一层参数。默认收集其他数据。
  85. ... specified={'collect_metric': False, 'histogram_regular': '^conv1.*'}
  86. ... summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_specified_data=specified)
  87. ... model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=False)