Browse Source

Pre Merge pull request !703 from 王芝伟/cherry-pick-1665110660

pull/703/MERGE
王芝伟 Gitee 3 years ago
parent
commit
8ebe38f136
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 4 additions and 4 deletions
  1. +1
    -1
      parser/common/op_registration_tbe.cc
  2. +1
    -1
      parser/common/op_registration_tbe.h
  3. +2
    -2
      parser/common/parser_fp16_t.cc

+ 1
- 1
parser/common/op_registration_tbe.cc View File

@@ -42,7 +42,7 @@ FMK_FUNC_HOST_VISIBILITY OpRegistrationTbe *OpRegistrationTbe::Instance() {
return &instance;
}

bool OpRegistrationTbe::Finalize(const OpRegistrationData &reg_data, bool is_train) {
bool OpRegistrationTbe::Finalize(const OpRegistrationData &reg_data, bool is_train) const {
static std::map<domi::FrameworkType, std::map<std::string, std::string> *> op_map = {{domi::CAFFE, &caffe_op_map}};
if (is_train) {
op_map[domi::TENSORFLOW] = &tensorflow_train_op_map;


+ 1
- 1
parser/common/op_registration_tbe.h View File

@@ -24,7 +24,7 @@ class OpRegistrationTbe {
public:
static OpRegistrationTbe *Instance();

bool Finalize(const OpRegistrationData &reg_data, bool is_train = false);
bool Finalize(const OpRegistrationData &reg_data, bool is_train = false) const;

private:
bool RegisterParser(const OpRegistrationData &reg_data) const;


+ 2
- 2
parser/common/parser_fp16_t.cc View File

@@ -380,7 +380,7 @@ static uint16_t Fp16ToUInt16(const uint16_t &fp_val) {
}
}
bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen);
m_ret = static_cast<uint16_t>((long_int_m >> static_cast<int16_t>(kFp16ManLen + shift_out)) & kBitLen16Max);
m_ret = static_cast<uint16_t>((long_int_m >> static_cast<uint16_t>(kFp16ManLen + shift_out)) & kBitLen16Max);
if (need_round && m_ret != kBitLen16Max) {
m_ret++;
}
@@ -1020,7 +1020,7 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) {
for (int i = 1; i < e_tmp; i++) {
trunc_mask = (trunc_mask << 1) + 1;
}
m_trunc = (m_ret & trunc_mask) << static_cast<int16_t>(static_cast<uint16_t>(kBitShift32) - e_tmp);
m_trunc = (m_ret & trunc_mask) << static_cast<uint16_t>(static_cast<uint16_t>(kBitShift32) - e_tmp);
for (int i = 0; i < e_tmp; i++) {
m_ret = (m_ret >> 1);
e_ret = e_ret + 1;


Loading…
Cancel
Save