Browse Source

!26595 PyNative Ascend MindRT

Merge pull request !26595 from caifubi/master-pynative-mindrt-lazy-build-with-ascend
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
76f1ca3d9f
13 changed files with 224 additions and 55 deletions
  1. +11
    -0
      mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h
  2. +0
    -1
      mindspore/ccsrc/backend/session/session_basic.cc
  3. +2
    -1
      mindspore/ccsrc/pipeline/jit/action.cc
  4. +1
    -1
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  5. +1
    -3
      mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
  6. +26
    -0
      mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
  7. +5
    -1
      mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
  8. +13
    -7
      mindspore/ccsrc/runtime/framework/graph_compiler.cc
  9. +79
    -20
      mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.cc
  10. +4
    -1
      mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.h
  11. +15
    -0
      mindspore/ccsrc/runtime/hardware/ascend/ascend_graph_optimization.cc
  12. +1
    -0
      mindspore/ccsrc/runtime/hardware/ascend/ascend_graph_optimization.h
  13. +66
    -20
      mindspore/ccsrc/vm/backend.cc

+ 11
- 0
mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h View File

@@ -21,6 +21,7 @@
#include <memory>
#include "runtime/device/ascend/ge_runtime/task_info.h"
#include "backend/kernel_compiler/kernel.h"
#include "runtime/device/executor/dynamic_kernel.h"
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
@@ -44,9 +45,19 @@ class AscendKernelMod : public KernelMod {
#endif
}

void InitDynamicKernel(const CNodePtr &cnode_ptr, void *stream) {
if (dynamic_kernel_ == nullptr) {
stream_ = stream;
dynamic_kernel_ = GenDynamicKernel(cnode_ptr, stream);
dynamic_kernel_->Initialize();
}
}
device::DynamicKernelPtr DynamicKernel() const { return dynamic_kernel_; }

protected:
uint32_t block_dim_{1};
uint32_t stream_id_{0};
device::DynamicKernelPtr dynamic_kernel_{nullptr};
};
} // namespace kernel
} // namespace mindspore


+ 0
- 1
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -2190,7 +2190,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
// set execution order
std::vector<CNodePtr> exe_order = {cnode};
graph->set_execution_order(exe_order);
// set output
if (is_ascend) {
graph->set_output(cnode);
} else {


+ 2
- 1
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -96,7 +96,8 @@ void ResetMindRTEnable(const ResourcePtr &res) {
auto manager = func_graph->manager();
size_t graph_nums = manager->func_graphs().size();
// Heterogeneous scenario
if (graph_nums == 1 && context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice) {
if (graph_nums == 1 && (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice ||
context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode)) {
return;
}
if (common::GetEnv("ENABLE_ASCEND_MINDRT") == "1" || common::kEnableAscendMindRT) {


+ 1
- 1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -2116,7 +2116,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
#endif

VectorRef outputs;
if (!enable_mind_rt || cur_target == "Ascend") {
if (!enable_mind_rt) {
auto cur_session = GetCurrentSession(cur_target, device_id);
MS_EXCEPTION_IF_NULL(cur_session);
cur_session->RunOp(&op_run_info, &outputs);


+ 1
- 3
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc View File

@@ -178,9 +178,7 @@ void AscendDeviceAddress::BindDevice() const {
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
auto ascend_device_context = dynamic_cast<AscendDeviceContext *>(device_context);
MS_EXCEPTION_IF_NULL(ascend_device_context);
if (!ascend_device_context->BindDeviceToCurrentThread()) {
MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
}
ascend_device_context->BindDeviceToCurrentThread();
} else {
MS_LOG(WARNING) << "device name is null.";
}


+ 26
- 0
mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc View File

@@ -160,6 +160,29 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}

void UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto ref_node_map = graph->GetRefMap();
for (auto iter : ref_node_map) {
auto &output_pair = iter.first;
auto &input_pair = iter.second;
auto &ref_node = output_pair.first;
auto output_index = output_pair.second;
auto &input_node = input_pair.first;
auto input_node_output_index = input_pair.second;

auto input_addr = AnfAlgo::GetMutableOutputAddr(input_node, input_node_output_index);
auto ref_node_output_addr = AnfAlgo::GetMutableOutputAddr(ref_node, output_index);
// Just compare shared_ptr of two DeviceAddress.
// The ptr of DeviceAddress may still be nullptr.
if (input_addr != ref_node_output_addr) {
// The output of RefNode will not be used by subsequent Node.
// So update the DeviceAddress of the kernel directly instead of updating the ptr of the DeviceAddress.
AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
}
}
}
} // namespace
void DataPrepareActor::Init() {
MS_EXCEPTION_IF_NULL(graph_compiler_info_);
@@ -295,6 +318,9 @@ void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::ve
const auto front_node = FetchFrontNodeByBackendNode(input_node, graph);
PrepareDataForWeightNode(input_node, front_node, input_tensor, device_context, context);
}
// The DeviceAddress of the graph parameter has been updated.
// The output address of RefNode needs to be consistent with the address of parameter.
UpdateRefNodeOutputDeviceAddress(graph);
}

PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context);


+ 5
- 1
mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc View File

@@ -335,7 +335,11 @@ void KernelActor::FetchOutputDeviceTensor() {
MS_EXCEPTION_IF_NULL(output_address);
if (output_size_list[i] != output_address->GetSize()) {
// The size of output address may be changed in dynamic shape scenario.
output_address->SetSize(output_size_list[i]);
// If the format of the DeviceAddress differs from the kernel's output format, the two sizes
// legitimately differ — e.g. NCHW (1,1,1,3) vs. NC1HWC0 (1,1,1,1,16) — so the size must not be updated.
if (AnfAlgo::GetOutputFormat(kernel_, i) == output_address->format()) {
output_address->SetSize(output_size_list[i]);
}
}

// When the tensor is the output of graph or in dynamic shape scenario, the output tensor may be changed.


+ 13
- 7
mindspore/ccsrc/runtime/framework/graph_compiler.cc View File

@@ -402,6 +402,10 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
MS_EXCEPTION_IF_NULL(device_context);
const auto &ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return graph->graph_id();
}

#ifdef ENABLE_DUMP_IR
bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
// Dump .pb graph before graph optimization.
@@ -426,10 +430,8 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
// Adjust kernel graph before run graph.
device_context->PreprocessBeforeRunGraph(graph);

if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
// Create device address for all anf nodes of graph.
CreateDeviceAddress(graph, device_context);
}
// Create device address for all anf nodes of graph.
CreateDeviceAddress(graph, device_context);

graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));

@@ -482,14 +484,16 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool
// Generate kernel graph.
MS_EXCEPTION_IF_NULL(session_);
KernelGraphPtr graph =
session_->ConstructSingleOpGraph(op_run_info, op_run_info.input_tensors, op_run_info.tensor_mask);
session_->ConstructSingleOpGraph(op_run_info, op_run_info.input_tensors, op_run_info.tensor_mask,
device_context->GetDeviceAddressType() == device::DeviceAddressType::kAscend);
MS_EXCEPTION_IF_NULL(graph);

// session_ is SessionBasic, AscendUnifyMindIR has not been executed.
device_context->UnifyMindIR(graph);

MS_EXCEPTION_IF_NULL(device_context);
device_context->OptimizeSingleOpGraph(graph);

device_context->PreprocessBeforeRunSingleOpGraph(graph);

// Create device address for all anf nodes of graph.
CreateDeviceAddressWithoutWorkspace(graph, device_context);

@@ -520,6 +524,7 @@ void GraphCompiler::BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graph
device_context->CreateKernel(node_to_build);

for (const auto &graph : graphs) {
device_context->PreprocessBeforeRunSingleOpGraph(graph);
CreateKernelWorkspaceDeviceAddress(device_context, graph);
}
}
@@ -553,6 +558,7 @@ void GraphCompiler::CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &gr
CreateValueNodeDeviceAddress(device_context, graph);
CreateKernelOutputDeviceAddress(device_context, graph);
UpdateDeviceAddressForInplaceNode(graph);
UpdateDeviceAddressForRefNode(graph);
}

