@@ -15,18 +15,19 @@
 */

#include "pybind_api/ir/primitive_py.h"

#include <mutex>

#include "ir/signature.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pybind11/pytypes.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/convert_utils_base.h"
#include "utils/convert_utils_py.h"
#include "utils/ms_context.h"
#include "utils/primitive_utils.h"

namespace mindspore {
namespace {
@@ -107,6 +108,42 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
  return grads;
}
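// CheckHookConsistency recursively verifies that the gradient returned by a
// user-registered hook matches the original gradient's structure: tuples must
// agree element-wise, and tensors must agree in shape and dtype. hook_grad_ is
// cleared before each exception so a failing hook leaves no stale cache behind.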
void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const {
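  // Composite case: both values must be tuples of equal size; recurse into each element.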
  if (py::isinstance<py::tuple>(expected_grad_out)) {
    if (!py::isinstance<py::tuple>(grad_out)) {
      hook_grad_.clear();
      MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
    }
    auto actual_out_tuple = py::cast<py::tuple>(grad_out);
    auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
    if (actual_out_tuple.size() != expected_out_tuple.size()) {
      hook_grad_.clear();
      MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
                               << ", but it is " << actual_out_tuple.size();
    }
    for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
      CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i]);
    }
  }
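  // Leaf case: tensors are compared by shape/dtype metadata only; values are not inspected.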
  if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
    if (!py::isinstance<tensor::Tensor>(grad_out)) {
      hook_grad_.clear();
      MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!";
    }
    auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
    auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
    MS_EXCEPTION_IF_NULL(actual_out_tensor);
    MS_EXCEPTION_IF_NULL(expected_out_tensor);
    if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
      hook_grad_.clear();
      MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be "
                               << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is "
                               << actual_out_tensor->GetShapeAndDataTypeInfo();
    }
  }
}
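// RunHookFunction invokes the registered Python hook on the gradient. If the
// hook returns None, the original gradient (py_args[2]) is passed through;
// otherwise the hook's result must pass CheckHookConsistency before it is
// wrapped in a tuple and handed back to the grad runtime.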
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
  py::tuple py_args = ConvertDatatoPyTuple(args);
  bool is_bprop = this->HasAttr(kBpropAttrName);
@@ -138,6 +175,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
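      // Cell-hook path (surrounding code elided between hunks): when a gradient is
      // already cached for this cell_id, fall back to the original gradient on a
      // None result, validate it, and consume the cache entry; otherwise cache the
      // incoming gradient for the next pass.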
      if (py::isinstance<py::none>(obj)) {
        obj = py_args[2];
      }
      CheckHookConsistency(obj, py_args[2]);
      hook_grad_.erase(cell_id);
    } else {
      hook_grad_[cell_id] = py_args[2];
@@ -149,6 +187,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
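    // Second hook path (context elided): the same None fallback and consistency
    // check are applied before the result is returned.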
    if (py::isinstance<py::none>(obj)) {
      obj = py_args[2];
    }
    CheckHookConsistency(obj, py_args[2]);
  }
  obj = py::make_tuple(obj);
  return std::make_shared<PyObjectRef>(obj);