def expand_dropoutgrad(expand_info):
    """Build the expanded graph for DropoutGrad.

    The emitted graph computes ``dy * (1 / keep_prob) * mask`` with two
    ``Mul`` ops.

    Args:
        expand_info (dict): Op description. ``input_desc`` holds the descs
            (``shape``, ``data_type``, ``format``) of dy and mask, in that
            order; ``attr`` is a list of dicts, one of which must carry
            ``keep_prob``.

    Returns:
        The first graph produced by the graph builder.

    Raises:
        RuntimeError: If ``keep_prob`` is missing from the attrs or is zero.
    """
    # get op info.
    dy_desc = expand_info['input_desc'][0]
    mask_desc = expand_info['input_desc'][1]
    keep_prob = None
    for attr in expand_info['attr']:
        if 'keep_prob' in attr:
            keep_prob = attr['keep_prob']
    if keep_prob is None:
        raise RuntimeError("keep_prob does not exist in attrs.")
    # The reciprocal below would otherwise fail with a bare ZeroDivisionError.
    if keep_prob == 0:
        raise RuntimeError("keep_prob can not be zero.")
    reciprocal = 1.0 / keep_prob

    # generate a graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
        input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
        graph_scope.set_input(input_dy, input_mask)
        # The scaling constant takes the dtype of dy.
        r_keep_prob = graph_builder.value(input_dy.dtype, reciprocal, "DefaultFormat")
        # create op.
        scaled_dy = graph_builder.emit('Mul', [input_dy, r_keep_prob])
        result = graph_builder.emit('Mul', [scaled_dy, input_mask])
        # set graph output.
        graph_scope.set_output(result)
    return graph_builder.get()[0]
def expand_maximumgrad(expand_info):
    """Build the expanded graph for MaximumGrad.

    The emitted graph computes
    ``dx = Cast(GreaterEqual(x, y), x.dtype) * dout`` and ``dy = dout - dx``.
    Which of them become graph outputs is chosen by the ``grad_x`` and
    ``grad_y`` attrs.

    Args:
        expand_info (dict): Op description. ``input_desc`` holds the descs
            (``shape``, ``data_type``, ``format``) of x, y and dout, in that
            order; ``attr`` is a list of dicts that may carry ``grad_x`` and
            ``grad_y``.

    Returns:
        The first graph produced by the graph builder.

    Raises:
        RuntimeError: If neither ``grad_x`` nor ``grad_y`` is set to a true
            value, since the graph would then have no output.
    """
    # get op info.
    input_desc_0 = expand_info['input_desc'][0]
    input_desc_1 = expand_info['input_desc'][1]
    input_desc_2 = expand_info['input_desc'][2]
    grad_x = None
    grad_y = None
    for item in expand_info['attr']:
        if 'grad_x' in item:
            grad_x = item['grad_x']
        if 'grad_y' in item:
            grad_y = item['grad_y']
    # Fail fast instead of silently returning a graph without any output.
    if not grad_x and not grad_y:
        raise RuntimeError("grad_x and grad_y are both False or do not exist in attrs.")

    # generate a graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
        input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
        input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
        graph_scope.set_input(input_x, input_y, input_dout)

        # calculate result: the comparison is cast to x's dtype so it can
        # be multiplied with dout.
        ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
        ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
        dx = graph_builder.emit('Mul', [ge_result, input_dout])
        dy = graph_builder.emit('Sub', [input_dout, dx])

        # set graph output according to grad_x and grad_y.
        if grad_x and grad_y:
            graph_scope.set_output(dx, dy)
        elif grad_x:
            graph_scope.set_output(dx)
        else:
            graph_scope.set_output(dy)

    return graph_builder.get()[0]