void GraphCompiler::GetParamAndOutputIndex(


+ 79
- 20
mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.cc View File

@@ -26,6 +26,8 @@
#include "runtime/device/ascend/ascend_stream_assign.h"
#include "runtime/device/ascend/kernel_build_ascend.h"
#include "runtime/hardware/ascend/ascend_graph_optimization.h"
#include "backend/kernel_compiler/ascend_kernel_mod.h"
#include "runtime/device/ascend/ascend_bucket.h"

#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
@@ -284,10 +286,18 @@ void AscendDeviceContext::Initialize() {
DumpInit(rank_id_);
#endif
compute_stream_ = runtime_instance_->compute_stream();
communication_stream_ = runtime_instance_->communication_stream();

initialized_ = true;
MS_LOG(INFO) << "Status record: Initialize success.";
}

bool AscendDeviceContext::IsGraphMode() {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
return context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode;
}

void AscendDeviceContext::Destroy() {
MS_LOG(INFO) << "Status record: Enter Destroy...";
if (!initialized_) {
@@ -306,7 +316,7 @@ void AscendDeviceContext::Destroy() {

std::vector<GraphSegmentPtr> AscendDeviceContext::PartitionGraph(
const FuncGraphPtr &func_graph, const std::vector<GraphSegmentPtr> &default_partition_segments) {
return std::vector<GraphSegmentPtr>();
return IsGraphMode() ? std::vector<GraphSegmentPtr>() : default_partition_segments;
}

void AscendDeviceContext::UnifyMindIR(const KernelGraphPtr &graph) const {
@@ -544,27 +554,71 @@ bool AscendDeviceContext::SyncStream(size_t stream_id) const {
bool AscendDeviceContext::IsExecutingSink(const KernelGraphPtr &graph) const {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
return ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
return ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && IsGraphMode();
}

bool AscendDeviceContext::IsLoopCountSink(const KernelGraphPtr &graph) const { return true; }
bool AscendDeviceContext::IsLoopCountSink(const KernelGraphPtr &graph) const { return IsGraphMode(); }

// kernel by kernel mode interface
void AscendDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {
MS_LOG(ERROR) << "!!! Ascend with MindRT not support kernel by kernel mode. !!! ";
AscendGraphOptimization::GetInstance().OptimizeSingleOpGraph(graph);
}

void AscendDeviceContext::PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const {
MS_LOG(ERROR) << "!!! Ascend with MindRT not support kernel by kernel mode. !!! ";
}
MS_EXCEPTION_IF_NULL(graph);
const auto &nodes = graph->execution_order();
// Remove placeholder
for (const auto &node : nodes) {
auto op_name = AnfAlgo::GetCNodeName(node);
static const std::set<std::string> place_holder_nodes = {kDynamicRNNOpName, kDynamicGRUV2OpName};
auto iter = place_holder_nodes.find(op_name);
if (iter != place_holder_nodes.end()) {
auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "placeholder_index");
// Remove seq_length
auto input_num = AnfAlgo::GetInputTensorNum(node);
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(node)};
for (size_t i = 0; i < input_num; ++i) {
auto item = std::find(none_index.begin(), none_index.end(), i);
if (item == none_index.end()) {
auto input_node = AnfAlgo::GetInputNode(node, i);
new_inputs.emplace_back(input_node);
}
}
node->set_inputs(new_inputs);
}
}

device::ascend::InsertAtomicCleanOps(nodes, &node_atomics_);
std::vector<CNodePtr> atomic_nodes;
for (const auto &node : nodes) {
auto iter = node_atomics_.find(node);
if (iter != node_atomics_.end()) {
const auto &atomics = iter->second;
std::copy(atomics.begin(), atomics.end(), std::back_inserter(atomic_nodes));
}
}

void AscendDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const {
MS_LOG(ERROR) << "!!! Ascend with MindRT not support function UpdateDynamicShape. !!! ";
CreateKernel(atomic_nodes);
}

void AscendDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const {}

std::shared_ptr<Bucket> AscendDeviceContext::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const {
MS_LOG(ERROR) << "!!! Ascend with MindRT not support function CreateBucket. !!! ";
return DeviceContext::CreateBucket(bucket_id, bucket_size);
auto bucket = std::make_shared<AscendBucket>(bucket_id, bucket_size);
MS_EXCEPTION_IF_NULL(bucket);

bucket->Init({compute_stream_}, {communication_stream_});
return bucket;
}

bool AscendDeviceContext::SyncRuning() const {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if ((ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) &&
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) && !SyncStream()) {
return false;
}
return true;
}

bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
@@ -582,6 +636,19 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);

if (is_dynamic_shape) {
kernel::AscendKernelMod *ascend_kernel = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
MS_EXCEPTION_IF_NULL(ascend_kernel);
ascend_kernel->InitDynamicKernel(kernel, compute_stream_);
auto dynamic_kernel = ascend_kernel->DynamicKernel();
MS_EXCEPTION_IF_NULL(dynamic_kernel);
dynamic_kernel->InferShape();
dynamic_kernel->UpdateArgs();
dynamic_kernel->Execute();
dynamic_kernel->PostExecute();
return SyncRuning();
}

std::vector<AddressPtr> real_inputs;
auto input_num = AnfAlgo::GetInputTensorNum(kernel);
if (input_num != inputs.size()) {
@@ -605,21 +672,13 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
return false;
}

// Sync running.
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if ((ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) &&
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) && !SyncStream()) {
return false;
}
return true;
return SyncRuning();
}

