|
|
|
@@ -41,8 +41,16 @@ def create_network(name, *args, **kwargs): |
|
|
|
Create a transformer network using the 'transformer_large' configuration. |
|
|
|
''' |
|
|
|
if name == 'transformer_large': |
|
|
|
if "batch_size" in kwargs: |
|
|
|
transformer_net_cfg_large.batch_size = kwargs["batch_size"] |
|
|
|
if "seq_length" in kwargs: |
|
|
|
transformer_net_cfg_large.seq_length = kwargs["seq_length"] |
|
|
|
if "vocab_size" in kwargs: |
|
|
|
transformer_net_cfg_large.vocab_size = kwargs["vocab_size"] |
|
|
|
is_training = kwargs.get("is_training", False) |
|
|
|
if not is_training: |
|
|
|
transformer_net_cfg_large.batch_size = 1 |
|
|
|
transformer_net_cfg_large.hidden_dropout_prob = 0. |
|
|
|
transformer_net_cfg_large.attention_probs_dropout_prob = 0. |
|
|
|
return TransformerModel(transformer_net_cfg_large, is_training, *args) |
|
|
|
raise NotImplementedError(f"{name} is not implemented in the repo") |