| @@ -23,3 +23,5 @@ from .bias_add import expand_biasadd | |||
| from .bias_add_grad import expand_biasaddgrad | |||
| from .fused_adam import expand_fusedadam | |||
| from .fused_adam_weight_decay import expand_fusedadamweightdecay | |||
| from .reduce_mean import expand_reducemean | |||
| from .tanh_grad import expand_tanhgrad | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for reduce_mean""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| def expand_reducemean(expand_info): | |||
| """ReduceMean expander""" | |||
| # get op info. | |||
| input_desc = expand_info['input_desc'][0] | |||
| attrs = expand_info['attr'] | |||
| axis = None | |||
| keep_dims = None | |||
| for item in attrs: | |||
| if 'axis' in item: | |||
| axis = item['axis'] | |||
| if 'keep_dims' in item: | |||
| keep_dims = item['keep_dims'] | |||
| graph_builder = builder.GraphBuilder() | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||
| x_shape = input_x.shape | |||
| graph_scope.set_input(input_x) | |||
| # cal reduce_mean | |||
| # when axis = None, reduce axis are all | |||
| all_shape = 1.0 | |||
| real_axis = [] | |||
| if not axis: | |||
| for i, shape in enumerate(x_shape): | |||
| real_axis.append(i) | |||
| all_shape *= shape | |||
| else: | |||
| for idx in axis: | |||
| all_shape *= x_shape[idx] | |||
| all_shape_value = graph_builder.value(input_x.dtype, all_shape, input_x.data_format) | |||
| if not axis: | |||
| sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims}) | |||
| else: | |||
| sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims}) | |||
| result = graph_builder.emit('RealDiv', [sum_x, all_shape_value]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| @@ -0,0 +1,47 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for tanh_grad""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| ONE = 1.0 | |||
| def expand_tanhgrad(expand_info): | |||
| """TanhGrad expander""" | |||
| # tanh_grad(y, dy) = dy * (1- y * y) | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| graph_builder = builder.GraphBuilder() | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| const_one = graph_builder.value(input_y.dtype, ONE, input_y.data_format) | |||
| graph_scope.set_input(input_y, input_dy) | |||
| # cal result | |||
| double_y = graph_builder.emit('Mul', [input_y, input_y]) | |||
| one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y]) | |||
| result = graph_builder.emit('Mul', [input_dy, one_sub_double_y]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| @@ -702,9 +702,9 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo | |||
| std::unordered_set<PrimitivePtr> GetExpandOps() { | |||
| std::unordered_set<PrimitivePtr> expand_ops = { | |||
| prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, | |||
| prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, | |||
| }; | |||
| prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, | |||
| prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, prim::kPrimTanhGrad, | |||
| prim::kPrimReduceMean}; | |||
| return expand_ops; | |||
| } | |||
| @@ -29,12 +29,19 @@ bool BindValueToGraph::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto &value_nodes = kernel_graph->graph_value_nodes(); | |||
| bool changed = false; | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| func_graph->set_manager(mng); | |||
| } | |||
| for (auto node : todos) { | |||
| if (!GetValueNode<tensor::TensorPtr>(node)) { | |||
| continue; | |||
| } | |||
| if (auto vptr = node->cast<ValueNodePtr>(); value_nodes.count(vptr) == 0) { | |||
| kernel_graph->AddValueNodeToGraph(vptr); | |||
| auto new_node = kernel_graph->NewValueNode(vptr); | |||
| mng->Replace(vptr, new_node); | |||
| kernel_graph->AddValueNodeToGraph(new_node); | |||
| changed = true; | |||
| } | |||
| } | |||
| @@ -0,0 +1,132 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=True) | |||
| class Net(nn.Cell): | |||
| def __init__(self, decay_flag=True): | |||
| super(Net, self).__init__() | |||
| self.decay_flag = decay_flag | |||
| self.op_mul = P.Mul() | |||
| self.op_square = P.Square() | |||
| self.op_sqrt = P.Sqrt() | |||
| self.op_cast = P.Cast() | |||
| self.op_reshape = P.Reshape() | |||
| self.op_shape = P.Shape() | |||
| self.param = Parameter(Tensor(np.array([1, 3, 5]).astype(np.float32)), name='param') | |||
| self.m = Parameter(Tensor(np.array([0.11, 0.33, 0.55]).astype(np.float32)), name='m') | |||
| self.v = Parameter(Tensor(np.array([1.2, 3.4, 5.6]).astype(np.float32)), name='v') | |||
| @ms_function | |||
| def construct(self, beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr): | |||
| param_fp32 = self.op_cast(self.param, mstype.float32) | |||
| m_fp32 = self.op_cast(self.m, mstype.float32) | |||
| v_fp32 = self.op_cast(self.v, mstype.float32) | |||
| gradient_fp32 = self.op_cast(gradient, mstype.float32) | |||
| next_m = self.op_mul(beta1, m_fp32) + \ | |||
| self.op_mul(self.op_cast(one_sub_beta_1, mstype.float32), gradient_fp32) | |||
| next_v = self.op_mul(beta2, v_fp32) + self.op_mul(self.op_cast(one_sub_beta_2, | |||
| mstype.float32), self.op_square(gradient_fp32)) | |||
| update = next_m / (eps + self.op_sqrt(next_v)) | |||
| if self.decay_flag: | |||
| update = self.op_mul(weight_decay_tensor, param_fp32) + update | |||
| update_with_lr = self.op_mul(lr, update) | |||
| next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32)) | |||
| depend_v = F.depend(next_param, F.assign(self.param, next_param)) | |||
| depend_v = F.depend(depend_v, F.assign(self.m, next_m)) | |||
| depend_v = F.depend(depend_v, F.assign(self.v, next_v)) | |||
| return depend_v | |||
| def CalFusedAdam(beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, param, m, v, | |||
| is_weight_decay=False): | |||
| m_expect = beta1 * m + one_sub_beta_1 * gradient | |||
| v_expect = beta2 * v + one_sub_beta_2 * gradient * gradient | |||
| update = m_expect / (np.sqrt(v_expect) + eps) | |||
| if is_weight_decay: | |||
| update += weight_decay_tensor * param | |||
| param_expect = param - lr * update | |||
| return param_expect, m_expect, v_expect | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_adam(): | |||
| np.random.seed(0) | |||
| beta1 = np.array([0.9]).astype(np.float32) | |||
| beta2 = np.array([0.999]).astype(np.float32) | |||
| one_sub_beta_1 = (np.array([1.0]) - np.array([0.9])).astype(np.float32) | |||
| one_sub_beta_2 = (np.array([1.0]) - np.array([0.999])).astype(np.float32) | |||
| lr = np.array([0.012]).astype(np.float32) | |||
| eps = np.array([1e-6]).astype(np.float32) | |||
| weight_decay_tensor = np.array([0.021]).astype(np.float32) | |||
| gradient = np.array([0.01, 0.03, 0.05]).astype(np.float32) | |||
| m = np.array([0.11, 0.33, 0.55]).astype(np.float32) | |||
| v = np.array([1.2, 3.4, 5.6]).astype(np.float32) | |||
| param = np.array([1, 3, 5]).astype(np.float32) | |||
| is_weight_decay = False | |||
| opt = Net(is_weight_decay) | |||
| _ = opt(Tensor(beta1), Tensor(beta2), Tensor(one_sub_beta_1), Tensor(one_sub_beta_2), Tensor(gradient), Tensor(eps), | |||
| Tensor(weight_decay_tensor), Tensor(lr)) | |||
| param_expect, m_expect, v_expect = CalFusedAdam( | |||
| beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, | |||
| param, m, v, is_weight_decay) | |||
| assert np.allclose(opt.param.data.asnumpy(), param_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) | |||
| assert np.allclose(opt.m.data.asnumpy(), m_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) | |||
| assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) | |||
| def test_adam_weight_decay(): | |||
| np.random.seed(0) | |||
| beta1 = np.array([0.9]).astype(np.float32) | |||
| beta2 = np.array([0.999]).astype(np.float32) | |||
| one_sub_beta_1 = (np.array([1.0]) - np.array([0.9])).astype(np.float32) | |||
| one_sub_beta_2 = (np.array([1.0]) - np.array([0.999])).astype(np.float32) | |||
| lr = np.array([0.012]).astype(np.float32) | |||
| eps = np.array([1e-6]).astype(np.float32) | |||
| weight_decay_tensor = np.array([0.021]).astype(np.float32) | |||
| gradient = np.array([0.01, 0.03, 0.05]).astype(np.float32) | |||
| m = np.array([0.11, 0.33, 0.55]).astype(np.float32) | |||
| v = np.array([1.2, 3.4, 5.6]).astype(np.float32) | |||
| param = np.array([1, 3, 5]).astype(np.float32) | |||
| is_weight_decay = True | |||
| opt = Net(is_weight_decay) | |||
| _ = opt(Tensor(beta1), Tensor(beta2), Tensor(one_sub_beta_1), Tensor(one_sub_beta_2), Tensor(gradient), Tensor(eps), | |||
| Tensor(weight_decay_tensor), Tensor(lr)) | |||
| param_expect, m_expect, v_expect = CalFusedAdam( | |||
| beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, | |||
| param, m, v, is_weight_decay) | |||
| assert np.allclose(opt.param.data.asnumpy(), param_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) | |||
| assert np.allclose(opt.m.data.asnumpy(), m_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) | |||
| assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) | |||
| @@ -0,0 +1,45 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.nn import Cell | |||
| import mindspore.ops.operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") | |||
| class Net(Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.reduce_mean = P.ReduceMean(keep_dims=False) | |||
| def construct(self, x): | |||
| return self.reduce_mean(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_reduce_mean(): | |||
| np.random.seed(0) | |||
| input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) | |||
| expect = np.mean(input_x, keepdims=False) | |||
| net = Net() | |||
| result = net(Tensor(input_x)) | |||
| res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True) | |||
| assert res | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.nn import Cell | |||
| import mindspore.ops.operations._grad_ops as G | |||
| context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") | |||
| class TanhGradNet(Cell): | |||
| def __init__(self): | |||
| super(TanhGradNet, self).__init__() | |||
| self.tanh_grad = G.TanhGrad() | |||
| def construct(self, y, dy): | |||
| return self.tanh_grad(y, dy) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_tanh_grad(): | |||
| np.random.seed(0) | |||
| input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) | |||
| input_dy = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) | |||
| net = TanhGradNet() | |||
| result = net(Tensor(input_y), Tensor(input_dy)) | |||
| expect = input_dy * (1.0 - input_y * input_y) | |||
| res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True) | |||
| assert res | |||