| @@ -17,7 +17,8 @@ from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \ | |||||
| BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \ | BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \ | ||||
| BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \ | BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \ | ||||
| BertTrainAccumulationAllReduceEachWithLossScaleCell, \ | BertTrainAccumulationAllReduceEachWithLossScaleCell, \ | ||||
| BertTrainAccumulationAllReducePostWithLossScaleCell | |||||
| BertTrainAccumulationAllReducePostWithLossScaleCell, \ | |||||
| BertTrainOneStepWithLossScaleCellForAdam | |||||
| from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ | from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ | ||||
| BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ | BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ | ||||
| EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ | EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ | ||||
| @@ -31,5 +32,6 @@ __all__ = [ | |||||
| "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", | "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", | ||||
| "BertSelfAttention", "BertTransformer", "EmbeddingLookup", | "BertSelfAttention", "BertTransformer", "EmbeddingLookup", | ||||
| "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert", | "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert", | ||||
| "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask" | |||||
| "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask", | |||||
| "BertTrainOneStepWithLossScaleCellForAdam" | |||||
| ] | ] | ||||