|
|
|
@@ -14,7 +14,7 @@ |
|
|
|
# ============================================================================ |
|
|
|
"""train_criteo.""" |
|
|
|
import os |
|
|
|
import pytest |
|
|
|
# import pytest |
|
|
|
|
|
|
|
from mindspore import context |
|
|
|
from mindspore.train.model import Model |
|
|
|
@@ -27,10 +27,10 @@ from src.callback import EvalCallBack, LossCallBack, TimeMonitor |
|
|
|
|
|
|
|
set_seed(1) |
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
|
@pytest.mark.env_onecard |
|
|
|
# @pytest.mark.level0 |
|
|
|
# @pytest.mark.platform_arm_ascend_training |
|
|
|
# @pytest.mark.platform_x86_ascend_training |
|
|
|
# @pytest.mark.env_onecard |
|
|
|
def test_deepfm(): |
|
|
|
data_config = DataConfig() |
|
|
|
train_config = TrainConfig() |
|
|
|
|