Browse Source

Decouple ir.Signature from python

tags/v0.6.0-beta
He Wei 5 years ago
parent
commit
535f399251
3 changed files with 19 additions and 32 deletions
  1. +11
    -7
      mindspore/ccsrc/ir/primitive.cc
  2. +5
    -6
      mindspore/ccsrc/ir/signature.h
  3. +3
    -19
      mindspore/ccsrc/ir/signature_py.cc

+ 11
- 7
mindspore/ccsrc/ir/primitive.cc View File

@@ -30,17 +30,21 @@
#include "pybind_api/export_flags.h"

namespace mindspore {
static ValuePtr PyArgToValue(const py::object &arg) {
if (py::isinstance<SignatureEnumKind>(arg) &&
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
return nullptr;
}
return parse::data_converter::PyDataToValue(arg);
}

void PrimitivePy::set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
signatures_.clear();
for (auto &signature : signatures) {
std::string name;
SignatureEnumRW rw;
SignatureEnumKind kind;
py::object default_value;
SignatureEnumDType dtype;
std::tie(name, rw, kind, default_value, dtype) = signature;
signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype));
auto [name, rw, kind, arg_default, dtype] = signature;
auto default_value = PyArgToValue(arg_default);
signatures_.emplace_back(name, rw, kind, default_value, dtype);
}
set_has_signature(true);
}


+ 5
- 6
mindspore/ccsrc/ir/signature.h View File

@@ -16,14 +16,11 @@

#ifndef MINDSPORE_CCSRC_IR_SIGNATURE_H_
#define MINDSPORE_CCSRC_IR_SIGNATURE_H_

#include <string>
#include <vector>

#include "pybind11/operators.h"
#include "ir/value.h"

namespace py = pybind11;

namespace mindspore {
// Input signature, support type
enum SignatureEnumRW {
@@ -62,8 +59,10 @@ struct Signature {
ValuePtr default_value; // nullptr for no default value
SignatureEnumDType dtype;
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object &arg_default, const SignatureEnumDType &arg_dtype);
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind);
const ValuePtr &arg_default, const SignatureEnumDType &arg_dtype)
: name(arg_name), rw(rw_tag), kind(arg_kind), default_value(arg_default), dtype(arg_dtype) {}
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: Signature(arg_name, rw_tag, arg_kind, nullptr, SignatureEnumDType::kDTypeEmptyDefaultValue) {}
};
} // namespace mindspore



mindspore/ccsrc/ir/signature.cc → mindspore/ccsrc/ir/signature_py.cc View File

@@ -15,30 +15,14 @@
*/

#include "ir/signature.h"

#include "pybind11/operators.h"
#include "pybind_api/api_register.h"
#include "pipeline/parse/data_converter.h"

namespace mindspore {
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object &arg_default, const SignatureEnumDType &arg_dtype)
: name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) {
if (py::isinstance<SignatureEnumKind>(arg_default) &&
py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) {
default_value = nullptr;
} else {
default_value = parse::data_converter::PyDataToValue(arg_default);
}
}

Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: name(arg_name),
rw(rw_tag),
kind(arg_kind),
default_value(nullptr),
dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {}
namespace py = pybind11;

namespace mindspore {
// Bind SignatureEnumRW as a python class.
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
.value("RW_READ", SignatureEnumRW::kRWRead)

Loading…
Cancel
Save