
add valuenode info to graph_info

tags/v1.2.0-rc1
simson wudengsong, 4 years ago
commit a9752ea5e1
3 changed files with 26 additions and 14 deletions:
  1. mindspore/ccsrc/pipeline/pynative/base.h (+1, -1)
  2. mindspore/ccsrc/pipeline/pynative/pynative_execute.cc (+23, -11)
  3. mindspore/ccsrc/pipeline/pynative/pynative_execute.h (+2, -2)
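
At a glance, the change makes the single-op graph cache key sensitive to the values of constant (value-node) inputs, not just their shapes, dtypes, and device formats: two calls that differ only in a scalar constant would otherwise hash to the same graph_info and reuse the wrong compiled graph. Below is a minimal standalone sketch of that failure mode and the fix; FakeInput, BuildGraphKey, the op name, and the numeric type ids are all invented for illustration and are not MindSpore code.

  // Standalone sketch: why a graph cache key must fold in constant values.
  #include <cstdint>
  #include <iostream>
  #include <string>
  #include <vector>

  constexpr int64_t kValueNodeTensorMask = 2;  // mask this commit uses for constants

  struct FakeInput {
    std::vector<int64_t> shape;
    int type_id;
    int64_t mask;        // 0: data, 1: weight, 2: value node
    double const_value;  // meaningful only when mask == kValueNodeTensorMask
  };

  std::string BuildGraphKey(const std::string &op, const std::vector<FakeInput> &inputs) {
    std::string key;
    for (const auto &in : inputs) {
      for (auto dim : in.shape) key += std::to_string(dim) + "_";
      key += std::to_string(in.type_id) + "_";
      if (in.mask == kValueNodeTensorMask) {
        key += std::to_string(in.const_value) + "_";  // the new part: the value goes into the key
      }
    }
    return key + op;
  }

  int main() {
    FakeInput x{{2, 3}, 43, 0, 0.0};
    FakeInput p2{{}, 35, kValueNodeTensorMask, 2.0};
    FakeInput p3{{}, 35, kValueNodeTensorMask, 3.0};
    // Without the value branch, Pow(x, 2) and Pow(x, 3) would produce identical keys.
    std::cout << BuildGraphKey("Pow", {x, p2}) << "\n";
    std::cout << BuildGraphKey("Pow", {x, p3}) << "\n";
  }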

mindspore/ccsrc/pipeline/pynative/base.h (+1, -1)

@@ -57,7 +57,7 @@ struct OpExecInfo {
   py::list op_inputs;
   py::dict op_attrs;
-  std::vector<bool> inputs_mask;
+  std::vector<int64_t> inputs_mask;
   bool is_dynamic_shape = false;
   std::string next_op_name = "";
   bool is_mixed_precision_cast = false;
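
The widening from std::vector<bool> to std::vector<int64_t> recurs through every signature in this commit: with constants now marked as value nodes, the mask takes three states (data: 0, weight: 1, value node: 2, per the comment in ConstructInputTensor below), which a boolean vector cannot represent.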


mindspore/ccsrc/pipeline/pynative/pynative_execute.cc (+23, -11)

@@ -266,17 +266,26 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
   MS_EXCEPTION_IF_NULL(op_exec_info);
   std::string graph_info;
   // get input tensor info
-  for (const auto &tensor : input_tensors) {
-    MS_EXCEPTION_IF_NULL(tensor);
-    auto tensor_shape = tensor->shape();
+  for (size_t index = 0; index < input_tensors.size(); ++index) {
+    MS_EXCEPTION_IF_NULL(input_tensors[index]);
+    auto tensor_shape = input_tensors[index]->shape();
     (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
                         [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
-    (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
-    if (tensor->device_address() != nullptr) {
-      (void)graph_info.append(
-        std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id()) + "_");
-      (void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format() +
-                              "_");
+    (void)graph_info.append(std::to_string(input_tensors[index]->data_type()) + "_");
+    auto tensor_addr = input_tensors[index]->device_address();
+    if (tensor_addr != nullptr) {
+      (void)graph_info.append(std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->type_id()) +
+                              "_");
+      (void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->format() + "_");
+    }
+    if (static_cast<int64_t>(op_exec_info->inputs_mask[index]) == kValueNodeTensorMask) {
+      if (input_tensors[index]->Dtype()->type_id() == kNumberTypeInt64) {
+        (void)graph_info.append(std::to_string(*reinterpret_cast<int64_t *>(input_tensors[index]->data_c())) + "_");
+      } else if (input_tensors[index]->Dtype()->type_id() == kNumberTypeFloat32) {
+        (void)graph_info.append(std::to_string(*reinterpret_cast<float *>(input_tensors[index]->data_c())) + "_");
+      } else {
+        MS_LOG(EXCEPTION) << "The dtype of the constant input is not int64 or float32!";
+      }
     }
   }
   // get prim and abstract info
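
The new branch above reads the constant's scalar straight out of the tensor's raw buffer (data_c()), so the reinterpret_cast must match the element width the dtype promises. A hedged standalone sketch of that pattern follows; Dtype and ScalarToKey are stand-ins, not MindSpore types.

  #include <cstdint>
  #include <iostream>
  #include <stdexcept>
  #include <string>

  enum class Dtype { kInt64, kFloat32 };

  // Reinterpret the buffer at the stored element width; casting an int64
  // buffer to a narrower int would read only part of the scalar.
  std::string ScalarToKey(const void *data, Dtype dtype) {
    switch (dtype) {
      case Dtype::kInt64:
        return std::to_string(*reinterpret_cast<const int64_t *>(data)) + "_";
      case Dtype::kFloat32:
        return std::to_string(*reinterpret_cast<const float *>(data)) + "_";
    }
    throw std::invalid_argument("constant input must be int64 or float32");
  }

  int main() {
    int64_t i = 42;
    float f = 1.5f;
    std::cout << ScalarToKey(&i, Dtype::kInt64) << "\n";   // 42_
    std::cout << ScalarToKey(&f, Dtype::kFloat32) << "\n"; // 1.500000_
  }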
@@ -387,8 +396,10 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
   } else if (py::isinstance<py::float_>(input_object)) {
     double input_value = py::cast<py::float_>(input_object);
     tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
+    *tensor_mask = kValueNodeTensorMask;
   } else if (py::isinstance<py::int_>(input_object)) {
     tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
+    *tensor_mask = kValueNodeTensorMask;
   } else if (py::isinstance<py::array>(input_object)) {
     tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
   } else if (py::isinstance<py::list>(input_object)) {
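
This hunk is where Python scalars acquire the new mark: a bare float becomes a kFloat32 tensor, a bare int a kInt64 tensor, and both are now tagged kValueNodeTensorMask, while arrays and lists keep whatever mask the caller supplied. A rough sketch of that dispatch rule, with std::variant standing in for the py::object checks (names are illustrative, not the real binding):

  #include <cstdint>
  #include <iostream>
  #include <variant>
  #include <vector>

  constexpr int64_t kValueNodeTensorMask = 2;

  using PyLike = std::variant<double, int64_t, std::vector<float>>;

  // Scalars are promoted to value-node tensors; array-like inputs keep the
  // caller-provided mask (the data/weight distinction).
  int64_t ConvertAndMask(const PyLike &obj, int64_t caller_mask) {
    if (std::holds_alternative<double>(obj)) return kValueNodeTensorMask;   // py::float_ case
    if (std::holds_alternative<int64_t>(obj)) return kValueNodeTensorMask;  // py::int_ case
    return caller_mask;                                                     // py::array / py::list case
  }

  int main() {
    std::cout << ConvertAndMask(2.0, 0) << "\n";                       // 2
    std::cout << ConvertAndMask(int64_t{7}, 0) << "\n";                // 2
    std::cout << ConvertAndMask(std::vector<float>{1, 2}, 1) << "\n";  // 1
  }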
@@ -452,6 +463,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
     int64_t tensor_mask = static_cast<int64_t>(op_run_info->inputs_mask[index]);
     ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
     // mark tensors, data : 0, weight : 1, valuenode: 2
+    op_run_info->inputs_mask[index] = tensor_mask;
     std::vector<int64_t> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
     tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
   }
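
The write-back of tensor_mask into op_run_info->inputs_mask[index] appears to be what makes the GetSingleOpGraphInfo change effective: the mask is upgraded to kValueNodeTensorMask during conversion, and without storing it back the key builder would still treat the constant as ordinary data.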
@@ -602,7 +614,7 @@ py::object PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
   }
   // make cnode for building grad graph if grad flag is set.
   abstract::AbstractBasePtrList args_spec_list;
-  std::vector<bool> op_masks;
+  std::vector<int64_t> op_masks;
   auto cnode = MakeCNode(op_exec_info, &op_masks, &args_spec_list);
   op_exec_info->inputs_mask = op_masks;
   // get output abstract info
@@ -677,7 +689,7 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
   return op_exec_info;
 }
 
-void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
+void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
                                    std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
   auto prim = op_exec_info->py_primitive;
   for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
@@ -715,7 +727,7 @@ void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vecto
   }
 }
 
-AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
+AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
                                        abstract::AbstractBasePtrList *args_spec_list) {
   MS_EXCEPTION_IF_NULL(op_masks);
   MS_EXCEPTION_IF_NULL(args_spec_list);


mindspore/ccsrc/pipeline/pynative/pynative_execute.h (+2, -2)

@@ -208,9 +208,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                                   PynativeStatusCode *const status);
   AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
   AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
-  void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, std::vector<AnfNodePtr> *inputs,
+  void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
                    abstract::AbstractBasePtrList *args_spec_list);
-  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
+  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
                        abstract::AbstractBasePtrList *args_spec_list);
   abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
                                             const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);

