
!31511 Change general vmap rule to metafg

Merge pull request !31511 from LiangZhibo/rule
r1.7
i-robot Gitee 4 years ago
parent commit e5484dd4dd
8 changed files with 254 additions and 152 deletions
  1. +182 -41  mindspore/ccsrc/frontend/operator/composite/vmap.cc
  2. +23 -0  mindspore/ccsrc/frontend/operator/composite/vmap.h
  3. +28 -17  mindspore/ccsrc/frontend/optimizer/irpass/vmap_eliminate.cc
  4. +0 -6  mindspore/ccsrc/include/common/utils/primitive_utils.h
  5. +0 -4  mindspore/ccsrc/pybind_api/ir/primitive_py.cc
  6. +0 -12  mindspore/ccsrc/utils/primitive_utils.cc
  7. +4 -2  mindspore/python/mindspore/ops/_vmap/__init__.py
  8. +17 -70  mindspore/python/mindspore/ops/_vmap/vmap_base.py

+182 -41  mindspore/ccsrc/frontend/operator/composite/vmap.cc

@@ -34,6 +34,41 @@
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
void GenerateFuncGraphAllNone(const FuncGraphPtr &fg, const AnfNodePtr &prim, int64_t args_size, bool wrapped_tuple,
bool bind) {
std::vector<AnfNodePtr> prim_output_cnode_inputs;
(void)prim_output_cnode_inputs.emplace_back(prim);
if (wrapped_tuple) {
auto val_in_param = fg->add_parameter();
std::vector<AnfNodePtr> prim_inputs_cnode_inputs;
(void)prim_inputs_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (int64_t i = 0; i < args_size; ++i) {
auto val_in_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(i)});
auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_cnode, NewValueNode(kValIndex)});
(void)prim_inputs_cnode_inputs.emplace_back(val_cnode);
}
auto prim_inputs_cnode = fg->NewCNode(prim_inputs_cnode_inputs);
(void)prim_output_cnode_inputs.emplace_back(prim_inputs_cnode);
} else {
for (int64_t i = 0; i < args_size; ++i) {
auto val_in_param = fg->add_parameter();
auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(kValIndex)});
(void)prim_output_cnode_inputs.emplace_back(val_cnode);
}
}
auto prim_output_cnode = fg->NewCNode(prim_output_cnode_inputs);
const py::function bind_all_none_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_bind_all_none");
auto bind_all_none_fg = parse::ParsePythonCode(bind_all_none_fn);
auto bind_all_none_cnode = fg->NewCNode({NewValueNode(bind_all_none_fg), prim_output_cnode});
if (bind) {
auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(true), bind_all_none_cnode});
fg->set_output(output_cnode);
return;
}
fg->set_output(bind_all_none_cnode);
return;
}

CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerBroadcastAxis(
const AnfNodePtr &inputs, const AnfNodePtr &out_axis, const AnfNodePtr &axis_size,
const AbstractBasePtr &inputs_abstract_elements_begin) const {
@@ -86,7 +121,6 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
auto value_cnode = fg_->NewCNode(value_cnode_inputs);
std::vector<AnfNodePtr> out_cnode_inputs;
if (inputs_abstract_elements_end->isa<abstract::AbstractNone>()) {
constexpr char kVmapFunctionModelName[] = "mindspore.ops._vmap";
const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
(void)out_cnode_inputs.emplace_back(NewValueNode(broadcast_by_axis_fg));
@@ -99,8 +133,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
(void)dim_cnode_inputs.emplace_back(inputs);
(void)dim_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(1)));
auto dim_cnode = fg_->NewCNode(dim_cnode_inputs);
constexpr char kVmapFunctionModelName[] = "mindspore.numpy";
const py::function move_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "moveaxis");
const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
auto move_axis_fg = parse::ParsePythonCode(move_axis);
(void)out_cnode_inputs.emplace_back(NewValueNode(move_axis_fg));
(void)out_cnode_inputs.emplace_back(value_cnode);
@@ -190,7 +223,6 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu
auto src_abstract = each_inputs_abstract_elements[1];
CNodePtr out_cnode = nullptr;
if (src_abstract->isa<abstract::AbstractNone>() && !dst_abstract->isa<abstract::AbstractNone>()) {
constexpr char kVmapFunctionModelName[] = "mindspore.ops._vmap";
const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
out_cnode = fg_->NewCNode({NewValueNode(broadcast_by_axis_fg), val_cnode, dst_cnode, axis_size});
@@ -199,8 +231,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu
} else if (src_abstract->isa<abstract::AbstractNone>() && dst_abstract->isa<abstract::AbstractNone>()) {
out_cnode = val_cnode;
} else {
constexpr char kVmapFunctionModelName[] = "mindspore.numpy";
const py::function move_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "moveaxis");
const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
auto move_axis_fg = parse::ParsePythonCode(move_axis);
out_cnode = fg_->NewCNode({NewValueNode(move_axis_fg), val_cnode, src_cnode, dst_cnode});
}
@@ -303,9 +334,6 @@ FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList
return args_spec_list;
};
auto tuple_elements = get_tuple_elements(args_spec_list);

constexpr int64_t val_index = 0;
constexpr int64_t dim_index = 1;
bool is_all_none = true;
constexpr size_t kCurTupleSize = 2;
for (int64_t i = 0; i < inputs_size; ++i) {
@@ -321,7 +349,7 @@ FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList
MS_LOG(EXCEPTION) << "The " << i + offset << "th input to VmapGeneralPreprocess should be a tuple with two "
<< "elements but got " << cur_arg_tuple_elements.size() << " elements.";
}
if (!cur_arg_tuple_elements[dim_index]->isa<abstract::AbstractNone>()) {
if (!cur_arg_tuple_elements[kDimIndex]->isa<abstract::AbstractNone>()) {
MS_LOG(INFO) << "The " << i + offset << "th input to VmapGeneralPreprocess has not None dim value.";
is_all_none = false;
break;
@@ -334,39 +362,11 @@ FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList
for (size_t i = 1; i < args_size; ++i) {
(void)fg->add_parameter();
}
(void)output_cnode_inputs.emplace_back(NewValueNode(false));
(void)output_cnode_inputs.emplace_back(NewValueNode(kNone));
auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(false), NewValueNode(kNone)});
fg->set_output(output_cnode);
} else {
std::vector<AnfNodePtr> prim_output_cnode_inputs;
(void)prim_output_cnode_inputs.emplace_back(prim);
if (wrapped_tuple) {
auto val_in_param = fg->add_parameter();
std::vector<AnfNodePtr> prim_inputs_cnode_inputs;
(void)prim_inputs_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (int64_t i = 0; i < inputs_size; ++i) {
auto val_in_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(i)});
auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_cnode, NewValueNode(val_index)});
(void)prim_inputs_cnode_inputs.emplace_back(val_cnode);
}
auto prim_inputs_cnode = fg->NewCNode(prim_inputs_cnode_inputs);
(void)prim_output_cnode_inputs.emplace_back(prim_inputs_cnode);
} else {
for (int64_t i = 0; i < inputs_size; ++i) {
auto val_in_param = fg->add_parameter();
auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(val_index)});
(void)prim_output_cnode_inputs.emplace_back(val_cnode);
}
}
auto prim_output_cnode = fg->NewCNode(prim_output_cnode_inputs);
const char kVmapFunctionModelName[] = "mindspore.ops._vmap";
const py::function bind_all_none_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_bind_all_none");
auto bind_all_none_fg = parse::ParsePythonCode(bind_all_none_fn);
auto bind_all_none_cnode = fg->NewCNode({NewValueNode(bind_all_none_fg), prim_output_cnode});
(void)output_cnode_inputs.emplace_back(NewValueNode(true));
(void)output_cnode_inputs.emplace_back(bind_all_none_cnode);
GenerateFuncGraphAllNone(fg, prim, inputs_size, wrapped_tuple, true);
}
auto output_cnode = fg->NewCNode(output_cnode_inputs);
fg->set_output(output_cnode);
return fg;
}

@@ -375,5 +375,146 @@ REGISTER_PYBIND_DEFINE(VmapGeneralPreprocess_, ([](const py::module *m) {
*m, "VmapGeneralPreprocess_")
.def(py::init<std::string &>(), py::arg("fn"));
}));

CNodeInpusList VmapGeneralRule::ConstructMapInput(const InputsAbstractList &tuple_elements_abstract, bool wrapped_tuple,
int64_t args_size) {
AnfNodePtr single_input = nullptr;
if (wrapped_tuple) {
single_input = fg_->add_parameter();
}

CNodeInpusList map_inputs(axis_size_);
for (int64_t i = 0; i < args_size; ++i) {
AnfNodePtr cur_arg_node = nullptr;
if (wrapped_tuple) {
cur_arg_node = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), single_input, NewValueNode(i)});
} else {
cur_arg_node = fg_->add_parameter();
}
auto tuple_element_abstract = tuple_elements_abstract[i];
auto val_abstract = tuple_element_abstract[kValIndex];
auto dim_abstract = tuple_element_abstract[kDimIndex];
AnfNodePtr val_cnode =
fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kValIndex)});

if (dim_abstract->isa<abstract::AbstractNone>()) {
for (int64_t m = 0; m < axis_size_; ++m) {
map_inputs[m].push_back(val_cnode);
}
} else {
if (!val_abstract->isa<abstract::AbstractTensor>()) {
MS_LOG(EXCEPTION) << "A variable of type other than `Tensor` is accepted, but the source axis is not `None`";
}
AnfNodePtr dim_cnode =
fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kDimIndex)});
const py::function unstack_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_unstack");
auto unstack_fg_ = parse::ParsePythonCode(unstack_fn);
auto out_cnode = fg_->NewCNode({NewValueNode(unstack_fg_), dim_cnode, val_cnode});
for (int64_t m = 0; m < axis_size_; ++m) {
auto out_element_cnode = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(m)});
map_inputs[m].push_back(out_element_cnode);
}
}
}
return map_inputs;
}

// When a primitive has not registered a specific VmapRule, we attempt to fall back to
// this general rule. The general rule combines loop and stack operators to simulate
// the behavior of vmap. Note that the general rule does not guarantee the correctness
// of execution results.
// Currently, only the following kinds of primitives are supported:
// 1. Most calculation operations, whose inputs are tensors, scalars, or a mix of both.
//    (A tuple whose elements are all scalars is also treated as a scalar.)
// 2. Operators with a variable number of inputs, such as `AddN`, whose inputs are
//    wrapped into a tuple.
// In other words, tuple-wrapped variables are not supported except for the special
// cases listed above.
FuncGraphPtr VmapGeneralRule::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
fg_ = std::make_shared<FuncGraph>();
int64_t args_size = static_cast<int64_t>(args_spec_list.size());

bool wrapped_tuple = false;
auto get_tuple_elements = [&wrapped_tuple,
&args_size](const AbstractBasePtrList &args_spec_list) -> const AbstractBasePtrList & {
if (args_size == 1) {
auto arg = args_spec_list[0];
if (!arg->isa<abstract::AbstractTuple>()) {
MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractTuple but got: "
<< arg->ToString() << ".";
}
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
const auto &arg_tuple_elements = arg_tuple->elements();
if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) {
// Operators with a variable number of inputs, such as `AddN`, wrap their inputs
// into a tuple. We need to process the internal elements separately and then
// re-wrap them into a tuple. This handles cases such as args: (((A, 0), (B, 1),
// (C, None)),), which differ from the single-parameter case ((A, 0),).
wrapped_tuple = true;
args_size = arg_tuple_elements.size();
return arg_tuple_elements;
}
}
return args_spec_list;
};
auto tuple_elements = get_tuple_elements(args_spec_list);

bool is_all_none = true;
constexpr size_t kCurTupleSize = 2;
InputsAbstractList tuple_elements_abstract(args_size);
for (int64_t i = 0; i < args_size; ++i) {
auto cur_arg = tuple_elements[i];
if (!cur_arg->isa<abstract::AbstractTuple>()) {
MS_LOG(EXCEPTION) << "The " << i
<< "th input to VmapGeneralPreprocess should be AbstractTuple but got: " << cur_arg->ToString()
<< ".";
}
auto cur_arg_tuple = cur_arg->cast<abstract::AbstractTuplePtr>();
auto cur_arg_tuple_elements = cur_arg_tuple->elements();
if (cur_arg_tuple_elements.size() != kCurTupleSize) {
MS_LOG(EXCEPTION) << "The " << i << "th input to VmapGeneralPreprocess should be a tuple with two "
<< "elements but got " << cur_arg_tuple_elements.size() << " elements.";
}
auto dim_abstract = cur_arg_tuple_elements[kDimIndex];
if (is_all_none && !dim_abstract->isa<abstract::AbstractNone>()) {
MS_LOG(INFO) << "The " << i << "th input to VmapGeneralPreprocess has not None dim value.";
is_all_none = false;
}
auto val_abstract = cur_arg_tuple_elements[kValIndex];
std::vector<abstract::AbstractBasePtr> element_abstract = {val_abstract, dim_abstract};
tuple_elements_abstract[i] = element_abstract;
}

if (is_all_none) {
GenerateFuncGraphAllNone(fg_, NewValueNode(prim_), args_size, wrapped_tuple, false);
return fg_;
}

CNodeInpusList map_inputs = ConstructMapInput(tuple_elements_abstract, wrapped_tuple, args_size);

std::vector<AnfNodePtr> output_cnode_inputs;
(void)output_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (auto map_input : map_inputs) {
std::vector<AnfNodePtr> output_element_cnode_inputs;
if (wrapped_tuple) {
std::vector<AnfNodePtr> tuple_cnode_inputs{NewValueNode(prim::kPrimMakeTuple)};
(void)tuple_cnode_inputs.insert(tuple_cnode_inputs.end(), map_input.begin(), map_input.end());
auto tuple_cnode = fg_->NewCNode(tuple_cnode_inputs);
output_element_cnode_inputs.push_back(NewValueNode(prim_));
output_element_cnode_inputs.push_back(tuple_cnode);
} else {
output_element_cnode_inputs.push_back(NewValueNode(prim_));
(void)output_element_cnode_inputs.insert(output_element_cnode_inputs.end(), map_input.begin(), map_input.end());
}
auto output_element_cnode = fg_->NewCNode(output_element_cnode_inputs);
(void)output_cnode_inputs.emplace_back(output_element_cnode);
}
auto output_cnode = fg_->NewCNode(output_cnode_inputs);
const py::function vmap_general_output_process_fn =
python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_general_output_process");
auto vmap_general_output_process_fg_ = parse::ParsePythonCode(vmap_general_output_process_fn);
auto vmap_general_output = fg_->NewCNode({NewValueNode(vmap_general_output_process_fg_), output_cnode});
fg_->set_output(vmap_general_output);
return fg_;
}
} // namespace prim
} // namespace mindspore
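
To make the loop-and-stack scheme concrete, here is a minimal NumPy sketch of the computation that the graph built by `VmapGeneralRule` expresses. It is an illustration under stated assumptions, not MindSpore API: `general_vmap_rule` is a made-up name, and `np.take`/`np.stack` stand in for the `TupleGetItem`/`vmap_unstack`/stack nodes the meta func graph emits.

import numpy as np

def general_vmap_rule(fn, axis_size):
    """Simulate vmap for `fn` by looping over the mapped axis and stacking.

    Each argument is a (value, dim) pair, mirroring kValIndex/kDimIndex;
    dim is None for unmapped arguments.
    """
    def wrapped(*val_dim_pairs):
        per_iter_outputs = []
        for m in range(axis_size):
            cur_args = []
            for val, dim in val_dim_pairs:
                if dim is None:
                    cur_args.append(val)  # unmapped input: reused every iteration
                else:
                    cur_args.append(np.take(val, m, axis=dim))  # m-th slice along dim
            per_iter_outputs.append(fn(*cur_args))
        # Stack the per-slice outputs along a new leading axis and report dim 0.
        return np.stack(per_iter_outputs), 0
    return wrapped

# Usage: map addition over axis 0 of x while broadcasting the shared bias b.
x = np.arange(6.0).reshape(3, 2)
b = np.ones(2)
out, out_dim = general_vmap_rule(lambda a, c: a + c, axis_size=3)((x, 0), (b, None))
assert out.shape == (3, 2) and out_dim == 0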

+23 -0  mindspore/ccsrc/frontend/operator/composite/vmap.h

@@ -25,6 +25,12 @@
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using CNodeInpusList = std::vector<std::vector<AnfNodePtr>>;
using InputsAbstractList = std::vector<std::vector<abstract::AbstractBasePtr>>;
constexpr int64_t kValIndex = 0;
constexpr int64_t kDimIndex = 1;
constexpr char kVmapFunctionModelName[] = "mindspore.ops._vmap";
constexpr char kNumpyModelName[] = "mindspore.numpy";
class VmapMatchOutAxis : public MetaFuncGraph {
public:
explicit VmapMatchOutAxis(const std::string &name) : MetaFuncGraph(name), fg_(std::make_shared<FuncGraph>()) {}
@@ -55,6 +61,23 @@ class VmapGeneralPreprocess : public MetaFuncGraph {

FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
};

class VmapGeneralRule : public MetaFuncGraph {
public:
explicit VmapGeneralRule(const std::string &name, const PrimitivePtr &prim, int64_t axis_size)
: MetaFuncGraph(name), prim_(prim), axis_size_(axis_size) {}
~VmapGeneralRule() override = default;
MS_DECLARE_PARENT(VmapGeneralRule, MetaFuncGraph);

FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;

private:
CNodeInpusList ConstructMapInput(const InputsAbstractList &tuple_elements_abstract, bool wrapped_tuple,
int64_t args_size);
PrimitivePtr prim_{nullptr};
int64_t axis_size_ = 0;
FuncGraphPtr fg_{nullptr};
};
} // namespace prim
} // namespace mindspore
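
The header fixes the (value, dim) pair layout used throughout: element kValIndex holds the value and kDimIndex holds the mapped axis (or None). For the all-None fast path, here is a hypothetical sketch of the `vmap_bind_all_none` helper that `GenerateFuncGraphAllNone` calls; the real implementation lives in mindspore.ops._vmap and may differ:

def bind_all_none(outputs):
    # Hypothetical stand-in for `vmap_bind_all_none`: when no input carries a
    # mapped axis, pair every primitive output with a None dim so downstream
    # rules can keep treating results as (value, dim) tuples.
    if isinstance(outputs, tuple):
        return tuple((out, None) for out in outputs)
    return (outputs, None)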



+28 -17  mindspore/ccsrc/frontend/optimizer/irpass/vmap_eliminate.cc

@@ -187,7 +187,7 @@ AnfNodePtr MatchOutAxis(const AnfNodePtr &expanded_vmap_node, int parameters_siz
return NewValueNode(vmap_post_fg);
}

FuncGraphPtr GetVmapRule(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resource, int axis_size) {
AnfNodePtr GetVmapRule(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resource, int axis_size) {
// Set a child scope named "vmap_'PrimitiveName'" for the vmap rule function,
// and add "VmapRule" to the front.
constexpr char vmap_rule_scope[] = "VmapRule/";
@@ -200,6 +200,7 @@ FuncGraphPtr GetVmapRule(const PrimitivePtr &prim, const pipeline::ResourceBaseP
// First we parse the Python VmapRule function registered for the specific primitive.
// If that fails, fall back to the general vmap rule.
FuncGraphPtr vmap_rule_fg = nullptr;
AnfNodePtr vmap_rule_node = nullptr;
py::function vmap_rule_fn;
bool is_side_effect = false;
if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_MEM)) {
@@ -207,6 +208,8 @@ FuncGraphPtr GetVmapRule(const PrimitivePtr &prim, const pipeline::ResourceBaseP
} else if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_IO) && prim->name() != prim::kPrimPrint->name()) {
MS_LOG(EXCEPTION) << prim->name() << " is a GRAPH_FLAG_SIDE_EFFECT_IO prim, which vmap does not support currently.";
}

// Get vmap rule for specific primitive.
if (prim->is_base()) {
vmap_rule_fn = GetVmapRuleFunction(prim->name(), axis_size);
} else {
@@ -215,25 +218,33 @@ FuncGraphPtr GetVmapRule(const PrimitivePtr &prim, const pipeline::ResourceBaseP
vmap_rule_fn = GetVmapRuleFunction(prim->name(), axis_size);
}
}

// If vmap rule for specific primitive not found, get vmap general rule.
if (!vmap_rule_fn || py::isinstance<py::none>(vmap_rule_fn)) {
MS_LOG(DEBUG) << "Fail to find vmap rule function for " << prim->name() << ", try to get the general vmap rule.";
vmap_rule_fn = GetVmapGeneralRuleFunction(prim->name(), is_side_effect, axis_size);
}
if (!vmap_rule_fn || py::isinstance<py::none>(vmap_rule_fn)) {
MS_LOG(EXCEPTION) << "Fail to find vmap rule function for " << prim->name() << ".";
}
vmap_rule_fg = parse::ParsePythonCode(vmap_rule_fn);
if (vmap_rule_fg == nullptr) {
MS_LOG(EXCEPTION) << "Fail to parse vmap rule function for " << prim->name() << ".";
if (is_side_effect) {
vmap_rule_fn = python_adapter::GetPyFn("mindspore.ops._vmap", "vmap_monad_rule")(prim->name(), axis_size);
} else {
vmap_rule_node =
NewValueNode(std::make_shared<prim::VmapGeneralRule>("VmapGeneralRule", prim, static_cast<int64_t>(axis_size)));
}
}
auto vmap_rule_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
if (vmap_rule_flag) {
vmap_rule_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);

if (vmap_rule_node == nullptr) {
vmap_rule_fg = parse::ParsePythonCode(vmap_rule_fn);
if (vmap_rule_fg == nullptr) {
MS_LOG(EXCEPTION) << "Fail to parse vmap rule function for " << prim->name() << ".";
}
auto vmap_rule_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
if (vmap_rule_flag) {
vmap_rule_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
}
pipeline::ResourceBasePtr res = (resource != nullptr) ? resource : std::make_shared<pipeline::Resource>();
(void)parse::ResolveFuncGraph(vmap_rule_fg, res);
vmap_rule_node = NewValueNode(vmap_rule_fg);
}
pipeline::ResourceBasePtr res = (resource != nullptr) ? resource : std::make_shared<pipeline::Resource>();
(void)parse::ResolveFuncGraph(vmap_rule_fg, res);

return vmap_rule_fg;
return vmap_rule_node;
}

AnfNodePtr ExpandVmapPrimitive(const AnfNodePtr &vnode, const pipeline::ResourceBasePtr &resource, int axis_size) {
@@ -246,12 +257,12 @@ AnfNodePtr ExpandVmapPrimitive(const AnfNodePtr &vnode, const pipeline::Resource
if (throughtout_op.count(prim->name())) {
return vnode;
} else {
FuncGraphPtr prim_vmap_rule = GetVmapRule(prim, resource, axis_size);
AnfNodePtr prim_vmap_rule = GetVmapRule(prim, resource, axis_size);
if (prim_vmap_rule == nullptr) {
MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " transform to VmapRule failed. NodeInfo: "
<< trace::GetDebugInfo(prim_vmap_rule->debug_info()) << ".";
}
return NewValueNode(prim_vmap_rule);
return prim_vmap_rule;
}
return nullptr;
}
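
The reworked dispatch in `GetVmapRule` can be summarized in Python-like pseudocode. Every helper name below is illustrative (`find_registered_vmap_rule`, `parse_and_resolve`, and so on do not exist under these names); only the order of the fallbacks is taken from the diff:

def get_vmap_rule_sketch(prim, axis_size, resource=None):
    # 1. Prefer the Python VmapRule registered for this specific primitive.
    rule_fn = find_registered_vmap_rule(prim, axis_size)   # hypothetical lookup
    if rule_fn is not None:
        return parse_and_resolve(rule_fn, resource)        # ValueNode of a FuncGraph
    # 2. Side-effect primitives still fall back to the Python monad rule.
    if has_side_effect(prim):
        return parse_and_resolve(vmap_monad_rule(prim, axis_size), resource)
    # 3. Everything else now receives the C++ VmapGeneralRule meta func graph
    #    directly, with no Python parsing or resolving on this path.
    return value_node(VmapGeneralRule("VmapGeneralRule", prim, axis_size))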


+0 -6  mindspore/ccsrc/include/common/utils/primitive_utils.h

@@ -47,12 +47,6 @@ COMMON_EXPORT py::tuple ConvertDatatoPyTuple(const VectorRef &args);
COMMON_EXPORT py::function GetVmapRuleFunctionByObj(const py::object &obj, int axis_size);

COMMON_EXPORT py::function GetVmapRuleFunction(const std::string &name, int axis_size);

COMMON_EXPORT py::function GetVmapGeneralRuleFunction(const std::string &name, const bool is_side_effect = false,
int axis_size = 0);

COMMON_EXPORT py::function GetVmapGeneralRuleByObj(const py::object &obj, const bool is_side_effect = false,
int axis_size = 0);
} // namespace mindspore

#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PRIMITIVE_UTILS_H_

+0 -4  mindspore/ccsrc/pybind_api/ir/primitive_py.cc

@@ -145,10 +145,6 @@ py::function PrimitivePy::GetVmapRuleFunction(const bool is_side_effect, int axi
return fn;
} else {
auto fn = GetVmapRuleFunctionByObj(python_obj_, axis_size);
if (!fn || py::isinstance<py::none>(fn)) {
MS_LOG(DEBUG) << "Fail to find vmap rule function for " << this->name() << ", try to get the vmap general rule";
fn = GetVmapGeneralRuleByObj(python_obj_, is_side_effect, axis_size);
}
return fn;
}
}


+0 -12  mindspore/ccsrc/utils/primitive_utils.cc

@@ -122,16 +122,4 @@ py::function GetVmapRuleFunction(const std::string &name, int axis_size) {
auto fn = GetVmapRuleFunctionByObj(py::str(name), axis_size);
return fn;
}

py::function GetVmapGeneralRuleFunction(const std::string &name, const bool is_side_effect, int axis_size) {
auto fn = GetVmapGeneralRuleByObj(py::str(name), is_side_effect, axis_size);
return fn;
}

py::function GetVmapGeneralRuleByObj(const py::object &obj, const bool is_side_effect, int axis_size) {
std::string get_vmap_rule_fn = is_side_effect ? "vmap_monad_rule" : "vmap_general_rule";
constexpr char vmap_module[] = "mindspore.ops._vmap";
py::function fn = python_adapter::GetPyFn(vmap_module, get_vmap_rule_fn)(obj, axis_size);
return fn;
}
} // namespace mindspore

+4 -2  mindspore/python/mindspore/ops/_vmap/__init__.py

@@ -14,6 +14,8 @@
# ============================================================================

"""vmap impl."""
from .vmap_base import get_vmap_rule, vmap_general_rule, vmap_monad_rule, _broadcast_by_axis, vmap_bind_all_none
from .vmap_base import get_vmap_rule, vmap_monad_rule, _broadcast_by_axis, vmap_bind_all_none,\
vmap_unstack, vmap_general_output_process

__all__ = ['get_vmap_rule', 'vmap_general_rule', 'vmap_monad_rule', '_broadcast_by_axis', 'vmap_bind_all_none']
__all__ = ['get_vmap_rule', 'vmap_monad_rule', '_broadcast_by_axis', 'vmap_bind_all_none',
'vmap_unstack', 'vmap_general_output_process']
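
Assuming a MindSpore build that includes this change, the re-exported helpers can be imported directly. A quick check of `vmap_unstack`, which splits a tensor into per-slice pieces along `dim`:

import mindspore.numpy as mnp
from mindspore.ops._vmap import vmap_unstack

x = mnp.arange(6).reshape(2, 3)
slices = vmap_unstack(0, x)  # a tuple of two tensors, each of shape (3,)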

+17 -70  mindspore/python/mindspore/ops/_vmap/vmap_base.py

@@ -20,7 +20,7 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import constexpr
from .._register_for_op import Registry
from ..composite import HyperMap, _VmapGeneralPreprocess
from ..composite import _VmapGeneralPreprocess
from ..primitive import Primitive
from ...common import Tensor

@@ -87,77 +87,24 @@ def vmap_bind_all_none(inputs):
vmap_general_preprocess = _VmapGeneralPreprocess()


def vmap_general_rule(prim, axis_size):
"""
When the primitive does not registered the relevant specific VmapRule, it attempts to get
this the general rule. The general rule is combining loop and stack operators to simulate
the behavior of Vmap. Noted that, general rules does not guarantee the correctness of
execution results.
Currently, only the following types of primitives are supported:
1、 Most calculation operations, whose inputs are tensors, scalars or both of them.
(If all elements in a tuple are scalars, it is also considered scalar.)
2、 Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped into a tuple.
In other words, we do not support any tuple wrapped variables except for the special cases
listed above.
"""

if isinstance(prim, str):
prim_name = prim
prim = Primitive(prim)
else:
prim_name = prim.name
common_map = HyperMap()
def vmap_unstack(dim, val):
    return P.Unstack(dim)(val)

def loop_stack(*args):
is_all_none, result = vmap_general_preprocess(prim, *args)
if is_all_none:
return result

wrapped_tuple = False
# Handle case such as args:(((A, 0), (B, 1)),)
if len(args) == 1 and isinstance(args[0][-1], tuple):
wrapped_tuple = True
args = args[0]

vals_in_tuple = ()
for val_in in args:
val, dim = val_in
out = ()
if dim is None:
# Handle case such as args:(..., (A, None), (1, None), ...)
for _ in range(axis_size):
out = out + (val,)
else:
if isinstance(val, Tensor):
# Handle case such as args:(..., (A, 0), (B, 1), ...)
out = P.Unstack(dim)(val)
else:
_raise_value_error("A variable of type other than `Tensor` is accepted, "
"but the source axis is not `None`")

vals_in_tuple = vals_in_tuple + (out,)

if wrapped_tuple:
output = ()
for sub_tuple in zip(*vals_in_tuple):
out = prim(sub_tuple)
output = output + (out,)
else:
output = common_map(prim, *vals_in_tuple)

vals_out_tuple = ()
if isinstance(output[0], tuple):
for res in zip(**output):
if not isinstance(res[0], Tensor):
_raise_value_error("The output of the operator is not of the Tensor type, "
"a specific vmap rule is required for op: ", prim_name)
out = F.stack(res)
vals_out_tuple = vals_out_tuple + ((out, 0),)
else:
out = F.stack(output)
vals_out_tuple = vals_out_tuple + (out, 0)
return vals_out_tuple
return loop_stack
def vmap_general_output_process(output):
    """Match the outputs to axis 0."""
    vals_out_tuple = ()
    if isinstance(output[0], tuple):
        for res in zip(*output):
            if not isinstance(res[0], Tensor):
                _raise_value_error("The output of the operator is not of the Tensor type, "
                                   "a specific vmap rule is required.")
            out = F.stack(res)
            vals_out_tuple = vals_out_tuple + ((out, 0),)
    else:
        out = F.stack(output)
        vals_out_tuple = vals_out_tuple + (out, 0)
    return vals_out_tuple
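
For the tuple-output branch above, a small NumPy analogue shows the zip-and-stack behavior; `np.stack` stands in for `F.stack` and the data is made up:

import numpy as np

# Three loop iterations, each producing two outputs.
output = ((np.zeros(2), np.ones(2)),
          (np.zeros(2), np.ones(2)),
          (np.zeros(2), np.ones(2)))

# zip(*output) regroups the results by output index; stacking each group along
# a new leading axis and pairing it with dim 0 mirrors the function above.
vals_out = tuple((np.stack(res), 0) for res in zip(*output))
assert all(v.shape == (3, 2) and d == 0 for v, d in vals_out)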


def vmap_monad_rule(prim, axis_size):

