|
|
|
@@ -22,20 +22,11 @@ from mindspore.common.tensor import Tensor |
|
|
|
from mindspore.train.model import Model |
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
|
|
|
|
from mindspore import context |
|
|
|
|
|
|
|
from src.dataset import load_dataset |
|
|
|
from .transformer_for_infer import TransformerInferModel |
|
|
|
from .transformer_for_train import TransformerTraining |
|
|
|
from ..utils.load_weights import load_infer_weights |
|
|
|
|
|
|
|
context.set_context( |
|
|
|
mode=context.GRAPH_MODE, |
|
|
|
save_graphs=False, |
|
|
|
device_target="Ascend", |
|
|
|
reserve_class_name_in_scope=False) |
|
|
|
|
|
|
|
|
|
|
|
class TransformerInferCell(nn.Cell): |
|
|
|
""" |
|
|
|
Encapsulation class of transformer network infer. |
|
|
|
|