Browse Source

PyNative Ascend dynamic shape bugfix

r1.7
caifubi 4 years ago
parent
commit
b174ae05de
3 changed files with 11 additions and 8 deletions
  1. +2
    -2
      mindspore/ccsrc/backend/graph_compiler/backend.cc
  2. +8
    -5
      mindspore/ccsrc/runtime/pynative/run_op_helper.cc
  3. +1
    -1
      mindspore/ccsrc/runtime/pynative/run_op_helper.h

+ 2
- 2
mindspore/ccsrc/backend/graph_compiler/backend.cc View File

@@ -1353,7 +1353,7 @@ void MindRTBackend::OpRunCallback(const std::shared_ptr<runtime::OpTaskContext>
auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, context->is_pynative_infer());
  runtime::RunSingleOpGraph(context->graph(), GetTensorWithoutValueMask(context->op_run_info()),
- context->device_context(), context->op_run_info().is_dynamic_shape);
+ context->device_context());
ClearGraphDeviceAddress(context->graph(), context->device_context(), context->op_run_info().is_gradient_out);
ClearInputDeviceAddress(context->graph(), context->device_context());
// Reset PyNative infer flag.
@@ -1473,7 +1473,7 @@ void MindRTBackend::RunOpImpl(bool single_op_cache_hit, GraphCompilerInfo *graph
}
auto tensors_without_value_mask = GetTensorWithoutValueMask(*op_run_info);
runtime::UpdateDeviceAddress(graph, tensors_without_value_mask, device_context);
- runtime::RunSingleOpGraph(graph, tensors_without_value_mask, device_context, op_run_info->is_dynamic_shape);
+ runtime::RunSingleOpGraph(graph, tensors_without_value_mask, device_context);
ReleaseForwardOutput(op_run_info->input_tensors);
UpdateOutput(output_nodes, outputs);
ClearGraphDeviceAddress(graph, device_context, op_run_info->is_gradient_out);


+ 8
- 5
mindspore/ccsrc/runtime/pynative/run_op_helper.cc View File

@@ -19,6 +19,7 @@
#include <string>
#include <vector>
#include <memory>
+ #include <algorithm>
#include "utils/log_adapter.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/convert_utils.h"
@@ -363,7 +364,7 @@ void CopyDataToDevice(const KernelGraphPtr &graph, const std::vector<tensor::Ten
}

// kernel_mode launch
- void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *device_context, bool is_dynamic_shape) {
+ void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);
MS_LOG(DEBUG) << "Start";
@@ -372,6 +373,7 @@ void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *dev
const auto &execution_order = graph->execution_order();
for (auto const &node : execution_order) {
MS_EXCEPTION_IF_NULL(node);
+ auto is_dynamic_shape = common::AnfAlgo::IsDynamicShape(node);
auto runtime_info = node->user_data<runtime::OpRuntimeInfo>();
MS_EXCEPTION_IF_NULL(runtime_info);

@@ -394,8 +396,9 @@ void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *dev
MS_LOG(EXCEPTION) << "Malloc for kernel output failed, Memory isn't enough, node:" << node->fullname_with_scope();
}
auto outputs = CreateKernelOutputAddress(runtime_info);

- device_context->LaunchKernel(node, inputs, workspaces, outputs, is_dynamic_shape);
+ if (!device_context->LaunchKernel(node, inputs, workspaces, outputs, is_dynamic_shape)) {
+   MS_LOG(EXCEPTION) << "Launch kernel failed, name:" << node->fullname_with_scope();
+ }

if (is_dynamic_shape) {
UpdateOutputAddrSize(node, runtime_info);
@@ -441,10 +444,10 @@ void UpdateDeviceAddress(const KernelGraphPtr &graph, const std::vector<tensor::
}

void RunSingleOpGraph(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &input_tensors,
- const device::DeviceContext *device_context, bool is_dynamic_shape) {
+ const device::DeviceContext *device_context) {
WaitCommunicationFinish(input_tensors);
CopyDataToDevice(graph, input_tensors, device_context);
- LaunchKernels(graph, device_context, is_dynamic_shape);
+ LaunchKernels(graph, device_context);
ReleaseKernelResource(graph);
}
} // namespace mindspore::runtime

+ 1
- 1
mindspore/ccsrc/runtime/pynative/run_op_helper.h View File

@@ -27,6 +27,6 @@ void UpdateDeviceAddress(const KernelGraphPtr &graph, const std::vector<tensor::
const device::DeviceContext *device_context);

void RunSingleOpGraph(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &input_tensors,
- const device::DeviceContext *device_context, bool is_dynamic_shape);
+ const device::DeviceContext *device_context);
} // namespace mindspore::runtime
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_RUN_OP_HELPER_H_

Loading…
Cancel
Save