You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

model_parameter.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define schema of model lineage input parameters."""
  16. from marshmallow import Schema, fields, ValidationError, pre_load, validates
  17. from marshmallow.validate import Range
  18. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \
  19. LineageErrors
  20. from mindinsight.lineagemgr.common.exceptions.exceptions import \
  21. LineageParamTypeError, LineageParamValueError
  22. from mindinsight.lineagemgr.common.log import logger
  23. from mindinsight.lineagemgr.common.utils import enum_to_list
  24. from mindinsight.lineagemgr.querier.querier import LineageType
  25. from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
  26. from mindinsight.utils.exceptions import MindInsightException
  27. try:
  28. from mindspore.dataset.engine import Dataset
  29. from mindspore.nn import Cell, Optimizer
  30. from mindspore.common.tensor import Tensor
  31. from mindspore.train.callback import _ListCallback
  32. except (ImportError, ModuleNotFoundError):
  33. logger.error('MindSpore Not Found!')
  34. class RunContextArgs(Schema):
  35. """Define the parameter schema for RunContext."""
  36. optimizer = fields.Function(allow_none=True)
  37. loss_fn = fields.Function(allow_none=True)
  38. net_outputs = fields.Function(allow_none=True)
  39. train_network = fields.Function(allow_none=True)
  40. train_dataset = fields.Function(allow_none=True)
  41. epoch_num = fields.Int(allow_none=True, validate=Range(min=1))
  42. batch_num = fields.Int(allow_none=True, validate=Range(min=0))
  43. cur_step_num = fields.Int(allow_none=True, validate=Range(min=0))
  44. parallel_mode = fields.Str(allow_none=True)
  45. device_number = fields.Int(allow_none=True, validate=Range(min=1))
  46. list_callback = fields.Function(allow_none=True)
  47. @pre_load
  48. def check_optimizer(self, data, **kwargs):
  49. optimizer = data.get("optimizer")
  50. if optimizer and not isinstance(optimizer, Optimizer):
  51. raise ValidationError({'optimizer': [
  52. "Parameter optimizer must be an instance of mindspore.nn.optim.Optimizer."
  53. ]})
  54. return data
  55. @pre_load
  56. def check_train_network(self, data, **kwargs):
  57. train_network = data.get("train_network")
  58. if train_network and not isinstance(train_network, Cell):
  59. raise ValidationError({'train_network': [
  60. "Parameter train_network must be an instance of mindspore.nn.Cell."]})
  61. return data
  62. @pre_load
  63. def check_train_dataset(self, data, **kwargs):
  64. train_dataset = data.get("train_dataset")
  65. if train_dataset and not isinstance(train_dataset, Dataset):
  66. raise ValidationError({'train_dataset': [
  67. "Parameter train_dataset must be an instance of "
  68. "mindspore.dataengine.datasets.Dataset"]})
  69. return data
  70. @pre_load
  71. def check_loss(self, data, **kwargs):
  72. net_outputs = data.get("net_outputs")
  73. if net_outputs and not isinstance(net_outputs, Tensor):
  74. raise ValidationError({'net_outpus': [
  75. "The parameter net_outputs is invalid. It should be a Tensor."
  76. ]})
  77. return data
  78. @pre_load
  79. def check_list_callback(self, data, **kwargs):
  80. list_callback = data.get("list_callback")
  81. if list_callback and not isinstance(list_callback, _ListCallback):
  82. raise ValidationError({'list_callback': [
  83. "Parameter list_callback must be an instance of "
  84. "mindspore.train.callback._ListCallback."
  85. ]})
  86. return data
  87. class EvalParameter(Schema):
  88. """Define the parameter schema for Evaluation job."""
  89. valid_dataset = fields.Function(allow_none=True)
  90. metrics = fields.Dict(allow_none=True)
  91. @pre_load
  92. def check_valid_dataset(self, data, **kwargs):
  93. valid_dataset = data.get("valid_dataset")
  94. if valid_dataset and not isinstance(valid_dataset, Dataset):
  95. raise ValidationError({'valid_dataset': [
  96. "Parameter valid_dataset must be an instance of "
  97. "mindspore.dataengine.datasets.Dataset"]})
  98. return data
  99. class SearchModelConditionParameter(Schema):
  100. """Define the search model condition parameter schema."""
  101. summary_dir = fields.Dict()
  102. loss_function = fields.Dict()
  103. train_dataset_path = fields.Dict()
  104. train_dataset_count = fields.Dict()
  105. test_dataset_path = fields.Dict()
  106. test_dataset_count = fields.Dict()
  107. network = fields.Dict()
  108. optimizer = fields.Dict()
  109. learning_rate = fields.Dict()
  110. epoch = fields.Dict()
  111. batch_size = fields.Dict()
  112. loss = fields.Dict()
  113. model_size = fields.Dict()
  114. limit = fields.Int(validate=lambda n: 0 < n <= 100)
  115. offset = fields.Int(validate=lambda n: 0 <= n <= 100000)
  116. sorted_name = fields.Str()
  117. sorted_type = fields.Str(allow_none=True)
  118. dataset_mark = fields.Dict()
  119. lineage_type = fields.Dict()
  120. @staticmethod
  121. def check_dict_value_type(data, value_type):
  122. """Check dict value type and int scope."""
  123. for key, value in data.items():
  124. if key == "in":
  125. if not isinstance(value, (list, tuple)):
  126. raise ValidationError("The value of `in` operation must be list or tuple.")
  127. else:
  128. if not isinstance(value, value_type):
  129. raise ValidationError("Wrong value type.")
  130. if value_type is int:
  131. if value < 0 or value > pow(2, 63) - 1:
  132. raise ValidationError("Int value should <= pow(2, 63) - 1.")
  133. if isinstance(value, bool):
  134. raise ValidationError("Wrong value type.")
  135. @staticmethod
  136. def check_param_value_type(data):
  137. """Check input param's value type."""
  138. for key, value in data.items():
  139. if key == "in":
  140. if not isinstance(value, (list, tuple)):
  141. raise ValidationError("The value of `in` operation must be list or tuple.")
  142. else:
  143. if isinstance(value, bool) or \
  144. (not isinstance(value, float) and not isinstance(value, int)):
  145. raise ValidationError("Wrong value type.")
  146. @staticmethod
  147. def check_operation(data):
  148. """Check input param's compare operation."""
  149. if not set(data.keys()).issubset(['in', 'eq']):
  150. raise ValidationError("Its operation should be `in` or `eq`.")
  151. if len(data.keys()) > 1:
  152. raise ValidationError("More than one operation.")
  153. @validates("loss")
  154. def check_loss(self, data):
  155. """Check loss."""
  156. SearchModelConditionParameter.check_param_value_type(data)
  157. @validates("learning_rate")
  158. def check_learning_rate(self, data):
  159. """Check learning_rate."""
  160. SearchModelConditionParameter.check_param_value_type(data)
  161. @validates("loss_function")
  162. def check_loss_function(self, data):
  163. """Check loss function."""
  164. SearchModelConditionParameter.check_operation(data)
  165. SearchModelConditionParameter.check_dict_value_type(data, str)
  166. @validates("train_dataset_path")
  167. def check_train_dataset_path(self, data):
  168. """Check train dataset path."""
  169. SearchModelConditionParameter.check_operation(data)
  170. SearchModelConditionParameter.check_dict_value_type(data, str)
  171. @validates("train_dataset_count")
  172. def check_train_dataset_count(self, data):
  173. """Check train dataset count."""
  174. SearchModelConditionParameter.check_dict_value_type(data, int)
  175. @validates("test_dataset_path")
  176. def check_test_dataset_path(self, data):
  177. """Check test dataset path."""
  178. SearchModelConditionParameter.check_operation(data)
  179. SearchModelConditionParameter.check_dict_value_type(data, str)
  180. @validates("test_dataset_count")
  181. def check_test_dataset_count(self, data):
  182. """Check test dataset count."""
  183. SearchModelConditionParameter.check_dict_value_type(data, int)
  184. @validates("network")
  185. def check_network(self, data):
  186. """Check network."""
  187. SearchModelConditionParameter.check_operation(data)
  188. SearchModelConditionParameter.check_dict_value_type(data, str)
  189. @validates("optimizer")
  190. def check_optimizer(self, data):
  191. """Check optimizer."""
  192. SearchModelConditionParameter.check_operation(data)
  193. SearchModelConditionParameter.check_dict_value_type(data, str)
  194. @validates("epoch")
  195. def check_epoch(self, data):
  196. """Check epoch."""
  197. SearchModelConditionParameter.check_dict_value_type(data, int)
  198. @validates("batch_size")
  199. def check_batch_size(self, data):
  200. """Check batch size."""
  201. SearchModelConditionParameter.check_dict_value_type(data, int)
  202. @validates("model_size")
  203. def check_model_size(self, data):
  204. """Check model size."""
  205. SearchModelConditionParameter.check_dict_value_type(data, int)
  206. @validates("summary_dir")
  207. def check_summary_dir(self, data):
  208. """Check summary dir."""
  209. SearchModelConditionParameter.check_operation(data)
  210. SearchModelConditionParameter.check_dict_value_type(data, str)
  211. @validates("dataset_mark")
  212. def check_dataset_mark(self, data):
  213. """Check dataset mark."""
  214. SearchModelConditionParameter.check_operation(data)
  215. SearchModelConditionParameter.check_dict_value_type(data, str)
  216. @validates("lineage_type")
  217. def check_lineage_type(self, data):
  218. """Check lineage type."""
  219. SearchModelConditionParameter.check_operation(data)
  220. SearchModelConditionParameter.check_dict_value_type(data, str)
  221. recv_types = []
  222. for key, value in data.items():
  223. if key == "in":
  224. recv_types = value
  225. else:
  226. recv_types.append(value)
  227. lineage_types = enum_to_list(LineageType)
  228. if not set(recv_types).issubset(lineage_types):
  229. raise ValidationError("Given lineage type should be one of %s." % lineage_types)
  230. @pre_load
  231. def check_comparision(self, data, **kwargs):
  232. """Check comparision for all parameters in schema."""
  233. for attr, condition in data.items():
  234. if attr in ["limit", "offset", "sorted_name", "sorted_type", 'lineage_type']:
  235. continue
  236. if not isinstance(attr, str):
  237. raise LineageParamValueError('The search attribute not supported.')
  238. if attr not in FIELD_MAPPING and not attr.startswith(('metric/', 'user_defined/')):
  239. raise LineageParamValueError('The search attribute not supported.')
  240. if not isinstance(condition, dict):
  241. raise LineageParamTypeError("The search_condition element {} should be dict."
  242. .format(attr))
  243. for key in condition.keys():
  244. if key not in ["eq", "lt", "gt", "le", "ge", "in"]:
  245. raise LineageParamValueError("The compare condition should be in "
  246. "('eq', 'lt', 'gt', 'le', 'ge', 'in').")
  247. if attr.startswith('metric/'):
  248. if len(attr) == 7:
  249. raise LineageParamValueError(
  250. 'The search attribute not supported.'
  251. )
  252. try:
  253. SearchModelConditionParameter.check_param_value_type(condition)
  254. except ValidationError:
  255. raise MindInsightException(
  256. error=LineageErrors.LINEAGE_PARAM_METRIC_ERROR,
  257. message=LineageErrorMsg.LINEAGE_METRIC_ERROR.value.format(attr)
  258. )
  259. return data

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。