# T6. fastNLP 与 paddle 或 jittor 的结合

&emsp; 1 &ensp; fastNLP 结合 paddle 训练模型
 
&emsp; &emsp; 1.1 &ensp; 关于 paddle 的简单介绍

&emsp; &emsp; 1.2 &ensp; 使用 paddle 搭建并训练模型

&emsp; 2 &ensp; fastNLP 结合 jittor 训练模型

&emsp; &emsp; 2.1 &ensp; 关于 jittor 的简单介绍

&emsp; &emsp; 2.2 &ensp; 使用 jittor 搭建并训练模型

&emsp; 3 &ensp; fastNLP 实现 paddle 与 pytorch 互转

In [None]:
from datasets import load_dataset

# GLUE SST-2 sentiment corpus from the HuggingFace hub; only its 'train'
# split is used by the data-preparation cell.
sst2data = load_dataset(path='glue', name='sst2')

In [None]:
import sys
# Make a fastNLP package located in the parent directory importable
# (presumably a local checkout of the repository — confirm for your setup).
sys.path.append('..')

from fastNLP import DataSet
from fastNLP import Vocabulary
from fastNLP.io import DataBundle

# Convert the HuggingFace 'train' split to a fastNLP DataSet and keep 6000 rows.
raw_train = sst2data['train'].to_pandas()
dataset = DataSet.from_pandas(raw_train)[:6000]


def _sentence_to_fields(ins):
    # Lower-case + whitespace tokenization into 'words'; 'label' becomes 'target'.
    return {'words': ins['sentence'].lower().split(), 'target': ins['label']}


dataset.apply_more(_sentence_to_fields, progress_bar="tqdm")
# The original columns are no longer needed once 'words'/'target' exist.
for obsolete_field in ('sentence', 'label', 'idx'):
    dataset.delete_field(obsolete_field)

# Build a vocabulary over the tokens, then map the 'words' field to indices.
vocab = Vocabulary()
vocab.from_dataset(dataset, field_name='words')
vocab.index_dataset(dataset, field_name='words')

train_dataset, evaluate_dataset = dataset.split(ratio=0.85)
print(type(train_dataset), isinstance(train_dataset, DataSet))

data_bundle = DataBundle(datasets={'train': train_dataset, 'dev': evaluate_dataset})

## 1. fastNLP 结合 paddle 训练模型

```python
import paddle

# Two stacked LSTM layers mapping 16-dim inputs to 32-dim hidden states.
rnn = paddle.nn.LSTM(16, 32, 2)

inputs = paddle.randn((4, 23, 16))   # batch of 4 sequences, length 23, 16 features
h0 = paddle.randn((2, 4, 32))        # initial hidden state: (num_layers, batch, hidden)
c0 = paddle.randn((2, 4, 32))        # initial cell state, same shape as h0

outputs, (h_n, c_n) = rnn(inputs, (h0, c0))

print(outputs.shape)  # [4, 23, 32]
print(h_n.shape)      # [2, 4, 32]
print(c_n.shape)      # [2, 4, 32]
```

In [None]:
import paddle
import paddle.nn as nn


class ClsByPaddle(nn.Layer):
    """Sentence classifier in paddle: embedding -> bidirectional LSTM -> MLP.

    The final hidden states of the last LSTM layer (forward and backward) are
    concatenated and passed through a two-layer MLP producing ``output_dim``
    logits per sentence.

    Args:
        vocab_size: number of rows in the embedding table (e.g. ``len(vocab)``).
        embedding_dim: size of each word embedding.
        output_dim: number of target classes.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability forwarded to ``paddle.nn.LSTM``.
    """

    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):
        super().__init__()
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        # paddle's LSTM defaults to time_major=False, so it accepts the
        # batch-major embedding output directly.
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim,
                            num_layers=num_layers, direction='bidirectional', dropout=dropout)
        # Input width is 2 * hidden_dim: forward and backward final states are concatenated.
        self.mlp = nn.Sequential(('linear_1', nn.Linear(hidden_dim * 2, hidden_dim * 2)),
                                 ('activate', nn.ReLU()),
                                 ('linear_2', nn.Linear(hidden_dim * 2, output_dim)))

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, words):
        """Return class logits for a batch of word-id sequences (batch first)."""
        output = self.embedding(words)
        # hidden: (num_layers * 2, batch, hidden_dim); its last two rows are the
        # two directions of the final layer.
        _, (hidden, _) = self.lstm(output)
        output = self.mlp(paddle.concat((hidden[-1], hidden[-2]), axis=1))
        return output

    def train_step(self, words, target):
        """Compute the cross-entropy loss for one training batch."""
        pred = self(words)
        return {"loss": self.loss_fn(pred, target)}

    def evaluate_step(self, words, target):
        """Return predicted class ids alongside the gold targets."""
        pred = self(words)
        # paddle.max returns only the maximum values (no indices), so the
        # predicted class has to come from argmax.
        pred = paddle.argmax(pred, axis=-1)
        return {"pred": pred, "target": target}

In [None]:
model = ClsByPaddle(
    vocab_size=len(vocab),
    embedding_dim=100,
    output_dim=2,  # two sentiment classes
)

model  # show the layer structure as the cell output

In [None]:
from paddle.optimizer import AdamW

# AdamW over all parameters of the paddle model, learning rate 0.01.
optimizers = AdamW(
    learning_rate=1e-2,
    parameters=model.parameters(),
)

