Browse Source

update mass gpu network.

tags/v1.1.0
linqingke 5 years ago
parent
commit
771042a457
1 changed files with 0 additions and 9 deletions
  1. +0
    -9
      model_zoo/official/nlp/mass/src/transformer/infer_mass.py

+ 0
- 9
model_zoo/official/nlp/mass/src/transformer/infer_mass.py View File

@@ -22,20 +22,11 @@ from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindspore import context

from src.dataset import load_dataset
from .transformer_for_infer import TransformerInferModel
from .transformer_for_train import TransformerTraining
from ..utils.load_weights import load_infer_weights

context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend",
reserve_class_name_in_scope=False)


class TransformerInferCell(nn.Cell):
"""
Encapsulation class of transformer network infer.


Loading…
Cancel
Save