Browse Source

!15415 [GraphKernel]adapt for logsoftmax in ascend

From: @wenfangpei
Reviewed-by: @gaoxiong1,@ckey_dou,@gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou,@ckey_dou
pull/15415/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
ed539597c2
3 changed files with 20 additions and 4 deletions
  1. +10
    -2
      mindspore/_extends/graph_kernel/expanders/logsoftmax.py
  2. +2
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
  3. +8
    -2
      tests/st/ops/graph_kernel/test_logsoftmax.py

+ 10
- 2
mindspore/_extends/graph_kernel/expanders/logsoftmax.py View File

@@ -17,7 +17,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis')
class LogSoftmax(Expander):
"""LogSoftmax expander"""
@@ -25,10 +25,18 @@ class LogSoftmax(Expander):
def _expand(self, graph_builder):
input_x = self.inputs[0]
axis = self.attrs['axis']
processor = self.processor

if isinstance(axis, int):
axis = (axis,)

max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
ori_dtype = input_x.dtype
if ori_dtype != "float16" and processor == "aicore":
input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True})
max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
else:
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
data_sub = graph_builder.emit('Sub', [input_x, max_x])
data_exp = graph_builder.emit('Exp', [data_sub])
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})


+ 2
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc View File

@@ -53,6 +53,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimBiasAddGrad,
prim::kPrimGeLU,
prim::kPrimSoftmax,
prim::kPrimLogSoftmax,
prim::kPrimLogSoftmaxGrad,
prim::kPrimTile,
#if ENABLE_D
prim::kPrimSqrtGrad,


+ 8
- 2
tests/st/ops/graph_kernel/test_logsoftmax.py View File

@@ -106,12 +106,18 @@ def test_logsoftmaxgrad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_logsoftmaxgrad()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_logsoftmax_asend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_logsoftmax()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_logsoftmaxgrad_asend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_logsoftmaxgrad()

Loading…
Cancel
Save