Browse Source

fixe cache error for aicpu op

tags/v1.1.0
chujinjin 5 years ago
parent
commit
4bcc88e333
2 changed files with 2 additions and 1 deletions
  1. +1
    -0
      mindspore/ccsrc/pipeline/pynative/base.h
  2. +1
    -1
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

+ 1
- 0
mindspore/ccsrc/pipeline/pynative/base.h View File

@@ -64,6 +64,7 @@ using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);

const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
} // namespace pynative
} // namespace mindspore



+ 1
- 1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -747,7 +747,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
}
}

if (op_exec_info->abstract == nullptr) {
if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) {
// use python infer method
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);


Loading…
Cancel
Save