Merge pull request !1697 from chenhaozhe/bert-optimization
@@ -588,7 +588,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
   graph->set_output_null(is_trace_back);
   AddParameterToGraphInputs(func_graph->parameters(), graph.get());
   MS_EXCEPTION_IF_NULL(context_);
-  FuncGraphManagerPtr manager = context_->manager();
+  FuncGraphManagerPtr manager = MakeManager({graph});
   if (manager) {
     manager->AddFuncGraph(graph);
     graph->set_manager(manager);
@@ -22,6 +22,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops.functional import identity
 from mindspore.ops.operations import _inner_ops as inner
+from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore.common.api import ms_function
@@ -236,6 +237,13 @@ class Dense(Cell):
         return str_info


+@constexpr
+def _is_equal_one(x):
+    if x is None:
+        return False
+    return bool(x.asnumpy().mean() == 1.0)
+
+
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
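Note on the new helper: `_is_equal_one` calls `.asnumpy()`, which is only legal because `@constexpr` evaluates the function in plain Python while the graph is being compiled; its boolean result is baked into the graph as a constant, so an `if` on it selects a single branch to compile. A minimal sketch of the same pattern (the names `_is_zero` and `AddUnlessZero` are illustrative, not part of this PR; it assumes a constant Tensor attribute reaches the constexpr as a compile-time value):

    import numpy as np
    from mindspore import Tensor, nn
    from mindspore.ops.primitive import constexpr

    @constexpr
    def _is_zero(x):
        # Runs in Python at graph-compile time, so .asnumpy() is allowed here.
        if x is None:
            return False
        return bool(x.asnumpy().sum() == 0.0)

    class AddUnlessZero(nn.Cell):
        """Adds a constant bias; compiles to an identity when the bias is all zero."""
        def __init__(self, bias):
            super(AddUnlessZero, self).__init__()
            self.bias = bias                 # constant Tensor attribute

        def construct(self, x):
            if _is_zero(self.bias):          # folded away during compilation
                return x
            return x + self.bias

    net = AddUnlessZero(Tensor(np.zeros(3).astype(np.float32)))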
@@ -290,7 +298,10 @@ class ClipByNorm(Cell):
         l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
         l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)

-        intermediate = x * clip_norm
+        if _is_equal_one(clip_norm):
+            intermediate = x
+        else:
+            intermediate = x * clip_norm
         max_norm = self.max_op(l2norm, clip_norm)
         values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
         values_clip = self.reshape(values_clip, self.shape(x))
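The branch is safe because the op computes out = x * clip_norm / max(||x||_2, clip_norm); when clip_norm == 1 the multiply is an identity, so eliding it changes nothing while saving a broadcast multiply on every call. A NumPy reference of the intended semantics (a sketch, assuming reduction over all axes, which is the behavior when no axis is given):

    import numpy as np

    def clip_by_norm_ref(x, clip_norm):
        # out = x * clip_norm / max(||x||_2, clip_norm)
        l2norm = np.sqrt((x * x).sum())
        intermediate = x if clip_norm == 1.0 else x * clip_norm  # mirrors the new branch
        return intermediate / max(l2norm, clip_norm)

    x = np.array([[-2, 0, 0], [0, 3, 4]], dtype=np.float32)
    assert np.allclose(clip_by_norm_ref(x, 1.0), x / np.sqrt(29.0))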
@@ -32,7 +32,6 @@ from .bert_model import BertModel

 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
-_nn_clip_by_norm = nn.ClipByNorm()
 clip_grad = C.MultitypeFuncGraph("clip_grad")
@@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
         new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                    F.cast(F.tuple_to_array((clip_value,)), dt))
     else:
-        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
     return new_grad
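Dropping the module-level `_nn_clip_by_norm` and constructing `nn.ClipByNorm()` inside the registered function means each specialization of `clip_grad` builds its own cell rather than sharing one captured instance; that matters now that `ClipByNorm` specializes its compiled graph on whether `clip_norm` is the constant 1 (this rationale is inferred from the diff, not stated in it). For context, the typical call site in the BERT training wrapper looks like the line below; it is outside this diff, so treat it as an assumption:

    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)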
@@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
@@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
     bert_net_cfg = BertConfig(
         batch_size=16,
         seq_length=512,
-        vocab_size=30528,
+        vocab_size=30522,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
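The corrected values match the actual tokenizer vocab files: 21128 is the entry count of the standard Chinese BERT vocab.txt (used by the base and NEZHA configs above), and 30522 is the entry count of the English uncased BERT vocab used by the large config. A quick way to confirm (illustrative; the path is hypothetical):

    # vocab_size must equal the number of lines in the tokenizer's vocab file
    with open("vocab.txt", encoding="utf-8") as f:
        print(sum(1 for _ in f))  # expect 21128 (Chinese BERT) or 30522 (uncased English BERT)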
@@ -26,3 +26,19 @@ def test_clip_by_norm():
     x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
     clip_norm = Tensor(np.array([1]).astype(np.float32))
     clip_by_norm(x, clip_norm)
+
+
+@non_graph_engine
+def test_clip_by_norm_const():
+    class Network(nn.Cell):
+        def __init__(self):
+            super(Network, self).__init__()
+            self.norm_value = Tensor(np.array([1]).astype(np.float32))
+            self.clip = nn.ClipByNorm()
+
+        def construct(self, x):
+            return self.clip(x, self.norm_value)
+
+    net = Network()
+    x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
+    output = net(x)
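The new test pins `norm_value` as a constant cell attribute, which is exactly what lets `_is_equal_one` run at compile time and take the `intermediate = x` path. A hand check of the expected output (assuming the op reduces over all axes when no axis is given): ||x||_2 = sqrt(4 + 9 + 16) = sqrt(29) ≈ 5.385 > 1, so every element is scaled by 1/sqrt(29):

    import numpy as np

    x = np.array([[-2, 0, 0], [0, 3, 4]], dtype=np.float32)
    expected = x / max(np.sqrt((x * x).sum()), 1.0)  # x / sqrt(29)
    # e.g. np.testing.assert_allclose(output.asnumpy(), expected, rtol=1e-5)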