"""End-to-end Trainer tests on MNIST using fastNLP's paddle driver."""
import os
from dataclasses import dataclass
from typing import Any

import pytest
from paddle.io import DataLoader
from paddle.optimizer import Adam

from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.utils import magic_argv_env_context


@dataclass
class MNISTTrainPaddleConfig:
    """Hyper-parameters for single-process paddle training on MNIST."""

    num_labels: int = 10
    feature_dimension: int = 784
    batch_size: int = 32
    shuffle: bool = True
    # Annotated so it is a real dataclass field (an un-annotated attribute is
    # only a class attribute). Passed straight through to Trainer.
    validate_every: int = -5
    driver: str = "paddle"
    device: str = "gpu"


@dataclass
class MNISTTrainFleetConfig:
    """Hyper-parameters for fleet (multi-device) paddle training on MNIST."""

    num_labels: int = 10
    feature_dimension: int = 784
    batch_size: int = 32
    shuffle: bool = True
    validate_every: int = -5


@dataclass
class TrainerParameters:
    """Bag of objects handed to ``Trainer`` by the tests below."""

    model: Any = None
    optimizers: Any = None
    train_dataloader: Any = None
    validate_dataloaders: Any = None
    # The test assigns this attribute, so declare it as a field too.
    validate_every: Any = None
    input_mapping: Any = None
    output_mapping: Any = None
    metrics: Any = None


def _make_callbacks():
    """Return a fresh list of callbacks for a single Trainer run.

    Instances passed directly to ``pytest.mark.parametrize`` are created once
    at collection time and shared by every parametrized case. The Record*
    helpers presumably keep state between hooks (e.g. the first recorded
    loss/metric; confirm in tests/helpers/callbacks). Sharing them could
    therefore leak one run's state into the next, so they are built per case.
    """
    return [
        RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
        RichCallback(5),
        RecordLossCallback(loss_threshold=0.3),
    ]


# NOTE: for fleet (multi-GPU) runs, parametrize "driver,device" with ("fleet", [0, 1]).
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
@pytest.mark.parametrize("callbacks", [_make_callbacks])
@magic_argv_env_context
def test_trainer_paddle(
        driver,
        device,
        callbacks,
        n_epochs=15,
):
    """Train a small classifier on MNIST and let the callbacks check the results.

    Args:
        driver: fastNLP driver name passed to ``Trainer``.
        device: ``Trainer`` device. A non-int/str value with more than one
            entry is treated as a multi-device run.
        callbacks: either a list of callbacks or a zero-argument factory
            returning one. A factory is used so every case gets fresh instances.
        n_epochs: number of training epochs.
    """
    if callable(callbacks):
        callbacks = callbacks()

    trainer_params = TrainerParameters()
    trainer_params.model = PaddleNormalModel_Classification(
        num_labels=MNISTTrainPaddleConfig.num_labels,
        feature_dimension=MNISTTrainPaddleConfig.feature_dimension
    )
    trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
    trainer_params.train_dataloader = DataLoader(
        dataset=PaddleDataset_MNIST("train"),
        batch_size=MNISTTrainPaddleConfig.batch_size,
        shuffle=True
    )
    trainer_params.validate_dataloaders = DataLoader(
        dataset=PaddleDataset_MNIST(mode="test"),
        batch_size=MNISTTrainPaddleConfig.batch_size,
        shuffle=True
    )
    trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
    trainer_params.metrics = {"acc": Accuracy(backend="paddle")}

    # Both branches build an identical Trainer; collect the arguments once.
    trainer_kwargs = dict(
        model=trainer_params.model,
        driver=driver,
        device=device,
        optimizers=trainer_params.optimizers,
        train_dataloader=trainer_params.train_dataloader,
        validate_dataloaders=trainer_params.validate_dataloaders,
        validate_every=trainer_params.validate_every,
        input_mapping=trainer_params.input_mapping,
        output_mapping=trainer_params.output_mapping,
        metrics=trainer_params.metrics,
        n_epochs=n_epochs,
        callbacks=callbacks,
    )

    is_multi_device = not isinstance(device, (int, str)) and len(device) > 1
    if is_multi_device and FASTNLP_DISTRIBUTED_CHECK not in os.environ:
        # Outside a distributed child process, constructing the Trainer is
        # expected to raise SystemExit with code 0 (presumably after launching
        # the distributed workers; confirm against fastNLP's launcher).
        with pytest.raises(SystemExit) as exc:
            Trainer(**trainer_kwargs)
        assert exc.value.code == 0
        return

    trainer = Trainer(**trainer_kwargs)
    trainer.run()