bool AscendDeviceContext::BindDeviceToCurrentThread() const {
void AscendDeviceContext::BindDeviceToCurrentThread() const {
if (initialized_) {
runtime_instance_->SetContext();
}
return true;
}

bool AscendDeviceContext::LaunchAtomicClean(const CNodePtr &node, const std::vector<AddressPtr> &workspace,


+ 4
- 1
mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.h View File

@@ -124,7 +124,7 @@ class AscendDeviceContext : public DeviceContext {
bool IsLoopCountSink(const KernelGraphPtr &graph) const override;

// set rt_context_ to this thread to control device
bool BindDeviceToCurrentThread() const;
void BindDeviceToCurrentThread() const;

// dump all graphs.
void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const override;
@@ -135,6 +135,8 @@ class AscendDeviceContext : public DeviceContext {
void AssignInputMemory(const NotNull<KernelGraphPtr> &graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void LoadModel(const NotNull<KernelGraphPtr> &root_graph) const;
void UpdateExecOrder(const KernelGraphPtr &graph) const;
static bool IsGraphMode();
bool SyncRuning() const;

// Kernel Runtime --- only for task sink
AscendKernelRuntime *runtime_instance_{nullptr};
@@ -157,6 +159,7 @@ class AscendDeviceContext : public DeviceContext {
bool LaunchAtomicClean(const CNodePtr &node, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const;
void *compute_stream_;
void *communication_stream_;
};
} // namespace ascend
} // namespace device


+ 15
- 0
mindspore/ccsrc/runtime/hardware/ascend/ascend_graph_optimization.cc View File

@@ -56,6 +56,20 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
MS_LOG(INFO) << "Status record: end optimize graph. graph id: " << graph->graph_id();
}

void AscendGraphOptimization::OptimizeSingleOpGraph(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
opt::RunOpAscendBackendIRFusionOptimization(graph);
SelectKernel(graph);
opt::RunOpAscendBackendOptimization(graph);

auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
// Cannot hide nop nodes in PyNative mode.
// If the graph contains more than one node and one of them is a nop node,
// hiding it would leave the actor DAG invalid (an input edge would be missing).
}

void AscendGraphOptimization::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto context_ptr = MsContext::GetInstance();
@@ -97,6 +111,7 @@ void AscendGraphOptimization::OptimizeExecutionOrder(const KernelGraphPtr &graph
DumpIRProto(graph, "before_removeNop_" + std::to_string(graph->graph_id()));
}
#endif

// TODO(sida): do not hide nop op in kernel_by_kernel mode
if (graph->is_executing_sink()) {
opt::HideNopNode(graph.get());


+ 1
- 0
mindspore/ccsrc/runtime/hardware/ascend/ascend_graph_optimization.h View File

@@ -43,6 +43,7 @@ class AscendGraphOptimization {
AscendGraphOptimization &operator=(const AscendGraphOptimization &) = delete;

void OptimizeGraph(const KernelGraphPtr &graph);
void OptimizeSingleOpGraph(const KernelGraphPtr &graph);
void SetOperatorInfo(const std::vector<CNodePtr> &nodes);
void UnifyMindIR(const KernelGraphPtr &graph);



+ 66
- 20
mindspore/ccsrc/vm/backend.cc View File

@@ -234,28 +234,25 @@ void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, Vec
}
}

void UpdateOutputDeviceAddress(const std::vector<session::KernelWithIndex> &output_nodes,
const DeviceContext *device_context) {
for (auto &item_with_index : output_nodes) {
auto &output_node = item_with_index.first;
auto output_index = item_with_index.second;
if (output_node != nullptr) {
if (!AnfAlgo::OutputAddrExist(output_node, output_index, false)) {
void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &node : graph->execution_order()) {
auto output_address_num = AnfAlgo::GetOutputAddressNum(node);
for (size_t i = 0; i < output_address_num; ++i) {
if (!AnfAlgo::OutputAddrExist(node, i, false)) {
continue;
}
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);

if ((device_tensor == nullptr) || (device_tensor->GetPtr() == nullptr)) {
const auto &device_address = AnfAlgo::GetMutableOutputAddr(node, i, false);
if (device_address == nullptr) {
continue;
}

MS_EXCEPTION_IF_NULL(device_context);
auto new_device_tensor = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
device_tensor->format(), device_tensor->type_id());
MS_EXCEPTION_IF_NULL(new_device_tensor);
new_device_tensor->set_original_ref_count(device_tensor->original_ref_count());
new_device_tensor->ResetRefCount();
AnfAlgo::SetOutputAddr(new_device_tensor, output_index, output_node.get());
auto new_device_address = device_context->CreateDeviceAddress(
nullptr, device_address->GetSize(), device_address->format(), device_address->type_id());
MS_EXCEPTION_IF_NULL(new_device_address);
new_device_address->set_original_ref_count(device_address->original_ref_count());
new_device_address->ResetRefCount();
AnfAlgo::SetOutputAddr(new_device_address, i, node.get());
}
}
}
@@ -269,6 +266,51 @@ void UpdateInputDeviceAddress(const KernelGraphPtr &graph) {
}
}
}

std::vector<tensor::TensorPtr> GetRealValueNodeTensorFromGraph(
const KernelGraphPtr &graph, size_t input_tensors_size,
const std::vector<tensor::TensorPtr> &tensors_without_value_node) {
std::vector<tensor::TensorPtr> new_input_tensors;
if (graph->execution_order().size() != 1) {
return new_input_tensors;
}

const auto &node = graph->execution_order().back();
auto input_num = AnfAlgo::GetInputTensorNum(node);
// In most scenarios, input_num and input_tensors_size are equal.
// They differ only in special cases where a new ValueNode is added to the graph during graph optimization.
if (input_num == input_tensors_size) {
return new_input_tensors;
}
MS_LOG(INFO) << "CNode input num:" << input_num << " input_tensors size:" << input_tensors_size;

std::map<size_t, tensor::TensorPtr> value_node_pos;
for (size_t i = 0; i < input_num; ++i) {
auto input = AnfAlgo::GetInputNode(node, i);
MS_EXCEPTION_IF_NULL(input);
if (input->isa<ValueNode>()) {
auto value_node = input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
auto tensor = value->cast<tensor::TensorPtr>();
value_node_pos.emplace(i, tensor);
}
}

size_t cur_input_tensor_index = 0;
for (size_t i = 0; i < input_num; ++i) {
auto iter = value_node_pos.find(i);
if (iter == value_node_pos.end()) {
new_input_tensors.emplace_back(tensors_without_value_node[cur_input_tensor_index]);
cur_input_tensor_index++;
} else {
new_input_tensors.emplace_back(iter->second);
}
}
MS_LOG(INFO) << "new input tensor size:" << new_input_tensors.size();
return new_input_tensors;
}
} // namespace

VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
@@ -1125,6 +1167,9 @@ void MindRTBackend::RunSingleOpGraph(const KernelGraphPtr &graph,
}
}

