From: @simson_wu Reviewed-by: @chujinjin, @zhoufeng54 Signed-off-by: @chujinjin tags/v1.2.0-rc1
| @@ -66,6 +66,7 @@ | |||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/utils.h" | |||||
| #if ENABLE_CPU && ENABLE_GPU | #if ENABLE_CPU && ENABLE_GPU | ||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/ps_cache/ps_cache_manager.h" | #include "ps/ps_cache/ps_cache_manager.h" | ||||
| @@ -448,7 +449,8 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, | const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| const std::vector<int64_t> &tensors_mask) { | const std::vector<int64_t> &tensors_mask) { | ||||
| // Check if the graph cache exists. | // Check if the graph cache exists. | ||||
| if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { | |||||
| if (run_op_graphs_.find(graph_info) != run_op_graphs_.end() && | |||||
| kOpCacheAllowList.find(op_run_info.op_name) == kOpCacheAllowList.end()) { | |||||
| return; | return; | ||||
| } | } | ||||
| // Prepare the graph | // Prepare the graph | ||||
| @@ -387,10 +387,8 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr | |||||
| } else if (py::isinstance<py::float_>(input_object)) { | } else if (py::isinstance<py::float_>(input_object)) { | ||||
| double input_value = py::cast<py::float_>(input_object); | double input_value = py::cast<py::float_>(input_object); | ||||
| tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32); | tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32); | ||||
| *tensor_mask = kValueNodeTensorMask; | |||||
| } else if (py::isinstance<py::int_>(input_object)) { | } else if (py::isinstance<py::int_>(input_object)) { | ||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64); | tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64); | ||||
| *tensor_mask = kValueNodeTensorMask; | |||||
| } else if (py::isinstance<py::array>(input_object)) { | } else if (py::isinstance<py::array>(input_object)) { | ||||
| tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr); | tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr); | ||||
| } else if (py::isinstance<py::list>(input_object)) { | } else if (py::isinstance<py::list>(input_object)) { | ||||
| @@ -271,6 +271,7 @@ constexpr auto kPadAndShiftOpName = "PadAndShift"; | |||||
| constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits"; | constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits"; | ||||
| constexpr auto kOneHotOpName = "OneHot"; | constexpr auto kOneHotOpName = "OneHot"; | ||||
| constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits"; | constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits"; | ||||
| constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler"; | |||||
| // Hcom Op Type | // Hcom Op Type | ||||
| constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; | constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; | ||||
| @@ -492,6 +493,8 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName, | |||||
| const std::set<std::string> kPosteriorOperatorSet = {kPullOpName}; | const std::set<std::string> kPosteriorOperatorSet = {kPullOpName}; | ||||
| const std::set<std::string> kOpCacheAllowList = {kUniformCandidateSamplerOpName}; | |||||
| const std::set<std::string> kHWSpecialFormatSet = { | const std::set<std::string> kHWSpecialFormatSet = { | ||||
| kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, | kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, | ||||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z}; | kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z}; | ||||