|
|
@@ -17,7 +17,8 @@ from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \ |
|
|
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \ |
|
|
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \ |
|
|
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \ |
|
|
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \ |
|
|
BertTrainAccumulationAllReduceEachWithLossScaleCell, \ |
|
|
BertTrainAccumulationAllReduceEachWithLossScaleCell, \ |
|
|
BertTrainAccumulationAllReducePostWithLossScaleCell |
|
|
|
|
|
|
|
|
BertTrainAccumulationAllReducePostWithLossScaleCell, \ |
|
|
|
|
|
BertTrainOneStepWithLossScaleCellForAdam |
|
|
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ |
|
|
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ |
|
|
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ |
|
|
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ |
|
|
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ |
|
|
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ |
|
|
@@ -31,5 +32,6 @@ __all__ = [ |
|
|
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", |
|
|
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", |
|
|
"BertSelfAttention", "BertTransformer", "EmbeddingLookup", |
|
|
"BertSelfAttention", "BertTransformer", "EmbeddingLookup", |
|
|
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert", |
|
|
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert", |
|
|
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask" |
|
|
|
|
|
|
|
|
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask", |
|
|
|
|
|
"BertTrainOneStepWithLossScaleCellForAdam" |
|
|
] |
|
|
] |