You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_parameter.py 13 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define schema of model lineage input parameters."""
  16. from marshmallow import Schema, fields, ValidationError, pre_load, validates
  17. from marshmallow.validate import Range
  18. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \
  19. LineageErrors
  20. from mindinsight.lineagemgr.common.exceptions.exceptions import \
  21. LineageParamTypeError, LineageParamValueError
  22. from mindinsight.lineagemgr.common.log import logger
  23. from mindinsight.lineagemgr.common.utils import enum_to_list
  24. from mindinsight.lineagemgr.querier.querier import LineageType
  25. from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
  26. from mindinsight.utils.exceptions import MindInsightException
  27. try:
  28. from mindspore.dataset.engine import Dataset
  29. from mindspore.nn import Cell, Optimizer
  30. from mindspore.common.tensor import Tensor
  31. from mindspore.train.callback import _ListCallback
  32. except (ImportError, ModuleNotFoundError):
  33. logger.error('MindSpore Not Found!')
  34. class RunContextArgs(Schema):
  35. """Define the parameter schema for RunContext."""
  36. optimizer = fields.Function(allow_none=True)
  37. loss_fn = fields.Function(allow_none=True)
  38. net_outputs = fields.Function(allow_none=True)
  39. train_network = fields.Function(allow_none=True)
  40. train_dataset = fields.Function(allow_none=True)
  41. epoch_num = fields.Int(allow_none=True, validate=Range(min=1))
  42. batch_num = fields.Int(allow_none=True, validate=Range(min=0))
  43. cur_step_num = fields.Int(allow_none=True, validate=Range(min=0))
  44. parallel_mode = fields.Str(allow_none=True)
  45. device_number = fields.Int(allow_none=True, validate=Range(min=1))
  46. list_callback = fields.Function(allow_none=True)
  47. @pre_load
  48. def check_optimizer(self, data, **kwargs):
  49. optimizer = data.get("optimizer")
  50. if optimizer and not isinstance(optimizer, Optimizer):
  51. raise ValidationError({'optimizer': [
  52. "Parameter optimizer must be an instance of mindspore.nn.optim.Optimizer."
  53. ]})
  54. return data
  55. @pre_load
  56. def check_train_network(self, data, **kwargs):
  57. train_network = data.get("train_network")
  58. if train_network and not isinstance(train_network, Cell):
  59. raise ValidationError({'train_network': [
  60. "Parameter train_network must be an instance of mindspore.nn.Cell."]})
  61. return data
  62. @pre_load
  63. def check_train_dataset(self, data, **kwargs):
  64. train_dataset = data.get("train_dataset")
  65. if train_dataset and not isinstance(train_dataset, Dataset):
  66. raise ValidationError({'train_dataset': [
  67. "Parameter train_dataset must be an instance of "
  68. "mindspore.dataengine.datasets.Dataset"]})
  69. return data
  70. @pre_load
  71. def check_loss(self, data, **kwargs):
  72. net_outputs = data.get("net_outputs")
  73. if net_outputs and not isinstance(net_outputs, Tensor):
  74. raise ValidationError({'net_outpus': [
  75. "The parameter net_outputs is invalid. It should be a Tensor."
  76. ]})
  77. return data
  78. @pre_load
  79. def check_list_callback(self, data, **kwargs):
  80. list_callback = data.get("list_callback")
  81. if list_callback and not isinstance(list_callback, _ListCallback):
  82. raise ValidationError({'list_callback': [
  83. "Parameter list_callback must be an instance of "
  84. "mindspore.train.callback._ListCallback."
  85. ]})
  86. return data
  87. class EvalParameter(Schema):
  88. """Define the parameter schema for Evaluation job."""
  89. valid_dataset = fields.Function(allow_none=True)
  90. metrics = fields.Dict(allow_none=True)
  91. @pre_load
  92. def check_valid_dataset(self, data, **kwargs):
  93. valid_dataset = data.get("valid_dataset")
  94. if valid_dataset and not isinstance(valid_dataset, Dataset):
  95. raise ValidationError({'valid_dataset': [
  96. "Parameter valid_dataset must be an instance of "
  97. "mindspore.dataengine.datasets.Dataset"]})
  98. return data
  99. class SearchModelConditionParameter(Schema):
  100. """Define the search model condition parameter schema."""
  101. summary_dir = fields.Dict()
  102. loss_function = fields.Dict()
  103. train_dataset_path = fields.Dict()
  104. train_dataset_count = fields.Dict()
  105. test_dataset_path = fields.Dict()
  106. test_dataset_count = fields.Dict()
  107. network = fields.Dict()
  108. optimizer = fields.Dict()
  109. learning_rate = fields.Dict()
  110. epoch = fields.Dict()
  111. batch_size = fields.Dict()
  112. device_num = fields.Dict()
  113. loss = fields.Dict()
  114. model_size = fields.Dict()
  115. limit = fields.Int(validate=lambda n: 0 < n <= 100)
  116. offset = fields.Int(validate=lambda n: 0 <= n <= 100000)
  117. sorted_name = fields.Str()
  118. sorted_type = fields.Str(allow_none=True)
  119. dataset_mark = fields.Dict()
  120. lineage_type = fields.Dict()
  121. @staticmethod
  122. def check_dict_value_type(data, value_type):
  123. """Check dict value type and int scope."""
  124. for key, value in data.items():
  125. if key == "in":
  126. if not isinstance(value, (list, tuple)):
  127. raise ValidationError("The value of `in` operation must be list or tuple.")
  128. else:
  129. if not isinstance(value, value_type):
  130. raise ValidationError("Wrong value type.")
  131. if value_type is int:
  132. if value < 0 or value > pow(2, 63) - 1:
  133. raise ValidationError("Int value should <= pow(2, 63) - 1.")
  134. if isinstance(value, bool):
  135. raise ValidationError("Wrong value type.")
  136. @staticmethod
  137. def check_param_value_type(data):
  138. """Check input param's value type."""
  139. for key, value in data.items():
  140. if key == "in":
  141. if not isinstance(value, (list, tuple)):
  142. raise ValidationError("The value of `in` operation must be list or tuple.")
  143. else:
  144. if isinstance(value, bool) or \
  145. (not isinstance(value, float) and not isinstance(value, int)):
  146. raise ValidationError("Wrong value type.")
  147. @staticmethod
  148. def check_operation(data):
  149. """Check input param's compare operation."""
  150. if not set(data.keys()).issubset(['in', 'eq']):
  151. raise ValidationError("Its operation should be `in` or `eq`.")
  152. if len(data.keys()) > 1:
  153. raise ValidationError("More than one operation.")
  154. @validates("loss")
  155. def check_loss(self, data):
  156. """Check loss."""
  157. SearchModelConditionParameter.check_param_value_type(data)
  158. @validates("learning_rate")
  159. def check_learning_rate(self, data):
  160. """Check learning_rate."""
  161. SearchModelConditionParameter.check_param_value_type(data)
  162. @validates("loss_function")
  163. def check_loss_function(self, data):
  164. """Check loss function."""
  165. SearchModelConditionParameter.check_operation(data)
  166. SearchModelConditionParameter.check_dict_value_type(data, str)
  167. @validates("train_dataset_path")
  168. def check_train_dataset_path(self, data):
  169. """Check train dataset path."""
  170. SearchModelConditionParameter.check_operation(data)
  171. SearchModelConditionParameter.check_dict_value_type(data, str)
  172. @validates("train_dataset_count")
  173. def check_train_dataset_count(self, data):
  174. """Check train dataset count."""
  175. SearchModelConditionParameter.check_dict_value_type(data, int)
  176. @validates("test_dataset_path")
  177. def check_test_dataset_path(self, data):
  178. """Check test dataset path."""
  179. SearchModelConditionParameter.check_operation(data)
  180. SearchModelConditionParameter.check_dict_value_type(data, str)
  181. @validates("test_dataset_count")
  182. def check_test_dataset_count(self, data):
  183. """Check test dataset count."""
  184. SearchModelConditionParameter.check_dict_value_type(data, int)
  185. @validates("network")
  186. def check_network(self, data):
  187. """Check network."""
  188. SearchModelConditionParameter.check_operation(data)
  189. SearchModelConditionParameter.check_dict_value_type(data, str)
  190. @validates("optimizer")
  191. def check_optimizer(self, data):
  192. """Check optimizer."""
  193. SearchModelConditionParameter.check_operation(data)
  194. SearchModelConditionParameter.check_dict_value_type(data, str)
  195. @validates("epoch")
  196. def check_epoch(self, data):
  197. """Check epoch."""
  198. SearchModelConditionParameter.check_dict_value_type(data, int)
  199. @validates("batch_size")
  200. def check_batch_size(self, data):
  201. """Check batch size."""
  202. SearchModelConditionParameter.check_dict_value_type(data, int)
  203. @validates("device_num")
  204. def check_device_num(self, data):
  205. """Check device num."""
  206. SearchModelConditionParameter.check_dict_value_type(data, int)
  207. @validates("model_size")
  208. def check_model_size(self, data):
  209. """Check model size."""
  210. SearchModelConditionParameter.check_dict_value_type(data, int)
  211. @validates("summary_dir")
  212. def check_summary_dir(self, data):
  213. """Check summary dir."""
  214. SearchModelConditionParameter.check_operation(data)
  215. SearchModelConditionParameter.check_dict_value_type(data, str)
  216. @validates("dataset_mark")
  217. def check_dataset_mark(self, data):
  218. """Check dataset mark."""
  219. SearchModelConditionParameter.check_operation(data)
  220. SearchModelConditionParameter.check_dict_value_type(data, str)
  221. @validates("lineage_type")
  222. def check_lineage_type(self, data):
  223. """Check lineage type."""
  224. SearchModelConditionParameter.check_operation(data)
  225. SearchModelConditionParameter.check_dict_value_type(data, str)
  226. recv_types = []
  227. for key, value in data.items():
  228. if key == "in":
  229. recv_types = value
  230. else:
  231. recv_types.append(value)
  232. lineage_types = enum_to_list(LineageType)
  233. if not set(recv_types).issubset(lineage_types):
  234. raise ValidationError("Given lineage type should be one of %s." % lineage_types)
  235. @pre_load
  236. def check_comparision(self, data, **kwargs):
  237. """Check comparision for all parameters in schema."""
  238. for attr, condition in data.items():
  239. if attr in ["limit", "offset", "sorted_name", "sorted_type", 'lineage_type']:
  240. continue
  241. if not isinstance(attr, str):
  242. raise LineageParamValueError('The search attribute not supported.')
  243. if attr not in FIELD_MAPPING and not attr.startswith(('metric/', 'user_defined/')):
  244. raise LineageParamValueError('The search attribute not supported.')
  245. if not isinstance(condition, dict):
  246. raise LineageParamTypeError("The search_condition element {} should be dict."
  247. .format(attr))
  248. for key in condition.keys():
  249. if key not in ["eq", "lt", "gt", "le", "ge", "in"]:
  250. raise LineageParamValueError("The compare condition should be in "
  251. "('eq', 'lt', 'gt', 'le', 'ge', 'in').")
  252. if attr.startswith('metric/'):
  253. if len(attr) == 7:
  254. raise LineageParamValueError(
  255. 'The search attribute not supported.'
  256. )
  257. try:
  258. SearchModelConditionParameter.check_param_value_type(condition)
  259. except ValidationError:
  260. raise MindInsightException(
  261. error=LineageErrors.LINEAGE_PARAM_METRIC_ERROR,
  262. message=LineageErrorMsg.LINEAGE_METRIC_ERROR.value.format(attr)
  263. )
  264. return data