|
|
|
@@ -34,6 +34,41 @@ |
|
|
|
namespace mindspore { |
|
|
|
// namespace to support composite operators definition |
|
|
|
namespace prim { |
|
|
|
void GenerateFuncGraphAllNone(const FuncGraphPtr &fg, const AnfNodePtr &prim, int64_t args_size, bool wrapped_tuple, |
|
|
|
bool bind) { |
|
|
|
std::vector<AnfNodePtr> prim_output_cnode_inputs; |
|
|
|
(void)prim_output_cnode_inputs.emplace_back(prim); |
|
|
|
if (wrapped_tuple) { |
|
|
|
auto val_in_param = fg->add_parameter(); |
|
|
|
std::vector<AnfNodePtr> prim_inputs_cnode_inputs; |
|
|
|
(void)prim_inputs_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); |
|
|
|
for (int64_t i = 0; i < args_size; ++i) { |
|
|
|
auto val_in_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(i)}); |
|
|
|
auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_cnode, NewValueNode(kValIndex)}); |
|
|
|
(void)prim_inputs_cnode_inputs.emplace_back(val_cnode); |
|
|
|
} |
|
|
|
auto prim_inputs_cnode = fg->NewCNode(prim_inputs_cnode_inputs); |
|
|
|
(void)prim_output_cnode_inputs.emplace_back(prim_inputs_cnode); |
|
|
|
} else { |
|
|
|
for (int64_t i = 0; i < args_size; ++i) { |
|
|
|
auto val_in_param = fg->add_parameter(); |
|
|
|
auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(kValIndex)}); |
|
|
|
(void)prim_output_cnode_inputs.emplace_back(val_cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
auto prim_output_cnode = fg->NewCNode(prim_output_cnode_inputs); |
|
|
|
const py::function bind_all_none_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_bind_all_none"); |
|
|
|
auto bind_all_none_fg = parse::ParsePythonCode(bind_all_none_fn); |
|
|
|
auto bind_all_none_cnode = fg->NewCNode({NewValueNode(bind_all_none_fg), prim_output_cnode}); |
|
|
|
if (bind) { |
|
|
|
auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(true), bind_all_none_cnode}); |
|
|
|
fg->set_output(output_cnode); |
|
|
|
return; |
|
|
|
} |
|
|
|
fg->set_output(bind_all_none_cnode); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerBroadcastAxis( |
|
|
|
const AnfNodePtr &inputs, const AnfNodePtr &out_axis, const AnfNodePtr &axis_size, |
|
|
|
const AbstractBasePtr &inputs_abstract_elements_begin) const { |
|
|
|
@@ -86,7 +121,6 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement( |
|
|
|
auto value_cnode = fg_->NewCNode(value_cnode_inputs); |
|
|
|
std::vector<AnfNodePtr> out_cnode_inputs; |
|
|
|
if (inputs_abstract_elements_end->isa<abstract::AbstractNone>()) { |
|
|
|
constexpr char kVmapFunctionModelName[] = "mindspore.ops._vmap"; |
|
|
|
const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis"); |
|
|
|
auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis); |
|
|
|
(void)out_cnode_inputs.emplace_back(NewValueNode(broadcast_by_axis_fg)); |
|
|
|
@@ -99,8 +133,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement( |
|
|
|
(void)dim_cnode_inputs.emplace_back(inputs); |
|
|
|
(void)dim_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(1))); |
|
|
|
auto dim_cnode = fg_->NewCNode(dim_cnode_inputs); |
|
|
|
constexpr char kVmapFunctionModelName[] = "mindspore.numpy"; |
|
|
|
const py::function move_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "moveaxis"); |
|
|
|
const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis"); |
|
|
|
auto move_axis_fg = parse::ParsePythonCode(move_axis); |
|
|
|
(void)out_cnode_inputs.emplace_back(NewValueNode(move_axis_fg)); |
|
|
|
(void)out_cnode_inputs.emplace_back(value_cnode); |
|
|
|
@@ -190,7 +223,6 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu |
|
|
|
auto src_abstract = each_inputs_abstract_elements[1]; |
|
|
|
CNodePtr out_cnode = nullptr; |
|
|
|
if (src_abstract->isa<abstract::AbstractNone>() && !dst_abstract->isa<abstract::AbstractNone>()) { |
|
|
|
constexpr char kVmapFunctionModelName[] = "mindspore.ops._vmap"; |
|
|
|
const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis"); |
|
|
|
auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis); |
|
|
|
out_cnode = fg_->NewCNode({NewValueNode(broadcast_by_axis_fg), val_cnode, dst_cnode, axis_size}); |
|
|
|
@@ -199,8 +231,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu |
|
|
|
} else if (src_abstract->isa<abstract::AbstractNone>() && dst_abstract->isa<abstract::AbstractNone>()) { |
|
|
|
out_cnode = val_cnode; |
|
|
|
} else { |
|
|
|
constexpr char kVmapFunctionModelName[] = "mindspore.numpy"; |
|
|
|
const py::function move_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "moveaxis"); |
|
|
|
const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis"); |
|
|
|
auto move_axis_fg = parse::ParsePythonCode(move_axis); |
|
|
|
out_cnode = fg_->NewCNode({NewValueNode(move_axis_fg), val_cnode, src_cnode, dst_cnode}); |
|
|
|
} |
|
|
|
@@ -303,9 +334,6 @@ FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList |
|
|
|
return args_spec_list; |
|
|
|
}; |
|
|
|
auto tuple_elements = get_tuple_elements(args_spec_list); |
|
|
|
|
|
|
|
constexpr int64_t val_index = 0; |
|
|
|
constexpr int64_t dim_index = 1; |
|
|
|
bool is_all_none = true; |
|
|
|
constexpr size_t kCurTupleSize = 2; |
|
|
|
for (int64_t i = 0; i < inputs_size; ++i) { |
|
|
|
@@ -321,7 +349,7 @@ FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList |
|
|
|
MS_LOG(EXCEPTION) << "The " << i + offset << "th input to VmapGeneralPreprocess should be a tuple with two " |
|
|
|
<< "elements but got " << cur_arg_tuple_elements.size() << " elements."; |
|
|
|
} |
|
|
|
if (!cur_arg_tuple_elements[dim_index]->isa<abstract::AbstractNone>()) { |
|
|
|
if (!cur_arg_tuple_elements[kDimIndex]->isa<abstract::AbstractNone>()) { |
|
|
|
MS_LOG(INFO) << "The " << i + offset << "th input to VmapGeneralPreprocess has not None dim value."; |
|
|
|
is_all_none = false; |
|
|
|
break; |
|
|
|
@@ -334,39 +362,11 @@ FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList |
|
|
|
for (size_t i = 1; i < args_size; ++i) { |
|
|
|
(void)fg->add_parameter(); |
|
|
|
} |
|
|
|
(void)output_cnode_inputs.emplace_back(NewValueNode(false)); |
|
|
|
(void)output_cnode_inputs.emplace_back(NewValueNode(kNone)); |
|
|
|
auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(false), NewValueNode(kNone)}); |
|
|
|
fg->set_output(output_cnode); |
|
|
|
} else { |
|
|
|
std::vector<AnfNodePtr> prim_output_cnode_inputs; |
|
|
|
(void)prim_output_cnode_inputs.emplace_back(prim); |
|
|
|
if (wrapped_tuple) { |
|
|
|
auto val_in_param = fg->add_parameter(); |
|
|
|
std::vector<AnfNodePtr> prim_inputs_cnode_inputs; |
|
|
|
(void)prim_inputs_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); |
|
|
|
for (int64_t i = 0; i < inputs_size; ++i) { |
|
|
|
auto val_in_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(i)}); |
|
|
|
auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_cnode, NewValueNode(val_index)}); |
|
|
|
(void)prim_inputs_cnode_inputs.emplace_back(val_cnode); |
|
|
|
} |
|
|
|
auto prim_inputs_cnode = fg->NewCNode(prim_inputs_cnode_inputs); |
|
|
|
(void)prim_output_cnode_inputs.emplace_back(prim_inputs_cnode); |
|
|
|
} else { |
|
|
|
for (int64_t i = 0; i < inputs_size; ++i) { |
|
|
|
auto val_in_param = fg->add_parameter(); |
|
|
|
auto val_cnode = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), val_in_param, NewValueNode(val_index)}); |
|
|
|
(void)prim_output_cnode_inputs.emplace_back(val_cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
auto prim_output_cnode = fg->NewCNode(prim_output_cnode_inputs); |
|
|
|
const char kVmapFunctionModelName[] = "mindspore.ops._vmap"; |
|
|
|
const py::function bind_all_none_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_bind_all_none"); |
|
|
|
auto bind_all_none_fg = parse::ParsePythonCode(bind_all_none_fn); |
|
|
|
auto bind_all_none_cnode = fg->NewCNode({NewValueNode(bind_all_none_fg), prim_output_cnode}); |
|
|
|
(void)output_cnode_inputs.emplace_back(NewValueNode(true)); |
|
|
|
(void)output_cnode_inputs.emplace_back(bind_all_none_cnode); |
|
|
|
GenerateFuncGraphAllNone(fg, prim, inputs_size, wrapped_tuple, true); |
|
|
|
} |
|
|
|
auto output_cnode = fg->NewCNode(output_cnode_inputs); |
|
|
|
fg->set_output(output_cnode); |
|
|
|
return fg; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -375,5 +375,146 @@ REGISTER_PYBIND_DEFINE(VmapGeneralPreprocess_, ([](const py::module *m) { |
|
|
|
*m, "VmapGeneralPreprocess_") |
|
|
|
.def(py::init<std::string &>(), py::arg("fn")); |
|
|
|
})); |
|
|
|
|
|
|
|
CNodeInpusList VmapGeneralRule::ConstructMapInput(const InputsAbstractList &tuple_elements_abstract, bool wrapped_tuple, |
|
|
|
int64_t args_size) { |
|
|
|
AnfNodePtr single_input = nullptr; |
|
|
|
if (wrapped_tuple) { |
|
|
|
single_input = fg_->add_parameter(); |
|
|
|
} |
|
|
|
|
|
|
|
CNodeInpusList map_inputs(axis_size_); |
|
|
|
for (int64_t i = 0; i < args_size; ++i) { |
|
|
|
AnfNodePtr cur_arg_node = nullptr; |
|
|
|
if (wrapped_tuple) { |
|
|
|
cur_arg_node = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), single_input, NewValueNode(i)}); |
|
|
|
} else { |
|
|
|
cur_arg_node = fg_->add_parameter(); |
|
|
|
} |
|
|
|
auto tuple_element_abstract = tuple_elements_abstract[i]; |
|
|
|
auto val_abstract = tuple_element_abstract[kValIndex]; |
|
|
|
auto dim_abstract = tuple_element_abstract[kDimIndex]; |
|
|
|
AnfNodePtr val_cnode = |
|
|
|
fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kValIndex)}); |
|
|
|
|
|
|
|
if (dim_abstract->isa<abstract::AbstractNone>()) { |
|
|
|
for (int64_t m = 0; m < axis_size_; ++m) { |
|
|
|
map_inputs[m].push_back(val_cnode); |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (!val_abstract->isa<abstract::AbstractTensor>()) { |
|
|
|
MS_LOG(EXCEPTION) << "A variable of type other than `Tensor` is accepted, but the source axis is not `None`"; |
|
|
|
} |
|
|
|
AnfNodePtr dim_cnode = |
|
|
|
fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kDimIndex)}); |
|
|
|
const py::function unstack_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_unstack"); |
|
|
|
auto unstack_fg_ = parse::ParsePythonCode(unstack_fn); |
|
|
|
auto out_cnode = fg_->NewCNode({NewValueNode(unstack_fg_), dim_cnode, val_cnode}); |
|
|
|
for (int64_t m = 0; m < axis_size_; ++m) { |
|
|
|
auto out_element_cnode = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(m)}); |
|
|
|
map_inputs[m].push_back(out_element_cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return map_inputs; |
|
|
|
} |
|
|
|
|
|
|
|
// When the primitive does not registered the relevant specific VmapRule, it attempts to get |
|
|
|
// this the general rule. The general rule is combining loop and stack operators to simulate |
|
|
|
// the behavior of Vmap. Noted that, general rules does not guarantee the correctness of |
|
|
|
// execution results. |
|
|
|
// Currently, only the following types of primitives are supported: |
|
|
|
// 1、 Most calculation operations, whose inputs are tensors, scalars or both of them. |
|
|
|
// (If all elements in a tuple are scalars, it is also considered scalar.) |
|
|
|
// 2、 Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped into a tuple. |
|
|
|
// In other words, we do not support any tuple wrapped variables except for the special cases |
|
|
|
// listed above. |
|
|
|
FuncGraphPtr VmapGeneralRule::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { |
|
|
|
fg_ = std::make_shared<FuncGraph>(); |
|
|
|
int64_t args_size = static_cast<int64_t>(args_spec_list.size()); |
|
|
|
|
|
|
|
bool wrapped_tuple = false; |
|
|
|
auto get_tuple_elements = [&wrapped_tuple, |
|
|
|
&args_size](const AbstractBasePtrList &args_spec_list) -> const AbstractBasePtrList & { |
|
|
|
if (args_size == 1) { |
|
|
|
auto arg = args_spec_list[0]; |
|
|
|
if (!arg->isa<abstract::AbstractTuple>()) { |
|
|
|
MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractTuple but got: " |
|
|
|
<< arg->ToString() << "."; |
|
|
|
} |
|
|
|
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>(); |
|
|
|
const auto &arg_tuple_elements = arg_tuple->elements(); |
|
|
|
if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) { |
|
|
|
// Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped |
|
|
|
// into a tuple. We need to process the internal elements separately and then re-wrap |
|
|
|
// them into tuple. Handle case such as args:(((A, 0), (B, 1), (C, None)),). Which |
|
|
|
// different from the case with single input parameter ((A, 0),). |
|
|
|
wrapped_tuple = true; |
|
|
|
args_size = arg_tuple_elements.size(); |
|
|
|
return arg_tuple_elements; |
|
|
|
} |
|
|
|
} |
|
|
|
return args_spec_list; |
|
|
|
}; |
|
|
|
auto tuple_elements = get_tuple_elements(args_spec_list); |
|
|
|
|
|
|
|
bool is_all_none = true; |
|
|
|
constexpr size_t kCurTupleSize = 2; |
|
|
|
InputsAbstractList tuple_elements_abstract(args_size); |
|
|
|
for (int64_t i = 0; i < args_size; ++i) { |
|
|
|
auto cur_arg = tuple_elements[i]; |
|
|
|
if (!cur_arg->isa<abstract::AbstractTuple>()) { |
|
|
|
MS_LOG(EXCEPTION) << "The " << i |
|
|
|
<< "th input to VmapGeneralPreprocess should be AbstractTuple but got: " << cur_arg->ToString() |
|
|
|
<< "."; |
|
|
|
} |
|
|
|
auto cur_arg_tuple = cur_arg->cast<abstract::AbstractTuplePtr>(); |
|
|
|
auto cur_arg_tuple_elements = cur_arg_tuple->elements(); |
|
|
|
if (cur_arg_tuple_elements.size() != kCurTupleSize) { |
|
|
|
MS_LOG(EXCEPTION) << "The " << i << "th input to VmapGeneralPreprocess should be a tuple with two " |
|
|
|
<< "elements but got " << cur_arg_tuple_elements.size() << " elements."; |
|
|
|
} |
|
|
|
auto dim_abstract = cur_arg_tuple_elements[kDimIndex]; |
|
|
|
if (is_all_none && !dim_abstract->isa<abstract::AbstractNone>()) { |
|
|
|
MS_LOG(INFO) << "The " << i << "th input to VmapGeneralPreprocess has not None dim value."; |
|
|
|
is_all_none = false; |
|
|
|
} |
|
|
|
auto val_abstract = cur_arg_tuple_elements[kValIndex]; |
|
|
|
std::vector<abstract::AbstractBasePtr> element_abstract = {val_abstract, dim_abstract}; |
|
|
|
tuple_elements_abstract[i] = element_abstract; |
|
|
|
} |
|
|
|
|
|
|
|
if (is_all_none) { |
|
|
|
GenerateFuncGraphAllNone(fg_, NewValueNode(prim_), args_size, wrapped_tuple, false); |
|
|
|
return fg_; |
|
|
|
} |
|
|
|
|
|
|
|
CNodeInpusList map_inputs = ConstructMapInput(tuple_elements_abstract, wrapped_tuple, args_size); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> output_cnode_inputs; |
|
|
|
(void)output_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); |
|
|
|
for (auto map_input : map_inputs) { |
|
|
|
std::vector<AnfNodePtr> output_element_cnode_inputs; |
|
|
|
if (wrapped_tuple) { |
|
|
|
std::vector<AnfNodePtr> tuple_cnode_inputs{NewValueNode(prim::kPrimMakeTuple)}; |
|
|
|
(void)tuple_cnode_inputs.insert(tuple_cnode_inputs.end(), map_input.begin(), map_input.end()); |
|
|
|
auto tuple_cnode = fg_->NewCNode(tuple_cnode_inputs); |
|
|
|
output_element_cnode_inputs.push_back(NewValueNode(prim_)); |
|
|
|
output_element_cnode_inputs.push_back(tuple_cnode); |
|
|
|
} else { |
|
|
|
output_element_cnode_inputs.push_back(NewValueNode(prim_)); |
|
|
|
(void)output_element_cnode_inputs.insert(output_element_cnode_inputs.end(), map_input.begin(), map_input.end()); |
|
|
|
} |
|
|
|
auto output_element_cnode = fg_->NewCNode(output_element_cnode_inputs); |
|
|
|
(void)output_cnode_inputs.emplace_back(output_element_cnode); |
|
|
|
} |
|
|
|
auto output_cnode = fg_->NewCNode(output_cnode_inputs); |
|
|
|
const py::function vmap_general_output_process_fn = |
|
|
|
python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_general_output_process"); |
|
|
|
auto vmap_general_output_process_fg_ = parse::ParsePythonCode(vmap_general_output_process_fn); |
|
|
|
auto vmap_general_output = fg_->NewCNode({NewValueNode(vmap_general_output_process_fg_), output_cnode}); |
|
|
|
fg_->set_output(vmap_general_output); |
|
|
|
return fg_; |
|
|
|
} |
|
|
|
} // namespace prim |
|
|
|
} // namespace mindspore |