|
|
|
@@ -17,7 +17,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF |
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD |
|
|
|
|
|
|
|
|
|
|
|
@VLD.add_format(DF.DEFAULT, DF.DEFAULT) |
|
|
|
@VLD.add_format(DF.DEFAULT) |
|
|
|
@VLD.check_attrs('axis') |
|
|
|
class LogSoftmax(Expander): |
|
|
|
"""LogSoftmax expander""" |
|
|
|
@@ -25,10 +25,18 @@ class LogSoftmax(Expander): |
|
|
|
def _expand(self, graph_builder): |
|
|
|
input_x = self.inputs[0] |
|
|
|
axis = self.attrs['axis'] |
|
|
|
processor = self.processor |
|
|
|
|
|
|
|
if isinstance(axis, int): |
|
|
|
axis = (axis,) |
|
|
|
|
|
|
|
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) |
|
|
|
ori_dtype = input_x.dtype |
|
|
|
if ori_dtype != "float16" and processor == "aicore": |
|
|
|
input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'}) |
|
|
|
max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True}) |
|
|
|
max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype}) |
|
|
|
else: |
|
|
|
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) |
|
|
|
data_sub = graph_builder.emit('Sub', [input_x, max_x]) |
|
|
|
data_exp = graph_builder.emit('Exp', [data_sub]) |
|
|
|
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) |
|
|
|
|