Browse Source

!12728 fix precision error after cache modification

From: @simson_wu
Reviewed-by: @chujinjin,@zhoufeng54
Signed-off-by: @chujinjin
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
00f25c8409
3 changed files with 6 additions and 3 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/session/gpu_session.cc
  2. +0
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  3. +3
    -0
      mindspore/ccsrc/utils/utils.h

+ 3
- 1
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -66,6 +66,7 @@
#include "utils/ms_utils.h"
#include "utils/config_manager.h"
#include "utils/ms_context.h"
#include "utils/utils.h"
#if ENABLE_CPU && ENABLE_GPU
#include "ps/util.h"
#include "ps/ps_cache/ps_cache_manager.h"
@@ -448,7 +449,8 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
// Check if the graph cache exists.
if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
if (run_op_graphs_.find(graph_info) != run_op_graphs_.end() &&
kOpCacheAllowList.find(op_run_info.op_name) == kOpCacheAllowList.end()) {
return;
}
// Prepare the graph


+ 0
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -387,10 +387,8 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
} else if (py::isinstance<py::float_>(input_object)) {
double input_value = py::cast<py::float_>(input_object);
tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
*tensor_mask = kValueNodeTensorMask;
} else if (py::isinstance<py::int_>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
*tensor_mask = kValueNodeTensorMask;
} else if (py::isinstance<py::array>(input_object)) {
tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
} else if (py::isinstance<py::list>(input_object)) {


+ 3
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -271,6 +271,7 @@ constexpr auto kPadAndShiftOpName = "PadAndShift";
constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits";
constexpr auto kOneHotOpName = "OneHot";
constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits";
constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler";

// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
@@ -492,6 +493,8 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,

const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};

const std::set<std::string> kOpCacheAllowList = {kUniformCandidateSamplerOpName};

const std::set<std::string> kHWSpecialFormatSet = {
kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};


Loading…
Cancel
Save