Browse Source

support int32 add int64 scalar op

tags/v1.3.0
yanglf1121 4 years ago
parent
commit
b0762483ce
2 changed files with 10 additions and 0 deletions
  1. +8
    -0
      mindspore/ccsrc/frontend/operator/cc_implementations.cc
  2. +2
    -0
      mindspore/core/utils/convert_utils_base.h

+ 8
- 0
mindspore/ccsrc/frontend/operator/cc_implementations.cc View File

@@ -237,6 +237,10 @@ bool InnerScalarGe(T x, U y) {
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y))); \
return MakeValue(sum); \
} \
if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
return MakeValue(sum); \
@@ -245,6 +249,10 @@ bool InnerScalarGe(T x, U y) {
double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y)); \
return MakeValue(sum); \
} \
MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \
<< ", y: " << y->ToString(); \
} while (0); \


+ 2
- 0
mindspore/core/utils/convert_utils_base.h View File

@@ -114,6 +114,8 @@ inline int32_t LongToInt(int64_t u) {
return static_cast<int32_t>(u);
}

inline int64_t IntToLong(int32_t v) { return static_cast<int64_t>(v); }

inline int64_t UlongToLong(uint64_t u) {
if (u > static_cast<uint64_t>((std::numeric_limits<int64_t>::max)())) {
MS_LOG(EXCEPTION) << "The uint64_t value(" << u << ") exceeds the maximum value of int64_t.";


Loading…
Cancel
Save