Browse Source

add updated-parameter tracking for kernel graphs

tags/v1.3.0
kswang 4 years ago
parent
commit
42155e7e5c
7 changed files with 77 additions and 26 deletions
  1. +4
    -1
      mindspore/ccsrc/backend/session/ascend_session.cc
  2. +20
    -14
      mindspore/ccsrc/backend/session/cpu_session.cc
  3. +12
    -8
      mindspore/ccsrc/backend/session/kernel_graph.cc
  4. +9
    -0
      mindspore/ccsrc/backend/session/kernel_graph.h
  5. +1
    -1
      mindspore/ccsrc/backend/session/session_basic.cc
  6. +28
    -2
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  7. +3
    -0
      mindspore/core/ir/tensor.h

+ 4
- 1
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -362,9 +362,12 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
AnfAlgo::IsParameterWeight(input_param)) {
tensor->set_device_address(device_address);
}
if (kernel_graph->IsUpdatedParameter(input_param)) {
tensor->SetIsUpdateByDevice();
}
}
tensor->set_sync_status(kNoNeedSync);
}


+ 20
- 14
mindspore/ccsrc/backend/session/cpu_session.cc View File

@@ -164,20 +164,26 @@ void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
}
for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) {
auto &item = input_nodes[input_idx];
MS_EXCEPTION_IF_NULL(item);
if (item->isa<Parameter>() && !HasAbstractMonad(item)) {
auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
auto tensor = inputs_const[input_idx];
auto tensor_address = tensor->device_address();
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(tensor);
if (tensor_address != nullptr && tensor_address != address &&
(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() !=
device::DeviceAddressType::kCPU ||
AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) {
tensor->data_sync(false);
}
auto &input_node = input_nodes[input_idx];
MS_EXCEPTION_IF_NULL(input_node);
if (!input_node->isa<Parameter>() || HasAbstractMonad(input_node)) {
continue;
}
auto address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
auto tensor = inputs_const[input_idx];
auto tensor_address = tensor->device_address();
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(tensor);
if (tensor_address == nullptr || tensor_address == address) {
continue;
}
auto input_param = input_node->cast<ParameterPtr>();
if (AnfAlgo::IsParameterWeight(input_param) && !tensor->IsUpdatedByDevice()) {
continue;
}
if (std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() !=
device::DeviceAddressType::kCPU) {
tensor->data_sync(false);
}
}
}


+ 12
- 8
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -1323,15 +1323,19 @@ void KernelGraph::SetOptimizerFlag() {
auto node_name = AnfAlgo::GetCNodeName(cnode);
if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
has_optimizer_ = true;
return;
} else if (node_name.find("Assign") == string::npos) {
continue;
}
if (node_name.find("Assign") != string::npos) {
for (auto &input : cnode->inputs()) {
MS_EXCEPTION_IF_NULL(input);
if (input->isa<Parameter>() && AnfAlgo::IsParameterWeight(input->cast<ParameterPtr>())) {
has_optimizer_ = true;
return;
}
for (auto &input : cnode->inputs()) {
MS_EXCEPTION_IF_NULL(input);
auto real_node = AnfAlgo::VisitKernel(input, 0).first;
if (!real_node->isa<Parameter>()) {
continue;
}
auto param = real_node->cast<ParameterPtr>();
if (AnfAlgo::IsParameterWeight(param)) {
has_optimizer_ = true;
(void)updated_parameters_.insert(param);
}
}
}


+ 9
- 0
mindspore/ccsrc/backend/session/kernel_graph.h View File

@@ -63,6 +63,7 @@ class KernelGraph : public FuncGraph {
ref_out_in_map_ = graph.ref_out_in_map_;
node_output_edges_ = graph.node_output_edges_;
summary_nodes_ = graph.summary_nodes_;
updated_parameters_ = graph.updated_parameters_;
executable_ = graph.executable_;
summary_node_exist_ = graph.summary_node_exist_;
valid_inputs_ = graph.valid_inputs_;
@@ -259,6 +260,12 @@ class KernelGraph : public FuncGraph {
void SetInputNodes();
const std::vector<AnfNodePtr> &input_nodes() const { return input_nodes_; }
bool has_optimizer() const { return has_optimizer_; }
// Returns true if `param` was recorded (in SetOptimizerFlag) as a parameter
// that will be modified in place when this graph executes; such tensors must
// be re-synced from device after a run.
bool IsUpdatedParameter(const ParameterPtr &param) {
  // Direct membership test; the original if/return-true/return-false
  // form collapses to the comparison itself.
  return updated_parameters_.find(param) != updated_parameters_.end();
}
// handle graph dependency
void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) {
if (graph != nullptr) {
@@ -373,6 +380,8 @@ class KernelGraph : public FuncGraph {
std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
// parameters that will be updated when graph is executed
std::unordered_set<ParameterPtr> updated_parameters_;
// graph needn't execute
bool executable_{false};
// exist summary node in graph


+ 1
- 1
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -74,7 +74,7 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodeP
(AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
return false;
}
if (AnfAlgo::IsRealKernel(node)) {
if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
return true;
}
(*idx) += 1;


+ 28
- 2
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -40,6 +40,28 @@ using mindspore::kernel::AddressPtr;

namespace mindspore {
namespace device {
namespace {
// Collect the graph's declared inputs, followed by every Parameter that is
// consumed by a real kernel in execution order but missing from the declared
// input list. Declared inputs keep their original order; discovered
// parameters are appended in first-use order, each at most once.
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  const auto &declared = graph->inputs();
  std::vector<AnfNodePtr> all_inputs(declared.begin(), declared.end());
  std::set<AnfNodePtr> seen(declared.begin(), declared.end());
  for (const auto &kernel : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    const size_t num_inputs = AnfAlgo::GetInputTensorNum(kernel);
    for (size_t idx = 0; idx < num_inputs; ++idx) {
      // input(0) is the primitive; tensor inputs start at index 1.
      auto real_input = AnfAlgo::VisitKernelWithReturnType(kernel->input(idx + 1), 0).first;
      // insert().second is true only on first sighting, so each parameter
      // is appended exactly once.
      if (real_input->isa<Parameter>() && seen.insert(real_input).second) {
        (void)all_inputs.emplace_back(real_input);
      }
    }
  }
  return all_inputs;
}
} // namespace
constexpr size_t kMinInputSize = 2;

KernelRuntime::~KernelRuntime() {}
@@ -277,17 +299,21 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_manager_);
MS_LOG(INFO) << "AssignStaticMemoryInput start for graph " << graph->graph_id();
auto graph_inputs = graph->inputs();
auto graph_inputs = GetGraphInputs(graph);
auto graph_valid_input = graph->valid_inputs();
graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
std::vector<AnfNodePtr> need_alloc_nodes;
auto add_need_alloc_nodes = [&need_alloc_nodes, this](const AnfNodePtr &node) {
auto add_need_alloc_nodes = [&need_alloc_nodes, graph, this](const AnfNodePtr &node) {
if (!node->isa<Parameter>()) {
return;
}
if (NodeOutputDeviceAddressExist(node, 0)) {
return;
}
auto input_param = node->cast<ParameterPtr>();
if (!input_param->IsUsedByRealKernelInGraph(graph->graph_id())) {
return;
}
need_alloc_nodes.push_back(node);
};



+ 3
- 0
mindspore/core/ir/tensor.h View File

@@ -356,6 +356,8 @@ class Tensor : public MetaTensor {

// True if this tensor is an output of a graph; the flag is set once via
// SetIsGraphOutput and never cleared here.
bool IsGraphOutput() { return graph_output_; }
void SetIsGraphOutput() { graph_output_ = true; }
// True if the device side may rewrite this tensor's data during graph
// execution, meaning the host copy must be synced back after a run.
bool IsUpdatedByDevice() { return updated_by_device_; }
// NOTE(review): name says "Update" while the getter/flag say "Updated";
// spelling kept as-is because callers depend on this exact name.
void SetIsUpdateByDevice() { updated_by_device_ = true; }

private:
bool init_flag_{false};
@@ -364,6 +366,7 @@ class Tensor : public MetaTensor {
mutable std::shared_ptr<WaitEvent> event_{nullptr};
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
bool graph_output_{false};
bool updated_by_device_{false};
DeviceSyncPtr device_sync_{nullptr};
bool cache_enable_{false};
std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr};


Loading…
Cancel
Save