From db8256e61f4efb952b1676389834fccc1cc33d47 Mon Sep 17 00:00:00 2001 From: wenfangpei Date: Tue, 20 Apr 2021 10:27:10 +0800 Subject: [PATCH] adapt for logsoftmax in ascend --- .../_extends/graph_kernel/expanders/logsoftmax.py | 12 ++++++++++-- .../optimizer/graph_kernel/graph_kernel_expander.cc | 2 ++ tests/st/ops/graph_kernel/test_logsoftmax.py | 10 ++++++++-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/mindspore/_extends/graph_kernel/expanders/logsoftmax.py b/mindspore/_extends/graph_kernel/expanders/logsoftmax.py index 1765840889..27aa803560 100644 --- a/mindspore/_extends/graph_kernel/expanders/logsoftmax.py +++ b/mindspore/_extends/graph_kernel/expanders/logsoftmax.py @@ -17,7 +17,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF from ._utils import Expander, ExpanderInfoValidator as VLD -@VLD.add_format(DF.DEFAULT, DF.DEFAULT) +@VLD.add_format(DF.DEFAULT) @VLD.check_attrs('axis') class LogSoftmax(Expander): """LogSoftmax expander""" @@ -25,10 +25,18 @@ class LogSoftmax(Expander): def _expand(self, graph_builder): input_x = self.inputs[0] axis = self.attrs['axis'] + processor = self.processor + if isinstance(axis, int): axis = (axis,) - max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) + ori_dtype = input_x.dtype + if ori_dtype != "float16" and processor == "aicore": + input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'}) + max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True}) + max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype}) + else: + max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) data_sub = graph_builder.emit('Sub', [input_x, max_x]) data_exp = graph_builder.emit('Exp', [data_sub]) data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index dc9a26dbdf..6922eb164c 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -53,6 +53,8 @@ std::vector GetExpandOps() { prim::kPrimBiasAddGrad, prim::kPrimGeLU, prim::kPrimSoftmax, + prim::kPrimLogSoftmax, + prim::kPrimLogSoftmaxGrad, prim::kPrimTile, #if ENABLE_D prim::kPrimSqrtGrad, diff --git a/tests/st/ops/graph_kernel/test_logsoftmax.py b/tests/st/ops/graph_kernel/test_logsoftmax.py index 31d9dfb0e6..26acd857a8 100644 --- a/tests/st/ops/graph_kernel/test_logsoftmax.py +++ b/tests/st/ops/graph_kernel/test_logsoftmax.py @@ -106,12 +106,18 @@ def test_logsoftmaxgrad_gpu(): context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") test_logsoftmaxgrad() - +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard def test_logsoftmax_asend(): context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend") test_logsoftmax() - +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard def test_logsoftmaxgrad_asend(): context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend") test_logsoftmaxgrad()