Browse Source

Fix implementation and formatting of the 2nd-order custom ops

tags/v0.5.0-beta
z00478463 5 years ago
parent
commit
b57c0839ab
4 changed files with 8 additions and 5 deletions
  1. +1
    -1
      mindspore/ops/_op_impl/_custom_op/__init__.py
  2. +1
    -1
      mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py
  3. +5
    -2
      mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py
  4. +1
    -1
      mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py

+ 1
- 1
mindspore/ops/_op_impl/_custom_op/__init__.py View File

@@ -24,4 +24,4 @@ from .matmul_cube_fracz_left_cast_impl import CusMatMulCubeFraczLeftCast
from .matmul_cube_fracz_right_mul_impl import CusMatMulCubeFraczRightMul
from .matmul_cube_impl import CusMatMulCube
from .matrix_combine_impl import CusMatrixCombine
- from .transpose_02314_impl import CusTranspose02314
+ from .transpose02314_impl import CusTranspose02314

+ 1
- 1
mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py View File

@@ -101,4 +101,4 @@ def CusCholeskyTrsm(input_x,output, kernel_name):
tik_instance.data_move(res[block_index,0,0], temp_ub, 0, 1, 8 * vector_repeat_times * split_dim,0,0)

tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res])
- return tik_instance
+ return tik_instance

+ 5
- 2
mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py View File

@@ -42,7 +42,8 @@ def CusFusedAbsMax1(input_x, output, origin_shape = None, kernel_name="fused_abs


if len(input_x_shape) > 2:
- if (input_x_shape[0] == 1 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or (input_x_shape[0] == 4 and input_x_shape[1] == 16) or (input_x_shape[0] == 16 and input_x_shape[1] == 4): input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
+ if (input_x_shape[0] == 1 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or (input_x_shape[0] == 4 and input_x_shape[1] == 16) or (input_x_shape[0] == 16 and input_x_shape[1] == 4):
+ input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
for val in input_x_shape:
@@ -131,7 +132,8 @@ def CusFusedAbsMax1(input_x, output, origin_shape = None, kernel_name="fused_abs
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, 8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
- elif (input_x_shape[0] == 4 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or (input_x_shape[0] == 8 and input_x_shape[1] == 32) or (input_x_shape[0] == 32 and input_x_shape[1] == 8): input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
+ elif (input_x_shape[0] == 4 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or (input_x_shape[0] == 8 and input_x_shape[1] == 32) or (input_x_shape[0] == 32 and input_x_shape[1] == 8):
+ input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
for val in input_x_shape:
@@ -608,6 +610,7 @@ def CusFusedAbsMax1(input_x, output, origin_shape = None, kernel_name="fused_abs
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
for val in input_x_shape:
total_elements *= val
blocks = 32
each_block_element = total_elements // blocks
with tik_instance.for_range(0,blocks,block_num=blocks) as block_index:


+ 1
- 1
mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py View File

@@ -54,7 +54,7 @@ def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}
shape_b_input = input_x2.get("shape")
matrix_max_input = input_x3.get("shape")
input_shape = (tuple(shape_a_input), tuple(shape_b_input), tuple(matrix_max_input))
- if input_shape not in support_shape:
+ if input_shape not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_shape))
if shape_a_temp[0] == 128 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 128:


Loading…
Cancel
Save