Browse Source

fix trace

tags/v1.3.0
huangmengxi 4 years ago
parent
commit
733158ac31
2 changed files with 8 additions and 1 deletions
  1. +6
    -0
      mindspore/_checkparam.py
  2. +2
    -1
      mindspore/_extends/parse/standard_method.py

+ 6
- 0
mindspore/_checkparam.py View File

@@ -25,6 +25,7 @@ from collections.abc import Iterable
import numpy as np
from mindspore import log as logger
from mindspore.common import dtype as mstype
from mindspore._c_expression import Tensor as Tensor_

class Rel(Enum):
"""Numerical relationship between variables, logical relationship enumeration definition of range."""
@@ -835,6 +836,11 @@ class Validator:
new_axes += (ax,)
return new_axes

@staticmethod
def empty_compile(dtype, shape):
"""Returns an empty Tensor."""
return Tensor_(dtype, shape)


def check_input_format(input_param):
"""Judge input format."""


+ 2
- 1
mindspore/_extends/parse/standard_method.py View File

@@ -777,7 +777,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
last_dim_end = min_(
shape[-2], max_(0, shape[-1] - offset)) - last_dim_begin
if last_dim_end <= 0:
return Tensor([])
return empty_compile(dtype, (0,))
size += (last_dim_end,)
res = F.tensor_slice(res, begin, size)
return res.astype(dtype)
@@ -1628,6 +1628,7 @@ infer_out_shape = constexpr(validator.infer_out_shape)
get_log2_size = constexpr(validator.get_log2_size)
check_axis_type = constexpr(validator.check_axis_type)
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
empty_compile = constexpr(validator.empty_compile)

def tensor_bool(x):
"""tensor as condition, if is constant, return immediate bool value"""


Loading…
Cancel
Save