std::vector<tensor::TensorPtr> new_input_tensors =
GetRealValueNodeTensorFromGraph(graph, input_tensors.size(), tensors_without_value_node);

for (auto &tensor : tensors_without_value_node) {
MS_EXCEPTION_IF_NULL(tensor);
if (tensor->NeedWaitDevice()) {
@@ -1135,7 +1180,8 @@ void MindRTBackend::RunSingleOpGraph(const KernelGraphPtr &graph,
// Run actor DAG.
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(graph_compiler_info->name_);
MS_EXCEPTION_IF_NULL(actor_set);
runtime::GraphScheduler::GetInstance().Run(actor_set, {}, {tensors_without_value_node}, input_tensors,
runtime::GraphScheduler::GetInstance().Run(actor_set, {}, {tensors_without_value_node},
new_input_tensors.empty() ? input_tensors : new_input_tensors,
runtime::GraphExecutionStrategy::kStep);

// Release the kernel resource.
@@ -1200,7 +1246,7 @@ void MindRTBackend::LazyExecuteTaskCallback() {
const auto &context = op_run_task->context();
RunSingleOpGraph(context->graph(), context->output_nodes(), context->op_run_info(),
context->graph_compiler_info(), context->device_context());
UpdateOutputDeviceAddress(context->output_nodes(), context->device_context());
ClearGraphDeviceAddress(context->graph(), context->device_context());
UpdateInputDeviceAddress(context->graph());

op_lazy_builder.PopOpRunTask();
@@ -1258,7 +1304,7 @@ void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *g
}
RunSingleOpGraph(graph, output_nodes, *op_run_info, graph_compiler_info, device_context);
UpdateOutput(output_nodes, outputs);
UpdateOutputDeviceAddress(output_nodes, device_context);
ClearGraphDeviceAddress(graph, device_context);
UpdateInputDeviceAddress(graph);
if (op_run_info->is_dynamic_shape) {
UpdateOutputAbstract(graph, op_run_info);


Loading…
Cancel
Save