You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_deepfm.py 3.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """train_criteo."""
  16. import os
  17. import pytest
  18. from mindspore import context
  19. from mindspore.train.model import Model
  20. from mindspore.common import set_seed
  21. from src.deepfm import ModelBuilder, AUCMetric
  22. from src.config import DataConfig, ModelConfig, TrainConfig
  23. from src.dataset import create_dataset, DataType
  24. from src.callback import EvalCallBack, LossCallBack, TimeMonitor
  25. set_seed(1)
  26. @pytest.mark.level0
  27. @pytest.mark.platform_arm_ascend_training
  28. @pytest.mark.platform_x86_ascend_training
  29. @pytest.mark.env_onecard
  30. def test_deepfm():
  31. data_config = DataConfig()
  32. train_config = TrainConfig()
  33. device_id = int(os.getenv('DEVICE_ID'))
  34. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
  35. rank_size = None
  36. rank_id = None
  37. dataset_path = "/home/workspace/mindspore_dataset/criteo_data/criteo_h5/"
  38. print("dataset_path:", dataset_path)
  39. ds_train = create_dataset(dataset_path,
  40. train_mode=True,
  41. epochs=1,
  42. batch_size=train_config.batch_size,
  43. data_type=DataType(data_config.data_format),
  44. rank_size=rank_size,
  45. rank_id=rank_id)
  46. model_builder = ModelBuilder(ModelConfig, TrainConfig)
  47. train_net, eval_net = model_builder.get_train_eval_net()
  48. auc_metric = AUCMetric()
  49. model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
  50. loss_file_name = './loss.log'
  51. time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
  52. loss_callback = LossCallBack(loss_file_path=loss_file_name)
  53. callback_list = [time_callback, loss_callback]
  54. eval_file_name = './auc.log'
  55. ds_eval = create_dataset(dataset_path, train_mode=False,
  56. epochs=1,
  57. batch_size=train_config.batch_size,
  58. data_type=DataType(data_config.data_format))
  59. eval_callback = EvalCallBack(model, ds_eval, auc_metric,
  60. eval_file_path=eval_file_name)
  61. callback_list.append(eval_callback)
  62. print("train_config.train_epochs:", train_config.train_epochs)
  63. model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
  64. export_loss_value = 0.51
  65. print("loss_callback.loss:", loss_callback.loss)
  66. assert loss_callback.loss < export_loss_value
  67. export_per_step_time = 40.0
  68. print("time_callback:", time_callback.per_step_time)
  69. assert time_callback.per_step_time < export_per_step_time
  70. print("*******test case pass!********")