Browse Source

refactor primitive ComputeFunction function

tags/v0.6.0-beta
WilliamLian 5 years ago
parent
commit
50e2fda52d
11 changed files with 86 additions and 45 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/session/kernel_graph.cc
  2. +7
    -5
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  3. +24
    -0
      mindspore/ccsrc/utils/primitive_utils.cc
  4. +5
    -0
      mindspore/ccsrc/utils/primitive_utils.h
  5. +6
    -18
      mindspore/ccsrc/vm/vmimpl.cc
  6. +1
    -0
      mindspore/core/ir/primitive.h
  7. +28
    -8
      mindspore/core/ir/primitive_py.cc
  8. +4
    -1
      mindspore/core/ir/primitive_py.h
  9. +1
    -2
      tests/ut/cpp/operator/ops_test.cc
  10. +1
    -2
      tests/ut/cpp/parallel/step_parallel_test.cc
  11. +8
    -8
      tests/ut/cpp/vm/segment_runner_test.cc

+ 1
- 1
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -307,7 +307,7 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealCNodeKernel(cnode)) {
if (AnfAlgo::IsRealKernel(cnode)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
}


+ 7
- 5
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -363,19 +363,21 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_LOG(INFO) << "RunOpInVM end";
return std::move(result);
}
auto func = op_exec_info->py_primitive->GetComputeFunction();
if (py::isinstance<py::none>(func)) {
MS_LOG(ERROR) << "VM failed to get func";
auto primitive = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(primitive);
auto result = primitive->RunPyComputeFunction(op_exec_info->op_inputs);
if (py::isinstance<py::none>(result)) {
MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
*status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
py::tuple err_ret(0);
return std::move(err_ret);
}

// execute op
py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs));
py::tuple tuple_result = py::make_tuple(result);
*status = PYNATIVE_SUCCESS;
MS_LOG(INFO) << "RunOpInVM end";
return std::move(result);
return std::move(tuple_result);
}

bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,


+ 24
- 0
mindspore/ccsrc/utils/primitive_utils.cc View File

@@ -15,6 +15,9 @@
*/

#include "utils/primitive_utils.h"

#include <memory>

#include "pipeline/jit/parse/python_adapter.h"
#include "utils/log_adapter.h"
#include "common/utils.h"
@@ -43,4 +46,25 @@ py::function GetComputeFunction(std::string name) {
py::object fn = mod.attr(common::SafeCStr(name));
return fn;
}

py::tuple ConvertDatatoPyTuple(const VectorRef &args) {
auto py_args = py::tuple(args.size());
size_t i = 0;
for (auto &arg : args) {
py_args[i] = BaseRefToPyData(arg);
MS_LOG(DEBUG) << "arg:" << i << ":" << arg.ToString();
i++;
}
return py_args;
}

BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) {
auto func = GetComputeFunction(prim->name());
if (py::isinstance<py::none>(func)) {
MS_LOG(EXCEPTION) << prim->name() << " 's compute function run failed, please check whether it is not implemented";
}
auto py_args = ConvertDatatoPyTuple(args);
py::object obj = func(*py_args);
return std::make_shared<PyObjectRef>(obj);
}
} // namespace mindspore

+ 5
- 0
mindspore/ccsrc/utils/primitive_utils.h View File

@@ -19,6 +19,7 @@

#include <string>
#include "pybind11/pybind11.h"
#include "utils/base_ref.h"

namespace py = pybind11;

@@ -28,6 +29,10 @@ py::function GetBpropFunctionByObj(py::object obj);
py::function GetBpropFunction(std::string name);

py::function GetComputeFunction(std::string name);

BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);

py::tuple ConvertDatatoPyTuple(const VectorRef &args);
} // namespace mindspore

#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_

+ 6
- 18
mindspore/ccsrc/vm/vmimpl.cc View File

@@ -440,25 +440,13 @@ VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) {
}

BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
PrimitivePyPtr operation = dyn_cast<PrimitivePy>(prim);

MS_LOG(DEBUG) << "operation start " << prim->name();
auto func = operation != nullptr ? operation->GetComputeFunction() : GetComputeFunction(prim->name());
if (py::isinstance<py::none>(func)) {
MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented";
}

py::tuple py_args = py::tuple(args.size());
MS_LOG(DEBUG) << "input for operation:";
size_t i = 0;
for (auto &arg : args) {
py_args[i] = BaseRefToPyData(arg);
MS_LOG(DEBUG) << "arg: " << i << ":";
i++;
}
py::object obj = func(*py_args);
MS_LOG(DEBUG) << "result:" << py::str(obj);
return obj;
MS_EXCEPTION_IF_NULL(prim);
auto result = prim->RunComputeFunction(args);
if (result.is_null()) {
return RunComputeFunction(prim, args);
}
return result;
}

} // namespace compile


+ 1
- 0
mindspore/core/ir/primitive.h View File

@@ -83,6 +83,7 @@ class Primitive : public Named {

void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; }

ValuePtr GetAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);


+ 28
- 8
mindspore/core/ir/primitive_py.cc View File

@@ -79,13 +79,7 @@ py::function PrimitivePy::GetBpropFunction() {
}

BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
auto py_args = py::tuple(args.size());
size_t i = 0;
for (auto &arg : args) {
py_args[i] = BaseRefToPyData(arg);
MS_LOG(DEBUG) << "arg:" << i << ":";
i++;
}
auto py_args = ConvertDatatoPyTuple(args);
py::object obj;
bool is_bprop = this->HasAttr(kBpropAttrName);
if (is_bprop) {
@@ -123,7 +117,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
return std::make_shared<PyObjectRef>(obj);
}

py::function PrimitivePy::GetComputeFunction() {
py::function PrimitivePy::GetComputeFunction() const {
static const char *const compute_func_name = "vm_impl";

if (py::hasattr(python_obj_, compute_func_name)) {
@@ -176,6 +170,32 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
this->set_hook(primitive_py->hook());
}

BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
auto py_args = ConvertDatatoPyTuple(args);
auto result = this->RunPyComputeFunction(py_args);
if (py::isinstance<py::none>(result)) {
return std::make_shared<BaseRef>(nullptr);
}
return std::make_shared<PyObjectRef>(result);
}

py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
auto func = this->GetComputeFunction();
if (py::isinstance<py::none>(func)) {
return py::none();
}
auto result = func(*py_args);
return result;
}

bool PrimitivePy::HasComputeFunction() const {
auto func = GetComputeFunction();
if (py::isinstance<py::none>(func)) {
return false;
}
return true;
}

REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
.value("unknown", PrimType::kPrimTypeUnknown)


+ 4
- 1
mindspore/core/ir/primitive_py.h View File

@@ -41,7 +41,6 @@ class PrimitivePy : public Primitive {
~PrimitivePy() override = default;
MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction();
py::function GetComputeFunction();

void set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
@@ -57,11 +56,15 @@ class PrimitivePy : public Primitive {
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }
BaseRef RunHookFunction(const VectorRef &args) const override;
BaseRef RunComputeFunction(const VectorRef &args) const override;
py::object RunPyComputeFunction(const py::tuple &py_args) const;
bool HasComputeFunction() const;
const bool parse_info_ = true;
const py::object &GetPyObj() const { return python_obj_; }
bool is_tuple_input_ = false;

private:
py::function GetComputeFunction() const;
py::object python_obj_;
py::function hook_;
std::vector<Signature> signatures_;


+ 1
- 2
tests/ut/cpp/operator/ops_test.cc View File

@@ -454,8 +454,7 @@ TEST_F(TestOps, GetConv2DPrimPyTest) {
ASSERT_TRUE(conv2d_ptr);
if (nullptr != conv2d_ptr) {
MS_LOG(INFO) << "Get PrimitivePyPtr: " << conv2d_ptr->name();
auto func = conv2d_ptr->GetComputeFunction();
if (py::isinstance<py::none>(func)) {
if(!conv2d_ptr->HasComputeFunction()){
MS_LOG(EXCEPTION) << "" << conv2d_ptr->name() << "'s compute function is not implemented";
}



+ 1
- 2
tests/ut/cpp/parallel/step_parallel_test.cc View File

@@ -294,8 +294,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
ASSERT_TRUE(allreduce_ptr);
if (nullptr != allreduce_ptr) {
MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name();
auto func = allreduce_ptr->GetComputeFunction();
if (py::isinstance<py::none>(func)) {
if (!allreduce_ptr->HasComputeFunction()) {
MS_LOG(EXCEPTION) << "" << allreduce_ptr->name() << "'s compute function is not implemented";
}



+ 8
- 8
tests/ut/cpp/vm/segment_runner_test.cc View File

@@ -57,11 +57,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {

std::vector<BaseRef> todos(splits.size());
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
[](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); });
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
todos.resize(std::distance(todos.begin(), it));
ASSERT_EQ(todos.size(), 1);

AnfNodePtrList anf_list;
AnfNodePtrList anf_list;
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
@@ -81,11 +81,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {

std::vector<BaseRef> todos(splits.size());
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
[](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); });
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
todos.resize(std::distance(todos.begin(), it));
ASSERT_EQ(todos.size(), 1);

AnfNodePtrList anf_list;
AnfNodePtrList anf_list;
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
@@ -105,11 +105,11 @@ TEST_F(TestCompileSegmentRunner, test_if) {

std::vector<BaseRef> todos(splits.size());
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
[](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); });
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
todos.resize(std::distance(todos.begin(), it));
ASSERT_EQ(todos.size(), 1);

AnfNodePtrList anf_list;
AnfNodePtrList anf_list;
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
@@ -122,13 +122,13 @@ TEST_F(TestCompileSegmentRunner, test_if) {

TEST_F(TestCompileSegmentRunner, test_RunOperation1) {
VectorRef args({1});
auto res = RunOperation(prim::kPrimIdentity, args);
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name()), py::none()), args);
ASSERT_EQ(py::cast<int>(BaseRefToPyData(res)), 1);
}

TEST_F(TestCompileSegmentRunner, test_RunOperation2) {
VectorRef args({1, 2});
auto res = RunOperation(prim::kPrimScalarGt, args);
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name()), py::none()), args);
ASSERT_EQ(py::cast<bool>(BaseRefToPyData(res)), false);
}
} // namespace compile


Loading…
Cancel
Save