In [None]:
from fastNLP import prepare_paddle_dataloader

# Passing the whole DataBundle yields one paddle dataloader per split, looked
# up below by the bundle's dataset names ('train' / 'dev'). Each DataSet can
# also be wrapped on its own, e.g.:
#     prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)
#     prepare_paddle_dataloader(evaluate_dataset, batch_size=16)
dl_bundle = prepare_paddle_dataloader(data_bundle, shuffle=True, batch_size=16)

In [None]:
from fastNLP import Trainer, Accuracy

paddle_metrics = {'acc': Accuracy()}
trainer = Trainer(
    model=model,
    optimizers=optimizers,
    driver='paddle',
    device='gpu',  # alternatives: 'cpu', 'gpu', 'gpu:x'
    n_epochs=10,
    train_dataloader=dl_bundle['train'],
    evaluate_dataloaders=dl_bundle['dev'],
    metrics=paddle_metrics,
)

In [None]:
# Train the paddle model. num_eval_batch_per_dl=10 limits each evaluation to
# 10 batches per evaluate dataloader (per fastNLP's Trainer.run docs — confirm
# for the installed version).
# NOTE(review): the original author noted that training appeared to hang at
# this step ("then it got stuck?") — investigate before relying on this cell.
trainer.run(num_eval_batch_per_dl=10)

## 2. fastNLP 结合 jittor 训练模型

In [None]:
import jittor
import jittor.nn as nn

from jittor import Module


class ClsByJittor(Module):
    """Sentence classifier in jittor: embedding -> bidirectional LSTM -> MLP.

    The final hidden states of the last LSTM layer (forward and backward) are
    concatenated and passed through a two-layer MLP producing ``output_dim``
    logits per sentence.

    Args:
        vocab_size: number of rows in the embedding table (e.g. ``len(vocab)``).
        embedding_dim: size of each word embedding.
        output_dim: number of target classes.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability forwarded to ``jittor.nn.LSTM``.
    """

    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Positional args work across jittor versions that renamed the
        # Embedding keywords (num/dim vs num_embeddings/embedding_dim).
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # jittor's LSTM defaults to batch_first=False; the embedding output is
        # batch-major, so without this the final states would be indexed by
        # sequence position instead of by sample.
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True,
                            num_layers=num_layers, bidirectional=True, dropout=dropout)
        # Input width is 2 * hidden_dim: forward and backward final states are concatenated.
        self.mlp = nn.Sequential([nn.Linear(hidden_dim * 2, hidden_dim * 2),
                                  nn.ReLU(),
                                  nn.Linear(hidden_dim * 2, output_dim)])

        # The MLP emits raw multi-class logits and targets are class ids, which
        # is what CrossEntropyLoss expects (BCELoss needs probabilities and
        # float targets with the same shape as the prediction).
        self.loss_fn = nn.CrossEntropyLoss()

    def execute(self, words):
        """Return class logits of shape (batch, output_dim) for a batch of word ids."""
        output = self.embedding(words)
        # hidden: (num_layers * 2, batch, hidden_dim); its last two rows are the
        # two directions of the final layer.
        _, (hidden, _) = self.lstm(output)
        # Concatenate along the feature dimension (dim passed positionally).
        output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), 1))
        return output

    def train_step(self, words, target):
        """Compute the cross-entropy loss for one training batch."""
        pred = self(words)
        return {"loss": self.loss_fn(pred, target)}

    def evaluate_step(self, words, target):
        """Return predicted class ids alongside the gold targets."""
        pred = self(words)
        # jittor.max returns only the maximum values; jittor.argmax returns
        # (indices, values), and the indices are the predicted classes.
        pred = jittor.argmax(pred, 1)[0]
        return {"pred": pred, "target": target}

In [None]:
model = ClsByJittor(
    vocab_size=len(vocab),
    embedding_dim=100,
    output_dim=2,  # two sentiment classes
)

model  # show the module structure as the cell output

In [None]:
from jittor.optim import AdamW

# AdamW over all parameters of the jittor model, learning rate 0.01.
optimizers = AdamW(
    lr=1e-2,
    params=model.parameters(),
)

In [None]:
from fastNLP import prepare_jittor_dataloader

# Passing the whole DataBundle yields one jittor dataloader per split, looked
# up below by the bundle's dataset names ('train' / 'dev'). Each DataSet can
# also be wrapped on its own, e.g.:
#     prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)
#     prepare_jittor_dataloader(evaluate_dataset, batch_size=16)
dl_bundle = prepare_jittor_dataloader(data_bundle, shuffle=True, batch_size=16)

In [None]:
from fastNLP import Trainer, Accuracy

jittor_metrics = {'acc': Accuracy()}
trainer = Trainer(
    model=model,
    optimizers=optimizers,
    driver='jittor',
    device='gpu',  # alternatives: 'cpu', 'gpu', 'cuda'
    n_epochs=10,
    train_dataloader=dl_bundle['train'],
    evaluate_dataloaders=dl_bundle['dev'],
    metrics=jittor_metrics,
)

In [None]:
# Train the jittor model. num_eval_batch_per_dl=10 limits each evaluation to
# 10 batches per evaluate dataloader (per fastNLP's Trainer.run docs — confirm
# for the installed version).
trainer.run(num_eval_batch_per_dl=10)