def expand_minimumgrad(expand_info):
    """Build the expanded graph for MinimumGrad.

    The emitted graph computes
    ``dx = Cast(LessEqual(x, y), x.dtype) * dout`` and ``dy = dout - dx``.
    Which of them become graph outputs is chosen by the ``grad_x`` and
    ``grad_y`` attrs.

    Args:
        expand_info (dict): Op description. ``input_desc`` holds the descs
            (``shape``, ``data_type``, ``format``) of x, y and dout, in that
            order; ``attr`` is a list of dicts that may carry ``grad_x`` and
            ``grad_y``.

    Returns:
        The first graph produced by the graph builder.

    Raises:
        RuntimeError: If neither ``grad_x`` nor ``grad_y`` is set to a true
            value, since the graph would then have no output.
    """
    # get op info.
    input_desc_0 = expand_info['input_desc'][0]
    input_desc_1 = expand_info['input_desc'][1]
    input_desc_2 = expand_info['input_desc'][2]
    grad_x = None
    grad_y = None
    for item in expand_info['attr']:
        if 'grad_x' in item:
            grad_x = item['grad_x']
        if 'grad_y' in item:
            grad_y = item['grad_y']
    # Fail fast instead of silently returning a graph without any output.
    if not grad_x and not grad_y:
        raise RuntimeError("grad_x and grad_y are both False or do not exist in attrs.")

    # generate a graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
        input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
        input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
        graph_scope.set_input(input_x, input_y, input_dout)

        # calculate result: the comparison is cast to x's dtype so it can
        # be multiplied with dout.
        le_result = graph_builder.emit('LessEqual', [input_x, input_y])
        le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype})
        dx = graph_builder.emit('Mul', [le_result, input_dout])
        dy = graph_builder.emit('Sub', [input_dout, dx])

        # set graph output according to grad_x and grad_y.
        if grad_x and grad_y:
            graph_scope.set_output(dx, dy)
        elif grad_x:
            graph_scope.set_output(dx)
        else:
            graph_scope.set_output(dy)

    return graph_builder.get()[0]
b/mindspore/core/base/core_ops.h index ad441456c1..2a4d126b88 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -161,6 +161,7 @@ inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared( inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared("LayerNormBetaGammaBackprop"); inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared("DropoutGenMask"); inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared("DropoutDoMask"); +inline const PrimitivePtr kPrimDropoutGrad = std::make_shared("DropoutGrad"); inline const PrimitivePtr kPrimOneHot = std::make_shared("OneHot"); inline const PrimitivePtr kPrimGelu = std::make_shared("Gelu"); inline const PrimitivePtr kPrimGeluGrad = std::make_shared("GeluGrad"); diff --git a/tests/st/ops/graph_kernel/test_maximum_grad.py b/tests/st/ops/graph_kernel/test_maximum_grad.py new file mode 100644 index 0000000000..ceba8676c5 --- /dev/null +++ b/tests/st/ops/graph_kernel/test_maximum_grad.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class MaxmumGradNet(Cell):
    """Cell that forwards its three inputs to the MaximumGrad primitive."""

    def __init__(self):
        super(MaxmumGradNet, self).__init__()
        self.maximum_grad = G.MaximumGrad()

    def construct(self, x, y, dy):
        return self.maximum_grad(x, y, dy)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maximum_grad():
    """Compare MaximumGrad outputs against a NumPy reference."""
    np.random.seed(0)
    shape = [2, 3]
    # x is drawn before y so the seeded random sequence is consumed in a fixed order.
    x_np = np.random.normal(0, 1, shape).astype(np.float32)
    y_np = np.random.normal(0, 1, shape).astype(np.float32)
    dout_np = np.maximum(x_np, y_np).astype(np.float32)

    outputs = MaxmumGradNet()(Tensor(x_np), Tensor(y_np), Tensor(dout_np))

    # NumPy reference: dout is routed to x where x >= y, the remainder goes to y.
    expect_dx = dout_np * (x_np >= y_np)
    expect_dy = dout_np - expect_dx
    for index, expected in enumerate((expect_dx, expect_dy)):
        assert np.allclose(outputs[index].asnumpy(), expected, rtol=1.e-4, atol=1.e-8, equal_nan=True)
class MinmumGradNet(Cell):
    """Cell that forwards its three inputs to the MinimumGrad primitive."""

    def __init__(self):
        super(MinmumGradNet, self).__init__()
        self.minimum_grad = G.MinimumGrad()

    def construct(self, x, y, dy):
        return self.minimum_grad(x, y, dy)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_minimum_grad():
    """Compare MinimumGrad outputs against a NumPy reference."""
    np.random.seed(0)
    shape = [2, 3]
    # x is drawn before y so the seeded random sequence is consumed in a fixed order.
    x_np = np.random.normal(0, 1, shape).astype(np.float32)
    y_np = np.random.normal(0, 1, shape).astype(np.float32)
    dout_np = np.minimum(x_np, y_np).astype(np.float32)

    outputs = MinmumGradNet()(Tensor(x_np), Tensor(y_np), Tensor(dout_np))

    # NumPy reference: dout is routed to x where x <= y, the remainder goes to y.
    expect_dx = dout_np * (x_np <= y_np)
    expect_dy = dout_np - expect_dx
    for index, expected in enumerate((expect_dx, expect_dy)):
        assert np.allclose(outputs[index].asnumpy(), expected, rtol=1.e-4, atol=1.e-8, equal_nan=True)