Merge pull request !2565 from ddwolf/add_case_for_precisoin_of_berttags/v0.6.0-beta
| @@ -348,16 +348,13 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph | |||
| uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); | |||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto true_stream_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream)); | |||
| auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream); | |||
| processed_streams_.emplace(true_stream_id); | |||
| auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); | |||
| if (value_ptr == nullptr) { | |||
| if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { | |||
| continue; | |||
| } | |||
| auto need_active = GetValue<bool>(value_ptr); | |||
| auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst); | |||
| if (need_active) { | |||
| processed_streams_.emplace(cur_stream_id); | |||
| } | |||
| @@ -371,20 +368,17 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph | |||
| void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr, | |||
| vector<CNodePtr> *orders) { | |||
| orders->emplace_back(switch_ptr); | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(switch_ptr); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); | |||
| if (value_ptr == nullptr) { | |||
| if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) { | |||
| return; | |||
| } | |||
| auto need_active = GetValue<bool>(value_ptr); | |||
| auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst); | |||
| if (!need_active) { | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(switch_ptr); | |||
| auto true_stream_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream)); | |||
| auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream); | |||
| MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) | |||
| << "; active stream id:" << true_stream_id; | |||
| @@ -677,14 +671,11 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra | |||
| for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | |||
| cur_cnode_ptr = cnode_ptr_list[i]; | |||
| MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); | |||
| if (value_ptr == nullptr) { | |||
| if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { | |||
| continue; | |||
| } | |||
| auto need_active = GetValue<bool>(value_ptr); | |||
| auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst); | |||
| if (need_active) { | |||
| auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); | |||
| MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first"; | |||
| @@ -276,7 +276,8 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j | |||
| input_desc_json[kName] = op_input_name; | |||
| input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index)); | |||
| auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); | |||
| if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { | |||
| if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && | |||
| GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { | |||
| MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2) | |||
| << "] as const tensor, shape: [" << Vector2Str(input_shape) | |||
| << "], value: " << input_desc_json[kValue]; | |||
| @@ -291,7 +291,7 @@ bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &no | |||
| // graph kernel cnode. | |||
| auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| return fg->has_flag(key); | |||
| return fg->has_attr(key); | |||
| } | |||
| size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) { | |||
| @@ -68,9 +68,14 @@ class AnfRuntimeAlgorithm { | |||
| std::string node_debug_log = node->DebugString(); | |||
| MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str(); | |||
| } | |||
| auto primitive = GetCNodePrimitive(node); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| return GetValue<T>(primitive->GetAttr(key)); | |||
| // single op cnode. | |||
| if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) { | |||
| return GetValue<T>(primitive->GetAttr(key)); | |||
| } | |||
| // graph kernel cnode. | |||
| auto fg = GetCNodeFuncGraphPtr(node); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| return GetValue<T>(fg->get_attr(key)); | |||
| } | |||
| static bool IsTupleOutput(const AnfNodePtr &anf); | |||
| // set attr of anf node | |||
| @@ -0,0 +1,193 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """train bert network without lossscale""" | |||
| import os | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset.engine.datasets as de | |||
| import mindspore.dataset.transforms.c_transforms as C | |||
| from mindspore import context | |||
| from mindspore import log as logger | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.nn.optim import Lamb | |||
| from mindspore.train.callback import Callback | |||
| from mindspore.train.loss_scale_manager import DynamicLossScaleManager | |||
| from mindspore.train.model import Model | |||
| from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell | |||
| from src.bert_model import BertConfig | |||
| DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] | |||
| SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" | |||
| def get_config(version='base', batch_size=1): | |||
| """get config""" | |||
| if version == 'base': | |||
| bert_config = BertConfig( | |||
| batch_size=batch_size, | |||
| seq_length=128, | |||
| vocab_size=21136, | |||
| hidden_size=768, | |||
| num_hidden_layers=2, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| hidden_act="gelu", | |||
| hidden_dropout_prob=0.1, | |||
| attention_probs_dropout_prob=0.1, | |||
| max_position_embeddings=512, | |||
| type_vocab_size=2, | |||
| initializer_range=0.02, | |||
| use_relative_positions=True, | |||
| input_mask_from_dataset=True, | |||
| token_type_ids_from_dataset=True, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float32) | |||
| elif version == 'large': | |||
| bert_config = BertConfig( | |||
| batch_size=batch_size, | |||
| seq_length=128, | |||
| vocab_size=30522, | |||
| hidden_size=1024, | |||
| num_hidden_layers=2, | |||
| num_attention_heads=16, | |||
| intermediate_size=4096, | |||
| hidden_act="gelu", | |||
| hidden_dropout_prob=0.0, | |||
| attention_probs_dropout_prob=0.0, | |||
| max_position_embeddings=512, | |||
| type_vocab_size=2, | |||
| initializer_range=0.02, | |||
| use_relative_positions=True, | |||
| input_mask_from_dataset=True, | |||
| token_type_ids_from_dataset=True, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float16, | |||
| enable_fused_layernorm=True) | |||
| else: | |||
| bert_config = BertConfig(batch_size=batch_size) | |||
| return bert_config | |||
| def me_de_train_dataset(): | |||
| """test me de train dataset""" | |||
| # apply repeat operations | |||
| repeat_count = 1 | |||
| ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", | |||
| "next_sentence_labels", "masked_lm_positions", | |||
| "masked_lm_ids", "masked_lm_weights"], shuffle=False) | |||
| type_cast_op = C.TypeCast(mstype.int32) | |||
| ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) | |||
| ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) | |||
| ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) | |||
| ds = ds.map(input_columns="segment_ids", operations=type_cast_op) | |||
| ds = ds.map(input_columns="input_mask", operations=type_cast_op) | |||
| ds = ds.map(input_columns="input_ids", operations=type_cast_op) | |||
| # apply batch operations | |||
| batch_size = int(os.getenv('BATCH_SIZE', '16')) | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_count) | |||
| return ds | |||
| def weight_variable(shape): | |||
| """weight variable""" | |||
| np.random.seed(1) | |||
| ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32) | |||
| return Tensor(ones) | |||
| class ModelCallback(Callback): | |||
| def __init__(self): | |||
| super(ModelCallback, self).__init__() | |||
| self.loss_list = [] | |||
| self.overflow_list = [] | |||
| self.lossscale_list = [] | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| self.loss_list.append(cb_params.net_outputs[0].asnumpy()[0]) | |||
| self.overflow_list.append(cb_params.net_outputs[1].asnumpy()) | |||
| self.lossscale_list.append(cb_params.net_outputs[2].asnumpy()) | |||
| print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_bert_tdt(): | |||
| """test bert tdt""" | |||
| np.random.seed(0) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) | |||
| context.set_context(enable_graph_kernel=True) | |||
| ds = me_de_train_dataset() | |||
| config = get_config(version='large', batch_size=16) | |||
| netwithloss = BertNetworkWithLoss(config, True) | |||
| optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), | |||
| start_learning_rate=5e-5, end_learning_rate=1e-9, | |||
| power=10.0, warmup_steps=0, weight_decay=0.01) | |||
| scale_window = 3 | |||
| scale_manager = DynamicLossScaleManager(262144, 2, scale_window) | |||
| netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, | |||
| scale_update_cell=scale_manager.get_update_cell()) | |||
| netwithgrads.set_train(True) | |||
| model = Model(netwithgrads) | |||
| callback = ModelCallback() | |||
| params = netwithloss.trainable_params() | |||
| for param in params: | |||
| param.init_data() | |||
| value = param.default_input | |||
| name = param.name | |||
| if isinstance(value, Tensor): | |||
| if name.split('.')[-1] in ['weight']: | |||
| if name.split('.')[-3] in ['cls2']: | |||
| logger.info("***************** BERT param name is 1 {}".format(name)) | |||
| param.default_input = weight_variable(value.asnumpy().shape) | |||
| else: | |||
| logger.info("***************** BERT param name is 2 {}".format(name)) | |||
| tempshape = value.asnumpy().shape | |||
| shape = (tempshape[1], tempshape[0]) | |||
| weight_value = weight_variable(shape).asnumpy() | |||
| param.default_input = Tensor(np.transpose(weight_value, [1, 0])) | |||
| else: | |||
| logger.info("***************** BERT param name is 3 {}".format(name)) | |||
| param.default_input = weight_variable(value.asnumpy().shape) | |||
| model.train(1, ds, callbacks=callback, dataset_sink_mode=False) | |||
| # assertion occurs while the loss value, overflow state or loss_scale value is wrong | |||
| loss_value = np.array(callback.loss_list) | |||
| expect_loss_value = [12.559319, 12.333815, 12.339806, 12.350235, 12.343947, 12.830965, 12.375336, 12.973715, | |||
| 12.57929, 12.7766905] | |||
| error = loss_value - expect_loss_value | |||
| print("loss value: {}".format(loss_value)) | |||
| print("error value: {}".format(error)) | |||
| assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) | |||
| overflow = np.array(callback.overflow_list) | |||
| expect_overflow = [True, True, True, True, False, False, False, True, False, False] | |||
| print("overflow: {}".format(overflow)) | |||
| assert (overflow == expect_overflow).all() | |||
| loss_scale = np.array(callback.lossscale_list) | |||
| expect_loss_scale = [131072.0, 65536.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0] | |||
| print("loss scale: {}".format(loss_scale)) | |||
| assert np.allclose(loss_scale, expect_loss_scale, 0, 0) | |||
| if __name__ == '__main__': | |||
| test_bert_tdt() | |||
| @@ -0,0 +1,130 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.nn import Cell | |||
| from mindspore.nn.graph_kernels import LambUpdateWithLR, LambNextMV | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class LambNet(Cell): | |||
| def __init__(self, i2, i5, x6): | |||
| super(LambNet, self).__init__() | |||
| self.i2 = Parameter(i2, name='i2') | |||
| self.i5 = Parameter(i5, name='i5') | |||
| self.x6 = Parameter(x6, name='x6') | |||
| self.lamb_next = LambNextMV() | |||
| self.lamb_update = LambUpdateWithLR() | |||
| def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3, | |||
| x1, x2, x3, x4, x5, gy, se, my): | |||
| return self.lamb_next(i1, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0, | |||
| ix1, ix2, ix3), \ | |||
| self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my) | |||
| def LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my): | |||
| trust_ratio = np.where(np.greater(x2, gy), | |||
| np.where(np.greater(x1, gy), np.divide(x2, x3), se), | |||
| se) | |||
| trust_ratio = np.maximum(np.minimum(trust_ratio, my), gy) | |||
| update_with_lr = trust_ratio * x4 * x5 | |||
| next_param = x6 - np.reshape(update_with_lr, x6.shape) | |||
| return next_param | |||
| def LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3): | |||
| m_fp32 = i5.astype(np.float32) | |||
| v_fp32 = i2.astype(np.float32) | |||
| next_m = i8 * m_fp32 + i9 * i4 | |||
| next_v = x0 * v_fp32 + x1 * i1 | |||
| next_mm = next_m / i6 | |||
| next_vv = next_v / i3 | |||
| update = next_mm / (np.sqrt(next_vv) + x3) | |||
| add3 = next_mm / np.sqrt(next_vv + x3) + x2 * i7 | |||
| return add3, next_m, next_v, update | |||
| def tensor_all(*args): | |||
| res = [Tensor(a) for a in args] | |||
| return res | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_graph_kernel_lamb(): | |||
| shape = [1, 16] | |||
| oshape = [1] | |||
| np.random.seed(0) | |||
| x1 = np.random.normal(0, 1, oshape).astype(np.float32) | |||
| x2 = np.random.normal(0, 1, oshape).astype(np.float32) | |||
| x3 = np.random.normal(0, 1, oshape).astype(np.float32) | |||
| x4 = np.random.normal(0, 1, oshape).astype(np.float32) | |||
| x5 = np.random.normal(0, 1, shape).astype(np.float32) | |||
| x6 = np.random.normal(0, 1, shape).astype(np.float32) | |||
| gy = np.random.normal(0, 1, oshape).astype(np.float32) | |||
| se = np.random.normal(0, 1, oshape).astype(np.float32) | |||
| my = np.random.normal(0, 1, oshape).astype(np.float32) | |||
| tx1, tx2, tx3, tx4, tx5, tx6, tgy, tse, tmy = tensor_all( | |||
| x1, x2, x3, x4, x5, x6, gy, se, my) | |||
| np.random.seed(1) | |||
| i1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) | |||
| i2 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) | |||
| i3 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) | |||
| i4 = np.random.normal(0, 1, shape).astype(np.float32) | |||
| i5 = np.random.normal(0, 1, shape).astype(np.float32) | |||
| i6 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) | |||
| i7 = np.random.normal(0, 1, shape).astype(np.float32) | |||
| i8 = np.random.normal(0, 1, shape).astype(np.float32) | |||
| i9 = np.random.normal(0, 1, shape).astype(np.float32) | |||
| ix0 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) | |||
| ix1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) | |||
| ix2 = np.random.normal(0, 1, shape).astype(np.float32) | |||
| ix3 = np.ones(shape).astype(np.float32) * 1e-6 | |||
| ti1, ti2, ti3, ti4, ti5, ti6, ti7, ti8, ti9, tix0, tix1, tix2, tix3 = \ | |||
| tensor_all(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0, ix1, ix2, ix3) | |||
| context.set_context(enable_graph_kernel=True) | |||
| net = LambNet(ti2, ti5, tx6) | |||
| (wa3, wup), _ = net(ti1, ti3, ti4, ti6, ti7, ti8, ti9, tix0, tix1, tix2, tix3, | |||
| tx1, tx2, tx3, tx4, tx5, tgy, tse, tmy) | |||
| wi2 = net.i2.data.asnumpy().copy() | |||
| wi5 = net.i5.data.asnumpy().copy() | |||
| ares = net.x6.data.asnumpy().copy() | |||
| context.set_context(enable_graph_kernel=False) | |||
| a3, a0, a1, up = LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0, | |||
| ix1, ix2, ix3) | |||
| np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my) | |||
| rtol = 0.0001 | |||
| atol = 0.0001 | |||
| wres = (wa3.asnumpy().copy(), wi5, wi2, wup.asnumpy().copy()) | |||
| bres = (a3, a0, a1, up) | |||
| cmp_res = list(map(lambda x, y: np.allclose(x, y, rtol, atol), | |||
| wres, bres)) | |||
| assert all(cmp_res) and np.allclose(ares, np_res, rtol, atol) | |||